1use crate::{GeneratorError, Result, analysis::SchemaAnalysis, streaming::StreamingConfig};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use std::collections::BTreeMap;
5use std::path::PathBuf;
6
/// Discriminator metadata recorded for an object schema that is a variant of a
/// discriminated union. `generate_types` builds a map of these keyed by the
/// variant's schema name.
#[derive(Clone)]
struct DiscriminatedVariantInfo {
    /// Name of the parent union's discriminator property.
    discriminator_field: String,
    /// Discriminator value that selects this variant.
    discriminator_value: String,
    /// True when the parent union is rendered untagged. When false, the
    /// variant struct omits the discriminator field (serde's internal tag
    /// supplies it) and arrays of the variant get a tagged wrapper enum.
    is_parent_untagged: bool,
}
17
18#[derive(Debug, Clone)]
19pub struct GeneratorConfig {
20 pub spec_path: PathBuf,
22 pub output_dir: PathBuf,
24 pub module_name: String,
26 pub enable_sse_client: bool,
28 pub enable_async_client: bool,
30 pub enable_specta: bool,
32 pub type_mappings: BTreeMap<String, String>,
34 pub streaming_config: Option<StreamingConfig>,
36 pub nullable_field_overrides: BTreeMap<String, bool>,
39 pub schema_extensions: Vec<PathBuf>,
42 pub http_client_config: Option<crate::http_config::HttpClientConfig>,
44 pub retry_config: Option<crate::http_config::RetryConfig>,
46 pub tracing_enabled: bool,
48 pub auth_config: Option<crate::http_config::AuthConfig>,
50 pub enable_registry: bool,
52 pub registry_only: bool,
54}
55
56impl Default for GeneratorConfig {
57 fn default() -> Self {
58 Self {
59 spec_path: "openapi.json".into(),
60 output_dir: "src/gen".into(),
61 module_name: "api_types".to_string(),
62 enable_sse_client: true,
63 enable_async_client: true,
64 enable_specta: false,
65 type_mappings: default_type_mappings(),
66 streaming_config: None,
67 nullable_field_overrides: BTreeMap::new(),
68 schema_extensions: Vec::new(),
69 http_client_config: None,
70 retry_config: None,
71 tracing_enabled: true,
72 auth_config: None,
73 enable_registry: false,
74 registry_only: false,
75 }
76 }
77}
78
/// Built-in mapping from OpenAPI primitive type names to Rust type names.
pub fn default_type_mappings() -> BTreeMap<String, String> {
    const PAIRS: [(&str, &str); 4] = [
        ("boolean", "bool"),
        ("integer", "i64"),
        ("number", "f64"),
        ("string", "String"),
    ];
    PAIRS
        .iter()
        .map(|&(openapi_name, rust_name)| (openapi_name.to_string(), rust_name.to_string()))
        .collect()
}
87
/// One generated source file.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// Path relative to `GeneratorConfig::output_dir` (joined onto it by
    /// `CodeGenerator::write_files`).
    pub path: PathBuf,
    /// Full Rust source text of the file.
    pub content: String,
}
96
97#[derive(Debug, Clone)]
99pub struct GenerationResult {
100 pub files: Vec<GeneratedFile>,
102 pub mod_file: GeneratedFile,
104}
105
106pub struct CodeGenerator {
107 config: GeneratorConfig,
108}
109
110impl CodeGenerator {
111 pub fn new(config: GeneratorConfig) -> Self {
112 Self { config }
113 }
114
115 pub fn config(&self) -> &GeneratorConfig {
117 &self.config
118 }
119
120 pub fn generate_all(&self, analysis: &mut SchemaAnalysis) -> Result<GenerationResult> {
122 let mut files = Vec::new();
123
124 if !self.config.registry_only {
125 let types_content = self.generate_types(analysis)?;
127 files.push(GeneratedFile {
128 path: "types.rs".into(),
129 content: types_content,
130 });
131
132 if let Some(ref streaming_config) = self.config.streaming_config {
134 let streaming_content =
135 self.generate_streaming_client(streaming_config, analysis)?;
136 files.push(GeneratedFile {
137 path: "streaming.rs".into(),
138 content: streaming_content,
139 });
140 }
141
142 if self.config.enable_async_client {
144 let http_content = self.generate_http_client(analysis)?;
145 files.push(GeneratedFile {
146 path: "client.rs".into(),
147 content: http_content,
148 });
149 }
150 }
151
152 if self.config.enable_registry || self.config.registry_only {
154 let registry_content = self.generate_registry(analysis)?;
155 files.push(GeneratedFile {
156 path: "registry.rs".into(),
157 content: registry_content,
158 });
159 }
160
161 let mod_content = self.generate_mod_file(&files)?;
163 let mod_file = GeneratedFile {
164 path: "mod.rs".into(),
165 content: mod_content,
166 };
167
168 Ok(GenerationResult { files, mod_file })
169 }
170
171 pub fn generate(&self, analysis: &mut SchemaAnalysis) -> Result<String> {
173 self.generate_types(analysis)
174 }
175
176 fn generate_types(&self, analysis: &mut SchemaAnalysis) -> Result<String> {
178 let mut type_definitions = TokenStream::new();
179
180 let mut discriminated_variant_info: BTreeMap<String, DiscriminatedVariantInfo> =
183 BTreeMap::new();
184
185 let mut sorted_schemas: Vec<_> = analysis.schemas.iter().collect();
187 sorted_schemas.sort_by_key(|(name, _)| name.as_str());
188
189 for (_parent_name, schema) in sorted_schemas {
190 if let crate::analysis::SchemaType::DiscriminatedUnion {
191 variants,
192 discriminator_field,
193 } = &schema.schema_type
194 {
195 let is_parent_untagged =
197 self.should_use_untagged_discriminated_union(schema, analysis);
198
199 for variant in variants {
200 if let Some(variant_schema) = analysis.schemas.get(&variant.type_name) {
203 if let crate::analysis::SchemaType::Object { properties, .. } =
204 &variant_schema.schema_type
205 {
206 if properties.contains_key(discriminator_field) {
207 discriminated_variant_info.insert(
208 variant.type_name.clone(),
209 DiscriminatedVariantInfo {
210 discriminator_field: discriminator_field.clone(),
211 discriminator_value: variant.discriminator_value.clone(),
212 is_parent_untagged,
213 },
214 );
215 }
216 }
217 }
218 }
219 }
220 }
221
222 let generation_order = analysis.dependencies.topological_sort()?;
224
225 let mut processed = std::collections::HashSet::new();
227
228 for schema_name in generation_order {
230 if let Some(schema) = analysis.schemas.get(&schema_name) {
231 let type_def =
232 self.generate_type_definition(schema, analysis, &discriminated_variant_info)?;
233 if !type_def.is_empty() {
234 type_definitions.extend(type_def);
235 }
236 processed.insert(schema_name);
237 }
238 }
239
240 let mut remaining_schemas: Vec<_> = analysis
243 .schemas
244 .iter()
245 .filter(|(name, _)| !processed.contains(*name))
246 .collect();
247 remaining_schemas.sort_by_key(|(name, _)| name.as_str());
248
249 for (_schema_name, schema) in remaining_schemas {
250 let type_def =
251 self.generate_type_definition(schema, analysis, &discriminated_variant_info)?;
252 if !type_def.is_empty() {
253 type_definitions.extend(type_def);
254 }
255 }
256
257 let generated = quote! {
259 #![allow(clippy::large_enum_variant)]
265 #![allow(clippy::format_in_format_args)]
266 #![allow(clippy::let_unit_value)]
267 #![allow(unreachable_patterns)]
268
269 use serde::{Deserialize, Serialize};
270
271 #type_definitions
272 };
273
274 let syntax_tree = syn::parse2::<syn::File>(generated).map_err(|e| {
276 GeneratorError::CodeGenError(format!("Failed to parse generated code: {e}"))
277 })?;
278
279 let formatted = prettyplease::unparse(&syntax_tree);
280
281 Ok(formatted)
282 }
283
284 fn generate_streaming_client(
286 &self,
287 streaming_config: &StreamingConfig,
288 analysis: &SchemaAnalysis,
289 ) -> Result<String> {
290 let mut client_code = TokenStream::new();
291
292 let imports = quote! {
294 #![allow(clippy::format_in_format_args)]
299 #![allow(clippy::let_unit_value)]
300 #![allow(unused_mut)]
301
302 use super::types::*;
303 use async_trait::async_trait;
304 use futures_util::{Stream, StreamExt};
305 use std::pin::Pin;
306 use std::time::Duration;
307 use reqwest::header::{HeaderMap, HeaderValue};
308 use tracing::{debug, error, info, warn, instrument};
309 };
310 client_code.extend(imports);
311
312 if streaming_config.generate_client {
314 let error_types = self.generate_streaming_error_types()?;
315 client_code.extend(error_types);
316 }
317
318 for endpoint in &streaming_config.endpoints {
320 let trait_code = self.generate_endpoint_trait(endpoint, analysis)?;
321 client_code.extend(trait_code);
322 }
323
324 if streaming_config.generate_client {
326 let client_impl = self.generate_streaming_client_impl(streaming_config, analysis)?;
327 client_code.extend(client_impl);
328 }
329
330 if streaming_config.event_parser_helpers {
332 let parser_code = self.generate_sse_parser_utilities(streaming_config)?;
333 client_code.extend(parser_code);
334 }
335
336 if let Some(reconnect_config) = &streaming_config.reconnection_config {
338 let reconnect_code = self.generate_reconnection_utilities(reconnect_config)?;
339 client_code.extend(reconnect_code);
340 }
341
342 let syntax_tree = syn::parse2::<syn::File>(client_code).map_err(|e| {
343 GeneratorError::CodeGenError(format!("Failed to parse streaming client code: {e}"))
344 })?;
345
346 Ok(prettyplease::unparse(&syntax_tree))
347 }
348
349 pub fn generate_http_client(&self, analysis: &SchemaAnalysis) -> Result<String> {
351 let error_types = self.generate_http_error_types();
352 let client_struct = self.generate_http_client_struct();
353 let operation_methods = self.generate_operation_methods(analysis);
354
355 let generated = quote! {
356 #![allow(clippy::format_in_format_args)]
361 #![allow(clippy::let_unit_value)]
362
363 use super::types::*;
364
365 #error_types
366
367 #client_struct
368
369 #operation_methods
370 };
371
372 let syntax_tree = syn::parse2::<syn::File>(generated).map_err(|e| {
373 GeneratorError::CodeGenError(format!("Failed to parse HTTP client code: {e}"))
374 })?;
375
376 Ok(prettyplease::unparse(&syntax_tree))
377 }
378
379 fn generate_http_error_types(&self) -> TokenStream {
381 quote! {
382 use thiserror::Error;
383
384 #[derive(Error, Debug)]
386 pub enum HttpError {
387 #[error("Network error: {0}")]
389 Network(#[from] reqwest::Error),
390
391 #[error("Middleware error: {0}")]
393 Middleware(#[from] reqwest_middleware::Error),
394
395 #[error("Failed to serialize request: {0}")]
397 Serialization(String),
398
399 #[error("Failed to deserialize response: {0}")]
401 Deserialization(String),
402
403 #[error("HTTP error {status}: {message}")]
405 Http {
406 status: u16,
407 message: String,
408 body: Option<String>,
409 },
410
411 #[error("Authentication error: {0}")]
413 Auth(String),
414
415 #[error("Request timeout")]
417 Timeout,
418
419 #[error("Configuration error: {0}")]
421 Config(String),
422
423 #[error("{0}")]
425 Other(String),
426 }
427
428 impl HttpError {
429 pub fn from_status(status: u16, message: impl Into<String>, body: Option<String>) -> Self {
431 Self::Http {
432 status,
433 message: message.into(),
434 body,
435 }
436 }
437
438 pub fn serialization_error(error: impl std::fmt::Display) -> Self {
440 Self::Serialization(error.to_string())
441 }
442
443 pub fn deserialization_error(error: impl std::fmt::Display) -> Self {
445 Self::Deserialization(error.to_string())
446 }
447
448 pub fn is_client_error(&self) -> bool {
450 matches!(self, Self::Http { status, .. } if *status >= 400 && *status < 500)
451 }
452
453 pub fn is_server_error(&self) -> bool {
455 matches!(self, Self::Http { status, .. } if *status >= 500 && *status < 600)
456 }
457
458 pub fn is_retryable(&self) -> bool {
460 match self {
461 Self::Network(_) => true,
462 Self::Middleware(_) => true,
463 Self::Timeout => true,
464 Self::Http { status, .. } => {
465 matches!(status, 429 | 500 | 502 | 503 | 504)
467 }
468 _ => false,
469 }
470 }
471 }
472
473 pub type HttpResult<T> = Result<T, HttpError>;
475 }
476 }
477
478 fn generate_mod_file(&self, files: &[GeneratedFile]) -> Result<String> {
480 let mut module_declarations = Vec::new();
481 let mut pub_uses = Vec::new();
482
483 for file in files {
484 if let Some(module_name) = file.path.file_stem().and_then(|s| s.to_str()) {
485 if module_name != "mod" {
486 module_declarations.push(format!("pub mod {module_name};"));
487 pub_uses.push(format!("pub use {module_name}::*;"));
488 }
489 }
490 }
491
492 let content = format!(
493 r#"//! Generated API modules
494//!
495//! This module exports all generated API types and clients.
496//! Do not edit manually - regenerate using the appropriate script.
497
498#![allow(unused_imports)]
499
500{}
501
502{}
503"#,
504 module_declarations.join("\n"),
505 pub_uses.join("\n")
506 );
507
508 Ok(content)
509 }
510
511 pub fn write_files(&self, result: &GenerationResult) -> Result<()> {
513 use std::fs;
514
515 fs::create_dir_all(&self.config.output_dir)?;
517
518 for file in &result.files {
520 let file_path = self.config.output_dir.join(&file.path);
521 fs::write(&file_path, &file.content)?;
522 }
523
524 let mod_path = self.config.output_dir.join(&result.mod_file.path);
526 fs::write(&mod_path, &result.mod_file.content)?;
527
528 Ok(())
529 }
530
531 fn generate_type_definition(
532 &self,
533 schema: &crate::analysis::AnalyzedSchema,
534 analysis: &crate::analysis::SchemaAnalysis,
535 discriminated_variant_info: &BTreeMap<String, DiscriminatedVariantInfo>,
536 ) -> Result<TokenStream> {
537 use crate::analysis::SchemaType;
538
539 match &schema.schema_type {
540 SchemaType::Primitive { rust_type } => {
541 self.generate_type_alias(schema, rust_type)
543 }
544 SchemaType::StringEnum { values } => self.generate_string_enum(schema, values),
545 SchemaType::ExtensibleEnum { known_values } => {
546 self.generate_extensible_enum(schema, known_values)
547 }
548 SchemaType::Object {
549 properties,
550 required,
551 additional_properties,
552 } => self.generate_struct(
553 schema,
554 properties,
555 required,
556 *additional_properties,
557 analysis,
558 discriminated_variant_info.get(&schema.name),
559 ),
560 SchemaType::DiscriminatedUnion {
561 discriminator_field,
562 variants,
563 } => {
564 if self.should_use_untagged_discriminated_union(schema, analysis) {
566 let schema_refs: Vec<crate::analysis::SchemaRef> = variants
568 .iter()
569 .map(|v| crate::analysis::SchemaRef {
570 target: v.type_name.clone(),
571 nullable: false,
572 })
573 .collect();
574 self.generate_union_enum(schema, &schema_refs)
575 } else {
576 self.generate_discriminated_enum(
577 schema,
578 discriminator_field,
579 variants,
580 analysis,
581 )
582 }
583 }
584 SchemaType::Union { variants } => self.generate_union_enum(schema, variants),
585 SchemaType::Reference { target } => {
586 if schema.name != *target {
589 let alias_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
591 let target_type = format_ident!("{}", self.to_rust_type_name(target));
592
593 let doc_comment = if let Some(desc) = &schema.description {
594 quote! { #[doc = #desc] }
595 } else {
596 TokenStream::new()
597 };
598
599 Ok(quote! {
600 #doc_comment
601 pub type #alias_name = #target_type;
602 })
603 } else {
604 Ok(TokenStream::new())
606 }
607 }
608 SchemaType::Array { item_type } => {
609 let array_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
617
618 if let SchemaType::Reference { target } = item_type.as_ref() {
620 if let Some(info) = discriminated_variant_info.get(target) {
621 if !info.is_parent_untagged {
622 let wrapper_name = format_ident!(
624 "{}Item",
625 self.to_rust_type_name(&schema.name)
626 );
627 let variant_type =
628 format_ident!("{}", self.to_rust_type_name(target));
629 let disc_field = &info.discriminator_field;
630 let disc_value = &info.discriminator_value;
631
632 let doc_comment = if let Some(desc) = &schema.description {
633 quote! { #[doc = #desc] }
634 } else {
635 TokenStream::new()
636 };
637
638 return Ok(quote! {
639 #[derive(Debug, Clone, Deserialize, Serialize)]
643 #[serde(tag = #disc_field)]
644 pub enum #wrapper_name {
645 #[serde(rename = #disc_value)]
646 #variant_type(#variant_type),
647 }
648 #doc_comment
649 pub type #array_name = Vec<#wrapper_name>;
650 });
651 }
652 }
653 }
654
655 let inner_type = self.generate_array_item_type(item_type, analysis);
656
657 let doc_comment = if let Some(desc) = &schema.description {
658 quote! { #[doc = #desc] }
659 } else {
660 TokenStream::new()
661 };
662
663 Ok(quote! {
664 #doc_comment
665 pub type #array_name = Vec<#inner_type>;
666 })
667 }
668 SchemaType::Composition { schemas } => {
669 self.generate_composition_struct(schema, schemas)
670 }
671 }
672 }
673
674 fn generate_type_alias(
675 &self,
676 schema: &crate::analysis::AnalyzedSchema,
677 rust_type: &str,
678 ) -> Result<TokenStream> {
679 let type_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
680
681 let base_type = if rust_type.contains("::") {
683 let parts: Vec<&str> = rust_type.split("::").collect();
684 if parts.len() == 2 {
685 let module = format_ident!("{}", parts[0]);
686 let type_name_part = format_ident!("{}", parts[1]);
687 quote! { #module::#type_name_part }
688 } else {
689 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
691 quote! { #(#path_parts)::* }
692 }
693 } else {
694 let simple_type = format_ident!("{}", rust_type);
695 quote! { #simple_type }
696 };
697
698 let doc_comment = if let Some(desc) = &schema.description {
699 let sanitized_desc = self.sanitize_doc_comment(desc);
700 quote! { #[doc = #sanitized_desc] }
701 } else {
702 TokenStream::new()
703 };
704
705 Ok(quote! {
706 #doc_comment
707 pub type #type_name = #base_type;
708 })
709 }
710
711 fn generate_extensible_enum(
712 &self,
713 schema: &crate::analysis::AnalyzedSchema,
714 known_values: &[String],
715 ) -> Result<TokenStream> {
716 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
717
718 let doc_comment = if let Some(desc) = &schema.description {
719 quote! { #[doc = #desc] }
720 } else {
721 TokenStream::new()
722 };
723
724 let known_variants = known_values.iter().map(|value| {
729 let variant_name = self.to_rust_enum_variant(value);
730 let variant_ident = format_ident!("{}", variant_name);
731 quote! {
732 #variant_ident,
733 }
734 });
735
736 let match_arms_de = known_values.iter().map(|value| {
737 let variant_name = self.to_rust_enum_variant(value);
738 let variant_ident = format_ident!("{}", variant_name);
739 quote! {
740 #value => Ok(#enum_name::#variant_ident),
741 }
742 });
743
744 let match_arms_ser = known_values.iter().map(|value| {
745 let variant_name = self.to_rust_enum_variant(value);
746 let variant_ident = format_ident!("{}", variant_name);
747 quote! {
748 #enum_name::#variant_ident => #value,
749 }
750 });
751
752 let derives = if self.config.enable_specta {
753 quote! {
754 #[derive(Debug, Clone, PartialEq, Eq)]
755 #[cfg_attr(feature = "specta", derive(specta::Type))]
756 }
757 } else {
758 quote! {
759 #[derive(Debug, Clone, PartialEq, Eq)]
760 }
761 };
762
763 Ok(quote! {
764 #doc_comment
765 #derives
766 pub enum #enum_name {
767 #(#known_variants)*
768 Custom(String),
770 }
771
772 impl<'de> serde::Deserialize<'de> for #enum_name {
773 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
774 where
775 D: serde::Deserializer<'de>,
776 {
777 let value = String::deserialize(deserializer)?;
778 match value.as_str() {
779 #(#match_arms_de)*
780 _ => Ok(#enum_name::Custom(value)),
781 }
782 }
783 }
784
785 impl serde::Serialize for #enum_name {
786 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
787 where
788 S: serde::Serializer,
789 {
790 let value = match self {
791 #(#match_arms_ser)*
792 #enum_name::Custom(s) => s.as_str(),
793 };
794 serializer.serialize_str(value)
795 }
796 }
797 })
798 }
799
800 fn generate_string_enum(
801 &self,
802 schema: &crate::analysis::AnalyzedSchema,
803 values: &[String],
804 ) -> Result<TokenStream> {
805 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
806
807 let default_value = schema
809 .default
810 .as_ref()
811 .and_then(|v| v.as_str())
812 .map(|s| s.to_string());
813
814 let variants = values.iter().enumerate().map(|(i, value)| {
815 let variant_name = self.to_rust_enum_variant(value);
817 let variant_ident = format_ident!("{}", variant_name);
818
819 let is_default = if let Some(ref default) = default_value {
821 value == default
822 } else {
823 i == 0 };
825
826 if is_default {
827 quote! {
828 #[default]
829 #[serde(rename = #value)]
830 #variant_ident,
831 }
832 } else {
833 quote! {
834 #[serde(rename = #value)]
835 #variant_ident,
836 }
837 }
838 });
839
840 let doc_comment = if let Some(desc) = &schema.description {
841 quote! { #[doc = #desc] }
842 } else {
843 TokenStream::new()
844 };
845
846 let derives = if self.config.enable_specta {
848 quote! {
849 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
850 #[cfg_attr(feature = "specta", derive(specta::Type))]
851 }
852 } else {
853 quote! {
854 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
855 }
856 };
857
858 Ok(quote! {
859 #doc_comment
860 #derives
861 pub enum #enum_name {
862 #(#variants)*
863 }
864 })
865 }
866
867 fn generate_struct(
868 &self,
869 schema: &crate::analysis::AnalyzedSchema,
870 properties: &BTreeMap<String, crate::analysis::PropertyInfo>,
871 required: &std::collections::HashSet<String>,
872 additional_properties: bool,
873 analysis: &crate::analysis::SchemaAnalysis,
874 discriminator_info: Option<&DiscriminatedVariantInfo>,
875 ) -> Result<TokenStream> {
876 let struct_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
877
878 let mut sorted_properties: Vec<_> = properties.iter().collect();
880 sorted_properties.sort_by_key(|(name, _)| name.as_str());
881
882 let mut fields: Vec<TokenStream> = sorted_properties
883 .into_iter()
884 .filter(|(field_name, _)| {
885 if let Some(info) = discriminator_info {
889 if !info.is_parent_untagged
890 && field_name.as_str() == info.discriminator_field.as_str()
891 {
892 false } else {
894 true }
896 } else {
897 true }
899 })
900 .map(|(field_name, prop)| {
901 let field_ident = Self::to_field_ident(&self.to_rust_field_name(field_name));
902 let is_required = required.contains(field_name);
903 let field_type =
904 self.generate_field_type(&schema.name, field_name, prop, is_required, analysis);
905
906 let serde_attrs = self.generate_serde_field_attrs(field_name, prop, is_required, analysis);
907 let specta_attrs = self.generate_specta_field_attrs(field_name);
908
909 let doc_comment = if let Some(desc) = &prop.description {
910 let sanitized_desc = self.sanitize_doc_comment(desc);
911 quote! { #[doc = #sanitized_desc] }
912 } else {
913 TokenStream::new()
914 };
915
916 quote! {
917 #doc_comment
918 #serde_attrs
919 #specta_attrs
920 pub #field_ident: #field_type,
921 }
922 })
923 .collect();
924
925 if additional_properties {
927 fields.push(quote! {
928 #[serde(flatten)]
930 pub additional_properties: std::collections::BTreeMap<String, serde_json::Value>,
931 });
932 }
933
934 let doc_comment = if let Some(desc) = &schema.description {
935 quote! { #[doc = #desc] }
936 } else {
937 TokenStream::new()
938 };
939
940 let derives = if self.config.enable_specta {
944 quote! {
945 #[derive(Debug, Clone, Deserialize, Serialize)]
946 #[cfg_attr(feature = "specta", derive(specta::Type))]
947 }
948 } else {
949 quote! {
950 #[derive(Debug, Clone, Deserialize, Serialize)]
951 }
952 };
953
954 Ok(quote! {
955 #doc_comment
956 #derives
957 pub struct #struct_name {
958 #(#fields)*
959 }
960 })
961 }
962
963 fn generate_discriminated_enum(
964 &self,
965 schema: &crate::analysis::AnalyzedSchema,
966 discriminator_field: &str,
967 variants: &[crate::analysis::UnionVariant],
968 analysis: &crate::analysis::SchemaAnalysis,
969 ) -> Result<TokenStream> {
970 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
971
972 let has_nested_discriminated_union = variants.iter().any(|variant| {
974 if let Some(variant_schema) = analysis.schemas.get(&variant.type_name) {
975 matches!(
976 variant_schema.schema_type,
977 crate::analysis::SchemaType::DiscriminatedUnion { .. }
978 )
979 } else {
980 false
981 }
982 });
983
984 if has_nested_discriminated_union {
986 let schema_refs: Vec<crate::analysis::SchemaRef> = variants
988 .iter()
989 .map(|v| crate::analysis::SchemaRef {
990 target: v.type_name.clone(),
991 nullable: false,
992 })
993 .collect();
994 return self.generate_union_enum(schema, &schema_refs);
995 }
996
997 let enum_variants = variants.iter().map(|variant| {
998 let variant_name = format_ident!("{}", variant.rust_name);
999 let variant_value = &variant.discriminator_value;
1000
1001 let variant_type = format_ident!("{}", self.to_rust_type_name(&variant.type_name));
1004 quote! {
1005 #[serde(rename = #variant_value)]
1006 #variant_name(#variant_type),
1007 }
1008 });
1009
1010 let doc_comment = if let Some(desc) = &schema.description {
1011 quote! { #[doc = #desc] }
1012 } else {
1013 TokenStream::new()
1014 };
1015
1016 let derives = if self.config.enable_specta {
1018 quote! {
1019 #[derive(Debug, Clone, Deserialize, Serialize)]
1020 #[cfg_attr(feature = "specta", derive(specta::Type))]
1021 #[serde(tag = #discriminator_field)]
1022 }
1023 } else {
1024 quote! {
1025 #[derive(Debug, Clone, Deserialize, Serialize)]
1026 #[serde(tag = #discriminator_field)]
1027 }
1028 };
1029
1030 Ok(quote! {
1031 #doc_comment
1032 #derives
1033 pub enum #enum_name {
1034 #(#enum_variants)*
1035 }
1036 })
1037 }
1038
1039 fn should_use_untagged_discriminated_union(
1041 &self,
1042 schema: &crate::analysis::AnalyzedSchema,
1043 analysis: &crate::analysis::SchemaAnalysis,
1044 ) -> bool {
1045 for other_schema in analysis.schemas.values() {
1050 if let crate::analysis::SchemaType::DiscriminatedUnion {
1051 variants,
1052 discriminator_field: _,
1053 } = &other_schema.schema_type
1054 {
1055 for variant in variants {
1056 if variant.type_name == schema.name {
1057 if let crate::analysis::SchemaType::DiscriminatedUnion {
1062 discriminator_field: current_discriminator,
1063 variants: current_variants,
1064 ..
1065 } = &schema.schema_type
1066 {
1067 for current_variant in current_variants {
1069 if let Some(variant_schema) =
1070 analysis.schemas.get(¤t_variant.type_name)
1071 {
1072 if let crate::analysis::SchemaType::Object {
1073 properties, ..
1074 } = &variant_schema.schema_type
1075 {
1076 if properties.contains_key(current_discriminator) {
1077 return false;
1080 }
1081 }
1082 }
1083 }
1084 }
1085
1086 return true;
1088 }
1089 }
1090 }
1091 }
1092 false
1093 }
1094
1095 fn generate_union_enum(
1096 &self,
1097 schema: &crate::analysis::AnalyzedSchema,
1098 variants: &[crate::analysis::SchemaRef],
1099 ) -> Result<TokenStream> {
1100 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1101
1102 let mut used_variant_names = std::collections::HashSet::new();
1104 let enum_variants = variants.iter().enumerate().map(|(i, variant)| {
1105 let base_variant_name = self.type_name_to_variant_name(&variant.target);
1107 let variant_name = self.ensure_unique_variant_name_generator(
1108 base_variant_name,
1109 &mut used_variant_names,
1110 i,
1111 );
1112 let variant_name_ident = format_ident!("{}", variant_name);
1113
1114 let variant_type_tokens = if matches!(
1116 variant.target.as_str(),
1117 "bool"
1118 | "i8"
1119 | "i16"
1120 | "i32"
1121 | "i64"
1122 | "i128"
1123 | "u8"
1124 | "u16"
1125 | "u32"
1126 | "u64"
1127 | "u128"
1128 | "f32"
1129 | "f64"
1130 | "String"
1131 ) {
1132 let type_ident = format_ident!("{}", variant.target);
1133 quote! { #type_ident }
1134 } else if variant.target.starts_with("Vec<") && variant.target.ends_with(">") {
1135 let inner = &variant.target[4..variant.target.len() - 1];
1137
1138 if inner.starts_with("Vec<") && inner.ends_with(">") {
1140 let inner_inner = &inner[4..inner.len() - 1];
1141 let inner_inner_type = if matches!(
1142 inner_inner,
1143 "bool"
1144 | "i8"
1145 | "i16"
1146 | "i32"
1147 | "i64"
1148 | "i128"
1149 | "u8"
1150 | "u16"
1151 | "u32"
1152 | "u64"
1153 | "u128"
1154 | "f32"
1155 | "f64"
1156 | "String"
1157 ) {
1158 format_ident!("{}", inner_inner)
1159 } else {
1160 format_ident!("{}", self.to_rust_type_name(inner_inner))
1161 };
1162 quote! { Vec<Vec<#inner_inner_type>> }
1163 } else {
1164 let inner_type = if matches!(
1165 inner,
1166 "bool"
1167 | "i8"
1168 | "i16"
1169 | "i32"
1170 | "i64"
1171 | "i128"
1172 | "u8"
1173 | "u16"
1174 | "u32"
1175 | "u64"
1176 | "u128"
1177 | "f32"
1178 | "f64"
1179 | "String"
1180 ) {
1181 format_ident!("{}", inner)
1182 } else {
1183 format_ident!("{}", self.to_rust_type_name(inner))
1184 };
1185 quote! { Vec<#inner_type> }
1186 }
1187 } else {
1188 let type_ident = format_ident!("{}", self.to_rust_type_name(&variant.target));
1189 quote! { #type_ident }
1190 };
1191
1192 quote! {
1193 #variant_name_ident(#variant_type_tokens),
1194 }
1195 });
1196
1197 let doc_comment = if let Some(desc) = &schema.description {
1198 quote! { #[doc = #desc] }
1199 } else {
1200 TokenStream::new()
1201 };
1202
1203 let derives = if self.config.enable_specta {
1205 quote! {
1206 #[derive(Debug, Clone, Deserialize, Serialize)]
1207 #[cfg_attr(feature = "specta", derive(specta::Type))]
1208 #[serde(untagged)]
1209 }
1210 } else {
1211 quote! {
1212 #[derive(Debug, Clone, Deserialize, Serialize)]
1213 #[serde(untagged)]
1214 }
1215 };
1216
1217 Ok(quote! {
1218 #doc_comment
1219 #derives
1220 pub enum #enum_name {
1221 #(#enum_variants)*
1222 }
1223 })
1224 }
1225
1226 fn generate_field_type(
1227 &self,
1228 schema_name: &str,
1229 field_name: &str,
1230 prop: &crate::analysis::PropertyInfo,
1231 is_required: bool,
1232 analysis: &crate::analysis::SchemaAnalysis,
1233 ) -> TokenStream {
1234 use crate::analysis::SchemaType;
1235
1236 let base_type = match &prop.schema_type {
1237 SchemaType::Primitive { rust_type } => {
1238 if rust_type.contains("::") {
1240 let parts: Vec<&str> = rust_type.split("::").collect();
1241 if parts.len() == 2 {
1242 let module = format_ident!("{}", parts[0]);
1243 let type_name = format_ident!("{}", parts[1]);
1244 quote! { #module::#type_name }
1245 } else {
1246 let path_parts: Vec<_> =
1248 parts.iter().map(|p| format_ident!("{}", p)).collect();
1249 quote! { #(#path_parts)::* }
1250 }
1251 } else {
1252 let type_ident = format_ident!("{}", rust_type);
1253 quote! { #type_ident }
1254 }
1255 }
1256 SchemaType::Reference { target } => {
1257 let target_type = format_ident!("{}", self.to_rust_type_name(target));
1258 if analysis.dependencies.recursive_schemas.contains(target) {
1260 quote! { Box<#target_type> }
1261 } else {
1262 quote! { #target_type }
1263 }
1264 }
1265 SchemaType::Array { item_type } => {
1266 let inner_type = self.generate_array_item_type(item_type, analysis);
1267 quote! { Vec<#inner_type> }
1268 }
1269 _ => {
1270 quote! { serde_json::Value }
1272 }
1273 };
1274
1275 let override_key = format!("{schema_name}.{field_name}");
1277 let is_nullable_override = self
1278 .config
1279 .nullable_field_overrides
1280 .get(&override_key)
1281 .copied()
1282 .unwrap_or(false);
1283
1284 if is_required && !prop.nullable && !is_nullable_override {
1285 if prop.default.is_some() && self.type_lacks_default(&prop.schema_type, analysis) {
1288 quote! { Option<#base_type> }
1289 } else {
1290 base_type
1291 }
1292 } else {
1293 quote! { Option<#base_type> }
1294 }
1295 }
1296
1297 fn generate_serde_field_attrs(
1298 &self,
1299 field_name: &str,
1300 prop: &crate::analysis::PropertyInfo,
1301 is_required: bool,
1302 analysis: &crate::analysis::SchemaAnalysis,
1303 ) -> TokenStream {
1304 let mut attrs = Vec::new();
1305
1306 let rust_field_name = self.to_rust_field_name(field_name);
1309 let comparison_name = rust_field_name
1310 .strip_prefix("r#")
1311 .unwrap_or(&rust_field_name);
1312 if comparison_name != field_name {
1313 attrs.push(quote! { rename = #field_name });
1314 }
1315
1316 if !is_required || prop.nullable {
1318 attrs.push(quote! { skip_serializing_if = "Option::is_none" });
1319 }
1320
1321 if prop.default.is_some() && (is_required && !prop.nullable) {
1325 if !self.type_lacks_default(&prop.schema_type, analysis) {
1326 attrs.push(quote! { default });
1327 }
1328 }
1329
1330 if attrs.is_empty() {
1331 TokenStream::new()
1332 } else {
1333 quote! { #[serde(#(#attrs),*)] }
1334 }
1335 }
1336
1337 fn type_lacks_default(
1341 &self,
1342 schema_type: &crate::analysis::SchemaType,
1343 analysis: &crate::analysis::SchemaAnalysis,
1344 ) -> bool {
1345 use crate::analysis::SchemaType;
1346 match schema_type {
1347 SchemaType::DiscriminatedUnion { .. } | SchemaType::Union { .. } => true,
1348 SchemaType::Reference { target } => {
1349 if let Some(schema) = analysis.schemas.get(target) {
1350 self.type_lacks_default(&schema.schema_type, analysis)
1351 } else {
1352 false
1353 }
1354 }
1355 _ => false,
1356 }
1357 }
1358
1359 fn generate_specta_field_attrs(&self, field_name: &str) -> TokenStream {
1360 if !self.config.enable_specta {
1361 return TokenStream::new();
1362 }
1363
1364 let camel_case_name = self.to_camel_case(field_name);
1366
1367 if camel_case_name != field_name {
1369 quote! { #[cfg_attr(feature = "specta", specta(rename = #camel_case_name))] }
1370 } else {
1371 TokenStream::new()
1372 }
1373 }
1374
1375 fn to_rust_enum_variant(&self, s: &str) -> String {
1376 let mut result = String::new();
1378 let mut next_upper = true;
1379 let mut prev_was_upper = false;
1380
1381 for (i, c) in s.chars().enumerate() {
1382 match c {
1383 'a'..='z' => {
1384 if next_upper {
1385 result.push(c.to_ascii_uppercase());
1386 next_upper = false;
1387 } else {
1388 result.push(c);
1389 }
1390 prev_was_upper = false;
1391 }
1392 'A'..='Z' => {
1393 if next_upper || (!prev_was_upper && i > 0) {
1394 result.push(c);
1396 next_upper = false;
1397 } else {
1398 result.push(c.to_ascii_lowercase());
1400 }
1401 prev_was_upper = true;
1402 }
1403 '0'..='9' => {
1404 result.push(c);
1405 next_upper = false;
1406 prev_was_upper = false;
1407 }
1408 '.' | '-' | '_' | ' ' | '@' | '#' | '$' | '/' | '\\' => {
1409 next_upper = true;
1411 prev_was_upper = false;
1412 }
1413 _ => {
1414 next_upper = true;
1416 prev_was_upper = false;
1417 }
1418 }
1419 }
1420
1421 if result.is_empty() {
1423 result = "Value".to_string();
1424 }
1425
1426 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1428 result = format!("Variant{result}");
1429 }
1430
1431 match result.as_str() {
1433 "Null" => "NullValue".to_string(),
1434 "True" => "TrueValue".to_string(),
1435 "False" => "FalseValue".to_string(),
1436 "Type" => "Type_".to_string(),
1437 "Match" => "Match_".to_string(),
1438 "Fn" => "Fn_".to_string(),
1439 "Impl" => "Impl_".to_string(),
1440 "Trait" => "Trait_".to_string(),
1441 "Struct" => "Struct_".to_string(),
1442 "Enum" => "Enum_".to_string(),
1443 "Mod" => "Mod_".to_string(),
1444 "Use" => "Use_".to_string(),
1445 "Pub" => "Pub_".to_string(),
1446 "Const" => "Const_".to_string(),
1447 "Static" => "Static_".to_string(),
1448 "Let" => "Let_".to_string(),
1449 "Mut" => "Mut_".to_string(),
1450 "Ref" => "Ref_".to_string(),
1451 "Move" => "Move_".to_string(),
1452 "Return" => "Return_".to_string(),
1453 "If" => "If_".to_string(),
1454 "Else" => "Else_".to_string(),
1455 "While" => "While_".to_string(),
1456 "For" => "For_".to_string(),
1457 "Loop" => "Loop_".to_string(),
1458 "Break" => "Break_".to_string(),
1459 "Continue" => "Continue_".to_string(),
1460 "Self" => "Self_".to_string(),
1461 "Super" => "Super_".to_string(),
1462 "Crate" => "Crate_".to_string(),
1463 "Async" => "Async_".to_string(),
1464 "Await" => "Await_".to_string(),
1465 _ => result,
1466 }
1467 }
1468
1469 #[allow(dead_code)]
1470 fn to_rust_identifier(&self, s: &str) -> String {
1471 let mut result = s
1473 .chars()
1474 .map(|c| match c {
1475 'a'..='z' | 'A'..='Z' | '0'..='9' => c,
1476 '.' | '-' | '_' | ' ' | '@' | '#' | '$' | '/' | '\\' => '_',
1477 _ => '_',
1478 })
1479 .collect::<String>();
1480
1481 result = result.trim_matches('_').to_string();
1483
1484 if result.is_empty() {
1486 result = "value".to_string();
1487 }
1488
1489 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1491 result = format!("variant_{result}");
1492 }
1493
1494 match result.as_str() {
1496 "null" => "null_value".to_string(),
1497 "true" => "true_value".to_string(),
1498 "false" => "false_value".to_string(),
1499 "type" => "type_".to_string(),
1500 "match" => "match_".to_string(),
1501 "fn" => "fn_".to_string(),
1502 "impl" => "impl_".to_string(),
1503 "trait" => "trait_".to_string(),
1504 "struct" => "struct_".to_string(),
1505 "enum" => "enum_".to_string(),
1506 "mod" => "mod_".to_string(),
1507 "use" => "use_".to_string(),
1508 "pub" => "pub_".to_string(),
1509 "const" => "const_".to_string(),
1510 "static" => "static_".to_string(),
1511 "let" => "let_".to_string(),
1512 "mut" => "mut_".to_string(),
1513 "ref" => "ref_".to_string(),
1514 "move" => "move_".to_string(),
1515 "return" => "return_".to_string(),
1516 "if" => "if_".to_string(),
1517 "else" => "else_".to_string(),
1518 "while" => "while_".to_string(),
1519 "for" => "for_".to_string(),
1520 "loop" => "loop_".to_string(),
1521 "break" => "break_".to_string(),
1522 "continue" => "continue_".to_string(),
1523 "self" => "self_".to_string(),
1524 "super" => "super_".to_string(),
1525 "crate" => "crate_".to_string(),
1526 "async" => "async_".to_string(),
1527 "await" => "await_".to_string(),
1528 "override" => "override_".to_string(),
1530 "box" => "box_".to_string(),
1531 "dyn" => "dyn_".to_string(),
1532 "where" => "where_".to_string(),
1533 "in" => "in_".to_string(),
1534 "abstract" => "abstract_".to_string(),
1536 "become" => "become_".to_string(),
1537 "do" => "do_".to_string(),
1538 "final" => "final_".to_string(),
1539 "macro" => "macro_".to_string(),
1540 "priv" => "priv_".to_string(),
1541 "try" => "try_".to_string(),
1542 "typeof" => "typeof_".to_string(),
1543 "unsized" => "unsized_".to_string(),
1544 "virtual" => "virtual_".to_string(),
1545 "yield" => "yield_".to_string(),
1546 _ => result,
1547 }
1548 }
1549
1550 fn sanitize_doc_comment(&self, desc: &str) -> String {
1551 let mut result = desc.to_string();
1553
1554 if result.contains('\n')
1562 && (result.contains('{')
1563 || result.contains("```")
1564 || result.contains("Human:")
1565 || result.contains("Assistant:")
1566 || result
1567 .lines()
1568 .any(|line| line.trim().starts_with('"') && line.trim().ends_with('"')))
1569 {
1570 if result.contains("```") {
1572 result = result.replace("```", "```ignore");
1573 } else {
1574 if result.lines().any(|line| {
1576 let trimmed = line.trim();
1577 trimmed.starts_with('"') && trimmed.ends_with('"') && trimmed.len() > 2
1578 }) {
1579 result = format!("```ignore\n{result}\n```");
1580 }
1581 }
1582 }
1583
1584 result
1585 }
1586
1587 pub(crate) fn to_rust_type_name(&self, s: &str) -> String {
1588 let mut result = String::new();
1590 let mut next_upper = true;
1591 let mut prev_was_lower = false;
1592
1593 for c in s.chars() {
1594 match c {
1595 'a'..='z' => {
1596 if next_upper {
1597 result.push(c.to_ascii_uppercase());
1598 next_upper = false;
1599 } else {
1600 result.push(c);
1601 }
1602 prev_was_lower = true;
1603 }
1604 'A'..='Z' => {
1605 result.push(c);
1606 next_upper = false;
1607 prev_was_lower = false;
1608 }
1609 '0'..='9' => {
1610 if prev_was_lower && !result.chars().last().unwrap_or(' ').is_ascii_digit() {
1613 }
1615 result.push(c);
1616 next_upper = false;
1617 prev_was_lower = false;
1618 }
1619 '_' | '-' | '.' | ' ' => {
1620 next_upper = true;
1622 prev_was_lower = false;
1623 }
1624 _ => {
1625 next_upper = true;
1627 prev_was_lower = false;
1628 }
1629 }
1630 }
1631
1632 if result.is_empty() {
1634 result = "Type".to_string();
1635 }
1636
1637 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1639 result = format!("Type{result}");
1640 }
1641
1642 result
1643 }
1644
1645 fn to_rust_field_name(&self, s: &str) -> String {
1646 let mut result = String::new();
1648 let mut prev_was_upper = false;
1649 let mut prev_was_underscore = false;
1650
1651 for (i, c) in s.chars().enumerate() {
1652 match c {
1653 'A'..='Z' => {
1654 if i > 0 && !prev_was_upper && !prev_was_underscore {
1656 result.push('_');
1657 }
1658 result.push(c.to_ascii_lowercase());
1659 prev_was_upper = true;
1660 prev_was_underscore = false;
1661 }
1662 'a'..='z' | '0'..='9' => {
1663 result.push(c);
1664 prev_was_upper = false;
1665 prev_was_underscore = false;
1666 }
1667 '-' | '.' | '_' | '@' | '#' | '$' | ' ' => {
1668 if !prev_was_underscore && !result.is_empty() {
1669 result.push('_');
1670 prev_was_underscore = true;
1671 }
1672 prev_was_upper = false;
1673 }
1674 _ => {
1675 if !prev_was_underscore && !result.is_empty() {
1677 result.push('_');
1678 }
1679 prev_was_upper = false;
1680 prev_was_underscore = true;
1681 }
1682 }
1683 }
1684
1685 let mut result = result.trim_matches('_').to_string();
1687 if result.is_empty() {
1688 return "field".to_string();
1689 }
1690
1691 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1693 result = format!("field_{result}");
1694 }
1695
1696 if Self::is_rust_keyword(&result) {
1698 format!("r#{result}")
1699 } else {
1700 result
1701 }
1702 }
1703
1704 pub fn is_rust_keyword(s: &str) -> bool {
1706 matches!(
1707 s,
1708 "type"
1709 | "match"
1710 | "fn"
1711 | "struct"
1712 | "enum"
1713 | "impl"
1714 | "trait"
1715 | "mod"
1716 | "use"
1717 | "pub"
1718 | "const"
1719 | "static"
1720 | "let"
1721 | "mut"
1722 | "ref"
1723 | "move"
1724 | "return"
1725 | "if"
1726 | "else"
1727 | "while"
1728 | "for"
1729 | "loop"
1730 | "break"
1731 | "continue"
1732 | "self"
1733 | "super"
1734 | "crate"
1735 | "async"
1736 | "await"
1737 | "override"
1738 | "box"
1739 | "dyn"
1740 | "where"
1741 | "in"
1742 | "abstract"
1743 | "become"
1744 | "do"
1745 | "final"
1746 | "macro"
1747 | "priv"
1748 | "try"
1749 | "typeof"
1750 | "unsized"
1751 | "virtual"
1752 | "yield"
1753 )
1754 }
1755
1756 pub fn to_field_ident(name: &str) -> proc_macro2::Ident {
1758 if let Some(raw) = name.strip_prefix("r#") {
1759 proc_macro2::Ident::new_raw(raw, proc_macro2::Span::call_site())
1760 } else {
1761 proc_macro2::Ident::new(name, proc_macro2::Span::call_site())
1762 }
1763 }
1764
1765 fn to_camel_case(&self, s: &str) -> String {
1766 let mut result = String::new();
1768 let mut capitalize_next = false;
1769
1770 for (i, c) in s.chars().enumerate() {
1771 match c {
1772 '_' | '-' | '.' | ' ' => {
1773 capitalize_next = true;
1775 }
1776 'A'..='Z' => {
1777 if i == 0 {
1778 result.push(c.to_ascii_lowercase());
1780 } else if capitalize_next {
1781 result.push(c);
1782 capitalize_next = false;
1783 } else {
1784 result.push(c.to_ascii_lowercase());
1785 }
1786 }
1787 'a'..='z' | '0'..='9' => {
1788 if capitalize_next {
1789 result.push(c.to_ascii_uppercase());
1790 capitalize_next = false;
1791 } else {
1792 result.push(c);
1793 }
1794 }
1795 _ => {
1796 capitalize_next = true;
1798 }
1799 }
1800 }
1801
1802 if result.is_empty() {
1803 return "field".to_string();
1804 }
1805
1806 result
1807 }
1808
1809 fn generate_composition_struct(
1810 &self,
1811 schema: &crate::analysis::AnalyzedSchema,
1812 schemas: &[crate::analysis::SchemaRef],
1813 ) -> Result<TokenStream> {
1814 let struct_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1815
1816 let fields = schemas.iter().enumerate().map(|(i, schema_ref)| {
1822 let field_name = format_ident!("part_{}", i);
1823 let field_type = format_ident!("{}", self.to_rust_type_name(&schema_ref.target));
1824
1825 quote! {
1826 #[serde(flatten)]
1827 pub #field_name: #field_type,
1828 }
1829 });
1830
1831 let doc_comment = if let Some(desc) = &schema.description {
1832 quote! { #[doc = #desc] }
1833 } else {
1834 TokenStream::new()
1835 };
1836
1837 let derives = if self.config.enable_specta {
1839 quote! {
1840 #[derive(Debug, Clone, Deserialize, Serialize)]
1841 #[cfg_attr(feature = "specta", derive(specta::Type))]
1842 }
1843 } else {
1844 quote! {
1845 #[derive(Debug, Clone, Deserialize, Serialize)]
1846 }
1847 };
1848
1849 Ok(quote! {
1850 #doc_comment
1851 #derives
1852 pub struct #struct_name {
1853 #(#fields)*
1854 }
1855 })
1856 }
1857
1858 #[allow(dead_code)]
1859 fn find_missing_types(&self, analysis: &SchemaAnalysis) -> std::collections::HashSet<String> {
1860 let mut missing = std::collections::HashSet::new();
1861 let defined_types: std::collections::HashSet<String> =
1862 analysis.schemas.keys().cloned().collect();
1863
1864 for schema in analysis.schemas.values() {
1866 match &schema.schema_type {
1867 crate::analysis::SchemaType::Union { variants } => {
1868 for variant in variants {
1869 if !defined_types.contains(&variant.target) {
1870 missing.insert(variant.target.clone());
1871 }
1872 }
1873 }
1874 crate::analysis::SchemaType::DiscriminatedUnion { variants, .. } => {
1875 for variant in variants {
1876 if !defined_types.contains(&variant.type_name) {
1877 missing.insert(variant.type_name.clone());
1878 }
1879 }
1880 }
1881 crate::analysis::SchemaType::Object { properties, .. } => {
1882 let mut sorted_props: Vec<_> = properties.iter().collect();
1884 sorted_props.sort_by_key(|(name, _)| name.as_str());
1885 for (_, prop) in sorted_props {
1886 if let crate::analysis::SchemaType::Reference { target } = &prop.schema_type
1887 {
1888 if !defined_types.contains(target) {
1889 missing.insert(target.clone());
1890 }
1891 }
1892 }
1893 }
1894 crate::analysis::SchemaType::Reference { target } => {
1895 if !defined_types.contains(target) {
1896 missing.insert(target.clone());
1897 }
1898 }
1899 _ => {}
1900 }
1901 }
1902
1903 missing
1904 }
1905
1906 #[allow(clippy::only_used_in_recursion)]
1907 fn generate_array_item_type(
1908 &self,
1909 item_type: &crate::analysis::SchemaType,
1910 analysis: &crate::analysis::SchemaAnalysis,
1911 ) -> TokenStream {
1912 use crate::analysis::SchemaType;
1913
1914 match item_type {
1915 SchemaType::Primitive { rust_type } => {
1916 if rust_type.contains("::") {
1918 let parts: Vec<&str> = rust_type.split("::").collect();
1919 if parts.len() == 2 {
1920 let module = format_ident!("{}", parts[0]);
1921 let type_name = format_ident!("{}", parts[1]);
1922 quote! { #module::#type_name }
1923 } else {
1924 let path_parts: Vec<_> =
1926 parts.iter().map(|p| format_ident!("{}", p)).collect();
1927 quote! { #(#path_parts)::* }
1928 }
1929 } else {
1930 let type_ident = format_ident!("{}", rust_type);
1931 quote! { #type_ident }
1932 }
1933 }
1934 SchemaType::Reference { target } => {
1935 let target_type = format_ident!("{}", self.to_rust_type_name(target));
1936 if analysis.dependencies.recursive_schemas.contains(target) {
1938 quote! { Box<#target_type> }
1939 } else {
1940 quote! { #target_type }
1941 }
1942 }
1943 SchemaType::Array { item_type } => {
1944 let inner_type = self.generate_array_item_type(item_type, analysis);
1946 quote! { Vec<#inner_type> }
1947 }
1948 _ => {
1949 quote! { serde_json::Value }
1951 }
1952 }
1953 }
1954
1955 fn type_name_to_variant_name(&self, type_name: &str) -> String {
1957 match type_name {
1959 "bool" => return "Boolean".to_string(),
1960 "i8" | "i16" | "i32" | "i64" | "i128" => return "Integer".to_string(),
1961 "u8" | "u16" | "u32" | "u64" | "u128" => return "UnsignedInteger".to_string(),
1962 "f32" | "f64" => return "Number".to_string(),
1963 "String" => return "String".to_string(),
1964 _ => {}
1965 }
1966
1967 if type_name.starts_with("Vec<") && type_name.ends_with(">") {
1969 let inner = &type_name[4..type_name.len() - 1];
1970 if inner.starts_with("Vec<") && inner.ends_with(">") {
1972 let inner_inner = &inner[4..inner.len() - 1];
1973 return format!("{}ArrayArray", self.type_name_to_variant_name(inner_inner));
1974 }
1975 return format!("{}Array", self.type_name_to_variant_name(inner));
1976 }
1977
1978 let clean_name = type_name
1984 .trim_end_matches("Type")
1985 .trim_end_matches("Schema")
1986 .trim_end_matches("Item");
1987
1988 self.to_rust_type_name(clean_name)
1990 }
1991
1992 fn ensure_unique_variant_name_generator(
1994 &self,
1995 base_name: String,
1996 used_names: &mut std::collections::HashSet<String>,
1997 fallback_index: usize,
1998 ) -> String {
1999 if used_names.insert(base_name.clone()) {
2000 return base_name;
2001 }
2002
2003 for i in 2..100 {
2005 let numbered_name = format!("{base_name}{i}");
2006 if used_names.insert(numbered_name.clone()) {
2007 return numbered_name;
2008 }
2009 }
2010
2011 let fallback = format!("Variant{fallback_index}");
2013 used_names.insert(fallback.clone());
2014 fallback
2015 }
2016
2017 fn find_request_type_for_operation(
2019 &self,
2020 operation_id: &str,
2021 analysis: &SchemaAnalysis,
2022 ) -> Option<String> {
2023 analysis.operations.get(operation_id).and_then(|op| {
2025 op.request_body
2026 .as_ref()
2027 .and_then(|rb| rb.schema_name().map(|s| s.to_string()))
2028 })
2029 }
2030
2031 fn resolve_streaming_event_type(
2033 &self,
2034 endpoint: &crate::streaming::StreamingEndpoint,
2035 analysis: &SchemaAnalysis,
2036 ) -> Result<String> {
2037 match &endpoint.event_flow {
2038 crate::streaming::EventFlow::Simple => {
2039 if analysis.schemas.contains_key(&endpoint.event_union_type) {
2042 Ok(endpoint.event_union_type.to_string())
2043 } else {
2044 Err(crate::error::GeneratorError::ValidationError(format!(
2045 "Streaming response type '{}' not found in schema for simple streaming endpoint '{}'",
2046 endpoint.event_union_type, endpoint.operation_id
2047 )))
2048 }
2049 }
2050 crate::streaming::EventFlow::StartDeltaStop { .. } => {
2051 if analysis.schemas.contains_key(&endpoint.event_union_type) {
2054 Ok(endpoint.event_union_type.to_string())
2055 } else {
2056 Err(crate::error::GeneratorError::ValidationError(format!(
2057 "Event union type '{}' not found in schema for complex streaming endpoint '{}'",
2058 endpoint.event_union_type, endpoint.operation_id
2059 )))
2060 }
2061 }
2062 }
2063 }
2064
    /// Generates the `StreamingError` enum used by the generated streaming clients.
    ///
    /// The emitted code derives `thiserror::Error` and provides these conversions:
    /// - `serde_json::Error` -> `Json`, via `#[from]`;
    /// - `reqwest::header::InvalidHeaderValue` -> `Api("Invalid header value: ...")`;
    /// - `reqwest::Error` -> `Timeout` for timeouts, `Http { status }` for status errors that
    ///   carry a status code, `Connection` for status errors without one, and `Request`
    ///   otherwise.
    ///
    /// Always returns `Ok`.
    fn generate_streaming_error_types(&self) -> Result<TokenStream> {
        Ok(quote! {
            #[derive(Debug, thiserror::Error)]
            pub enum StreamingError {
                #[error("Connection error: {0}")]
                Connection(String),
                #[error("HTTP error: {status}")]
                Http { status: u16 },
                #[error("SSE parsing error: {0}")]
                Parsing(String),
                #[error("Authentication error: {0}")]
                Authentication(String),
                #[error("Rate limit error: {0}")]
                RateLimit(String),
                #[error("API error: {0}")]
                Api(String),
                #[error("Timeout error: {0}")]
                Timeout(String),
                #[error("JSON serialization/deserialization error: {0}")]
                Json(#[from] serde_json::Error),
                // No `#[from]` here: the manual `From<reqwest::Error>` impl below classifies
                // errors into Timeout/Http/Connection before falling back to this variant.
                #[error("Request error: {0}")]
                Request(reqwest::Error),
            }

            impl From<reqwest::header::InvalidHeaderValue> for StreamingError {
                fn from(err: reqwest::header::InvalidHeaderValue) -> Self {
                    StreamingError::Api(format!("Invalid header value: {}", err))
                }
            }

            impl From<reqwest::Error> for StreamingError {
                fn from(err: reqwest::Error) -> Self {
                    if err.is_timeout() {
                        StreamingError::Timeout(err.to_string())
                    } else if err.is_status() {
                        if let Some(status) = err.status() {
                            StreamingError::Http { status: status.as_u16() }
                        } else {
                            StreamingError::Connection(err.to_string())
                        }
                    } else {
                        StreamingError::Request(err)
                    }
                }
            }
        })
    }
2114
2115 fn generate_endpoint_trait(
2117 &self,
2118 endpoint: &crate::streaming::StreamingEndpoint,
2119 analysis: &SchemaAnalysis,
2120 ) -> Result<TokenStream> {
2121 use crate::streaming::HttpMethod;
2122
2123 let trait_name = format_ident!(
2124 "{}StreamingClient",
2125 self.to_rust_type_name(&endpoint.operation_id)
2126 );
2127 let method_name =
2128 format_ident!("stream_{}", self.to_rust_field_name(&endpoint.operation_id));
2129 let event_type =
2130 format_ident!("{}", self.resolve_streaming_event_type(endpoint, analysis)?);
2131
2132 let method_signature = match endpoint.http_method {
2134 HttpMethod::Get => {
2135 let mut param_defs = Vec::new();
2137 for qp in &endpoint.query_parameters {
2138 let param_name = format_ident!("{}", self.to_rust_field_name(&qp.name));
2139 if qp.required {
2140 param_defs.push(quote! { #param_name: &str });
2141 } else {
2142 param_defs.push(quote! { #param_name: Option<&str> });
2143 }
2144 }
2145 quote! {
2146 async fn #method_name(
2147 &self,
2148 #(#param_defs),*
2149 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error>;
2150 }
2151 }
2152 HttpMethod::Post => {
2153 let request_type = self
2155 .find_request_type_for_operation(&endpoint.operation_id, analysis)
2156 .unwrap_or_else(|| "serde_json::Value".to_string());
2157 let request_type_ident = if request_type.contains("::") {
2158 let parts: Vec<&str> = request_type.split("::").collect();
2159 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
2160 quote! { #(#path_parts)::* }
2161 } else {
2162 let ident = format_ident!("{}", request_type);
2163 quote! { #ident }
2164 };
2165 quote! {
2166 async fn #method_name(
2167 &self,
2168 request: #request_type_ident,
2169 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error>;
2170 }
2171 }
2172 };
2173
2174 Ok(quote! {
2175 #[async_trait]
2177 pub trait #trait_name {
2178 type Error: std::error::Error + Send + Sync + 'static;
2179
2180 #method_signature
2182 }
2183 })
2184 }
2185
2186 fn generate_streaming_client_impl(
2188 &self,
2189 streaming_config: &crate::streaming::StreamingConfig,
2190 analysis: &SchemaAnalysis,
2191 ) -> Result<TokenStream> {
2192 let client_name = format_ident!(
2193 "{}Client",
2194 self.to_rust_type_name(&streaming_config.client_module_name)
2195 );
2196
2197 let mut struct_fields = vec![
2200 quote! { base_url: String },
2201 quote! { api_key: Option<String> },
2202 quote! { http_client: reqwest::Client },
2203 quote! { custom_headers: std::collections::BTreeMap<String, String> },
2204 ];
2205
2206 let has_optional_headers = !streaming_config
2207 .endpoints
2208 .iter()
2209 .all(|e| e.optional_headers.is_empty());
2210
2211 if has_optional_headers {
2212 struct_fields
2213 .push(quote! { optional_headers: std::collections::BTreeMap<String, String> });
2214 }
2215
2216 let default_base_url = if let Some(ref streaming_config) = self.config.streaming_config {
2219 streaming_config
2220 .endpoints
2221 .first()
2222 .and_then(|e| e.base_url.as_deref())
2223 .unwrap_or("https://api.example.com")
2224 } else {
2225 "https://api.example.com"
2226 };
2227
2228 let constructor_fields = if has_optional_headers {
2230 quote! {
2231 base_url: #default_base_url.to_string(),
2232 api_key: None,
2233 http_client: reqwest::Client::new(),
2234 custom_headers: std::collections::BTreeMap::new(),
2235 optional_headers: std::collections::BTreeMap::new(),
2236 }
2237 } else {
2238 quote! {
2239 base_url: #default_base_url.to_string(),
2240 api_key: None,
2241 http_client: reqwest::Client::new(),
2242 custom_headers: std::collections::BTreeMap::new(),
2243 }
2244 };
2245
2246 let optional_headers_method = if has_optional_headers {
2248 quote! {
2249 pub fn set_optional_headers(&mut self, headers: std::collections::BTreeMap<String, String>) {
2251 self.optional_headers = headers;
2252 }
2253 }
2254 } else {
2255 TokenStream::new()
2256 };
2257
2258 let constructor = quote! {
2259 impl #client_name {
2260 pub fn new() -> Self {
2262 Self {
2263 #constructor_fields
2264 }
2265 }
2266
2267 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
2269 self.base_url = base_url.into();
2270 self
2271 }
2272
2273 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
2275 self.api_key = Some(api_key.into());
2276 self
2277 }
2278
2279 pub fn with_header(
2281 mut self,
2282 name: impl Into<String>,
2283 value: impl Into<String>,
2284 ) -> Self {
2285 self.custom_headers.insert(name.into(), value.into());
2286 self
2287 }
2288
2289 pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
2291 self.http_client = client;
2292 self
2293 }
2294
2295 #optional_headers_method
2296 }
2297 };
2298
2299 let mut trait_impls = Vec::new();
2301 for endpoint in &streaming_config.endpoints {
2302 let trait_impl = self.generate_endpoint_trait_impl(endpoint, &client_name, analysis)?;
2303 trait_impls.push(trait_impl);
2304 }
2305
2306 let default_impl = quote! {
2308 impl Default for #client_name {
2309 fn default() -> Self {
2310 Self::new()
2311 }
2312 }
2313 };
2314
2315 Ok(quote! {
2316 #[derive(Debug, Clone)]
2318 pub struct #client_name {
2319 #(#struct_fields,)*
2320 }
2321
2322 #constructor
2323
2324 #default_impl
2325
2326 #(#trait_impls)*
2327 })
2328 }
2329
2330 fn generate_endpoint_trait_impl(
2332 &self,
2333 endpoint: &crate::streaming::StreamingEndpoint,
2334 client_name: &proc_macro2::Ident,
2335 analysis: &SchemaAnalysis,
2336 ) -> Result<TokenStream> {
2337 use crate::streaming::HttpMethod;
2338
2339 let trait_name = format_ident!(
2340 "{}StreamingClient",
2341 self.to_rust_type_name(&endpoint.operation_id)
2342 );
2343 let method_name =
2344 format_ident!("stream_{}", self.to_rust_field_name(&endpoint.operation_id));
2345 let event_type =
2346 format_ident!("{}", self.resolve_streaming_event_type(endpoint, analysis)?);
2347
2348 let mut header_setup = Vec::new();
2350 for (name, value) in &endpoint.required_headers {
2351 header_setup.push(quote! {
2352 headers.insert(#name, HeaderValue::from_static(#value));
2353 });
2354 }
2355
2356 if let Some(auth_header) = &endpoint.auth_header {
2359 match auth_header {
2360 crate::streaming::AuthHeader::Bearer(header_name) => {
2361 header_setup.push(quote! {
2362 if let Some(ref api_key) = self.api_key {
2363 headers.insert(#header_name, HeaderValue::from_str(&format!("Bearer {}", api_key))?);
2364 }
2365 });
2366 }
2367 crate::streaming::AuthHeader::ApiKey(header_name) => {
2368 header_setup.push(quote! {
2369 if let Some(ref api_key) = self.api_key {
2370 headers.insert(#header_name, HeaderValue::from_str(api_key)?);
2371 }
2372 });
2373 }
2374 }
2375 } else {
2376 header_setup.push(quote! {
2378 if let Some(ref api_key) = self.api_key {
2379 headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", api_key))?);
2380 }
2381 });
2382 }
2383
2384 header_setup.push(quote! {
2386 for (name, value) in &self.custom_headers {
2387 if let (Ok(header_name), Ok(header_value)) = (reqwest::header::HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(value)) {
2388 headers.insert(header_name, header_value);
2389 }
2390 }
2391 });
2392
2393 if !endpoint.optional_headers.is_empty() {
2395 header_setup.push(quote! {
2396 for (key, value) in &self.optional_headers {
2397 if let (Ok(header_name), Ok(header_value)) = (reqwest::header::HeaderName::from_bytes(key.as_bytes()), HeaderValue::from_str(value)) {
2398 headers.insert(header_name, header_value);
2399 }
2400 }
2401 });
2402 }
2403
2404 match endpoint.http_method {
2406 HttpMethod::Get => self.generate_get_streaming_impl(
2407 endpoint,
2408 client_name,
2409 &trait_name,
2410 &method_name,
2411 &event_type,
2412 &header_setup,
2413 ),
2414 HttpMethod::Post => self.generate_post_streaming_impl(
2415 endpoint,
2416 client_name,
2417 &trait_name,
2418 &method_name,
2419 &event_type,
2420 &header_setup,
2421 analysis,
2422 ),
2423 }
2424 }
2425
2426 fn generate_get_streaming_impl(
2428 &self,
2429 endpoint: &crate::streaming::StreamingEndpoint,
2430 client_name: &proc_macro2::Ident,
2431 trait_name: &proc_macro2::Ident,
2432 method_name: &proc_macro2::Ident,
2433 event_type: &proc_macro2::Ident,
2434 header_setup: &[TokenStream],
2435 ) -> Result<TokenStream> {
2436 let path = &endpoint.path;
2437
2438 let mut param_defs = Vec::new();
2440 let mut query_params = Vec::new();
2441
2442 for qp in &endpoint.query_parameters {
2443 let param_name = format_ident!("{}", self.to_rust_field_name(&qp.name));
2444 let param_name_str = &qp.name;
2445
2446 if qp.required {
2447 param_defs.push(quote! { #param_name: &str });
2448 query_params.push(quote! {
2449 url.query_pairs_mut().append_pair(#param_name_str, #param_name);
2450 });
2451 } else {
2452 param_defs.push(quote! { #param_name: Option<&str> });
2453 query_params.push(quote! {
2454 if let Some(v) = #param_name {
2455 url.query_pairs_mut().append_pair(#param_name_str, v);
2456 }
2457 });
2458 }
2459 }
2460
2461 let url_construction = quote! {
2463 let base_url = url::Url::parse(&self.base_url)
2464 .map_err(|e| StreamingError::Connection(format!("Invalid base URL: {}", e)))?;
2465 let path_to_join = #path.trim_start_matches('/');
2466 let mut url = base_url.join(path_to_join)
2467 .map_err(|e| StreamingError::Connection(format!("URL join error: {}", e)))?;
2468 #(#query_params)*
2469 };
2470
2471 let instrument_skip = quote! { #[instrument(skip(self), name = "streaming_get_request")] };
2472
2473 Ok(quote! {
2474 #[async_trait]
2475 impl #trait_name for #client_name {
2476 type Error = StreamingError;
2477
2478 #instrument_skip
2479 async fn #method_name(
2480 &self,
2481 #(#param_defs),*
2482 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error> {
2483 debug!("Starting streaming GET request");
2484
2485 let mut headers = HeaderMap::new();
2486 #(#header_setup)*
2487
2488 #url_construction
2489 let url_str = url.to_string();
2490 debug!("Making streaming GET request to: {}", url_str);
2491
2492 let request_builder = self.http_client
2493 .get(url_str)
2494 .headers(headers);
2495
2496 debug!("Creating SSE stream from request");
2497 let stream = parse_sse_stream::<#event_type>(request_builder).await?;
2498 info!("SSE stream created successfully");
2499 Ok(Box::pin(stream))
2500 }
2501 }
2502 })
2503 }
2504
2505 #[allow(clippy::too_many_arguments)]
2507 fn generate_post_streaming_impl(
2508 &self,
2509 endpoint: &crate::streaming::StreamingEndpoint,
2510 client_name: &proc_macro2::Ident,
2511 trait_name: &proc_macro2::Ident,
2512 method_name: &proc_macro2::Ident,
2513 event_type: &proc_macro2::Ident,
2514 header_setup: &[TokenStream],
2515 analysis: &SchemaAnalysis,
2516 ) -> Result<TokenStream> {
2517 let path = &endpoint.path;
2518
2519 let request_type = self
2521 .find_request_type_for_operation(&endpoint.operation_id, analysis)
2522 .unwrap_or_else(|| "serde_json::Value".to_string());
2523 let request_type_ident = if request_type.contains("::") {
2524 let parts: Vec<&str> = request_type.split("::").collect();
2525 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
2526 quote! { #(#path_parts)::* }
2527 } else {
2528 let ident = format_ident!("{}", request_type);
2529 quote! { #ident }
2530 };
2531
2532 let url_construction = quote! {
2534 let base_url = url::Url::parse(&self.base_url)
2535 .map_err(|e| StreamingError::Connection(format!("Invalid base URL: {}", e)))?;
2536 let path_to_join = #path.trim_start_matches('/');
2537 let url = base_url.join(path_to_join)
2538 .map_err(|e| StreamingError::Connection(format!("URL join error: {}", e)))?
2539 .to_string();
2540 };
2541
2542 let stream_param = &endpoint.stream_parameter;
2544 let stream_setup = if stream_param.is_empty() {
2545 quote! {
2546 let streaming_request = request;
2547 }
2548 } else {
2549 quote! {
2550 let mut streaming_request = request;
2552 if let Ok(mut request_value) = serde_json::to_value(&streaming_request) {
2553 if let Some(obj) = request_value.as_object_mut() {
2554 obj.insert(#stream_param.to_string(), serde_json::Value::Bool(true));
2555 }
2556 streaming_request = serde_json::from_value(request_value)?;
2557 }
2558 }
2559 };
2560
2561 Ok(quote! {
2562 #[async_trait]
2563 impl #trait_name for #client_name {
2564 type Error = StreamingError;
2565
2566 #[instrument(skip(self, request), name = "streaming_post_request")]
2567 async fn #method_name(
2568 &self,
2569 request: #request_type_ident,
2570 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error> {
2571 debug!("Starting streaming POST request");
2572
2573 #stream_setup
2574
2575 let mut headers = HeaderMap::new();
2576 #(#header_setup)*
2577
2578 #url_construction
2579 debug!("Making streaming POST request to: {}", url);
2580
2581 let request_builder = self.http_client
2582 .post(&url)
2583 .headers(headers)
2584 .json(&streaming_request);
2585
2586 debug!("Creating SSE stream from request");
2587 let stream = parse_sse_stream::<#event_type>(request_builder).await?;
2588 info!("SSE stream created successfully");
2589 Ok(Box::pin(stream))
2590 }
2591 }
2592 })
2593 }
2594
2595 fn generate_sse_parser_utilities(
2597 &self,
2598 _streaming_config: &crate::streaming::StreamingConfig,
2599 ) -> Result<TokenStream> {
2600 Ok(quote! {
2601 pub async fn parse_sse_stream<T>(
2603 request_builder: reqwest::RequestBuilder
2604 ) -> Result<impl Stream<Item = Result<T, StreamingError>>, StreamingError>
2605 where
2606 T: serde::de::DeserializeOwned + Send + 'static,
2607 {
2608 let mut event_source = reqwest_eventsource::EventSource::new(request_builder).map_err(|e| {
2609 StreamingError::Connection(format!("Failed to create event source: {}", e))
2610 })?;
2611
2612 let stream = event_source.filter_map(|event_result| async move {
2613 match event_result {
2614 Ok(reqwest_eventsource::Event::Open) => {
2615 debug!("SSE connection opened");
2616 None
2617 }
2618 Ok(reqwest_eventsource::Event::Message(message)) => {
2619 if message.event == "ping" {
2621 debug!("Received SSE ping event, skipping");
2622 return None;
2623 }
2624
2625 if message.data.trim().is_empty() {
2627 debug!("Empty SSE data, skipping");
2628 return None;
2629 }
2630
2631 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&message.data) {
2633 if let Some(event_type) = json_value.get("event").and_then(|v| v.as_str()) {
2634 if event_type == "ping" {
2635 debug!("Received ping event in JSON data, skipping");
2636 return None;
2637 }
2638 }
2639
2640 match serde_json::from_value::<T>(json_value) {
2642 Ok(parsed_event) => {
2643 Some(Ok(parsed_event))
2644 }
2645 Err(e) => {
2646 if message.data.contains("ping") || message.event.contains("ping") {
2647 debug!("Ignoring ping-related event: {}", message.data);
2648 None
2649 } else {
2650 Some(Err(StreamingError::Parsing(
2651 format!("Failed to parse SSE event: {} (raw: {})", e, message.data)
2652 )))
2653 }
2654 }
2655 }
2656 } else {
2657 Some(Err(StreamingError::Parsing(
2659 format!("SSE event is not valid JSON: {}", message.data)
2660 )))
2661 }
2662 }
2663 Err(e) => {
2664 match e {
2666 reqwest_eventsource::Error::StreamEnded => {
2667 debug!("SSE stream completed normally");
2668 None }
2670 reqwest_eventsource::Error::InvalidStatusCode(status, response) => {
2671 let status_code = status.as_u16();
2673
2674 let error_body = match response.text().await {
2676 Ok(body) => body,
2677 Err(_) => "Failed to read error response body".to_string()
2678 };
2679
2680 error!("SSE connection error - HTTP {}: {}", status_code, error_body);
2681
2682 let detailed_error = format!(
2683 "HTTP {} error: {}",
2684 status_code,
2685 error_body
2686 );
2687
2688 Some(Err(StreamingError::Connection(detailed_error)))
2689 }
2690 _ => {
2691 let error_str = e.to_string();
2692 if error_str.contains("stream closed") {
2693 debug!("SSE stream closed");
2694 None
2695 } else {
2696 error!("SSE connection error: {}", e);
2697 Some(Err(StreamingError::Connection(error_str)))
2698 }
2699 }
2700 }
2701 }
2702 }
2703 });
2704
2705 Ok(stream)
2706 }
2707 })
2708 }
2709
2710 fn generate_reconnection_utilities(
2712 &self,
2713 reconnect_config: &crate::streaming::ReconnectionConfig,
2714 ) -> Result<TokenStream> {
2715 let max_retries = reconnect_config.max_retries;
2716 let initial_delay = reconnect_config.initial_delay_ms;
2717 let max_delay = reconnect_config.max_delay_ms;
2718 let backoff_multiplier = reconnect_config.backoff_multiplier;
2719
2720 Ok(quote! {
2721 #[derive(Debug, Clone)]
2723 pub struct ReconnectionManager {
2724 max_retries: u32,
2725 initial_delay_ms: u64,
2726 max_delay_ms: u64,
2727 backoff_multiplier: f64,
2728 current_attempt: u32,
2729 }
2730
2731 impl ReconnectionManager {
2732 pub fn new() -> Self {
2734 Self {
2735 max_retries: #max_retries,
2736 initial_delay_ms: #initial_delay,
2737 max_delay_ms: #max_delay,
2738 backoff_multiplier: #backoff_multiplier,
2739 current_attempt: 0,
2740 }
2741 }
2742
2743 pub fn should_retry(&self) -> bool {
2745 self.current_attempt < self.max_retries
2746 }
2747
2748 pub fn next_retry_delay(&mut self) -> Duration {
2750 if !self.should_retry() {
2751 return Duration::from_secs(0);
2752 }
2753
2754 let delay_ms = (self.initial_delay_ms as f64
2755 * self.backoff_multiplier.powi(self.current_attempt as i32)) as u64;
2756 let delay_ms = delay_ms.min(self.max_delay_ms);
2757
2758 self.current_attempt += 1;
2759 Duration::from_millis(delay_ms)
2760 }
2761
2762 pub fn reset(&mut self) {
2764 self.current_attempt = 0;
2765 }
2766
2767 pub fn current_attempt(&self) -> u32 {
2769 self.current_attempt
2770 }
2771 }
2772
2773 impl Default for ReconnectionManager {
2774 fn default() -> Self {
2775 Self::new()
2776 }
2777 }
2778 })
2779 }
2780}