1use crate::{GeneratorError, Result, analysis::SchemaAnalysis, streaming::StreamingConfig};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use std::collections::BTreeMap;
5use std::path::PathBuf;
6
/// Metadata for an object schema that is a variant of a discriminated union
/// and itself declares that union's discriminator property.
///
/// `generate_struct` consults it to decide whether the discriminator property
/// is omitted from the variant struct (parent not rendered untagged) or kept
/// (parent rendered untagged).
#[derive(Debug, Clone)]
struct DiscriminatedVariantInfo {
    /// Name of the discriminator property declared by the parent union.
    discriminator_field: String,
    /// Result of `should_use_untagged_discriminated_union` for the parent union.
    is_parent_untagged: bool,
}
15
16#[derive(Debug, Clone)]
17pub struct GeneratorConfig {
18 pub spec_path: PathBuf,
20 pub output_dir: PathBuf,
22 pub module_name: String,
24 pub enable_sse_client: bool,
26 pub enable_async_client: bool,
28 pub enable_specta: bool,
30 pub type_mappings: BTreeMap<String, String>,
32 pub streaming_config: Option<StreamingConfig>,
34 pub nullable_field_overrides: BTreeMap<String, bool>,
37 pub schema_extensions: Vec<PathBuf>,
40 pub http_client_config: Option<crate::http_config::HttpClientConfig>,
42 pub retry_config: Option<crate::http_config::RetryConfig>,
44 pub tracing_enabled: bool,
46 pub auth_config: Option<crate::http_config::AuthConfig>,
48}
49
50impl Default for GeneratorConfig {
51 fn default() -> Self {
52 Self {
53 spec_path: "openapi.json".into(),
54 output_dir: "src/gen".into(),
55 module_name: "api_types".to_string(),
56 enable_sse_client: true,
57 enable_async_client: true,
58 enable_specta: false,
59 type_mappings: default_type_mappings(),
60 streaming_config: None,
61 nullable_field_overrides: BTreeMap::new(),
62 schema_extensions: Vec::new(),
63 http_client_config: None,
64 retry_config: None,
65 tracing_enabled: true,
66 auth_config: None,
67 }
68 }
69}
70
/// Built-in mapping from OpenAPI primitive type names to Rust types.
///
/// `integer` → `i64`, `number` → `f64`, `string` → `String`,
/// `boolean` → `bool`.
pub fn default_type_mappings() -> BTreeMap<String, String> {
    const TABLE: &[(&str, &str)] = &[
        ("integer", "i64"),
        ("number", "f64"),
        ("string", "String"),
        ("boolean", "bool"),
    ];

    TABLE
        .iter()
        .map(|&(openapi_name, rust_name)| (openapi_name.to_string(), rust_name.to_string()))
        .collect()
}
79
/// One generated source file, held in memory until it is written out.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File name relative to the configured output directory (e.g. `types.rs`).
    pub path: PathBuf,
    /// Complete contents of the file.
    pub content: String,
}
88
89#[derive(Debug, Clone)]
91pub struct GenerationResult {
92 pub files: Vec<GeneratedFile>,
94 pub mod_file: GeneratedFile,
96}
97
98pub struct CodeGenerator {
99 config: GeneratorConfig,
100}
101
102impl CodeGenerator {
103 pub fn new(config: GeneratorConfig) -> Self {
104 Self { config }
105 }
106
107 pub fn config(&self) -> &GeneratorConfig {
109 &self.config
110 }
111
112 pub fn generate_all(&self, analysis: &mut SchemaAnalysis) -> Result<GenerationResult> {
114 let mut files = Vec::new();
115
116 let types_content = self.generate_types(analysis)?;
118 files.push(GeneratedFile {
119 path: "types.rs".into(),
120 content: types_content,
121 });
122
123 if let Some(ref streaming_config) = self.config.streaming_config {
125 let streaming_content = self.generate_streaming_client(streaming_config, analysis)?;
126 files.push(GeneratedFile {
127 path: "streaming.rs".into(),
128 content: streaming_content,
129 });
130 }
131
132 if self.config.enable_async_client {
134 let http_content = self.generate_http_client(analysis)?;
135 files.push(GeneratedFile {
136 path: "client.rs".into(),
137 content: http_content,
138 });
139 }
140
141 let mod_content = self.generate_mod_file(&files)?;
143 let mod_file = GeneratedFile {
144 path: "mod.rs".into(),
145 content: mod_content,
146 };
147
148 Ok(GenerationResult { files, mod_file })
149 }
150
151 pub fn generate(&self, analysis: &mut SchemaAnalysis) -> Result<String> {
153 self.generate_types(analysis)
154 }
155
156 fn generate_types(&self, analysis: &mut SchemaAnalysis) -> Result<String> {
158 let mut type_definitions = TokenStream::new();
159
160 let mut discriminated_variant_info: BTreeMap<String, DiscriminatedVariantInfo> =
163 BTreeMap::new();
164
165 let mut sorted_schemas: Vec<_> = analysis.schemas.iter().collect();
167 sorted_schemas.sort_by_key(|(name, _)| name.as_str());
168
169 for (_parent_name, schema) in sorted_schemas {
170 if let crate::analysis::SchemaType::DiscriminatedUnion {
171 variants,
172 discriminator_field,
173 } = &schema.schema_type
174 {
175 let is_parent_untagged =
177 self.should_use_untagged_discriminated_union(schema, analysis);
178
179 for variant in variants {
180 if let Some(variant_schema) = analysis.schemas.get(&variant.type_name) {
183 if let crate::analysis::SchemaType::Object { properties, .. } =
184 &variant_schema.schema_type
185 {
186 if properties.contains_key(discriminator_field) {
187 discriminated_variant_info.insert(
188 variant.type_name.clone(),
189 DiscriminatedVariantInfo {
190 discriminator_field: discriminator_field.clone(),
191 is_parent_untagged,
192 },
193 );
194 }
195 }
196 }
197 }
198 }
199 }
200
201 let generation_order = analysis.dependencies.topological_sort()?;
203
204 let mut processed = std::collections::HashSet::new();
206
207 for schema_name in generation_order {
209 if let Some(schema) = analysis.schemas.get(&schema_name) {
210 let type_def =
211 self.generate_type_definition(schema, analysis, &discriminated_variant_info)?;
212 if !type_def.is_empty() {
213 type_definitions.extend(type_def);
214 }
215 processed.insert(schema_name);
216 }
217 }
218
219 let mut remaining_schemas: Vec<_> = analysis
222 .schemas
223 .iter()
224 .filter(|(name, _)| !processed.contains(*name))
225 .collect();
226 remaining_schemas.sort_by_key(|(name, _)| name.as_str());
227
228 for (_schema_name, schema) in remaining_schemas {
229 let type_def =
230 self.generate_type_definition(schema, analysis, &discriminated_variant_info)?;
231 if !type_def.is_empty() {
232 type_definitions.extend(type_def);
233 }
234 }
235
236 let generated = quote! {
238 #![allow(clippy::large_enum_variant)]
244 #![allow(clippy::format_in_format_args)]
245 #![allow(clippy::let_unit_value)]
246 #![allow(unreachable_patterns)]
247
248 use serde::{Deserialize, Serialize};
249
250 #type_definitions
251 };
252
253 let syntax_tree = syn::parse2::<syn::File>(generated).map_err(|e| {
255 GeneratorError::CodeGenError(format!("Failed to parse generated code: {e}"))
256 })?;
257
258 let formatted = prettyplease::unparse(&syntax_tree);
259
260 Ok(formatted)
261 }
262
263 fn generate_streaming_client(
265 &self,
266 streaming_config: &StreamingConfig,
267 analysis: &SchemaAnalysis,
268 ) -> Result<String> {
269 let mut client_code = TokenStream::new();
270
271 let imports = quote! {
273 #![allow(clippy::format_in_format_args)]
278 #![allow(clippy::let_unit_value)]
279 #![allow(unused_mut)]
280
281 use super::types::*;
282 use async_trait::async_trait;
283 use futures_util::{Stream, StreamExt};
284 use std::pin::Pin;
285 use std::time::Duration;
286 use reqwest::header::{HeaderMap, HeaderValue};
287 use tracing::{debug, error, info, warn, instrument};
288 };
289 client_code.extend(imports);
290
291 if streaming_config.generate_client {
293 let error_types = self.generate_streaming_error_types()?;
294 client_code.extend(error_types);
295 }
296
297 for endpoint in &streaming_config.endpoints {
299 let trait_code = self.generate_endpoint_trait(endpoint, analysis)?;
300 client_code.extend(trait_code);
301 }
302
303 if streaming_config.generate_client {
305 let client_impl = self.generate_streaming_client_impl(streaming_config, analysis)?;
306 client_code.extend(client_impl);
307 }
308
309 if streaming_config.event_parser_helpers {
311 let parser_code = self.generate_sse_parser_utilities(streaming_config)?;
312 client_code.extend(parser_code);
313 }
314
315 if let Some(reconnect_config) = &streaming_config.reconnection_config {
317 let reconnect_code = self.generate_reconnection_utilities(reconnect_config)?;
318 client_code.extend(reconnect_code);
319 }
320
321 let syntax_tree = syn::parse2::<syn::File>(client_code).map_err(|e| {
322 GeneratorError::CodeGenError(format!("Failed to parse streaming client code: {e}"))
323 })?;
324
325 Ok(prettyplease::unparse(&syntax_tree))
326 }
327
328 pub fn generate_http_client(&self, analysis: &SchemaAnalysis) -> Result<String> {
330 let error_types = self.generate_http_error_types();
331 let client_struct = self.generate_http_client_struct();
332 let operation_methods = self.generate_operation_methods(analysis);
333
334 let generated = quote! {
335 #![allow(clippy::format_in_format_args)]
340 #![allow(clippy::let_unit_value)]
341
342 use super::types::*;
343
344 #error_types
345
346 #client_struct
347
348 #operation_methods
349 };
350
351 let syntax_tree = syn::parse2::<syn::File>(generated).map_err(|e| {
352 GeneratorError::CodeGenError(format!("Failed to parse HTTP client code: {e}"))
353 })?;
354
355 Ok(prettyplease::unparse(&syntax_tree))
356 }
357
358 fn generate_http_error_types(&self) -> TokenStream {
360 quote! {
361 use thiserror::Error;
362
363 #[derive(Error, Debug)]
365 pub enum HttpError {
366 #[error("Network error: {0}")]
368 Network(#[from] reqwest::Error),
369
370 #[error("Middleware error: {0}")]
372 Middleware(#[from] reqwest_middleware::Error),
373
374 #[error("Failed to serialize request: {0}")]
376 Serialization(String),
377
378 #[error("Failed to deserialize response: {0}")]
380 Deserialization(String),
381
382 #[error("HTTP error {status}: {message}")]
384 Http {
385 status: u16,
386 message: String,
387 body: Option<String>,
388 },
389
390 #[error("Authentication error: {0}")]
392 Auth(String),
393
394 #[error("Request timeout")]
396 Timeout,
397
398 #[error("Configuration error: {0}")]
400 Config(String),
401
402 #[error("{0}")]
404 Other(String),
405 }
406
407 impl HttpError {
408 pub fn from_status(status: u16, message: impl Into<String>, body: Option<String>) -> Self {
410 Self::Http {
411 status,
412 message: message.into(),
413 body,
414 }
415 }
416
417 pub fn serialization_error(error: impl std::fmt::Display) -> Self {
419 Self::Serialization(error.to_string())
420 }
421
422 pub fn deserialization_error(error: impl std::fmt::Display) -> Self {
424 Self::Deserialization(error.to_string())
425 }
426
427 pub fn is_client_error(&self) -> bool {
429 matches!(self, Self::Http { status, .. } if *status >= 400 && *status < 500)
430 }
431
432 pub fn is_server_error(&self) -> bool {
434 matches!(self, Self::Http { status, .. } if *status >= 500 && *status < 600)
435 }
436
437 pub fn is_retryable(&self) -> bool {
439 match self {
440 Self::Network(_) => true,
441 Self::Middleware(_) => true,
442 Self::Timeout => true,
443 Self::Http { status, .. } => {
444 matches!(status, 429 | 500 | 502 | 503 | 504)
446 }
447 _ => false,
448 }
449 }
450 }
451
452 pub type HttpResult<T> = Result<T, HttpError>;
454 }
455 }
456
457 fn generate_mod_file(&self, files: &[GeneratedFile]) -> Result<String> {
459 let mut module_declarations = Vec::new();
460 let mut pub_uses = Vec::new();
461
462 for file in files {
463 if let Some(module_name) = file.path.file_stem().and_then(|s| s.to_str()) {
464 if module_name != "mod" {
465 module_declarations.push(format!("pub mod {module_name};"));
466 pub_uses.push(format!("pub use {module_name}::*;"));
467 }
468 }
469 }
470
471 let content = format!(
472 r#"//! Generated API modules
473//!
474//! This module exports all generated API types and clients.
475//! Do not edit manually - regenerate using the appropriate script.
476
477#![allow(unused_imports)]
478
479{}
480
481{}
482"#,
483 module_declarations.join("\n"),
484 pub_uses.join("\n")
485 );
486
487 Ok(content)
488 }
489
490 pub fn write_files(&self, result: &GenerationResult) -> Result<()> {
492 use std::fs;
493
494 fs::create_dir_all(&self.config.output_dir)?;
496
497 for file in &result.files {
499 let file_path = self.config.output_dir.join(&file.path);
500 fs::write(&file_path, &file.content)?;
501 }
502
503 let mod_path = self.config.output_dir.join(&result.mod_file.path);
505 fs::write(&mod_path, &result.mod_file.content)?;
506
507 Ok(())
508 }
509
510 fn generate_type_definition(
511 &self,
512 schema: &crate::analysis::AnalyzedSchema,
513 analysis: &crate::analysis::SchemaAnalysis,
514 discriminated_variant_info: &BTreeMap<String, DiscriminatedVariantInfo>,
515 ) -> Result<TokenStream> {
516 use crate::analysis::SchemaType;
517
518 match &schema.schema_type {
519 SchemaType::Primitive { rust_type } => {
520 self.generate_type_alias(schema, rust_type)
522 }
523 SchemaType::StringEnum { values } => self.generate_string_enum(schema, values),
524 SchemaType::ExtensibleEnum { known_values } => {
525 self.generate_extensible_enum(schema, known_values)
526 }
527 SchemaType::Object {
528 properties,
529 required,
530 additional_properties,
531 } => self.generate_struct(
532 schema,
533 properties,
534 required,
535 *additional_properties,
536 analysis,
537 discriminated_variant_info.get(&schema.name),
538 ),
539 SchemaType::DiscriminatedUnion {
540 discriminator_field,
541 variants,
542 } => {
543 if self.should_use_untagged_discriminated_union(schema, analysis) {
545 let schema_refs: Vec<crate::analysis::SchemaRef> = variants
547 .iter()
548 .map(|v| crate::analysis::SchemaRef {
549 target: v.type_name.clone(),
550 nullable: false,
551 })
552 .collect();
553 self.generate_union_enum(schema, &schema_refs)
554 } else {
555 self.generate_discriminated_enum(
556 schema,
557 discriminator_field,
558 variants,
559 analysis,
560 )
561 }
562 }
563 SchemaType::Union { variants } => self.generate_union_enum(schema, variants),
564 SchemaType::Reference { target } => {
565 if schema.name != *target {
568 let alias_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
570 let target_type = format_ident!("{}", self.to_rust_type_name(target));
571
572 let doc_comment = if let Some(desc) = &schema.description {
573 quote! { #[doc = #desc] }
574 } else {
575 TokenStream::new()
576 };
577
578 Ok(quote! {
579 #doc_comment
580 pub type #alias_name = #target_type;
581 })
582 } else {
583 Ok(TokenStream::new())
585 }
586 }
587 SchemaType::Array { item_type } => {
588 let array_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
590 let inner_type = self.generate_array_item_type(item_type, analysis);
591
592 let doc_comment = if let Some(desc) = &schema.description {
593 quote! { #[doc = #desc] }
594 } else {
595 TokenStream::new()
596 };
597
598 Ok(quote! {
599 #doc_comment
600 pub type #array_name = Vec<#inner_type>;
601 })
602 }
603 SchemaType::Composition { schemas } => {
604 self.generate_composition_struct(schema, schemas)
605 }
606 }
607 }
608
609 fn generate_type_alias(
610 &self,
611 schema: &crate::analysis::AnalyzedSchema,
612 rust_type: &str,
613 ) -> Result<TokenStream> {
614 let type_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
615
616 let base_type = if rust_type.contains("::") {
618 let parts: Vec<&str> = rust_type.split("::").collect();
619 if parts.len() == 2 {
620 let module = format_ident!("{}", parts[0]);
621 let type_name_part = format_ident!("{}", parts[1]);
622 quote! { #module::#type_name_part }
623 } else {
624 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
626 quote! { #(#path_parts)::* }
627 }
628 } else {
629 let simple_type = format_ident!("{}", rust_type);
630 quote! { #simple_type }
631 };
632
633 let doc_comment = if let Some(desc) = &schema.description {
634 let sanitized_desc = self.sanitize_doc_comment(desc);
635 quote! { #[doc = #sanitized_desc] }
636 } else {
637 TokenStream::new()
638 };
639
640 Ok(quote! {
641 #doc_comment
642 pub type #type_name = #base_type;
643 })
644 }
645
646 fn generate_extensible_enum(
647 &self,
648 schema: &crate::analysis::AnalyzedSchema,
649 known_values: &[String],
650 ) -> Result<TokenStream> {
651 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
652
653 let doc_comment = if let Some(desc) = &schema.description {
654 quote! { #[doc = #desc] }
655 } else {
656 TokenStream::new()
657 };
658
659 let known_variants = known_values.iter().map(|value| {
664 let variant_name = self.to_rust_enum_variant(value);
665 let variant_ident = format_ident!("{}", variant_name);
666 quote! {
667 #variant_ident,
668 }
669 });
670
671 let match_arms_de = known_values.iter().map(|value| {
672 let variant_name = self.to_rust_enum_variant(value);
673 let variant_ident = format_ident!("{}", variant_name);
674 quote! {
675 #value => Ok(#enum_name::#variant_ident),
676 }
677 });
678
679 let match_arms_ser = known_values.iter().map(|value| {
680 let variant_name = self.to_rust_enum_variant(value);
681 let variant_ident = format_ident!("{}", variant_name);
682 quote! {
683 #enum_name::#variant_ident => #value,
684 }
685 });
686
687 let derives = if self.config.enable_specta {
688 quote! {
689 #[derive(Debug, Clone, PartialEq, Eq)]
690 #[cfg_attr(feature = "specta", derive(specta::Type))]
691 }
692 } else {
693 quote! {
694 #[derive(Debug, Clone, PartialEq, Eq)]
695 }
696 };
697
698 Ok(quote! {
699 #doc_comment
700 #derives
701 pub enum #enum_name {
702 #(#known_variants)*
703 Custom(String),
705 }
706
707 impl<'de> serde::Deserialize<'de> for #enum_name {
708 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
709 where
710 D: serde::Deserializer<'de>,
711 {
712 let value = String::deserialize(deserializer)?;
713 match value.as_str() {
714 #(#match_arms_de)*
715 _ => Ok(#enum_name::Custom(value)),
716 }
717 }
718 }
719
720 impl serde::Serialize for #enum_name {
721 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
722 where
723 S: serde::Serializer,
724 {
725 let value = match self {
726 #(#match_arms_ser)*
727 #enum_name::Custom(s) => s.as_str(),
728 };
729 serializer.serialize_str(value)
730 }
731 }
732 })
733 }
734
735 fn generate_string_enum(
736 &self,
737 schema: &crate::analysis::AnalyzedSchema,
738 values: &[String],
739 ) -> Result<TokenStream> {
740 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
741
742 let default_value = schema
744 .default
745 .as_ref()
746 .and_then(|v| v.as_str())
747 .map(|s| s.to_string());
748
749 let variants = values.iter().enumerate().map(|(i, value)| {
750 let variant_name = self.to_rust_enum_variant(value);
752 let variant_ident = format_ident!("{}", variant_name);
753
754 let is_default = if let Some(ref default) = default_value {
756 value == default
757 } else {
758 i == 0 };
760
761 if is_default {
762 quote! {
763 #[default]
764 #[serde(rename = #value)]
765 #variant_ident,
766 }
767 } else {
768 quote! {
769 #[serde(rename = #value)]
770 #variant_ident,
771 }
772 }
773 });
774
775 let doc_comment = if let Some(desc) = &schema.description {
776 quote! { #[doc = #desc] }
777 } else {
778 TokenStream::new()
779 };
780
781 let derives = if self.config.enable_specta {
783 quote! {
784 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
785 #[cfg_attr(feature = "specta", derive(specta::Type))]
786 }
787 } else {
788 quote! {
789 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
790 }
791 };
792
793 Ok(quote! {
794 #doc_comment
795 #derives
796 pub enum #enum_name {
797 #(#variants)*
798 }
799 })
800 }
801
802 fn generate_struct(
803 &self,
804 schema: &crate::analysis::AnalyzedSchema,
805 properties: &BTreeMap<String, crate::analysis::PropertyInfo>,
806 required: &std::collections::HashSet<String>,
807 additional_properties: bool,
808 analysis: &crate::analysis::SchemaAnalysis,
809 discriminator_info: Option<&DiscriminatedVariantInfo>,
810 ) -> Result<TokenStream> {
811 let struct_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
812
813 let mut sorted_properties: Vec<_> = properties.iter().collect();
815 sorted_properties.sort_by_key(|(name, _)| name.as_str());
816
817 let mut fields: Vec<TokenStream> = sorted_properties
818 .into_iter()
819 .filter(|(field_name, _)| {
820 if let Some(info) = discriminator_info {
824 if !info.is_parent_untagged
825 && field_name.as_str() == info.discriminator_field.as_str()
826 {
827 false } else {
829 true }
831 } else {
832 true }
834 })
835 .map(|(field_name, prop)| {
836 let field_ident = format_ident!("{}", self.to_rust_field_name(field_name));
837 let is_required = required.contains(field_name);
838 let field_type =
839 self.generate_field_type(&schema.name, field_name, prop, is_required, analysis);
840
841 let serde_attrs = self.generate_serde_field_attrs(field_name, prop, is_required);
842 let specta_attrs = self.generate_specta_field_attrs(field_name);
843
844 let doc_comment = if let Some(desc) = &prop.description {
845 let sanitized_desc = self.sanitize_doc_comment(desc);
846 quote! { #[doc = #sanitized_desc] }
847 } else {
848 TokenStream::new()
849 };
850
851 quote! {
852 #doc_comment
853 #serde_attrs
854 #specta_attrs
855 pub #field_ident: #field_type,
856 }
857 })
858 .collect();
859
860 if additional_properties {
862 fields.push(quote! {
863 #[serde(flatten)]
865 pub additional_properties: std::collections::BTreeMap<String, serde_json::Value>,
866 });
867 }
868
869 let doc_comment = if let Some(desc) = &schema.description {
870 quote! { #[doc = #desc] }
871 } else {
872 TokenStream::new()
873 };
874
875 let derives = if self.config.enable_specta {
879 quote! {
880 #[derive(Debug, Clone, Deserialize, Serialize)]
881 #[cfg_attr(feature = "specta", derive(specta::Type))]
882 }
883 } else {
884 quote! {
885 #[derive(Debug, Clone, Deserialize, Serialize)]
886 }
887 };
888
889 Ok(quote! {
890 #doc_comment
891 #derives
892 pub struct #struct_name {
893 #(#fields)*
894 }
895 })
896 }
897
898 fn generate_discriminated_enum(
899 &self,
900 schema: &crate::analysis::AnalyzedSchema,
901 discriminator_field: &str,
902 variants: &[crate::analysis::UnionVariant],
903 analysis: &crate::analysis::SchemaAnalysis,
904 ) -> Result<TokenStream> {
905 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
906
907 let has_nested_discriminated_union = variants.iter().any(|variant| {
909 if let Some(variant_schema) = analysis.schemas.get(&variant.type_name) {
910 matches!(
911 variant_schema.schema_type,
912 crate::analysis::SchemaType::DiscriminatedUnion { .. }
913 )
914 } else {
915 false
916 }
917 });
918
919 if has_nested_discriminated_union {
921 let schema_refs: Vec<crate::analysis::SchemaRef> = variants
923 .iter()
924 .map(|v| crate::analysis::SchemaRef {
925 target: v.type_name.clone(),
926 nullable: false,
927 })
928 .collect();
929 return self.generate_union_enum(schema, &schema_refs);
930 }
931
932 let enum_variants = variants.iter().map(|variant| {
933 let variant_name = format_ident!("{}", variant.rust_name);
934 let variant_value = &variant.discriminator_value;
935
936 let variant_type = format_ident!("{}", self.to_rust_type_name(&variant.type_name));
939 quote! {
940 #[serde(rename = #variant_value)]
941 #variant_name(#variant_type),
942 }
943 });
944
945 let doc_comment = if let Some(desc) = &schema.description {
946 quote! { #[doc = #desc] }
947 } else {
948 TokenStream::new()
949 };
950
951 let derives = if self.config.enable_specta {
953 quote! {
954 #[derive(Debug, Clone, Deserialize, Serialize)]
955 #[cfg_attr(feature = "specta", derive(specta::Type))]
956 #[serde(tag = #discriminator_field)]
957 }
958 } else {
959 quote! {
960 #[derive(Debug, Clone, Deserialize, Serialize)]
961 #[serde(tag = #discriminator_field)]
962 }
963 };
964
965 Ok(quote! {
966 #doc_comment
967 #derives
968 pub enum #enum_name {
969 #(#enum_variants)*
970 }
971 })
972 }
973
974 fn should_use_untagged_discriminated_union(
976 &self,
977 schema: &crate::analysis::AnalyzedSchema,
978 analysis: &crate::analysis::SchemaAnalysis,
979 ) -> bool {
980 for other_schema in analysis.schemas.values() {
985 if let crate::analysis::SchemaType::DiscriminatedUnion {
986 variants,
987 discriminator_field: _,
988 } = &other_schema.schema_type
989 {
990 for variant in variants {
991 if variant.type_name == schema.name {
992 if let crate::analysis::SchemaType::DiscriminatedUnion {
997 discriminator_field: current_discriminator,
998 variants: current_variants,
999 ..
1000 } = &schema.schema_type
1001 {
1002 for current_variant in current_variants {
1004 if let Some(variant_schema) =
1005 analysis.schemas.get(¤t_variant.type_name)
1006 {
1007 if let crate::analysis::SchemaType::Object {
1008 properties, ..
1009 } = &variant_schema.schema_type
1010 {
1011 if properties.contains_key(current_discriminator) {
1012 return false;
1015 }
1016 }
1017 }
1018 }
1019 }
1020
1021 return true;
1023 }
1024 }
1025 }
1026 }
1027 false
1028 }
1029
1030 fn generate_union_enum(
1031 &self,
1032 schema: &crate::analysis::AnalyzedSchema,
1033 variants: &[crate::analysis::SchemaRef],
1034 ) -> Result<TokenStream> {
1035 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1036
1037 let mut used_variant_names = std::collections::HashSet::new();
1039 let enum_variants = variants.iter().enumerate().map(|(i, variant)| {
1040 let base_variant_name = self.type_name_to_variant_name(&variant.target);
1042 let variant_name = self.ensure_unique_variant_name_generator(
1043 base_variant_name,
1044 &mut used_variant_names,
1045 i,
1046 );
1047 let variant_name_ident = format_ident!("{}", variant_name);
1048
1049 let variant_type_tokens = if matches!(
1051 variant.target.as_str(),
1052 "bool"
1053 | "i8"
1054 | "i16"
1055 | "i32"
1056 | "i64"
1057 | "i128"
1058 | "u8"
1059 | "u16"
1060 | "u32"
1061 | "u64"
1062 | "u128"
1063 | "f32"
1064 | "f64"
1065 | "String"
1066 ) {
1067 let type_ident = format_ident!("{}", variant.target);
1068 quote! { #type_ident }
1069 } else if variant.target.starts_with("Vec<") && variant.target.ends_with(">") {
1070 let inner = &variant.target[4..variant.target.len() - 1];
1072
1073 if inner.starts_with("Vec<") && inner.ends_with(">") {
1075 let inner_inner = &inner[4..inner.len() - 1];
1076 let inner_inner_type = if matches!(
1077 inner_inner,
1078 "bool"
1079 | "i8"
1080 | "i16"
1081 | "i32"
1082 | "i64"
1083 | "i128"
1084 | "u8"
1085 | "u16"
1086 | "u32"
1087 | "u64"
1088 | "u128"
1089 | "f32"
1090 | "f64"
1091 | "String"
1092 ) {
1093 format_ident!("{}", inner_inner)
1094 } else {
1095 format_ident!("{}", self.to_rust_type_name(inner_inner))
1096 };
1097 quote! { Vec<Vec<#inner_inner_type>> }
1098 } else {
1099 let inner_type = if matches!(
1100 inner,
1101 "bool"
1102 | "i8"
1103 | "i16"
1104 | "i32"
1105 | "i64"
1106 | "i128"
1107 | "u8"
1108 | "u16"
1109 | "u32"
1110 | "u64"
1111 | "u128"
1112 | "f32"
1113 | "f64"
1114 | "String"
1115 ) {
1116 format_ident!("{}", inner)
1117 } else {
1118 format_ident!("{}", self.to_rust_type_name(inner))
1119 };
1120 quote! { Vec<#inner_type> }
1121 }
1122 } else {
1123 let type_ident = format_ident!("{}", self.to_rust_type_name(&variant.target));
1124 quote! { #type_ident }
1125 };
1126
1127 quote! {
1128 #variant_name_ident(#variant_type_tokens),
1129 }
1130 });
1131
1132 let doc_comment = if let Some(desc) = &schema.description {
1133 quote! { #[doc = #desc] }
1134 } else {
1135 TokenStream::new()
1136 };
1137
1138 let derives = if self.config.enable_specta {
1140 quote! {
1141 #[derive(Debug, Clone, Deserialize, Serialize)]
1142 #[cfg_attr(feature = "specta", derive(specta::Type))]
1143 #[serde(untagged)]
1144 }
1145 } else {
1146 quote! {
1147 #[derive(Debug, Clone, Deserialize, Serialize)]
1148 #[serde(untagged)]
1149 }
1150 };
1151
1152 Ok(quote! {
1153 #doc_comment
1154 #derives
1155 pub enum #enum_name {
1156 #(#enum_variants)*
1157 }
1158 })
1159 }
1160
1161 fn generate_field_type(
1162 &self,
1163 schema_name: &str,
1164 field_name: &str,
1165 prop: &crate::analysis::PropertyInfo,
1166 is_required: bool,
1167 analysis: &crate::analysis::SchemaAnalysis,
1168 ) -> TokenStream {
1169 use crate::analysis::SchemaType;
1170
1171 let base_type = match &prop.schema_type {
1172 SchemaType::Primitive { rust_type } => {
1173 if rust_type.contains("::") {
1175 let parts: Vec<&str> = rust_type.split("::").collect();
1176 if parts.len() == 2 {
1177 let module = format_ident!("{}", parts[0]);
1178 let type_name = format_ident!("{}", parts[1]);
1179 quote! { #module::#type_name }
1180 } else {
1181 let path_parts: Vec<_> =
1183 parts.iter().map(|p| format_ident!("{}", p)).collect();
1184 quote! { #(#path_parts)::* }
1185 }
1186 } else {
1187 let type_ident = format_ident!("{}", rust_type);
1188 quote! { #type_ident }
1189 }
1190 }
1191 SchemaType::Reference { target } => {
1192 let target_type = format_ident!("{}", self.to_rust_type_name(target));
1193 if analysis.dependencies.recursive_schemas.contains(target) {
1195 quote! { Box<#target_type> }
1196 } else {
1197 quote! { #target_type }
1198 }
1199 }
1200 SchemaType::Array { item_type } => {
1201 let inner_type = self.generate_array_item_type(item_type, analysis);
1202 quote! { Vec<#inner_type> }
1203 }
1204 _ => {
1205 quote! { serde_json::Value }
1207 }
1208 };
1209
1210 let override_key = format!("{schema_name}.{field_name}");
1212 let is_nullable_override = self
1213 .config
1214 .nullable_field_overrides
1215 .get(&override_key)
1216 .copied()
1217 .unwrap_or(false);
1218
1219 if is_required && !prop.nullable && !is_nullable_override {
1220 base_type
1221 } else {
1222 quote! { Option<#base_type> }
1223 }
1224 }
1225
1226 fn generate_serde_field_attrs(
1227 &self,
1228 field_name: &str,
1229 prop: &crate::analysis::PropertyInfo,
1230 is_required: bool,
1231 ) -> TokenStream {
1232 let mut attrs = Vec::new();
1233
1234 let rust_field_name = self.to_rust_field_name(field_name);
1236 if rust_field_name != field_name {
1237 attrs.push(quote! { rename = #field_name });
1238 }
1239
1240 if !is_required || prop.nullable {
1242 attrs.push(quote! { skip_serializing_if = "Option::is_none" });
1243 }
1244
1245 if prop.default.is_some() && (is_required && !prop.nullable) {
1248 attrs.push(quote! { default });
1249 }
1250
1251 if attrs.is_empty() {
1252 TokenStream::new()
1253 } else {
1254 quote! { #[serde(#(#attrs),*)] }
1255 }
1256 }
1257
1258 fn generate_specta_field_attrs(&self, field_name: &str) -> TokenStream {
1259 if !self.config.enable_specta {
1260 return TokenStream::new();
1261 }
1262
1263 let camel_case_name = self.to_camel_case(field_name);
1265
1266 if camel_case_name != field_name {
1268 quote! { #[cfg_attr(feature = "specta", specta(rename = #camel_case_name))] }
1269 } else {
1270 TokenStream::new()
1271 }
1272 }
1273
1274 fn to_rust_enum_variant(&self, s: &str) -> String {
1275 let mut result = String::new();
1277 let mut next_upper = true;
1278 let mut prev_was_upper = false;
1279
1280 for (i, c) in s.chars().enumerate() {
1281 match c {
1282 'a'..='z' => {
1283 if next_upper {
1284 result.push(c.to_ascii_uppercase());
1285 next_upper = false;
1286 } else {
1287 result.push(c);
1288 }
1289 prev_was_upper = false;
1290 }
1291 'A'..='Z' => {
1292 if next_upper || (!prev_was_upper && i > 0) {
1293 result.push(c);
1295 next_upper = false;
1296 } else {
1297 result.push(c.to_ascii_lowercase());
1299 }
1300 prev_was_upper = true;
1301 }
1302 '0'..='9' => {
1303 result.push(c);
1304 next_upper = false;
1305 prev_was_upper = false;
1306 }
1307 '.' | '-' | '_' | ' ' | '@' | '#' | '$' | '/' | '\\' => {
1308 next_upper = true;
1310 prev_was_upper = false;
1311 }
1312 _ => {
1313 next_upper = true;
1315 prev_was_upper = false;
1316 }
1317 }
1318 }
1319
1320 if result.is_empty() {
1322 result = "Value".to_string();
1323 }
1324
1325 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1327 result = format!("Variant{result}");
1328 }
1329
1330 match result.as_str() {
1332 "Null" => "NullValue".to_string(),
1333 "True" => "TrueValue".to_string(),
1334 "False" => "FalseValue".to_string(),
1335 "Type" => "Type_".to_string(),
1336 "Match" => "Match_".to_string(),
1337 "Fn" => "Fn_".to_string(),
1338 "Impl" => "Impl_".to_string(),
1339 "Trait" => "Trait_".to_string(),
1340 "Struct" => "Struct_".to_string(),
1341 "Enum" => "Enum_".to_string(),
1342 "Mod" => "Mod_".to_string(),
1343 "Use" => "Use_".to_string(),
1344 "Pub" => "Pub_".to_string(),
1345 "Const" => "Const_".to_string(),
1346 "Static" => "Static_".to_string(),
1347 "Let" => "Let_".to_string(),
1348 "Mut" => "Mut_".to_string(),
1349 "Ref" => "Ref_".to_string(),
1350 "Move" => "Move_".to_string(),
1351 "Return" => "Return_".to_string(),
1352 "If" => "If_".to_string(),
1353 "Else" => "Else_".to_string(),
1354 "While" => "While_".to_string(),
1355 "For" => "For_".to_string(),
1356 "Loop" => "Loop_".to_string(),
1357 "Break" => "Break_".to_string(),
1358 "Continue" => "Continue_".to_string(),
1359 "Self" => "Self_".to_string(),
1360 "Super" => "Super_".to_string(),
1361 "Crate" => "Crate_".to_string(),
1362 "Async" => "Async_".to_string(),
1363 "Await" => "Await_".to_string(),
1364 _ => result,
1365 }
1366 }
1367
1368 #[allow(dead_code)]
1369 fn to_rust_identifier(&self, s: &str) -> String {
1370 let mut result = s
1372 .chars()
1373 .map(|c| match c {
1374 'a'..='z' | 'A'..='Z' | '0'..='9' => c,
1375 '.' | '-' | '_' | ' ' | '@' | '#' | '$' | '/' | '\\' => '_',
1376 _ => '_',
1377 })
1378 .collect::<String>();
1379
1380 result = result.trim_matches('_').to_string();
1382
1383 if result.is_empty() {
1385 result = "value".to_string();
1386 }
1387
1388 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1390 result = format!("variant_{result}");
1391 }
1392
1393 match result.as_str() {
1395 "null" => "null_value".to_string(),
1396 "true" => "true_value".to_string(),
1397 "false" => "false_value".to_string(),
1398 "type" => "type_".to_string(),
1399 "match" => "match_".to_string(),
1400 "fn" => "fn_".to_string(),
1401 "impl" => "impl_".to_string(),
1402 "trait" => "trait_".to_string(),
1403 "struct" => "struct_".to_string(),
1404 "enum" => "enum_".to_string(),
1405 "mod" => "mod_".to_string(),
1406 "use" => "use_".to_string(),
1407 "pub" => "pub_".to_string(),
1408 "const" => "const_".to_string(),
1409 "static" => "static_".to_string(),
1410 "let" => "let_".to_string(),
1411 "mut" => "mut_".to_string(),
1412 "ref" => "ref_".to_string(),
1413 "move" => "move_".to_string(),
1414 "return" => "return_".to_string(),
1415 "if" => "if_".to_string(),
1416 "else" => "else_".to_string(),
1417 "while" => "while_".to_string(),
1418 "for" => "for_".to_string(),
1419 "loop" => "loop_".to_string(),
1420 "break" => "break_".to_string(),
1421 "continue" => "continue_".to_string(),
1422 "self" => "self_".to_string(),
1423 "super" => "super_".to_string(),
1424 "crate" => "crate_".to_string(),
1425 "async" => "async_".to_string(),
1426 "await" => "await_".to_string(),
1427 "override" => "override_".to_string(),
1429 "box" => "box_".to_string(),
1430 "dyn" => "dyn_".to_string(),
1431 "where" => "where_".to_string(),
1432 "in" => "in_".to_string(),
1433 "abstract" => "abstract_".to_string(),
1435 "become" => "become_".to_string(),
1436 "do" => "do_".to_string(),
1437 "final" => "final_".to_string(),
1438 "macro" => "macro_".to_string(),
1439 "priv" => "priv_".to_string(),
1440 "try" => "try_".to_string(),
1441 "typeof" => "typeof_".to_string(),
1442 "unsized" => "unsized_".to_string(),
1443 "virtual" => "virtual_".to_string(),
1444 "yield" => "yield_".to_string(),
1445 _ => result,
1446 }
1447 }
1448
1449 fn sanitize_doc_comment(&self, desc: &str) -> String {
1450 let mut result = desc.to_string();
1452
1453 if result.contains('\n')
1461 && (result.contains('{')
1462 || result.contains("```")
1463 || result.contains("Human:")
1464 || result.contains("Assistant:")
1465 || result
1466 .lines()
1467 .any(|line| line.trim().starts_with('"') && line.trim().ends_with('"')))
1468 {
1469 if result.contains("```") {
1471 result = result.replace("```", "```ignore");
1472 } else {
1473 if result.lines().any(|line| {
1475 let trimmed = line.trim();
1476 trimmed.starts_with('"') && trimmed.ends_with('"') && trimmed.len() > 2
1477 }) {
1478 result = format!("```ignore\n{result}\n```");
1479 }
1480 }
1481 }
1482
1483 result
1484 }
1485
1486 pub(crate) fn to_rust_type_name(&self, s: &str) -> String {
1487 let mut result = String::new();
1489 let mut next_upper = true;
1490 let mut prev_was_lower = false;
1491
1492 for c in s.chars() {
1493 match c {
1494 'a'..='z' => {
1495 if next_upper {
1496 result.push(c.to_ascii_uppercase());
1497 next_upper = false;
1498 } else {
1499 result.push(c);
1500 }
1501 prev_was_lower = true;
1502 }
1503 'A'..='Z' => {
1504 result.push(c);
1505 next_upper = false;
1506 prev_was_lower = false;
1507 }
1508 '0'..='9' => {
1509 if prev_was_lower && !result.chars().last().unwrap_or(' ').is_ascii_digit() {
1512 }
1514 result.push(c);
1515 next_upper = false;
1516 prev_was_lower = false;
1517 }
1518 '_' | '-' | '.' | ' ' => {
1519 next_upper = true;
1521 prev_was_lower = false;
1522 }
1523 _ => {
1524 next_upper = true;
1526 prev_was_lower = false;
1527 }
1528 }
1529 }
1530
1531 if result.is_empty() {
1533 result = "Type".to_string();
1534 }
1535
1536 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1538 result = format!("Type{result}");
1539 }
1540
1541 result
1542 }
1543
1544 fn to_rust_field_name(&self, s: &str) -> String {
1545 let mut result = String::new();
1547 let mut prev_was_upper = false;
1548 let mut prev_was_underscore = false;
1549
1550 for (i, c) in s.chars().enumerate() {
1551 match c {
1552 'A'..='Z' => {
1553 if i > 0 && !prev_was_upper && !prev_was_underscore {
1555 result.push('_');
1556 }
1557 result.push(c.to_ascii_lowercase());
1558 prev_was_upper = true;
1559 prev_was_underscore = false;
1560 }
1561 'a'..='z' | '0'..='9' => {
1562 result.push(c);
1563 prev_was_upper = false;
1564 prev_was_underscore = false;
1565 }
1566 '-' | '.' | '_' | '@' | '#' | '$' | ' ' => {
1567 if !prev_was_underscore && !result.is_empty() {
1568 result.push('_');
1569 prev_was_underscore = true;
1570 }
1571 prev_was_upper = false;
1572 }
1573 _ => {
1574 if !prev_was_underscore && !result.is_empty() {
1576 result.push('_');
1577 }
1578 prev_was_upper = false;
1579 prev_was_underscore = true;
1580 }
1581 }
1582 }
1583
1584 let mut result = result.trim_matches('_').to_string();
1586 if result.is_empty() {
1587 return "field".to_string();
1588 }
1589
1590 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1592 result = format!("field_{result}");
1593 }
1594
1595 match result.as_str() {
1597 "type" => "type_".to_string(),
1598 "match" => "match_".to_string(),
1599 "fn" => "fn_".to_string(),
1600 "struct" => "struct_".to_string(),
1601 "enum" => "enum_".to_string(),
1602 "impl" => "impl_".to_string(),
1603 "trait" => "trait_".to_string(),
1604 "mod" => "mod_".to_string(),
1605 "use" => "use_".to_string(),
1606 "pub" => "pub_".to_string(),
1607 "const" => "const_".to_string(),
1608 "static" => "static_".to_string(),
1609 "let" => "let_".to_string(),
1610 "mut" => "mut_".to_string(),
1611 "ref" => "ref_".to_string(),
1612 "move" => "move_".to_string(),
1613 "return" => "return_".to_string(),
1614 "if" => "if_".to_string(),
1615 "else" => "else_".to_string(),
1616 "while" => "while_".to_string(),
1617 "for" => "for_".to_string(),
1618 "loop" => "loop_".to_string(),
1619 "break" => "break_".to_string(),
1620 "continue" => "continue_".to_string(),
1621 "self" => "self_".to_string(),
1623 "super" => "super_".to_string(),
1624 "crate" => "crate_".to_string(),
1625 "async" => "async_".to_string(),
1626 "await" => "await_".to_string(),
1627 "override" => "override_".to_string(),
1629 "box" => "box_".to_string(),
1630 "dyn" => "dyn_".to_string(),
1631 "where" => "where_".to_string(),
1632 "in" => "in_".to_string(),
1633 "abstract" => "abstract_".to_string(),
1635 "become" => "become_".to_string(),
1636 "do" => "do_".to_string(),
1637 "final" => "final_".to_string(),
1638 "macro" => "macro_".to_string(),
1639 "priv" => "priv_".to_string(),
1640 "try" => "try_".to_string(),
1641 "typeof" => "typeof_".to_string(),
1642 "unsized" => "unsized_".to_string(),
1643 "virtual" => "virtual_".to_string(),
1644 "yield" => "yield_".to_string(),
1645 _ => result,
1646 }
1647 }
1648
1649 fn to_camel_case(&self, s: &str) -> String {
1650 let mut result = String::new();
1652 let mut capitalize_next = false;
1653
1654 for (i, c) in s.chars().enumerate() {
1655 match c {
1656 '_' | '-' | '.' | ' ' => {
1657 capitalize_next = true;
1659 }
1660 'A'..='Z' => {
1661 if i == 0 {
1662 result.push(c.to_ascii_lowercase());
1664 } else if capitalize_next {
1665 result.push(c);
1666 capitalize_next = false;
1667 } else {
1668 result.push(c.to_ascii_lowercase());
1669 }
1670 }
1671 'a'..='z' | '0'..='9' => {
1672 if capitalize_next {
1673 result.push(c.to_ascii_uppercase());
1674 capitalize_next = false;
1675 } else {
1676 result.push(c);
1677 }
1678 }
1679 _ => {
1680 capitalize_next = true;
1682 }
1683 }
1684 }
1685
1686 if result.is_empty() {
1687 return "field".to_string();
1688 }
1689
1690 result
1691 }
1692
1693 fn generate_composition_struct(
1694 &self,
1695 schema: &crate::analysis::AnalyzedSchema,
1696 schemas: &[crate::analysis::SchemaRef],
1697 ) -> Result<TokenStream> {
1698 let struct_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1699
1700 let fields = schemas.iter().enumerate().map(|(i, schema_ref)| {
1706 let field_name = format_ident!("part_{}", i);
1707 let field_type = format_ident!("{}", self.to_rust_type_name(&schema_ref.target));
1708
1709 quote! {
1710 #[serde(flatten)]
1711 pub #field_name: #field_type,
1712 }
1713 });
1714
1715 let doc_comment = if let Some(desc) = &schema.description {
1716 quote! { #[doc = #desc] }
1717 } else {
1718 TokenStream::new()
1719 };
1720
1721 let derives = if self.config.enable_specta {
1723 quote! {
1724 #[derive(Debug, Clone, Deserialize, Serialize)]
1725 #[cfg_attr(feature = "specta", derive(specta::Type))]
1726 }
1727 } else {
1728 quote! {
1729 #[derive(Debug, Clone, Deserialize, Serialize)]
1730 }
1731 };
1732
1733 Ok(quote! {
1734 #doc_comment
1735 #derives
1736 pub struct #struct_name {
1737 #(#fields)*
1738 }
1739 })
1740 }
1741
1742 #[allow(dead_code)]
1743 fn find_missing_types(&self, analysis: &SchemaAnalysis) -> std::collections::HashSet<String> {
1744 let mut missing = std::collections::HashSet::new();
1745 let defined_types: std::collections::HashSet<String> =
1746 analysis.schemas.keys().cloned().collect();
1747
1748 for schema in analysis.schemas.values() {
1750 match &schema.schema_type {
1751 crate::analysis::SchemaType::Union { variants } => {
1752 for variant in variants {
1753 if !defined_types.contains(&variant.target) {
1754 missing.insert(variant.target.clone());
1755 }
1756 }
1757 }
1758 crate::analysis::SchemaType::DiscriminatedUnion { variants, .. } => {
1759 for variant in variants {
1760 if !defined_types.contains(&variant.type_name) {
1761 missing.insert(variant.type_name.clone());
1762 }
1763 }
1764 }
1765 crate::analysis::SchemaType::Object { properties, .. } => {
1766 let mut sorted_props: Vec<_> = properties.iter().collect();
1768 sorted_props.sort_by_key(|(name, _)| name.as_str());
1769 for (_, prop) in sorted_props {
1770 if let crate::analysis::SchemaType::Reference { target } = &prop.schema_type
1771 {
1772 if !defined_types.contains(target) {
1773 missing.insert(target.clone());
1774 }
1775 }
1776 }
1777 }
1778 crate::analysis::SchemaType::Reference { target } => {
1779 if !defined_types.contains(target) {
1780 missing.insert(target.clone());
1781 }
1782 }
1783 _ => {}
1784 }
1785 }
1786
1787 missing
1788 }
1789
1790 #[allow(clippy::only_used_in_recursion)]
1791 fn generate_array_item_type(
1792 &self,
1793 item_type: &crate::analysis::SchemaType,
1794 analysis: &crate::analysis::SchemaAnalysis,
1795 ) -> TokenStream {
1796 use crate::analysis::SchemaType;
1797
1798 match item_type {
1799 SchemaType::Primitive { rust_type } => {
1800 if rust_type.contains("::") {
1802 let parts: Vec<&str> = rust_type.split("::").collect();
1803 if parts.len() == 2 {
1804 let module = format_ident!("{}", parts[0]);
1805 let type_name = format_ident!("{}", parts[1]);
1806 quote! { #module::#type_name }
1807 } else {
1808 let path_parts: Vec<_> =
1810 parts.iter().map(|p| format_ident!("{}", p)).collect();
1811 quote! { #(#path_parts)::* }
1812 }
1813 } else {
1814 let type_ident = format_ident!("{}", rust_type);
1815 quote! { #type_ident }
1816 }
1817 }
1818 SchemaType::Reference { target } => {
1819 let target_type = format_ident!("{}", self.to_rust_type_name(target));
1820 if analysis.dependencies.recursive_schemas.contains(target) {
1822 quote! { Box<#target_type> }
1823 } else {
1824 quote! { #target_type }
1825 }
1826 }
1827 SchemaType::Array { item_type } => {
1828 let inner_type = self.generate_array_item_type(item_type, analysis);
1830 quote! { Vec<#inner_type> }
1831 }
1832 _ => {
1833 quote! { serde_json::Value }
1835 }
1836 }
1837 }
1838
1839 fn type_name_to_variant_name(&self, type_name: &str) -> String {
1841 match type_name {
1843 "bool" => return "Boolean".to_string(),
1844 "i8" | "i16" | "i32" | "i64" | "i128" => return "Integer".to_string(),
1845 "u8" | "u16" | "u32" | "u64" | "u128" => return "UnsignedInteger".to_string(),
1846 "f32" | "f64" => return "Number".to_string(),
1847 "String" => return "String".to_string(),
1848 _ => {}
1849 }
1850
1851 if type_name.starts_with("Vec<") && type_name.ends_with(">") {
1853 let inner = &type_name[4..type_name.len() - 1];
1854 if inner.starts_with("Vec<") && inner.ends_with(">") {
1856 let inner_inner = &inner[4..inner.len() - 1];
1857 return format!("{}ArrayArray", self.type_name_to_variant_name(inner_inner));
1858 }
1859 return format!("{}Array", self.type_name_to_variant_name(inner));
1860 }
1861
1862 let clean_name = type_name
1868 .trim_end_matches("Type")
1869 .trim_end_matches("Schema")
1870 .trim_end_matches("Item");
1871
1872 self.to_rust_type_name(clean_name)
1874 }
1875
1876 fn ensure_unique_variant_name_generator(
1878 &self,
1879 base_name: String,
1880 used_names: &mut std::collections::HashSet<String>,
1881 fallback_index: usize,
1882 ) -> String {
1883 if used_names.insert(base_name.clone()) {
1884 return base_name;
1885 }
1886
1887 for i in 2..100 {
1889 let numbered_name = format!("{base_name}{i}");
1890 if used_names.insert(numbered_name.clone()) {
1891 return numbered_name;
1892 }
1893 }
1894
1895 let fallback = format!("Variant{fallback_index}");
1897 used_names.insert(fallback.clone());
1898 fallback
1899 }
1900
1901 fn find_request_type_for_operation(
1903 &self,
1904 operation_id: &str,
1905 analysis: &SchemaAnalysis,
1906 ) -> Option<String> {
1907 analysis.operations.get(operation_id).and_then(|op| {
1909 op.request_body
1910 .as_ref()
1911 .and_then(|rb| rb.schema_name().map(|s| s.to_string()))
1912 })
1913 }
1914
1915 fn resolve_streaming_event_type(
1917 &self,
1918 endpoint: &crate::streaming::StreamingEndpoint,
1919 analysis: &SchemaAnalysis,
1920 ) -> Result<String> {
1921 match &endpoint.event_flow {
1922 crate::streaming::EventFlow::Simple => {
1923 if analysis.schemas.contains_key(&endpoint.event_union_type) {
1926 Ok(endpoint.event_union_type.to_string())
1927 } else {
1928 Err(crate::error::GeneratorError::ValidationError(format!(
1929 "Streaming response type '{}' not found in schema for simple streaming endpoint '{}'",
1930 endpoint.event_union_type, endpoint.operation_id
1931 )))
1932 }
1933 }
1934 crate::streaming::EventFlow::StartDeltaStop { .. } => {
1935 if analysis.schemas.contains_key(&endpoint.event_union_type) {
1938 Ok(endpoint.event_union_type.to_string())
1939 } else {
1940 Err(crate::error::GeneratorError::ValidationError(format!(
1941 "Event union type '{}' not found in schema for complex streaming endpoint '{}'",
1942 endpoint.event_union_type, endpoint.operation_id
1943 )))
1944 }
1945 }
1946 }
1947 }
1948
1949 fn generate_streaming_error_types(&self) -> Result<TokenStream> {
1951 Ok(quote! {
1952 #[derive(Debug, thiserror::Error)]
1954 pub enum StreamingError {
1955 #[error("Connection error: {0}")]
1956 Connection(String),
1957 #[error("HTTP error: {status}")]
1958 Http { status: u16 },
1959 #[error("SSE parsing error: {0}")]
1960 Parsing(String),
1961 #[error("Authentication error: {0}")]
1962 Authentication(String),
1963 #[error("Rate limit error: {0}")]
1964 RateLimit(String),
1965 #[error("API error: {0}")]
1966 Api(String),
1967 #[error("Timeout error: {0}")]
1968 Timeout(String),
1969 #[error("JSON serialization/deserialization error: {0}")]
1970 Json(#[from] serde_json::Error),
1971 #[error("Request error: {0}")]
1972 Request(reqwest::Error),
1973 }
1974
1975 impl From<reqwest::header::InvalidHeaderValue> for StreamingError {
1976 fn from(err: reqwest::header::InvalidHeaderValue) -> Self {
1977 StreamingError::Api(format!("Invalid header value: {}", err))
1978 }
1979 }
1980
1981 impl From<reqwest::Error> for StreamingError {
1982 fn from(err: reqwest::Error) -> Self {
1983 if err.is_timeout() {
1984 StreamingError::Timeout(err.to_string())
1985 } else if err.is_status() {
1986 if let Some(status) = err.status() {
1987 StreamingError::Http { status: status.as_u16() }
1988 } else {
1989 StreamingError::Connection(err.to_string())
1990 }
1991 } else {
1992 StreamingError::Request(err)
1993 }
1994 }
1995 }
1996 })
1997 }
1998
1999 fn generate_endpoint_trait(
2001 &self,
2002 endpoint: &crate::streaming::StreamingEndpoint,
2003 analysis: &SchemaAnalysis,
2004 ) -> Result<TokenStream> {
2005 use crate::streaming::HttpMethod;
2006
2007 let trait_name = format_ident!(
2008 "{}StreamingClient",
2009 self.to_rust_type_name(&endpoint.operation_id)
2010 );
2011 let method_name =
2012 format_ident!("stream_{}", self.to_rust_field_name(&endpoint.operation_id));
2013 let event_type =
2014 format_ident!("{}", self.resolve_streaming_event_type(endpoint, analysis)?);
2015
2016 let method_signature = match endpoint.http_method {
2018 HttpMethod::Get => {
2019 let mut param_defs = Vec::new();
2021 for qp in &endpoint.query_parameters {
2022 let param_name = format_ident!("{}", self.to_rust_field_name(&qp.name));
2023 if qp.required {
2024 param_defs.push(quote! { #param_name: &str });
2025 } else {
2026 param_defs.push(quote! { #param_name: Option<&str> });
2027 }
2028 }
2029 quote! {
2030 async fn #method_name(
2031 &self,
2032 #(#param_defs),*
2033 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error>;
2034 }
2035 }
2036 HttpMethod::Post => {
2037 let request_type = self
2039 .find_request_type_for_operation(&endpoint.operation_id, analysis)
2040 .unwrap_or_else(|| "serde_json::Value".to_string());
2041 let request_type_ident = if request_type.contains("::") {
2042 let parts: Vec<&str> = request_type.split("::").collect();
2043 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
2044 quote! { #(#path_parts)::* }
2045 } else {
2046 let ident = format_ident!("{}", request_type);
2047 quote! { #ident }
2048 };
2049 quote! {
2050 async fn #method_name(
2051 &self,
2052 request: #request_type_ident,
2053 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error>;
2054 }
2055 }
2056 };
2057
2058 Ok(quote! {
2059 #[async_trait]
2061 pub trait #trait_name {
2062 type Error: std::error::Error + Send + Sync + 'static;
2063
2064 #method_signature
2066 }
2067 })
2068 }
2069
2070 fn generate_streaming_client_impl(
2072 &self,
2073 streaming_config: &crate::streaming::StreamingConfig,
2074 analysis: &SchemaAnalysis,
2075 ) -> Result<TokenStream> {
2076 let client_name = format_ident!(
2077 "{}Client",
2078 self.to_rust_type_name(&streaming_config.client_module_name)
2079 );
2080
2081 let mut struct_fields = vec![
2084 quote! { base_url: String },
2085 quote! { api_key: Option<String> },
2086 quote! { http_client: reqwest::Client },
2087 quote! { custom_headers: std::collections::BTreeMap<String, String> },
2088 ];
2089
2090 let has_optional_headers = !streaming_config
2091 .endpoints
2092 .iter()
2093 .all(|e| e.optional_headers.is_empty());
2094
2095 if has_optional_headers {
2096 struct_fields
2097 .push(quote! { optional_headers: std::collections::BTreeMap<String, String> });
2098 }
2099
2100 let default_base_url = if let Some(ref streaming_config) = self.config.streaming_config {
2103 streaming_config
2104 .endpoints
2105 .first()
2106 .and_then(|e| e.base_url.as_deref())
2107 .unwrap_or("https://api.example.com")
2108 } else {
2109 "https://api.example.com"
2110 };
2111
2112 let constructor_fields = if has_optional_headers {
2114 quote! {
2115 base_url: #default_base_url.to_string(),
2116 api_key: None,
2117 http_client: reqwest::Client::new(),
2118 custom_headers: std::collections::BTreeMap::new(),
2119 optional_headers: std::collections::BTreeMap::new(),
2120 }
2121 } else {
2122 quote! {
2123 base_url: #default_base_url.to_string(),
2124 api_key: None,
2125 http_client: reqwest::Client::new(),
2126 custom_headers: std::collections::BTreeMap::new(),
2127 }
2128 };
2129
2130 let optional_headers_method = if has_optional_headers {
2132 quote! {
2133 pub fn set_optional_headers(&mut self, headers: std::collections::BTreeMap<String, String>) {
2135 self.optional_headers = headers;
2136 }
2137 }
2138 } else {
2139 TokenStream::new()
2140 };
2141
2142 let constructor = quote! {
2143 impl #client_name {
2144 pub fn new() -> Self {
2146 Self {
2147 #constructor_fields
2148 }
2149 }
2150
2151 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
2153 self.base_url = base_url.into();
2154 self
2155 }
2156
2157 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
2159 self.api_key = Some(api_key.into());
2160 self
2161 }
2162
2163 pub fn with_header(
2165 mut self,
2166 name: impl Into<String>,
2167 value: impl Into<String>,
2168 ) -> Self {
2169 self.custom_headers.insert(name.into(), value.into());
2170 self
2171 }
2172
2173 pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
2175 self.http_client = client;
2176 self
2177 }
2178
2179 #optional_headers_method
2180 }
2181 };
2182
2183 let mut trait_impls = Vec::new();
2185 for endpoint in &streaming_config.endpoints {
2186 let trait_impl = self.generate_endpoint_trait_impl(endpoint, &client_name, analysis)?;
2187 trait_impls.push(trait_impl);
2188 }
2189
2190 let default_impl = quote! {
2192 impl Default for #client_name {
2193 fn default() -> Self {
2194 Self::new()
2195 }
2196 }
2197 };
2198
2199 Ok(quote! {
2200 #[derive(Debug, Clone)]
2202 pub struct #client_name {
2203 #(#struct_fields,)*
2204 }
2205
2206 #constructor
2207
2208 #default_impl
2209
2210 #(#trait_impls)*
2211 })
2212 }
2213
2214 fn generate_endpoint_trait_impl(
2216 &self,
2217 endpoint: &crate::streaming::StreamingEndpoint,
2218 client_name: &proc_macro2::Ident,
2219 analysis: &SchemaAnalysis,
2220 ) -> Result<TokenStream> {
2221 use crate::streaming::HttpMethod;
2222
2223 let trait_name = format_ident!(
2224 "{}StreamingClient",
2225 self.to_rust_type_name(&endpoint.operation_id)
2226 );
2227 let method_name =
2228 format_ident!("stream_{}", self.to_rust_field_name(&endpoint.operation_id));
2229 let event_type =
2230 format_ident!("{}", self.resolve_streaming_event_type(endpoint, analysis)?);
2231
2232 let mut header_setup = Vec::new();
2234 for (name, value) in &endpoint.required_headers {
2235 header_setup.push(quote! {
2236 headers.insert(#name, HeaderValue::from_static(#value));
2237 });
2238 }
2239
2240 if let Some(auth_header) = &endpoint.auth_header {
2243 match auth_header {
2244 crate::streaming::AuthHeader::Bearer(header_name) => {
2245 header_setup.push(quote! {
2246 if let Some(ref api_key) = self.api_key {
2247 headers.insert(#header_name, HeaderValue::from_str(&format!("Bearer {}", api_key))?);
2248 }
2249 });
2250 }
2251 crate::streaming::AuthHeader::ApiKey(header_name) => {
2252 header_setup.push(quote! {
2253 if let Some(ref api_key) = self.api_key {
2254 headers.insert(#header_name, HeaderValue::from_str(api_key)?);
2255 }
2256 });
2257 }
2258 }
2259 } else {
2260 header_setup.push(quote! {
2262 if let Some(ref api_key) = self.api_key {
2263 headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", api_key))?);
2264 }
2265 });
2266 }
2267
2268 header_setup.push(quote! {
2270 for (name, value) in &self.custom_headers {
2271 if let (Ok(header_name), Ok(header_value)) = (reqwest::header::HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(value)) {
2272 headers.insert(header_name, header_value);
2273 }
2274 }
2275 });
2276
2277 if !endpoint.optional_headers.is_empty() {
2279 header_setup.push(quote! {
2280 for (key, value) in &self.optional_headers {
2281 if let (Ok(header_name), Ok(header_value)) = (reqwest::header::HeaderName::from_bytes(key.as_bytes()), HeaderValue::from_str(value)) {
2282 headers.insert(header_name, header_value);
2283 }
2284 }
2285 });
2286 }
2287
2288 match endpoint.http_method {
2290 HttpMethod::Get => self.generate_get_streaming_impl(
2291 endpoint,
2292 client_name,
2293 &trait_name,
2294 &method_name,
2295 &event_type,
2296 &header_setup,
2297 ),
2298 HttpMethod::Post => self.generate_post_streaming_impl(
2299 endpoint,
2300 client_name,
2301 &trait_name,
2302 &method_name,
2303 &event_type,
2304 &header_setup,
2305 analysis,
2306 ),
2307 }
2308 }
2309
2310 fn generate_get_streaming_impl(
2312 &self,
2313 endpoint: &crate::streaming::StreamingEndpoint,
2314 client_name: &proc_macro2::Ident,
2315 trait_name: &proc_macro2::Ident,
2316 method_name: &proc_macro2::Ident,
2317 event_type: &proc_macro2::Ident,
2318 header_setup: &[TokenStream],
2319 ) -> Result<TokenStream> {
2320 let path = &endpoint.path;
2321
2322 let mut param_defs = Vec::new();
2324 let mut query_params = Vec::new();
2325
2326 for qp in &endpoint.query_parameters {
2327 let param_name = format_ident!("{}", self.to_rust_field_name(&qp.name));
2328 let param_name_str = &qp.name;
2329
2330 if qp.required {
2331 param_defs.push(quote! { #param_name: &str });
2332 query_params.push(quote! {
2333 url.query_pairs_mut().append_pair(#param_name_str, #param_name);
2334 });
2335 } else {
2336 param_defs.push(quote! { #param_name: Option<&str> });
2337 query_params.push(quote! {
2338 if let Some(v) = #param_name {
2339 url.query_pairs_mut().append_pair(#param_name_str, v);
2340 }
2341 });
2342 }
2343 }
2344
2345 let url_construction = quote! {
2347 let base_url = url::Url::parse(&self.base_url)
2348 .map_err(|e| StreamingError::Connection(format!("Invalid base URL: {}", e)))?;
2349 let path_to_join = #path.trim_start_matches('/');
2350 let mut url = base_url.join(path_to_join)
2351 .map_err(|e| StreamingError::Connection(format!("URL join error: {}", e)))?;
2352 #(#query_params)*
2353 };
2354
2355 let instrument_skip = quote! { #[instrument(skip(self), name = "streaming_get_request")] };
2356
2357 Ok(quote! {
2358 #[async_trait]
2359 impl #trait_name for #client_name {
2360 type Error = StreamingError;
2361
2362 #instrument_skip
2363 async fn #method_name(
2364 &self,
2365 #(#param_defs),*
2366 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error> {
2367 debug!("Starting streaming GET request");
2368
2369 let mut headers = HeaderMap::new();
2370 #(#header_setup)*
2371
2372 #url_construction
2373 let url_str = url.to_string();
2374 debug!("Making streaming GET request to: {}", url_str);
2375
2376 let request_builder = self.http_client
2377 .get(url_str)
2378 .headers(headers);
2379
2380 debug!("Creating SSE stream from request");
2381 let stream = parse_sse_stream::<#event_type>(request_builder).await?;
2382 info!("SSE stream created successfully");
2383 Ok(Box::pin(stream))
2384 }
2385 }
2386 })
2387 }
2388
2389 #[allow(clippy::too_many_arguments)]
2391 fn generate_post_streaming_impl(
2392 &self,
2393 endpoint: &crate::streaming::StreamingEndpoint,
2394 client_name: &proc_macro2::Ident,
2395 trait_name: &proc_macro2::Ident,
2396 method_name: &proc_macro2::Ident,
2397 event_type: &proc_macro2::Ident,
2398 header_setup: &[TokenStream],
2399 analysis: &SchemaAnalysis,
2400 ) -> Result<TokenStream> {
2401 let path = &endpoint.path;
2402
2403 let request_type = self
2405 .find_request_type_for_operation(&endpoint.operation_id, analysis)
2406 .unwrap_or_else(|| "serde_json::Value".to_string());
2407 let request_type_ident = if request_type.contains("::") {
2408 let parts: Vec<&str> = request_type.split("::").collect();
2409 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
2410 quote! { #(#path_parts)::* }
2411 } else {
2412 let ident = format_ident!("{}", request_type);
2413 quote! { #ident }
2414 };
2415
2416 let url_construction = quote! {
2418 let base_url = url::Url::parse(&self.base_url)
2419 .map_err(|e| StreamingError::Connection(format!("Invalid base URL: {}", e)))?;
2420 let path_to_join = #path.trim_start_matches('/');
2421 let url = base_url.join(path_to_join)
2422 .map_err(|e| StreamingError::Connection(format!("URL join error: {}", e)))?
2423 .to_string();
2424 };
2425
2426 let stream_param = &endpoint.stream_parameter;
2428 let stream_setup = if stream_param.is_empty() {
2429 quote! {
2430 let streaming_request = request;
2431 }
2432 } else {
2433 quote! {
2434 let mut streaming_request = request;
2436 if let Ok(mut request_value) = serde_json::to_value(&streaming_request) {
2437 if let Some(obj) = request_value.as_object_mut() {
2438 obj.insert(#stream_param.to_string(), serde_json::Value::Bool(true));
2439 }
2440 streaming_request = serde_json::from_value(request_value)?;
2441 }
2442 }
2443 };
2444
2445 Ok(quote! {
2446 #[async_trait]
2447 impl #trait_name for #client_name {
2448 type Error = StreamingError;
2449
2450 #[instrument(skip(self, request), name = "streaming_post_request")]
2451 async fn #method_name(
2452 &self,
2453 request: #request_type_ident,
2454 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error> {
2455 debug!("Starting streaming POST request");
2456
2457 #stream_setup
2458
2459 let mut headers = HeaderMap::new();
2460 #(#header_setup)*
2461
2462 #url_construction
2463 debug!("Making streaming POST request to: {}", url);
2464
2465 let request_builder = self.http_client
2466 .post(&url)
2467 .headers(headers)
2468 .json(&streaming_request);
2469
2470 debug!("Creating SSE stream from request");
2471 let stream = parse_sse_stream::<#event_type>(request_builder).await?;
2472 info!("SSE stream created successfully");
2473 Ok(Box::pin(stream))
2474 }
2475 }
2476 })
2477 }
2478
2479 fn generate_sse_parser_utilities(
2481 &self,
2482 _streaming_config: &crate::streaming::StreamingConfig,
2483 ) -> Result<TokenStream> {
2484 Ok(quote! {
2485 pub async fn parse_sse_stream<T>(
2487 request_builder: reqwest::RequestBuilder
2488 ) -> Result<impl Stream<Item = Result<T, StreamingError>>, StreamingError>
2489 where
2490 T: serde::de::DeserializeOwned + Send + 'static,
2491 {
2492 let mut event_source = reqwest_eventsource::EventSource::new(request_builder).map_err(|e| {
2493 StreamingError::Connection(format!("Failed to create event source: {}", e))
2494 })?;
2495
2496 let stream = event_source.filter_map(|event_result| async move {
2497 match event_result {
2498 Ok(reqwest_eventsource::Event::Open) => {
2499 debug!("SSE connection opened");
2500 None
2501 }
2502 Ok(reqwest_eventsource::Event::Message(message)) => {
2503 if message.event == "ping" {
2505 debug!("Received SSE ping event, skipping");
2506 return None;
2507 }
2508
2509 if message.data.trim().is_empty() {
2511 debug!("Empty SSE data, skipping");
2512 return None;
2513 }
2514
2515 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&message.data) {
2517 if let Some(event_type) = json_value.get("event").and_then(|v| v.as_str()) {
2518 if event_type == "ping" {
2519 debug!("Received ping event in JSON data, skipping");
2520 return None;
2521 }
2522 }
2523
2524 match serde_json::from_value::<T>(json_value) {
2526 Ok(parsed_event) => {
2527 Some(Ok(parsed_event))
2528 }
2529 Err(e) => {
2530 if message.data.contains("ping") || message.event.contains("ping") {
2531 debug!("Ignoring ping-related event: {}", message.data);
2532 None
2533 } else {
2534 Some(Err(StreamingError::Parsing(
2535 format!("Failed to parse SSE event: {} (raw: {})", e, message.data)
2536 )))
2537 }
2538 }
2539 }
2540 } else {
2541 Some(Err(StreamingError::Parsing(
2543 format!("SSE event is not valid JSON: {}", message.data)
2544 )))
2545 }
2546 }
2547 Err(e) => {
2548 match e {
2550 reqwest_eventsource::Error::StreamEnded => {
2551 debug!("SSE stream completed normally");
2552 None }
2554 reqwest_eventsource::Error::InvalidStatusCode(status, response) => {
2555 let status_code = status.as_u16();
2557
2558 let error_body = match response.text().await {
2560 Ok(body) => body,
2561 Err(_) => "Failed to read error response body".to_string()
2562 };
2563
2564 error!("SSE connection error - HTTP {}: {}", status_code, error_body);
2565
2566 let detailed_error = format!(
2567 "HTTP {} error: {}",
2568 status_code,
2569 error_body
2570 );
2571
2572 Some(Err(StreamingError::Connection(detailed_error)))
2573 }
2574 _ => {
2575 let error_str = e.to_string();
2576 if error_str.contains("stream closed") {
2577 debug!("SSE stream closed");
2578 None
2579 } else {
2580 error!("SSE connection error: {}", e);
2581 Some(Err(StreamingError::Connection(error_str)))
2582 }
2583 }
2584 }
2585 }
2586 }
2587 });
2588
2589 Ok(stream)
2590 }
2591 })
2592 }
2593
2594 fn generate_reconnection_utilities(
2596 &self,
2597 reconnect_config: &crate::streaming::ReconnectionConfig,
2598 ) -> Result<TokenStream> {
2599 let max_retries = reconnect_config.max_retries;
2600 let initial_delay = reconnect_config.initial_delay_ms;
2601 let max_delay = reconnect_config.max_delay_ms;
2602 let backoff_multiplier = reconnect_config.backoff_multiplier;
2603
2604 Ok(quote! {
2605 #[derive(Debug, Clone)]
2607 pub struct ReconnectionManager {
2608 max_retries: u32,
2609 initial_delay_ms: u64,
2610 max_delay_ms: u64,
2611 backoff_multiplier: f64,
2612 current_attempt: u32,
2613 }
2614
2615 impl ReconnectionManager {
2616 pub fn new() -> Self {
2618 Self {
2619 max_retries: #max_retries,
2620 initial_delay_ms: #initial_delay,
2621 max_delay_ms: #max_delay,
2622 backoff_multiplier: #backoff_multiplier,
2623 current_attempt: 0,
2624 }
2625 }
2626
2627 pub fn should_retry(&self) -> bool {
2629 self.current_attempt < self.max_retries
2630 }
2631
2632 pub fn next_retry_delay(&mut self) -> Duration {
2634 if !self.should_retry() {
2635 return Duration::from_secs(0);
2636 }
2637
2638 let delay_ms = (self.initial_delay_ms as f64
2639 * self.backoff_multiplier.powi(self.current_attempt as i32)) as u64;
2640 let delay_ms = delay_ms.min(self.max_delay_ms);
2641
2642 self.current_attempt += 1;
2643 Duration::from_millis(delay_ms)
2644 }
2645
2646 pub fn reset(&mut self) {
2648 self.current_attempt = 0;
2649 }
2650
2651 pub fn current_attempt(&self) -> u32 {
2653 self.current_attempt
2654 }
2655 }
2656
2657 impl Default for ReconnectionManager {
2658 fn default() -> Self {
2659 Self::new()
2660 }
2661 }
2662 })
2663 }
2664}