1use crate::{GeneratorError, Result, analysis::SchemaAnalysis, streaming::StreamingConfig};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use std::collections::BTreeMap;
5use std::path::PathBuf;
6
/// Discriminator metadata recorded for an object schema that is used as a
/// variant of a discriminated union.
///
/// `generate_types` collects one entry per such variant. `generate_struct`
/// and the array branch of `generate_type_definition` consult it to decide
/// whether the discriminator property is dropped from the struct and restored
/// through a serde `tag` instead.
#[derive(Debug, Clone)]
struct DiscriminatedVariantInfo {
    /// Name of the JSON property that carries the discriminator.
    discriminator_field: String,
    /// Discriminator value that selects this variant.
    discriminator_value: String,
    /// `true` when the parent union is emitted as an untagged enum; the
    /// variant struct then keeps the discriminator as an ordinary field.
    is_parent_untagged: bool,
}
17
18#[derive(Debug, Clone)]
19pub struct GeneratorConfig {
20 pub spec_path: PathBuf,
22 pub output_dir: PathBuf,
24 pub module_name: String,
26 pub enable_sse_client: bool,
28 pub enable_async_client: bool,
30 pub enable_specta: bool,
32 pub type_mappings: BTreeMap<String, String>,
34 pub streaming_config: Option<StreamingConfig>,
36 pub nullable_field_overrides: BTreeMap<String, bool>,
39 pub schema_extensions: Vec<PathBuf>,
42 pub http_client_config: Option<crate::http_config::HttpClientConfig>,
44 pub retry_config: Option<crate::http_config::RetryConfig>,
46 pub tracing_enabled: bool,
48 pub auth_config: Option<crate::http_config::AuthConfig>,
50 pub enable_registry: bool,
52 pub registry_only: bool,
54}
55
56impl Default for GeneratorConfig {
57 fn default() -> Self {
58 Self {
59 spec_path: "openapi.json".into(),
60 output_dir: "src/gen".into(),
61 module_name: "api_types".to_string(),
62 enable_sse_client: true,
63 enable_async_client: true,
64 enable_specta: false,
65 type_mappings: default_type_mappings(),
66 streaming_config: None,
67 nullable_field_overrides: BTreeMap::new(),
68 schema_extensions: Vec::new(),
69 http_client_config: None,
70 retry_config: None,
71 tracing_enabled: true,
72 auth_config: None,
73 enable_registry: false,
74 registry_only: false,
75 }
76 }
77}
78
/// Built-in mapping from JSON Schema primitive type names to Rust type names.
///
/// | schema type | Rust type |
/// |-------------|-----------|
/// | `integer`   | `i64`     |
/// | `number`    | `f64`     |
/// | `string`    | `String`  |
/// | `boolean`   | `bool`    |
pub fn default_type_mappings() -> BTreeMap<String, String> {
    const TABLE: [(&str, &str); 4] = [
        ("integer", "i64"),
        ("number", "f64"),
        ("string", "String"),
        ("boolean", "bool"),
    ];

    TABLE
        .iter()
        .map(|&(schema_type, rust_type)| (schema_type.to_string(), rust_type.to_string()))
        .collect()
}
87
/// One generated source file, held in memory until `write_files` persists it.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File name relative to the configured output directory (e.g. `types.rs`).
    pub path: PathBuf,
    /// Full text of the file.
    pub content: String,
}
96
97#[derive(Debug, Clone)]
99pub struct GenerationResult {
100 pub files: Vec<GeneratedFile>,
102 pub mod_file: GeneratedFile,
104}
105
106pub struct CodeGenerator {
107 config: GeneratorConfig,
108}
109
110impl CodeGenerator {
111 pub fn new(config: GeneratorConfig) -> Self {
112 Self { config }
113 }
114
115 pub fn config(&self) -> &GeneratorConfig {
117 &self.config
118 }
119
120 pub fn generate_all(&self, analysis: &mut SchemaAnalysis) -> Result<GenerationResult> {
122 let mut files = Vec::new();
123
124 if !self.config.registry_only {
125 let types_content = self.generate_types(analysis)?;
127 files.push(GeneratedFile {
128 path: "types.rs".into(),
129 content: types_content,
130 });
131
132 if let Some(ref streaming_config) = self.config.streaming_config {
134 let streaming_content =
135 self.generate_streaming_client(streaming_config, analysis)?;
136 files.push(GeneratedFile {
137 path: "streaming.rs".into(),
138 content: streaming_content,
139 });
140 }
141
142 if self.config.enable_async_client {
144 let http_content = self.generate_http_client(analysis)?;
145 files.push(GeneratedFile {
146 path: "client.rs".into(),
147 content: http_content,
148 });
149 }
150 }
151
152 if self.config.enable_registry || self.config.registry_only {
154 let registry_content = self.generate_registry(analysis)?;
155 files.push(GeneratedFile {
156 path: "registry.rs".into(),
157 content: registry_content,
158 });
159 }
160
161 let mod_content = self.generate_mod_file(&files)?;
163 let mod_file = GeneratedFile {
164 path: "mod.rs".into(),
165 content: mod_content,
166 };
167
168 Ok(GenerationResult { files, mod_file })
169 }
170
171 pub fn generate(&self, analysis: &mut SchemaAnalysis) -> Result<String> {
173 self.generate_types(analysis)
174 }
175
176 fn generate_types(&self, analysis: &mut SchemaAnalysis) -> Result<String> {
178 let mut type_definitions = TokenStream::new();
179
180 let mut discriminated_variant_info: BTreeMap<String, DiscriminatedVariantInfo> =
183 BTreeMap::new();
184
185 let mut sorted_schemas: Vec<_> = analysis.schemas.iter().collect();
187 sorted_schemas.sort_by_key(|(name, _)| name.as_str());
188
189 for (_parent_name, schema) in sorted_schemas {
190 if let crate::analysis::SchemaType::DiscriminatedUnion {
191 variants,
192 discriminator_field,
193 } = &schema.schema_type
194 {
195 let is_parent_untagged =
197 self.should_use_untagged_discriminated_union(schema, analysis);
198
199 for variant in variants {
200 if let Some(variant_schema) = analysis.schemas.get(&variant.type_name) {
203 if let crate::analysis::SchemaType::Object { properties, .. } =
204 &variant_schema.schema_type
205 {
206 if properties.contains_key(discriminator_field) {
207 discriminated_variant_info.insert(
208 variant.type_name.clone(),
209 DiscriminatedVariantInfo {
210 discriminator_field: discriminator_field.clone(),
211 discriminator_value: variant.discriminator_value.clone(),
212 is_parent_untagged,
213 },
214 );
215 }
216 }
217 }
218 }
219 }
220 }
221
222 let generation_order = analysis.dependencies.topological_sort()?;
224
225 let mut processed = std::collections::HashSet::new();
227
228 for schema_name in generation_order {
230 if let Some(schema) = analysis.schemas.get(&schema_name) {
231 let type_def =
232 self.generate_type_definition(schema, analysis, &discriminated_variant_info)?;
233 if !type_def.is_empty() {
234 type_definitions.extend(type_def);
235 }
236 processed.insert(schema_name);
237 }
238 }
239
240 let mut remaining_schemas: Vec<_> = analysis
243 .schemas
244 .iter()
245 .filter(|(name, _)| !processed.contains(*name))
246 .collect();
247 remaining_schemas.sort_by_key(|(name, _)| name.as_str());
248
249 for (_schema_name, schema) in remaining_schemas {
250 let type_def =
251 self.generate_type_definition(schema, analysis, &discriminated_variant_info)?;
252 if !type_def.is_empty() {
253 type_definitions.extend(type_def);
254 }
255 }
256
257 let generated = quote! {
259 #![allow(clippy::large_enum_variant)]
265 #![allow(clippy::format_in_format_args)]
266 #![allow(clippy::let_unit_value)]
267 #![allow(unreachable_patterns)]
268
269 use serde::{Deserialize, Serialize};
270
271 #type_definitions
272 };
273
274 let syntax_tree = syn::parse2::<syn::File>(generated).map_err(|e| {
276 GeneratorError::CodeGenError(format!("Failed to parse generated code: {e}"))
277 })?;
278
279 let formatted = prettyplease::unparse(&syntax_tree);
280
281 Ok(formatted)
282 }
283
284 fn generate_streaming_client(
286 &self,
287 streaming_config: &StreamingConfig,
288 analysis: &SchemaAnalysis,
289 ) -> Result<String> {
290 let mut client_code = TokenStream::new();
291
292 let imports = quote! {
294 #![allow(clippy::format_in_format_args)]
299 #![allow(clippy::let_unit_value)]
300 #![allow(unused_mut)]
301
302 use super::types::*;
303 use async_trait::async_trait;
304 use futures_util::{Stream, StreamExt};
305 use std::pin::Pin;
306 use std::time::Duration;
307 use reqwest::header::{HeaderMap, HeaderValue};
308 use tracing::{debug, error, info, warn, instrument};
309 };
310 client_code.extend(imports);
311
312 if streaming_config.generate_client {
314 let error_types = self.generate_streaming_error_types()?;
315 client_code.extend(error_types);
316 }
317
318 for endpoint in &streaming_config.endpoints {
320 let trait_code = self.generate_endpoint_trait(endpoint, analysis)?;
321 client_code.extend(trait_code);
322 }
323
324 if streaming_config.generate_client {
326 let client_impl = self.generate_streaming_client_impl(streaming_config, analysis)?;
327 client_code.extend(client_impl);
328 }
329
330 if streaming_config.event_parser_helpers {
332 let parser_code = self.generate_sse_parser_utilities(streaming_config)?;
333 client_code.extend(parser_code);
334 }
335
336 if let Some(reconnect_config) = &streaming_config.reconnection_config {
338 let reconnect_code = self.generate_reconnection_utilities(reconnect_config)?;
339 client_code.extend(reconnect_code);
340 }
341
342 let syntax_tree = syn::parse2::<syn::File>(client_code).map_err(|e| {
343 GeneratorError::CodeGenError(format!("Failed to parse streaming client code: {e}"))
344 })?;
345
346 Ok(prettyplease::unparse(&syntax_tree))
347 }
348
349 pub fn generate_http_client(&self, analysis: &SchemaAnalysis) -> Result<String> {
351 let error_types = self.generate_http_error_types();
352 let client_struct = self.generate_http_client_struct();
353 let operation_methods = self.generate_operation_methods(analysis);
354
355 let generated = quote! {
356 #![allow(clippy::format_in_format_args)]
361 #![allow(clippy::let_unit_value)]
362
363 use super::types::*;
364
365 #error_types
366
367 #client_struct
368
369 #operation_methods
370 };
371
372 let syntax_tree = syn::parse2::<syn::File>(generated).map_err(|e| {
373 GeneratorError::CodeGenError(format!("Failed to parse HTTP client code: {e}"))
374 })?;
375
376 Ok(prettyplease::unparse(&syntax_tree))
377 }
378
379 fn generate_http_error_types(&self) -> TokenStream {
381 quote! {
382 use thiserror::Error;
383
384 #[derive(Error, Debug)]
386 pub enum HttpError {
387 #[error("Network error: {0}")]
389 Network(#[from] reqwest::Error),
390
391 #[error("Middleware error: {0}")]
393 Middleware(#[from] reqwest_middleware::Error),
394
395 #[error("Failed to serialize request: {0}")]
397 Serialization(String),
398
399 #[error("Failed to deserialize response: {0}")]
401 Deserialization(String),
402
403 #[error("HTTP error {status}: {message}")]
405 Http {
406 status: u16,
407 message: String,
408 body: Option<String>,
409 },
410
411 #[error("Authentication error: {0}")]
413 Auth(String),
414
415 #[error("Request timeout")]
417 Timeout,
418
419 #[error("Configuration error: {0}")]
421 Config(String),
422
423 #[error("{0}")]
425 Other(String),
426 }
427
428 impl HttpError {
429 pub fn from_status(status: u16, message: impl Into<String>, body: Option<String>) -> Self {
431 Self::Http {
432 status,
433 message: message.into(),
434 body,
435 }
436 }
437
438 pub fn serialization_error(error: impl std::fmt::Display) -> Self {
440 Self::Serialization(error.to_string())
441 }
442
443 pub fn deserialization_error(error: impl std::fmt::Display) -> Self {
445 Self::Deserialization(error.to_string())
446 }
447
448 pub fn is_client_error(&self) -> bool {
450 matches!(self, Self::Http { status, .. } if *status >= 400 && *status < 500)
451 }
452
453 pub fn is_server_error(&self) -> bool {
455 matches!(self, Self::Http { status, .. } if *status >= 500 && *status < 600)
456 }
457
458 pub fn is_retryable(&self) -> bool {
460 match self {
461 Self::Network(_) => true,
462 Self::Middleware(_) => true,
463 Self::Timeout => true,
464 Self::Http { status, .. } => {
465 matches!(status, 429 | 500 | 502 | 503 | 504)
467 }
468 _ => false,
469 }
470 }
471 }
472
473 pub type HttpResult<T> = Result<T, HttpError>;
475 }
476 }
477
478 fn generate_mod_file(&self, files: &[GeneratedFile]) -> Result<String> {
480 let mut module_declarations = Vec::new();
481 let mut pub_uses = Vec::new();
482
483 for file in files {
484 if let Some(module_name) = file.path.file_stem().and_then(|s| s.to_str()) {
485 if module_name != "mod" {
486 module_declarations.push(format!("pub mod {module_name};"));
487 pub_uses.push(format!("pub use {module_name}::*;"));
488 }
489 }
490 }
491
492 let content = format!(
493 r#"//! Generated API modules
494//!
495//! This module exports all generated API types and clients.
496//! Do not edit manually - regenerate using the appropriate script.
497
498#![allow(unused_imports)]
499
500{}
501
502{}
503"#,
504 module_declarations.join("\n"),
505 pub_uses.join("\n")
506 );
507
508 Ok(content)
509 }
510
511 pub fn write_files(&self, result: &GenerationResult) -> Result<()> {
513 use std::fs;
514
515 fs::create_dir_all(&self.config.output_dir)?;
517
518 for file in &result.files {
520 let file_path = self.config.output_dir.join(&file.path);
521 fs::write(&file_path, &file.content)?;
522 }
523
524 let mod_path = self.config.output_dir.join(&result.mod_file.path);
526 fs::write(&mod_path, &result.mod_file.content)?;
527
528 Ok(())
529 }
530
531 fn generate_type_definition(
532 &self,
533 schema: &crate::analysis::AnalyzedSchema,
534 analysis: &crate::analysis::SchemaAnalysis,
535 discriminated_variant_info: &BTreeMap<String, DiscriminatedVariantInfo>,
536 ) -> Result<TokenStream> {
537 use crate::analysis::SchemaType;
538
539 match &schema.schema_type {
540 SchemaType::Primitive { rust_type } => {
541 self.generate_type_alias(schema, rust_type)
543 }
544 SchemaType::StringEnum { values } => self.generate_string_enum(schema, values),
545 SchemaType::ExtensibleEnum { known_values } => {
546 self.generate_extensible_enum(schema, known_values)
547 }
548 SchemaType::Object {
549 properties,
550 required,
551 additional_properties,
552 } => self.generate_struct(
553 schema,
554 properties,
555 required,
556 *additional_properties,
557 analysis,
558 discriminated_variant_info.get(&schema.name),
559 ),
560 SchemaType::DiscriminatedUnion {
561 discriminator_field,
562 variants,
563 } => {
564 if self.should_use_untagged_discriminated_union(schema, analysis) {
566 let schema_refs: Vec<crate::analysis::SchemaRef> = variants
568 .iter()
569 .map(|v| crate::analysis::SchemaRef {
570 target: v.type_name.clone(),
571 nullable: false,
572 })
573 .collect();
574 self.generate_union_enum(schema, &schema_refs)
575 } else {
576 self.generate_discriminated_enum(
577 schema,
578 discriminator_field,
579 variants,
580 analysis,
581 )
582 }
583 }
584 SchemaType::Union { variants } => self.generate_union_enum(schema, variants),
585 SchemaType::Reference { target } => {
586 if schema.name != *target {
589 let alias_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
591 let target_type = format_ident!("{}", self.to_rust_type_name(target));
592
593 let doc_comment = if let Some(desc) = &schema.description {
594 quote! { #[doc = #desc] }
595 } else {
596 TokenStream::new()
597 };
598
599 Ok(quote! {
600 #doc_comment
601 pub type #alias_name = #target_type;
602 })
603 } else {
604 Ok(TokenStream::new())
606 }
607 }
608 SchemaType::Array { item_type } => {
609 let array_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
617
618 if let SchemaType::Reference { target } = item_type.as_ref() {
620 if let Some(info) = discriminated_variant_info.get(target) {
621 if !info.is_parent_untagged {
622 let wrapper_name =
624 format_ident!("{}Item", self.to_rust_type_name(&schema.name));
625 let variant_type = format_ident!("{}", self.to_rust_type_name(target));
626 let disc_field = &info.discriminator_field;
627 let disc_value = &info.discriminator_value;
628
629 let doc_comment = if let Some(desc) = &schema.description {
630 quote! { #[doc = #desc] }
631 } else {
632 TokenStream::new()
633 };
634
635 return Ok(quote! {
636 #[derive(Debug, Clone, Deserialize, Serialize)]
640 #[serde(tag = #disc_field)]
641 pub enum #wrapper_name {
642 #[serde(rename = #disc_value)]
643 #variant_type(#variant_type),
644 }
645 #doc_comment
646 pub type #array_name = Vec<#wrapper_name>;
647 });
648 }
649 }
650 }
651
652 let inner_type = self.generate_array_item_type(item_type, analysis);
653
654 let doc_comment = if let Some(desc) = &schema.description {
655 quote! { #[doc = #desc] }
656 } else {
657 TokenStream::new()
658 };
659
660 Ok(quote! {
661 #doc_comment
662 pub type #array_name = Vec<#inner_type>;
663 })
664 }
665 SchemaType::Composition { schemas } => {
666 self.generate_composition_struct(schema, schemas)
667 }
668 }
669 }
670
671 fn generate_type_alias(
672 &self,
673 schema: &crate::analysis::AnalyzedSchema,
674 rust_type: &str,
675 ) -> Result<TokenStream> {
676 let type_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
677
678 let base_type = if rust_type.contains("::") {
680 let parts: Vec<&str> = rust_type.split("::").collect();
681 if parts.len() == 2 {
682 let module = format_ident!("{}", parts[0]);
683 let type_name_part = format_ident!("{}", parts[1]);
684 quote! { #module::#type_name_part }
685 } else {
686 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
688 quote! { #(#path_parts)::* }
689 }
690 } else {
691 let simple_type = format_ident!("{}", rust_type);
692 quote! { #simple_type }
693 };
694
695 let doc_comment = if let Some(desc) = &schema.description {
696 let sanitized_desc = self.sanitize_doc_comment(desc);
697 quote! { #[doc = #sanitized_desc] }
698 } else {
699 TokenStream::new()
700 };
701
702 Ok(quote! {
703 #doc_comment
704 pub type #type_name = #base_type;
705 })
706 }
707
708 fn generate_extensible_enum(
709 &self,
710 schema: &crate::analysis::AnalyzedSchema,
711 known_values: &[String],
712 ) -> Result<TokenStream> {
713 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
714
715 let doc_comment = if let Some(desc) = &schema.description {
716 quote! { #[doc = #desc] }
717 } else {
718 TokenStream::new()
719 };
720
721 let known_variants = known_values.iter().map(|value| {
726 let variant_name = self.to_rust_enum_variant(value);
727 let variant_ident = format_ident!("{}", variant_name);
728 quote! {
729 #variant_ident,
730 }
731 });
732
733 let match_arms_de = known_values.iter().map(|value| {
734 let variant_name = self.to_rust_enum_variant(value);
735 let variant_ident = format_ident!("{}", variant_name);
736 quote! {
737 #value => Ok(#enum_name::#variant_ident),
738 }
739 });
740
741 let match_arms_ser = known_values.iter().map(|value| {
742 let variant_name = self.to_rust_enum_variant(value);
743 let variant_ident = format_ident!("{}", variant_name);
744 quote! {
745 #enum_name::#variant_ident => #value,
746 }
747 });
748
749 let derives = if self.config.enable_specta {
750 quote! {
751 #[derive(Debug, Clone, PartialEq, Eq)]
752 #[cfg_attr(feature = "specta", derive(specta::Type))]
753 }
754 } else {
755 quote! {
756 #[derive(Debug, Clone, PartialEq, Eq)]
757 }
758 };
759
760 Ok(quote! {
761 #doc_comment
762 #derives
763 pub enum #enum_name {
764 #(#known_variants)*
765 Custom(String),
767 }
768
769 impl<'de> serde::Deserialize<'de> for #enum_name {
770 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
771 where
772 D: serde::Deserializer<'de>,
773 {
774 let value = String::deserialize(deserializer)?;
775 match value.as_str() {
776 #(#match_arms_de)*
777 _ => Ok(#enum_name::Custom(value)),
778 }
779 }
780 }
781
782 impl serde::Serialize for #enum_name {
783 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
784 where
785 S: serde::Serializer,
786 {
787 let value = match self {
788 #(#match_arms_ser)*
789 #enum_name::Custom(s) => s.as_str(),
790 };
791 serializer.serialize_str(value)
792 }
793 }
794 })
795 }
796
797 fn generate_string_enum(
798 &self,
799 schema: &crate::analysis::AnalyzedSchema,
800 values: &[String],
801 ) -> Result<TokenStream> {
802 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
803
804 let default_value = schema
806 .default
807 .as_ref()
808 .and_then(|v| v.as_str())
809 .map(|s| s.to_string());
810
811 let variants = values.iter().enumerate().map(|(i, value)| {
812 let variant_name = self.to_rust_enum_variant(value);
814 let variant_ident = format_ident!("{}", variant_name);
815
816 let is_default = if let Some(ref default) = default_value {
818 value == default
819 } else {
820 i == 0 };
822
823 if is_default {
824 quote! {
825 #[default]
826 #[serde(rename = #value)]
827 #variant_ident,
828 }
829 } else {
830 quote! {
831 #[serde(rename = #value)]
832 #variant_ident,
833 }
834 }
835 });
836
837 let doc_comment = if let Some(desc) = &schema.description {
838 quote! { #[doc = #desc] }
839 } else {
840 TokenStream::new()
841 };
842
843 let derives = if self.config.enable_specta {
845 quote! {
846 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
847 #[cfg_attr(feature = "specta", derive(specta::Type))]
848 }
849 } else {
850 quote! {
851 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
852 }
853 };
854
855 Ok(quote! {
856 #doc_comment
857 #derives
858 pub enum #enum_name {
859 #(#variants)*
860 }
861 })
862 }
863
864 fn generate_struct(
865 &self,
866 schema: &crate::analysis::AnalyzedSchema,
867 properties: &BTreeMap<String, crate::analysis::PropertyInfo>,
868 required: &std::collections::HashSet<String>,
869 additional_properties: bool,
870 analysis: &crate::analysis::SchemaAnalysis,
871 discriminator_info: Option<&DiscriminatedVariantInfo>,
872 ) -> Result<TokenStream> {
873 let struct_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
874
875 let mut sorted_properties: Vec<_> = properties.iter().collect();
877 sorted_properties.sort_by_key(|(name, _)| name.as_str());
878
879 let mut fields: Vec<TokenStream> = sorted_properties
880 .into_iter()
881 .filter(|(field_name, _)| {
882 if let Some(info) = discriminator_info {
886 if !info.is_parent_untagged
887 && field_name.as_str() == info.discriminator_field.as_str()
888 {
889 false } else {
891 true }
893 } else {
894 true }
896 })
897 .map(|(field_name, prop)| {
898 let field_ident = Self::to_field_ident(&self.to_rust_field_name(field_name));
899 let is_required = required.contains(field_name);
900 let field_type =
901 self.generate_field_type(&schema.name, field_name, prop, is_required, analysis);
902
903 let serde_attrs =
904 self.generate_serde_field_attrs(field_name, prop, is_required, analysis);
905 let specta_attrs = self.generate_specta_field_attrs(field_name);
906
907 let doc_comment = if let Some(desc) = &prop.description {
908 let sanitized_desc = self.sanitize_doc_comment(desc);
909 quote! { #[doc = #sanitized_desc] }
910 } else {
911 TokenStream::new()
912 };
913
914 quote! {
915 #doc_comment
916 #serde_attrs
917 #specta_attrs
918 pub #field_ident: #field_type,
919 }
920 })
921 .collect();
922
923 if additional_properties {
925 fields.push(quote! {
926 #[serde(flatten)]
928 pub additional_properties: std::collections::BTreeMap<String, serde_json::Value>,
929 });
930 }
931
932 let doc_comment = if let Some(desc) = &schema.description {
933 quote! { #[doc = #desc] }
934 } else {
935 TokenStream::new()
936 };
937
938 let derives = if self.config.enable_specta {
942 quote! {
943 #[derive(Debug, Clone, Deserialize, Serialize)]
944 #[cfg_attr(feature = "specta", derive(specta::Type))]
945 }
946 } else {
947 quote! {
948 #[derive(Debug, Clone, Deserialize, Serialize)]
949 }
950 };
951
952 Ok(quote! {
953 #doc_comment
954 #derives
955 pub struct #struct_name {
956 #(#fields)*
957 }
958 })
959 }
960
961 fn generate_discriminated_enum(
962 &self,
963 schema: &crate::analysis::AnalyzedSchema,
964 discriminator_field: &str,
965 variants: &[crate::analysis::UnionVariant],
966 analysis: &crate::analysis::SchemaAnalysis,
967 ) -> Result<TokenStream> {
968 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
969
970 let has_nested_discriminated_union = variants.iter().any(|variant| {
972 if let Some(variant_schema) = analysis.schemas.get(&variant.type_name) {
973 matches!(
974 variant_schema.schema_type,
975 crate::analysis::SchemaType::DiscriminatedUnion { .. }
976 )
977 } else {
978 false
979 }
980 });
981
982 if has_nested_discriminated_union {
984 let schema_refs: Vec<crate::analysis::SchemaRef> = variants
986 .iter()
987 .map(|v| crate::analysis::SchemaRef {
988 target: v.type_name.clone(),
989 nullable: false,
990 })
991 .collect();
992 return self.generate_union_enum(schema, &schema_refs);
993 }
994
995 let enum_variants = variants.iter().map(|variant| {
996 let variant_name = format_ident!("{}", variant.rust_name);
997 let variant_value = &variant.discriminator_value;
998
999 let variant_type = format_ident!("{}", self.to_rust_type_name(&variant.type_name));
1002 quote! {
1003 #[serde(rename = #variant_value)]
1004 #variant_name(#variant_type),
1005 }
1006 });
1007
1008 let doc_comment = if let Some(desc) = &schema.description {
1009 quote! { #[doc = #desc] }
1010 } else {
1011 TokenStream::new()
1012 };
1013
1014 let derives = if self.config.enable_specta {
1016 quote! {
1017 #[derive(Debug, Clone, Deserialize, Serialize)]
1018 #[cfg_attr(feature = "specta", derive(specta::Type))]
1019 #[serde(tag = #discriminator_field)]
1020 }
1021 } else {
1022 quote! {
1023 #[derive(Debug, Clone, Deserialize, Serialize)]
1024 #[serde(tag = #discriminator_field)]
1025 }
1026 };
1027
1028 Ok(quote! {
1029 #doc_comment
1030 #derives
1031 pub enum #enum_name {
1032 #(#enum_variants)*
1033 }
1034 })
1035 }
1036
1037 fn should_use_untagged_discriminated_union(
1039 &self,
1040 schema: &crate::analysis::AnalyzedSchema,
1041 analysis: &crate::analysis::SchemaAnalysis,
1042 ) -> bool {
1043 for other_schema in analysis.schemas.values() {
1048 if let crate::analysis::SchemaType::DiscriminatedUnion {
1049 variants,
1050 discriminator_field: _,
1051 } = &other_schema.schema_type
1052 {
1053 for variant in variants {
1054 if variant.type_name == schema.name {
1055 if let crate::analysis::SchemaType::DiscriminatedUnion {
1060 discriminator_field: current_discriminator,
1061 variants: current_variants,
1062 ..
1063 } = &schema.schema_type
1064 {
1065 for current_variant in current_variants {
1067 if let Some(variant_schema) =
1068 analysis.schemas.get(¤t_variant.type_name)
1069 {
1070 if let crate::analysis::SchemaType::Object {
1071 properties, ..
1072 } = &variant_schema.schema_type
1073 {
1074 if properties.contains_key(current_discriminator) {
1075 return false;
1078 }
1079 }
1080 }
1081 }
1082 }
1083
1084 return true;
1086 }
1087 }
1088 }
1089 }
1090 false
1091 }
1092
1093 fn generate_union_enum(
1094 &self,
1095 schema: &crate::analysis::AnalyzedSchema,
1096 variants: &[crate::analysis::SchemaRef],
1097 ) -> Result<TokenStream> {
1098 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1099
1100 let mut used_variant_names = std::collections::HashSet::new();
1102 let enum_variants = variants.iter().enumerate().map(|(i, variant)| {
1103 let base_variant_name = self.type_name_to_variant_name(&variant.target);
1105 let variant_name = self.ensure_unique_variant_name_generator(
1106 base_variant_name,
1107 &mut used_variant_names,
1108 i,
1109 );
1110 let variant_name_ident = format_ident!("{}", variant_name);
1111
1112 let variant_type_tokens = if matches!(
1114 variant.target.as_str(),
1115 "bool"
1116 | "i8"
1117 | "i16"
1118 | "i32"
1119 | "i64"
1120 | "i128"
1121 | "u8"
1122 | "u16"
1123 | "u32"
1124 | "u64"
1125 | "u128"
1126 | "f32"
1127 | "f64"
1128 | "String"
1129 ) {
1130 let type_ident = format_ident!("{}", variant.target);
1131 quote! { #type_ident }
1132 } else if variant.target == "serde_json::Value" {
1133 quote! { serde_json::Value }
1136 } else if variant.target.starts_with("Vec<") && variant.target.ends_with(">") {
1137 let inner = &variant.target[4..variant.target.len() - 1];
1139
1140 if inner.starts_with("Vec<") && inner.ends_with(">") {
1142 let inner_inner = &inner[4..inner.len() - 1];
1143 if inner_inner == "serde_json::Value" {
1144 quote! { Vec<Vec<serde_json::Value>> }
1145 } else {
1146 let inner_inner_type = if matches!(
1147 inner_inner,
1148 "bool"
1149 | "i8"
1150 | "i16"
1151 | "i32"
1152 | "i64"
1153 | "i128"
1154 | "u8"
1155 | "u16"
1156 | "u32"
1157 | "u64"
1158 | "u128"
1159 | "f32"
1160 | "f64"
1161 | "String"
1162 ) {
1163 format_ident!("{}", inner_inner)
1164 } else {
1165 format_ident!("{}", self.to_rust_type_name(inner_inner))
1166 };
1167 quote! { Vec<Vec<#inner_inner_type>> }
1168 }
1169 } else if inner == "serde_json::Value" {
1170 quote! { Vec<serde_json::Value> }
1171 } else {
1172 let inner_type = if matches!(
1173 inner,
1174 "bool"
1175 | "i8"
1176 | "i16"
1177 | "i32"
1178 | "i64"
1179 | "i128"
1180 | "u8"
1181 | "u16"
1182 | "u32"
1183 | "u64"
1184 | "u128"
1185 | "f32"
1186 | "f64"
1187 | "String"
1188 ) {
1189 format_ident!("{}", inner)
1190 } else {
1191 format_ident!("{}", self.to_rust_type_name(inner))
1192 };
1193 quote! { Vec<#inner_type> }
1194 }
1195 } else {
1196 let type_ident = format_ident!("{}", self.to_rust_type_name(&variant.target));
1197 quote! { #type_ident }
1198 };
1199
1200 quote! {
1201 #variant_name_ident(#variant_type_tokens),
1202 }
1203 });
1204
1205 let doc_comment = if let Some(desc) = &schema.description {
1206 quote! { #[doc = #desc] }
1207 } else {
1208 TokenStream::new()
1209 };
1210
1211 let derives = if self.config.enable_specta {
1213 quote! {
1214 #[derive(Debug, Clone, Deserialize, Serialize)]
1215 #[cfg_attr(feature = "specta", derive(specta::Type))]
1216 #[serde(untagged)]
1217 }
1218 } else {
1219 quote! {
1220 #[derive(Debug, Clone, Deserialize, Serialize)]
1221 #[serde(untagged)]
1222 }
1223 };
1224
1225 Ok(quote! {
1226 #doc_comment
1227 #derives
1228 pub enum #enum_name {
1229 #(#enum_variants)*
1230 }
1231 })
1232 }
1233
1234 fn generate_field_type(
1235 &self,
1236 schema_name: &str,
1237 field_name: &str,
1238 prop: &crate::analysis::PropertyInfo,
1239 is_required: bool,
1240 analysis: &crate::analysis::SchemaAnalysis,
1241 ) -> TokenStream {
1242 use crate::analysis::SchemaType;
1243
1244 let base_type = match &prop.schema_type {
1245 SchemaType::Primitive { rust_type } => {
1246 if rust_type.contains("::") {
1248 let parts: Vec<&str> = rust_type.split("::").collect();
1249 if parts.len() == 2 {
1250 let module = format_ident!("{}", parts[0]);
1251 let type_name = format_ident!("{}", parts[1]);
1252 quote! { #module::#type_name }
1253 } else {
1254 let path_parts: Vec<_> =
1256 parts.iter().map(|p| format_ident!("{}", p)).collect();
1257 quote! { #(#path_parts)::* }
1258 }
1259 } else {
1260 let type_ident = format_ident!("{}", rust_type);
1261 quote! { #type_ident }
1262 }
1263 }
1264 SchemaType::Reference { target } => {
1265 let target_type = format_ident!("{}", self.to_rust_type_name(target));
1266 if analysis.dependencies.recursive_schemas.contains(target) {
1268 quote! { Box<#target_type> }
1269 } else {
1270 quote! { #target_type }
1271 }
1272 }
1273 SchemaType::Array { item_type } => {
1274 let inner_type = self.generate_array_item_type(item_type, analysis);
1275 quote! { Vec<#inner_type> }
1276 }
1277 _ => {
1278 quote! { serde_json::Value }
1280 }
1281 };
1282
1283 let override_key = format!("{schema_name}.{field_name}");
1285 let is_nullable_override = self
1286 .config
1287 .nullable_field_overrides
1288 .get(&override_key)
1289 .copied()
1290 .unwrap_or(false);
1291
1292 if is_required && !prop.nullable && !is_nullable_override {
1293 if prop.default.is_some() && self.type_lacks_default(&prop.schema_type, analysis) {
1296 quote! { Option<#base_type> }
1297 } else {
1298 base_type
1299 }
1300 } else {
1301 quote! { Option<#base_type> }
1302 }
1303 }
1304
1305 fn generate_serde_field_attrs(
1306 &self,
1307 field_name: &str,
1308 prop: &crate::analysis::PropertyInfo,
1309 is_required: bool,
1310 analysis: &crate::analysis::SchemaAnalysis,
1311 ) -> TokenStream {
1312 let mut attrs = Vec::new();
1313
1314 let rust_field_name = self.to_rust_field_name(field_name);
1317 let comparison_name = rust_field_name
1318 .strip_prefix("r#")
1319 .unwrap_or(&rust_field_name);
1320 if comparison_name != field_name {
1321 attrs.push(quote! { rename = #field_name });
1322 }
1323
1324 if !is_required || prop.nullable {
1326 attrs.push(quote! { skip_serializing_if = "Option::is_none" });
1327 }
1328
1329 if prop.default.is_some()
1333 && (is_required && !prop.nullable)
1334 && !self.type_lacks_default(&prop.schema_type, analysis)
1335 {
1336 attrs.push(quote! { default });
1337 }
1338
1339 if attrs.is_empty() {
1340 TokenStream::new()
1341 } else {
1342 quote! { #[serde(#(#attrs),*)] }
1343 }
1344 }
1345
1346 fn type_lacks_default(
1350 &self,
1351 schema_type: &crate::analysis::SchemaType,
1352 analysis: &crate::analysis::SchemaAnalysis,
1353 ) -> bool {
1354 use crate::analysis::SchemaType;
1355 match schema_type {
1356 SchemaType::DiscriminatedUnion { .. } | SchemaType::Union { .. } => true,
1357 SchemaType::Reference { target } => {
1358 if let Some(schema) = analysis.schemas.get(target) {
1359 self.type_lacks_default(&schema.schema_type, analysis)
1360 } else {
1361 false
1362 }
1363 }
1364 _ => false,
1365 }
1366 }
1367
1368 fn generate_specta_field_attrs(&self, field_name: &str) -> TokenStream {
1369 if !self.config.enable_specta {
1370 return TokenStream::new();
1371 }
1372
1373 let camel_case_name = self.to_camel_case(field_name);
1375
1376 if camel_case_name != field_name {
1378 quote! { #[cfg_attr(feature = "specta", specta(rename = #camel_case_name))] }
1379 } else {
1380 TokenStream::new()
1381 }
1382 }
1383
1384 fn to_rust_enum_variant(&self, s: &str) -> String {
1385 let mut result = String::new();
1387 let mut next_upper = true;
1388 let mut prev_was_upper = false;
1389
1390 for (i, c) in s.chars().enumerate() {
1391 match c {
1392 'a'..='z' => {
1393 if next_upper {
1394 result.push(c.to_ascii_uppercase());
1395 next_upper = false;
1396 } else {
1397 result.push(c);
1398 }
1399 prev_was_upper = false;
1400 }
1401 'A'..='Z' => {
1402 if next_upper || (!prev_was_upper && i > 0) {
1403 result.push(c);
1405 next_upper = false;
1406 } else {
1407 result.push(c.to_ascii_lowercase());
1409 }
1410 prev_was_upper = true;
1411 }
1412 '0'..='9' => {
1413 result.push(c);
1414 next_upper = false;
1415 prev_was_upper = false;
1416 }
1417 '.' | '-' | '_' | ' ' | '@' | '#' | '$' | '/' | '\\' => {
1418 next_upper = true;
1420 prev_was_upper = false;
1421 }
1422 _ => {
1423 next_upper = true;
1425 prev_was_upper = false;
1426 }
1427 }
1428 }
1429
1430 if result.is_empty() {
1432 result = "Value".to_string();
1433 }
1434
1435 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1437 result = format!("Variant{result}");
1438 }
1439
1440 match result.as_str() {
1442 "Null" => "NullValue".to_string(),
1443 "True" => "TrueValue".to_string(),
1444 "False" => "FalseValue".to_string(),
1445 "Type" => "Type_".to_string(),
1446 "Match" => "Match_".to_string(),
1447 "Fn" => "Fn_".to_string(),
1448 "Impl" => "Impl_".to_string(),
1449 "Trait" => "Trait_".to_string(),
1450 "Struct" => "Struct_".to_string(),
1451 "Enum" => "Enum_".to_string(),
1452 "Mod" => "Mod_".to_string(),
1453 "Use" => "Use_".to_string(),
1454 "Pub" => "Pub_".to_string(),
1455 "Const" => "Const_".to_string(),
1456 "Static" => "Static_".to_string(),
1457 "Let" => "Let_".to_string(),
1458 "Mut" => "Mut_".to_string(),
1459 "Ref" => "Ref_".to_string(),
1460 "Move" => "Move_".to_string(),
1461 "Return" => "Return_".to_string(),
1462 "If" => "If_".to_string(),
1463 "Else" => "Else_".to_string(),
1464 "While" => "While_".to_string(),
1465 "For" => "For_".to_string(),
1466 "Loop" => "Loop_".to_string(),
1467 "Break" => "Break_".to_string(),
1468 "Continue" => "Continue_".to_string(),
1469 "Self" => "Self_".to_string(),
1470 "Super" => "Super_".to_string(),
1471 "Crate" => "Crate_".to_string(),
1472 "Async" => "Async_".to_string(),
1473 "Await" => "Await_".to_string(),
1474 _ => result,
1475 }
1476 }
1477
1478 #[allow(dead_code)]
1479 fn to_rust_identifier(&self, s: &str) -> String {
1480 let mut result = s
1482 .chars()
1483 .map(|c| match c {
1484 'a'..='z' | 'A'..='Z' | '0'..='9' => c,
1485 '.' | '-' | '_' | ' ' | '@' | '#' | '$' | '/' | '\\' => '_',
1486 _ => '_',
1487 })
1488 .collect::<String>();
1489
1490 result = result.trim_matches('_').to_string();
1492
1493 if result.is_empty() {
1495 result = "value".to_string();
1496 }
1497
1498 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1500 result = format!("variant_{result}");
1501 }
1502
1503 match result.as_str() {
1505 "null" => "null_value".to_string(),
1506 "true" => "true_value".to_string(),
1507 "false" => "false_value".to_string(),
1508 "type" => "type_".to_string(),
1509 "match" => "match_".to_string(),
1510 "fn" => "fn_".to_string(),
1511 "impl" => "impl_".to_string(),
1512 "trait" => "trait_".to_string(),
1513 "struct" => "struct_".to_string(),
1514 "enum" => "enum_".to_string(),
1515 "mod" => "mod_".to_string(),
1516 "use" => "use_".to_string(),
1517 "pub" => "pub_".to_string(),
1518 "const" => "const_".to_string(),
1519 "static" => "static_".to_string(),
1520 "let" => "let_".to_string(),
1521 "mut" => "mut_".to_string(),
1522 "ref" => "ref_".to_string(),
1523 "move" => "move_".to_string(),
1524 "return" => "return_".to_string(),
1525 "if" => "if_".to_string(),
1526 "else" => "else_".to_string(),
1527 "while" => "while_".to_string(),
1528 "for" => "for_".to_string(),
1529 "loop" => "loop_".to_string(),
1530 "break" => "break_".to_string(),
1531 "continue" => "continue_".to_string(),
1532 "self" => "self_".to_string(),
1533 "super" => "super_".to_string(),
1534 "crate" => "crate_".to_string(),
1535 "async" => "async_".to_string(),
1536 "await" => "await_".to_string(),
1537 "override" => "override_".to_string(),
1539 "box" => "box_".to_string(),
1540 "dyn" => "dyn_".to_string(),
1541 "where" => "where_".to_string(),
1542 "in" => "in_".to_string(),
1543 "abstract" => "abstract_".to_string(),
1545 "become" => "become_".to_string(),
1546 "do" => "do_".to_string(),
1547 "final" => "final_".to_string(),
1548 "macro" => "macro_".to_string(),
1549 "priv" => "priv_".to_string(),
1550 "try" => "try_".to_string(),
1551 "typeof" => "typeof_".to_string(),
1552 "unsized" => "unsized_".to_string(),
1553 "virtual" => "virtual_".to_string(),
1554 "yield" => "yield_".to_string(),
1555 _ => result,
1556 }
1557 }
1558
1559 fn sanitize_doc_comment(&self, desc: &str) -> String {
1560 let mut result = desc.to_string();
1562
1563 if result.contains('\n')
1571 && (result.contains('{')
1572 || result.contains("```")
1573 || result.contains("Human:")
1574 || result.contains("Assistant:")
1575 || result
1576 .lines()
1577 .any(|line| line.trim().starts_with('"') && line.trim().ends_with('"')))
1578 {
1579 if result.contains("```") {
1581 result = result.replace("```", "```ignore");
1582 } else {
1583 if result.lines().any(|line| {
1585 let trimmed = line.trim();
1586 trimmed.starts_with('"') && trimmed.ends_with('"') && trimmed.len() > 2
1587 }) {
1588 result = format!("```ignore\n{result}\n```");
1589 }
1590 }
1591 }
1592
1593 result
1594 }
1595
1596 pub(crate) fn to_rust_type_name(&self, s: &str) -> String {
1597 let mut result = String::new();
1599 let mut next_upper = true;
1600 let mut prev_was_lower = false;
1601
1602 for c in s.chars() {
1603 match c {
1604 'a'..='z' => {
1605 if next_upper {
1606 result.push(c.to_ascii_uppercase());
1607 next_upper = false;
1608 } else {
1609 result.push(c);
1610 }
1611 prev_was_lower = true;
1612 }
1613 'A'..='Z' => {
1614 result.push(c);
1615 next_upper = false;
1616 prev_was_lower = false;
1617 }
1618 '0'..='9' => {
1619 if prev_was_lower && !result.chars().last().unwrap_or(' ').is_ascii_digit() {
1622 }
1624 result.push(c);
1625 next_upper = false;
1626 prev_was_lower = false;
1627 }
1628 '_' | '-' | '.' | ' ' => {
1629 next_upper = true;
1631 prev_was_lower = false;
1632 }
1633 _ => {
1634 next_upper = true;
1636 prev_was_lower = false;
1637 }
1638 }
1639 }
1640
1641 if result.is_empty() {
1643 result = "Type".to_string();
1644 }
1645
1646 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1648 result = format!("Type{result}");
1649 }
1650
1651 result
1652 }
1653
1654 fn to_rust_field_name(&self, s: &str) -> String {
1655 let mut result = String::new();
1657 let mut prev_was_upper = false;
1658 let mut prev_was_underscore = false;
1659
1660 for (i, c) in s.chars().enumerate() {
1661 match c {
1662 'A'..='Z' => {
1663 if i > 0 && !prev_was_upper && !prev_was_underscore {
1665 result.push('_');
1666 }
1667 result.push(c.to_ascii_lowercase());
1668 prev_was_upper = true;
1669 prev_was_underscore = false;
1670 }
1671 'a'..='z' | '0'..='9' => {
1672 result.push(c);
1673 prev_was_upper = false;
1674 prev_was_underscore = false;
1675 }
1676 '-' | '.' | '_' | '@' | '#' | '$' | ' ' => {
1677 if !prev_was_underscore && !result.is_empty() {
1678 result.push('_');
1679 prev_was_underscore = true;
1680 }
1681 prev_was_upper = false;
1682 }
1683 _ => {
1684 if !prev_was_underscore && !result.is_empty() {
1686 result.push('_');
1687 }
1688 prev_was_upper = false;
1689 prev_was_underscore = true;
1690 }
1691 }
1692 }
1693
1694 let mut result = result.trim_matches('_').to_string();
1696 if result.is_empty() {
1697 return "field".to_string();
1698 }
1699
1700 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1702 result = format!("field_{result}");
1703 }
1704
1705 if Self::is_rust_keyword(&result) {
1707 format!("r#{result}")
1708 } else {
1709 result
1710 }
1711 }
1712
1713 pub fn is_rust_keyword(s: &str) -> bool {
1715 matches!(
1716 s,
1717 "type"
1718 | "match"
1719 | "fn"
1720 | "struct"
1721 | "enum"
1722 | "impl"
1723 | "trait"
1724 | "mod"
1725 | "use"
1726 | "pub"
1727 | "const"
1728 | "static"
1729 | "let"
1730 | "mut"
1731 | "ref"
1732 | "move"
1733 | "return"
1734 | "if"
1735 | "else"
1736 | "while"
1737 | "for"
1738 | "loop"
1739 | "break"
1740 | "continue"
1741 | "self"
1742 | "super"
1743 | "crate"
1744 | "async"
1745 | "await"
1746 | "override"
1747 | "box"
1748 | "dyn"
1749 | "where"
1750 | "in"
1751 | "abstract"
1752 | "become"
1753 | "do"
1754 | "final"
1755 | "macro"
1756 | "priv"
1757 | "try"
1758 | "typeof"
1759 | "unsized"
1760 | "virtual"
1761 | "yield"
1762 )
1763 }
1764
1765 pub fn to_field_ident(name: &str) -> proc_macro2::Ident {
1767 if let Some(raw) = name.strip_prefix("r#") {
1768 proc_macro2::Ident::new_raw(raw, proc_macro2::Span::call_site())
1769 } else {
1770 proc_macro2::Ident::new(name, proc_macro2::Span::call_site())
1771 }
1772 }
1773
1774 fn to_camel_case(&self, s: &str) -> String {
1775 let mut result = String::new();
1777 let mut capitalize_next = false;
1778
1779 for (i, c) in s.chars().enumerate() {
1780 match c {
1781 '_' | '-' | '.' | ' ' => {
1782 capitalize_next = true;
1784 }
1785 'A'..='Z' => {
1786 if i == 0 {
1787 result.push(c.to_ascii_lowercase());
1789 } else if capitalize_next {
1790 result.push(c);
1791 capitalize_next = false;
1792 } else {
1793 result.push(c.to_ascii_lowercase());
1794 }
1795 }
1796 'a'..='z' | '0'..='9' => {
1797 if capitalize_next {
1798 result.push(c.to_ascii_uppercase());
1799 capitalize_next = false;
1800 } else {
1801 result.push(c);
1802 }
1803 }
1804 _ => {
1805 capitalize_next = true;
1807 }
1808 }
1809 }
1810
1811 if result.is_empty() {
1812 return "field".to_string();
1813 }
1814
1815 result
1816 }
1817
1818 fn generate_composition_struct(
1819 &self,
1820 schema: &crate::analysis::AnalyzedSchema,
1821 schemas: &[crate::analysis::SchemaRef],
1822 ) -> Result<TokenStream> {
1823 let struct_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1824
1825 let fields = schemas.iter().enumerate().map(|(i, schema_ref)| {
1831 let field_name = format_ident!("part_{}", i);
1832 let field_type = format_ident!("{}", self.to_rust_type_name(&schema_ref.target));
1833
1834 quote! {
1835 #[serde(flatten)]
1836 pub #field_name: #field_type,
1837 }
1838 });
1839
1840 let doc_comment = if let Some(desc) = &schema.description {
1841 quote! { #[doc = #desc] }
1842 } else {
1843 TokenStream::new()
1844 };
1845
1846 let derives = if self.config.enable_specta {
1848 quote! {
1849 #[derive(Debug, Clone, Deserialize, Serialize)]
1850 #[cfg_attr(feature = "specta", derive(specta::Type))]
1851 }
1852 } else {
1853 quote! {
1854 #[derive(Debug, Clone, Deserialize, Serialize)]
1855 }
1856 };
1857
1858 Ok(quote! {
1859 #doc_comment
1860 #derives
1861 pub struct #struct_name {
1862 #(#fields)*
1863 }
1864 })
1865 }
1866
1867 #[allow(dead_code)]
1868 fn find_missing_types(&self, analysis: &SchemaAnalysis) -> std::collections::HashSet<String> {
1869 let mut missing = std::collections::HashSet::new();
1870 let defined_types: std::collections::HashSet<String> =
1871 analysis.schemas.keys().cloned().collect();
1872
1873 for schema in analysis.schemas.values() {
1875 match &schema.schema_type {
1876 crate::analysis::SchemaType::Union { variants } => {
1877 for variant in variants {
1878 if !defined_types.contains(&variant.target) {
1879 missing.insert(variant.target.clone());
1880 }
1881 }
1882 }
1883 crate::analysis::SchemaType::DiscriminatedUnion { variants, .. } => {
1884 for variant in variants {
1885 if !defined_types.contains(&variant.type_name) {
1886 missing.insert(variant.type_name.clone());
1887 }
1888 }
1889 }
1890 crate::analysis::SchemaType::Object { properties, .. } => {
1891 let mut sorted_props: Vec<_> = properties.iter().collect();
1893 sorted_props.sort_by_key(|(name, _)| name.as_str());
1894 for (_, prop) in sorted_props {
1895 if let crate::analysis::SchemaType::Reference { target } = &prop.schema_type
1896 {
1897 if !defined_types.contains(target) {
1898 missing.insert(target.clone());
1899 }
1900 }
1901 }
1902 }
1903 crate::analysis::SchemaType::Reference { target }
1904 if !defined_types.contains(target) =>
1905 {
1906 missing.insert(target.clone());
1907 }
1908 _ => {}
1909 }
1910 }
1911
1912 missing
1913 }
1914
1915 #[allow(clippy::only_used_in_recursion)]
1916 fn generate_array_item_type(
1917 &self,
1918 item_type: &crate::analysis::SchemaType,
1919 analysis: &crate::analysis::SchemaAnalysis,
1920 ) -> TokenStream {
1921 use crate::analysis::SchemaType;
1922
1923 match item_type {
1924 SchemaType::Primitive { rust_type } => {
1925 if rust_type.contains("::") {
1927 let parts: Vec<&str> = rust_type.split("::").collect();
1928 if parts.len() == 2 {
1929 let module = format_ident!("{}", parts[0]);
1930 let type_name = format_ident!("{}", parts[1]);
1931 quote! { #module::#type_name }
1932 } else {
1933 let path_parts: Vec<_> =
1935 parts.iter().map(|p| format_ident!("{}", p)).collect();
1936 quote! { #(#path_parts)::* }
1937 }
1938 } else {
1939 let type_ident = format_ident!("{}", rust_type);
1940 quote! { #type_ident }
1941 }
1942 }
1943 SchemaType::Reference { target } => {
1944 let target_type = format_ident!("{}", self.to_rust_type_name(target));
1945 if analysis.dependencies.recursive_schemas.contains(target) {
1947 quote! { Box<#target_type> }
1948 } else {
1949 quote! { #target_type }
1950 }
1951 }
1952 SchemaType::Array { item_type } => {
1953 let inner_type = self.generate_array_item_type(item_type, analysis);
1955 quote! { Vec<#inner_type> }
1956 }
1957 _ => {
1958 quote! { serde_json::Value }
1960 }
1961 }
1962 }
1963
1964 fn type_name_to_variant_name(&self, type_name: &str) -> String {
1966 match type_name {
1968 "bool" => return "Boolean".to_string(),
1969 "i8" | "i16" | "i32" | "i64" | "i128" => return "Integer".to_string(),
1970 "u8" | "u16" | "u32" | "u64" | "u128" => return "UnsignedInteger".to_string(),
1971 "f32" | "f64" => return "Number".to_string(),
1972 "String" => return "String".to_string(),
1973 "serde_json::Value" => return "Value".to_string(),
1974 _ => {}
1975 }
1976
1977 if type_name.starts_with("Vec<") && type_name.ends_with(">") {
1979 let inner = &type_name[4..type_name.len() - 1];
1980 if inner.starts_with("Vec<") && inner.ends_with(">") {
1982 let inner_inner = &inner[4..inner.len() - 1];
1983 return format!("{}ArrayArray", self.type_name_to_variant_name(inner_inner));
1984 }
1985 return format!("{}Array", self.type_name_to_variant_name(inner));
1986 }
1987
1988 let clean_name = type_name
1994 .trim_end_matches("Type")
1995 .trim_end_matches("Schema")
1996 .trim_end_matches("Item");
1997
1998 self.to_rust_type_name(clean_name)
2000 }
2001
2002 fn ensure_unique_variant_name_generator(
2004 &self,
2005 base_name: String,
2006 used_names: &mut std::collections::HashSet<String>,
2007 fallback_index: usize,
2008 ) -> String {
2009 if used_names.insert(base_name.clone()) {
2010 return base_name;
2011 }
2012
2013 for i in 2..100 {
2015 let numbered_name = format!("{base_name}{i}");
2016 if used_names.insert(numbered_name.clone()) {
2017 return numbered_name;
2018 }
2019 }
2020
2021 let fallback = format!("Variant{fallback_index}");
2023 used_names.insert(fallback.clone());
2024 fallback
2025 }
2026
2027 fn find_request_type_for_operation(
2029 &self,
2030 operation_id: &str,
2031 analysis: &SchemaAnalysis,
2032 ) -> Option<String> {
2033 analysis.operations.get(operation_id).and_then(|op| {
2035 op.request_body
2036 .as_ref()
2037 .and_then(|rb| rb.schema_name().map(|s| s.to_string()))
2038 })
2039 }
2040
2041 fn resolve_streaming_event_type(
2043 &self,
2044 endpoint: &crate::streaming::StreamingEndpoint,
2045 analysis: &SchemaAnalysis,
2046 ) -> Result<String> {
2047 match &endpoint.event_flow {
2048 crate::streaming::EventFlow::Simple => {
2049 if analysis.schemas.contains_key(&endpoint.event_union_type) {
2052 Ok(endpoint.event_union_type.to_string())
2053 } else {
2054 Err(crate::error::GeneratorError::ValidationError(format!(
2055 "Streaming response type '{}' not found in schema for simple streaming endpoint '{}'",
2056 endpoint.event_union_type, endpoint.operation_id
2057 )))
2058 }
2059 }
2060 crate::streaming::EventFlow::StartDeltaStop { .. } => {
2061 if analysis.schemas.contains_key(&endpoint.event_union_type) {
2064 Ok(endpoint.event_union_type.to_string())
2065 } else {
2066 Err(crate::error::GeneratorError::ValidationError(format!(
2067 "Event union type '{}' not found in schema for complex streaming endpoint '{}'",
2068 endpoint.event_union_type, endpoint.operation_id
2069 )))
2070 }
2071 }
2072 }
2073 }
2074
2075 fn generate_streaming_error_types(&self) -> Result<TokenStream> {
2077 Ok(quote! {
2078 #[derive(Debug, thiserror::Error)]
2080 pub enum StreamingError {
2081 #[error("Connection error: {0}")]
2082 Connection(String),
2083 #[error("HTTP error: {status}")]
2084 Http { status: u16 },
2085 #[error("SSE parsing error: {0}")]
2086 Parsing(String),
2087 #[error("Authentication error: {0}")]
2088 Authentication(String),
2089 #[error("Rate limit error: {0}")]
2090 RateLimit(String),
2091 #[error("API error: {0}")]
2092 Api(String),
2093 #[error("Timeout error: {0}")]
2094 Timeout(String),
2095 #[error("JSON serialization/deserialization error: {0}")]
2096 Json(#[from] serde_json::Error),
2097 #[error("Request error: {0}")]
2098 Request(reqwest::Error),
2099 }
2100
2101 impl From<reqwest::header::InvalidHeaderValue> for StreamingError {
2102 fn from(err: reqwest::header::InvalidHeaderValue) -> Self {
2103 StreamingError::Api(format!("Invalid header value: {}", err))
2104 }
2105 }
2106
2107 impl From<reqwest::Error> for StreamingError {
2108 fn from(err: reqwest::Error) -> Self {
2109 if err.is_timeout() {
2110 StreamingError::Timeout(err.to_string())
2111 } else if err.is_status() {
2112 if let Some(status) = err.status() {
2113 StreamingError::Http { status: status.as_u16() }
2114 } else {
2115 StreamingError::Connection(err.to_string())
2116 }
2117 } else {
2118 StreamingError::Request(err)
2119 }
2120 }
2121 }
2122 })
2123 }
2124
2125 fn generate_endpoint_trait(
2127 &self,
2128 endpoint: &crate::streaming::StreamingEndpoint,
2129 analysis: &SchemaAnalysis,
2130 ) -> Result<TokenStream> {
2131 use crate::streaming::HttpMethod;
2132
2133 let trait_name = format_ident!(
2134 "{}StreamingClient",
2135 self.to_rust_type_name(&endpoint.operation_id)
2136 );
2137 let method_name =
2138 format_ident!("stream_{}", self.to_rust_field_name(&endpoint.operation_id));
2139 let event_type =
2140 format_ident!("{}", self.resolve_streaming_event_type(endpoint, analysis)?);
2141
2142 let method_signature = match endpoint.http_method {
2144 HttpMethod::Get => {
2145 let mut param_defs = Vec::new();
2147 for qp in &endpoint.query_parameters {
2148 let param_name = format_ident!("{}", self.to_rust_field_name(&qp.name));
2149 if qp.required {
2150 param_defs.push(quote! { #param_name: &str });
2151 } else {
2152 param_defs.push(quote! { #param_name: Option<&str> });
2153 }
2154 }
2155 quote! {
2156 async fn #method_name(
2157 &self,
2158 #(#param_defs),*
2159 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error>;
2160 }
2161 }
2162 HttpMethod::Post => {
2163 let request_type = self
2165 .find_request_type_for_operation(&endpoint.operation_id, analysis)
2166 .unwrap_or_else(|| "serde_json::Value".to_string());
2167 let request_type_ident = if request_type.contains("::") {
2168 let parts: Vec<&str> = request_type.split("::").collect();
2169 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
2170 quote! { #(#path_parts)::* }
2171 } else {
2172 let ident = format_ident!("{}", request_type);
2173 quote! { #ident }
2174 };
2175 quote! {
2176 async fn #method_name(
2177 &self,
2178 request: #request_type_ident,
2179 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error>;
2180 }
2181 }
2182 };
2183
2184 Ok(quote! {
2185 #[async_trait]
2187 pub trait #trait_name {
2188 type Error: std::error::Error + Send + Sync + 'static;
2189
2190 #method_signature
2192 }
2193 })
2194 }
2195
2196 fn generate_streaming_client_impl(
2198 &self,
2199 streaming_config: &crate::streaming::StreamingConfig,
2200 analysis: &SchemaAnalysis,
2201 ) -> Result<TokenStream> {
2202 let client_name = format_ident!(
2203 "{}Client",
2204 self.to_rust_type_name(&streaming_config.client_module_name)
2205 );
2206
2207 let mut struct_fields = vec![
2210 quote! { base_url: String },
2211 quote! { api_key: Option<String> },
2212 quote! { http_client: reqwest::Client },
2213 quote! { custom_headers: std::collections::BTreeMap<String, String> },
2214 ];
2215
2216 let has_optional_headers = !streaming_config
2217 .endpoints
2218 .iter()
2219 .all(|e| e.optional_headers.is_empty());
2220
2221 if has_optional_headers {
2222 struct_fields
2223 .push(quote! { optional_headers: std::collections::BTreeMap<String, String> });
2224 }
2225
2226 let default_base_url = if let Some(ref streaming_config) = self.config.streaming_config {
2229 streaming_config
2230 .endpoints
2231 .first()
2232 .and_then(|e| e.base_url.as_deref())
2233 .unwrap_or("https://api.example.com")
2234 } else {
2235 "https://api.example.com"
2236 };
2237
2238 let constructor_fields = if has_optional_headers {
2240 quote! {
2241 base_url: #default_base_url.to_string(),
2242 api_key: None,
2243 http_client: reqwest::Client::new(),
2244 custom_headers: std::collections::BTreeMap::new(),
2245 optional_headers: std::collections::BTreeMap::new(),
2246 }
2247 } else {
2248 quote! {
2249 base_url: #default_base_url.to_string(),
2250 api_key: None,
2251 http_client: reqwest::Client::new(),
2252 custom_headers: std::collections::BTreeMap::new(),
2253 }
2254 };
2255
2256 let optional_headers_method = if has_optional_headers {
2258 quote! {
2259 pub fn set_optional_headers(&mut self, headers: std::collections::BTreeMap<String, String>) {
2261 self.optional_headers = headers;
2262 }
2263 }
2264 } else {
2265 TokenStream::new()
2266 };
2267
2268 let constructor = quote! {
2269 impl #client_name {
2270 pub fn new() -> Self {
2272 Self {
2273 #constructor_fields
2274 }
2275 }
2276
2277 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
2279 self.base_url = base_url.into();
2280 self
2281 }
2282
2283 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
2285 self.api_key = Some(api_key.into());
2286 self
2287 }
2288
2289 pub fn with_header(
2291 mut self,
2292 name: impl Into<String>,
2293 value: impl Into<String>,
2294 ) -> Self {
2295 self.custom_headers.insert(name.into(), value.into());
2296 self
2297 }
2298
2299 pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
2301 self.http_client = client;
2302 self
2303 }
2304
2305 #optional_headers_method
2306 }
2307 };
2308
2309 let mut trait_impls = Vec::new();
2311 for endpoint in &streaming_config.endpoints {
2312 let trait_impl = self.generate_endpoint_trait_impl(endpoint, &client_name, analysis)?;
2313 trait_impls.push(trait_impl);
2314 }
2315
2316 let default_impl = quote! {
2318 impl Default for #client_name {
2319 fn default() -> Self {
2320 Self::new()
2321 }
2322 }
2323 };
2324
2325 Ok(quote! {
2326 #[derive(Debug, Clone)]
2328 pub struct #client_name {
2329 #(#struct_fields,)*
2330 }
2331
2332 #constructor
2333
2334 #default_impl
2335
2336 #(#trait_impls)*
2337 })
2338 }
2339
2340 fn generate_endpoint_trait_impl(
2342 &self,
2343 endpoint: &crate::streaming::StreamingEndpoint,
2344 client_name: &proc_macro2::Ident,
2345 analysis: &SchemaAnalysis,
2346 ) -> Result<TokenStream> {
2347 use crate::streaming::HttpMethod;
2348
2349 let trait_name = format_ident!(
2350 "{}StreamingClient",
2351 self.to_rust_type_name(&endpoint.operation_id)
2352 );
2353 let method_name =
2354 format_ident!("stream_{}", self.to_rust_field_name(&endpoint.operation_id));
2355 let event_type =
2356 format_ident!("{}", self.resolve_streaming_event_type(endpoint, analysis)?);
2357
2358 let mut header_setup = Vec::new();
2360 for (name, value) in &endpoint.required_headers {
2361 header_setup.push(quote! {
2362 headers.insert(#name, HeaderValue::from_static(#value));
2363 });
2364 }
2365
2366 if let Some(auth_header) = &endpoint.auth_header {
2369 match auth_header {
2370 crate::streaming::AuthHeader::Bearer(header_name) => {
2371 header_setup.push(quote! {
2372 if let Some(ref api_key) = self.api_key {
2373 headers.insert(#header_name, HeaderValue::from_str(&format!("Bearer {}", api_key))?);
2374 }
2375 });
2376 }
2377 crate::streaming::AuthHeader::ApiKey(header_name) => {
2378 header_setup.push(quote! {
2379 if let Some(ref api_key) = self.api_key {
2380 headers.insert(#header_name, HeaderValue::from_str(api_key)?);
2381 }
2382 });
2383 }
2384 }
2385 } else {
2386 header_setup.push(quote! {
2388 if let Some(ref api_key) = self.api_key {
2389 headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", api_key))?);
2390 }
2391 });
2392 }
2393
2394 header_setup.push(quote! {
2396 for (name, value) in &self.custom_headers {
2397 if let (Ok(header_name), Ok(header_value)) = (reqwest::header::HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(value)) {
2398 headers.insert(header_name, header_value);
2399 }
2400 }
2401 });
2402
2403 if !endpoint.optional_headers.is_empty() {
2405 header_setup.push(quote! {
2406 for (key, value) in &self.optional_headers {
2407 if let (Ok(header_name), Ok(header_value)) = (reqwest::header::HeaderName::from_bytes(key.as_bytes()), HeaderValue::from_str(value)) {
2408 headers.insert(header_name, header_value);
2409 }
2410 }
2411 });
2412 }
2413
2414 match endpoint.http_method {
2416 HttpMethod::Get => self.generate_get_streaming_impl(
2417 endpoint,
2418 client_name,
2419 &trait_name,
2420 &method_name,
2421 &event_type,
2422 &header_setup,
2423 ),
2424 HttpMethod::Post => self.generate_post_streaming_impl(
2425 endpoint,
2426 client_name,
2427 &trait_name,
2428 &method_name,
2429 &event_type,
2430 &header_setup,
2431 analysis,
2432 ),
2433 }
2434 }
2435
2436 fn generate_get_streaming_impl(
2438 &self,
2439 endpoint: &crate::streaming::StreamingEndpoint,
2440 client_name: &proc_macro2::Ident,
2441 trait_name: &proc_macro2::Ident,
2442 method_name: &proc_macro2::Ident,
2443 event_type: &proc_macro2::Ident,
2444 header_setup: &[TokenStream],
2445 ) -> Result<TokenStream> {
2446 let path = &endpoint.path;
2447
2448 let mut param_defs = Vec::new();
2450 let mut query_params = Vec::new();
2451
2452 for qp in &endpoint.query_parameters {
2453 let param_name = format_ident!("{}", self.to_rust_field_name(&qp.name));
2454 let param_name_str = &qp.name;
2455
2456 if qp.required {
2457 param_defs.push(quote! { #param_name: &str });
2458 query_params.push(quote! {
2459 url.query_pairs_mut().append_pair(#param_name_str, #param_name);
2460 });
2461 } else {
2462 param_defs.push(quote! { #param_name: Option<&str> });
2463 query_params.push(quote! {
2464 if let Some(v) = #param_name {
2465 url.query_pairs_mut().append_pair(#param_name_str, v);
2466 }
2467 });
2468 }
2469 }
2470
2471 let url_construction = quote! {
2473 let base_url = url::Url::parse(&self.base_url)
2474 .map_err(|e| StreamingError::Connection(format!("Invalid base URL: {}", e)))?;
2475 let path_to_join = #path.trim_start_matches('/');
2476 let mut url = base_url.join(path_to_join)
2477 .map_err(|e| StreamingError::Connection(format!("URL join error: {}", e)))?;
2478 #(#query_params)*
2479 };
2480
2481 let instrument_skip = quote! { #[instrument(skip(self), name = "streaming_get_request")] };
2482
2483 Ok(quote! {
2484 #[async_trait]
2485 impl #trait_name for #client_name {
2486 type Error = StreamingError;
2487
2488 #instrument_skip
2489 async fn #method_name(
2490 &self,
2491 #(#param_defs),*
2492 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error> {
2493 debug!("Starting streaming GET request");
2494
2495 let mut headers = HeaderMap::new();
2496 #(#header_setup)*
2497
2498 #url_construction
2499 let url_str = url.to_string();
2500 debug!("Making streaming GET request to: {}", url_str);
2501
2502 let request_builder = self.http_client
2503 .get(url_str)
2504 .headers(headers);
2505
2506 debug!("Creating SSE stream from request");
2507 let stream = parse_sse_stream::<#event_type>(request_builder).await?;
2508 info!("SSE stream created successfully");
2509 Ok(Box::pin(stream))
2510 }
2511 }
2512 })
2513 }
2514
2515 #[allow(clippy::too_many_arguments)]
2517 fn generate_post_streaming_impl(
2518 &self,
2519 endpoint: &crate::streaming::StreamingEndpoint,
2520 client_name: &proc_macro2::Ident,
2521 trait_name: &proc_macro2::Ident,
2522 method_name: &proc_macro2::Ident,
2523 event_type: &proc_macro2::Ident,
2524 header_setup: &[TokenStream],
2525 analysis: &SchemaAnalysis,
2526 ) -> Result<TokenStream> {
2527 let path = &endpoint.path;
2528
2529 let request_type = self
2531 .find_request_type_for_operation(&endpoint.operation_id, analysis)
2532 .unwrap_or_else(|| "serde_json::Value".to_string());
2533 let request_type_ident = if request_type.contains("::") {
2534 let parts: Vec<&str> = request_type.split("::").collect();
2535 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
2536 quote! { #(#path_parts)::* }
2537 } else {
2538 let ident = format_ident!("{}", request_type);
2539 quote! { #ident }
2540 };
2541
2542 let url_construction = quote! {
2544 let base_url = url::Url::parse(&self.base_url)
2545 .map_err(|e| StreamingError::Connection(format!("Invalid base URL: {}", e)))?;
2546 let path_to_join = #path.trim_start_matches('/');
2547 let url = base_url.join(path_to_join)
2548 .map_err(|e| StreamingError::Connection(format!("URL join error: {}", e)))?
2549 .to_string();
2550 };
2551
2552 let stream_param = &endpoint.stream_parameter;
2554 let stream_setup = if stream_param.is_empty() {
2555 quote! {
2556 let streaming_request = request;
2557 }
2558 } else {
2559 quote! {
2560 let mut streaming_request = request;
2562 if let Ok(mut request_value) = serde_json::to_value(&streaming_request) {
2563 if let Some(obj) = request_value.as_object_mut() {
2564 obj.insert(#stream_param.to_string(), serde_json::Value::Bool(true));
2565 }
2566 streaming_request = serde_json::from_value(request_value)?;
2567 }
2568 }
2569 };
2570
2571 Ok(quote! {
2572 #[async_trait]
2573 impl #trait_name for #client_name {
2574 type Error = StreamingError;
2575
2576 #[instrument(skip(self, request), name = "streaming_post_request")]
2577 async fn #method_name(
2578 &self,
2579 request: #request_type_ident,
2580 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error> {
2581 debug!("Starting streaming POST request");
2582
2583 #stream_setup
2584
2585 let mut headers = HeaderMap::new();
2586 #(#header_setup)*
2587
2588 #url_construction
2589 debug!("Making streaming POST request to: {}", url);
2590
2591 let request_builder = self.http_client
2592 .post(&url)
2593 .headers(headers)
2594 .json(&streaming_request);
2595
2596 debug!("Creating SSE stream from request");
2597 let stream = parse_sse_stream::<#event_type>(request_builder).await?;
2598 info!("SSE stream created successfully");
2599 Ok(Box::pin(stream))
2600 }
2601 }
2602 })
2603 }
2604
2605 fn generate_sse_parser_utilities(
2607 &self,
2608 _streaming_config: &crate::streaming::StreamingConfig,
2609 ) -> Result<TokenStream> {
2610 Ok(quote! {
2611 pub async fn parse_sse_stream<T>(
2613 request_builder: reqwest::RequestBuilder
2614 ) -> Result<impl Stream<Item = Result<T, StreamingError>>, StreamingError>
2615 where
2616 T: serde::de::DeserializeOwned + Send + 'static,
2617 {
2618 let mut event_source = reqwest_eventsource::EventSource::new(request_builder).map_err(|e| {
2619 StreamingError::Connection(format!("Failed to create event source: {}", e))
2620 })?;
2621
2622 let stream = event_source.filter_map(|event_result| async move {
2623 match event_result {
2624 Ok(reqwest_eventsource::Event::Open) => {
2625 debug!("SSE connection opened");
2626 None
2627 }
2628 Ok(reqwest_eventsource::Event::Message(message)) => {
2629 if message.event == "ping" {
2631 debug!("Received SSE ping event, skipping");
2632 return None;
2633 }
2634
2635 if message.data.trim().is_empty() {
2637 debug!("Empty SSE data, skipping");
2638 return None;
2639 }
2640
2641 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&message.data) {
2643 if let Some(event_type) = json_value.get("event").and_then(|v| v.as_str()) {
2644 if event_type == "ping" {
2645 debug!("Received ping event in JSON data, skipping");
2646 return None;
2647 }
2648 }
2649
2650 match serde_json::from_value::<T>(json_value) {
2652 Ok(parsed_event) => {
2653 Some(Ok(parsed_event))
2654 }
2655 Err(e) => {
2656 if message.data.contains("ping") || message.event.contains("ping") {
2657 debug!("Ignoring ping-related event: {}", message.data);
2658 None
2659 } else {
2660 Some(Err(StreamingError::Parsing(
2661 format!("Failed to parse SSE event: {} (raw: {})", e, message.data)
2662 )))
2663 }
2664 }
2665 }
2666 } else {
2667 Some(Err(StreamingError::Parsing(
2669 format!("SSE event is not valid JSON: {}", message.data)
2670 )))
2671 }
2672 }
2673 Err(e) => {
2674 match e {
2676 reqwest_eventsource::Error::StreamEnded => {
2677 debug!("SSE stream completed normally");
2678 None }
2680 reqwest_eventsource::Error::InvalidStatusCode(status, response) => {
2681 let status_code = status.as_u16();
2683
2684 let error_body = match response.text().await {
2686 Ok(body) => body,
2687 Err(_) => "Failed to read error response body".to_string()
2688 };
2689
2690 error!("SSE connection error - HTTP {}: {}", status_code, error_body);
2691
2692 let detailed_error = format!(
2693 "HTTP {} error: {}",
2694 status_code,
2695 error_body
2696 );
2697
2698 Some(Err(StreamingError::Connection(detailed_error)))
2699 }
2700 _ => {
2701 let error_str = e.to_string();
2702 if error_str.contains("stream closed") {
2703 debug!("SSE stream closed");
2704 None
2705 } else {
2706 error!("SSE connection error: {}", e);
2707 Some(Err(StreamingError::Connection(error_str)))
2708 }
2709 }
2710 }
2711 }
2712 }
2713 });
2714
2715 Ok(stream)
2716 }
2717 })
2718 }
2719
2720 fn generate_reconnection_utilities(
2722 &self,
2723 reconnect_config: &crate::streaming::ReconnectionConfig,
2724 ) -> Result<TokenStream> {
2725 let max_retries = reconnect_config.max_retries;
2726 let initial_delay = reconnect_config.initial_delay_ms;
2727 let max_delay = reconnect_config.max_delay_ms;
2728 let backoff_multiplier = reconnect_config.backoff_multiplier;
2729
2730 Ok(quote! {
2731 #[derive(Debug, Clone)]
2733 pub struct ReconnectionManager {
2734 max_retries: u32,
2735 initial_delay_ms: u64,
2736 max_delay_ms: u64,
2737 backoff_multiplier: f64,
2738 current_attempt: u32,
2739 }
2740
2741 impl ReconnectionManager {
2742 pub fn new() -> Self {
2744 Self {
2745 max_retries: #max_retries,
2746 initial_delay_ms: #initial_delay,
2747 max_delay_ms: #max_delay,
2748 backoff_multiplier: #backoff_multiplier,
2749 current_attempt: 0,
2750 }
2751 }
2752
2753 pub fn should_retry(&self) -> bool {
2755 self.current_attempt < self.max_retries
2756 }
2757
2758 pub fn next_retry_delay(&mut self) -> Duration {
2760 if !self.should_retry() {
2761 return Duration::from_secs(0);
2762 }
2763
2764 let delay_ms = (self.initial_delay_ms as f64
2765 * self.backoff_multiplier.powi(self.current_attempt as i32)) as u64;
2766 let delay_ms = delay_ms.min(self.max_delay_ms);
2767
2768 self.current_attempt += 1;
2769 Duration::from_millis(delay_ms)
2770 }
2771
2772 pub fn reset(&mut self) {
2774 self.current_attempt = 0;
2775 }
2776
2777 pub fn current_attempt(&self) -> u32 {
2779 self.current_attempt
2780 }
2781 }
2782
2783 impl Default for ReconnectionManager {
2784 fn default() -> Self {
2785 Self::new()
2786 }
2787 }
2788 })
2789 }
2790}