1use crate::{GeneratorError, Result, analysis::SchemaAnalysis, streaming::StreamingConfig};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use std::collections::BTreeMap;
5use std::path::PathBuf;
6
/// Bookkeeping for an object schema that is used as a variant of a
/// discriminated union and itself declares the union's discriminator property.
///
/// Entries are built in `CodeGenerator::generate_types` and consulted while
/// generating structs and array aliases.
#[derive(Clone)]
struct DiscriminatedVariantInfo {
    /// Name of the parent union's discriminator property.
    discriminator_field: String,
    /// Discriminator value that selects this variant in the parent union.
    discriminator_value: String,
    /// `true` when the parent union is emitted as an untagged enum, in which
    /// case the variant struct keeps its discriminator field.
    is_parent_untagged: bool,
}
17
/// Settings that control what `CodeGenerator` produces and where it writes it.
///
/// See the `Default` implementation for the baseline values.
#[derive(Debug, Clone)]
pub struct GeneratorConfig {
    /// Path of the input specification (defaults to `openapi.json`).
    /// NOTE(review): not read in the visible part of this file.
    pub spec_path: PathBuf,
    /// Directory that `CodeGenerator::write_files` creates and writes into.
    pub output_dir: PathBuf,
    /// Module name (defaults to `api_types`).
    /// NOTE(review): not read in the visible part of this file.
    pub module_name: String,
    /// NOTE(review): presumably toggles SSE client generation; not read in the
    /// visible part of this file — confirm against its users.
    pub enable_sse_client: bool,
    /// When `true`, `generate_all` emits `client.rs`.
    pub enable_async_client: bool,
    /// When `true`, generated types get
    /// `#[cfg_attr(feature = "specta", derive(specta::Type))]` and fields may
    /// get a specta rename attribute.
    pub enable_specta: bool,
    /// Type-name mapping table; defaults to `default_type_mappings()`.
    /// NOTE(review): not read in the visible part of this file.
    pub type_mappings: BTreeMap<String, String>,
    /// When set, `generate_all` emits `streaming.rs` from this configuration.
    pub streaming_config: Option<StreamingConfig>,
    /// Per-field nullability overrides keyed by `"<schema name>.<field name>"`;
    /// a `true` entry makes the generated field an `Option`.
    pub nullable_field_overrides: BTreeMap<String, bool>,
    /// NOTE(review): presumably extra schema files merged into the spec; not
    /// read in the visible part of this file — confirm.
    pub schema_extensions: Vec<PathBuf>,
    /// NOTE(review): HTTP client settings; not read in the visible part of this file.
    pub http_client_config: Option<crate::http_config::HttpClientConfig>,
    /// NOTE(review): retry settings; not read in the visible part of this file.
    pub retry_config: Option<crate::http_config::RetryConfig>,
    /// NOTE(review): presumably toggles tracing in generated clients (defaults
    /// to `true`); not read in the visible part of this file.
    pub tracing_enabled: bool,
    /// NOTE(review): authentication settings; not read in the visible part of this file.
    pub auth_config: Option<crate::http_config::AuthConfig>,
    /// When `true`, `generate_all` additionally emits `registry.rs`.
    pub enable_registry: bool,
    /// When `true`, `generate_all` emits only `registry.rs` (plus `mod.rs`),
    /// skipping types, streaming and HTTP client files.
    pub registry_only: bool,
}
55
56impl Default for GeneratorConfig {
57 fn default() -> Self {
58 Self {
59 spec_path: "openapi.json".into(),
60 output_dir: "src/gen".into(),
61 module_name: "api_types".to_string(),
62 enable_sse_client: true,
63 enable_async_client: true,
64 enable_specta: false,
65 type_mappings: default_type_mappings(),
66 streaming_config: None,
67 nullable_field_overrides: BTreeMap::new(),
68 schema_extensions: Vec::new(),
69 http_client_config: None,
70 retry_config: None,
71 tracing_enabled: true,
72 auth_config: None,
73 enable_registry: false,
74 registry_only: false,
75 }
76 }
77}
78
/// Returns the built-in mapping table from the four names `integer`, `number`,
/// `string` and `boolean` to the Rust types `i64`, `f64`, `String` and `bool`.
pub fn default_type_mappings() -> BTreeMap<String, String> {
    const DEFAULTS: [(&str, &str); 4] = [
        ("integer", "i64"),
        ("number", "f64"),
        ("string", "String"),
        ("boolean", "bool"),
    ];

    DEFAULTS
        .iter()
        .map(|&(schema_type, rust_type)| (schema_type.to_string(), rust_type.to_string()))
        .collect()
}
87
/// A single generated source file held in memory.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File path; `CodeGenerator::write_files` joins it onto the configured
    /// output directory.
    pub path: PathBuf,
    /// Complete text of the file.
    pub content: String,
}
96
/// Everything produced by one `CodeGenerator::generate_all` run.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// The generated module files (`types.rs`, `streaming.rs`, `client.rs`,
    /// `registry.rs`, depending on configuration).
    pub files: Vec<GeneratedFile>,
    /// The `mod.rs` file that declares and re-exports the modules in `files`.
    pub mod_file: GeneratedFile,
}
105
/// Turns a `SchemaAnalysis` into Rust source text according to a
/// `GeneratorConfig`.
pub struct CodeGenerator {
    /// Configuration supplied at construction time.
    config: GeneratorConfig,
}
109
110impl CodeGenerator {
111 pub fn new(config: GeneratorConfig) -> Self {
112 Self { config }
113 }
114
115 pub fn config(&self) -> &GeneratorConfig {
117 &self.config
118 }
119
120 pub fn generate_all(&self, analysis: &mut SchemaAnalysis) -> Result<GenerationResult> {
122 let mut files = Vec::new();
123
124 if !self.config.registry_only {
125 let types_content = self.generate_types(analysis)?;
127 files.push(GeneratedFile {
128 path: "types.rs".into(),
129 content: types_content,
130 });
131
132 if let Some(ref streaming_config) = self.config.streaming_config {
134 let streaming_content =
135 self.generate_streaming_client(streaming_config, analysis)?;
136 files.push(GeneratedFile {
137 path: "streaming.rs".into(),
138 content: streaming_content,
139 });
140 }
141
142 if self.config.enable_async_client {
144 let http_content = self.generate_http_client(analysis)?;
145 files.push(GeneratedFile {
146 path: "client.rs".into(),
147 content: http_content,
148 });
149 }
150 }
151
152 if self.config.enable_registry || self.config.registry_only {
154 let registry_content = self.generate_registry(analysis)?;
155 files.push(GeneratedFile {
156 path: "registry.rs".into(),
157 content: registry_content,
158 });
159 }
160
161 let mod_content = self.generate_mod_file(&files)?;
163 let mod_file = GeneratedFile {
164 path: "mod.rs".into(),
165 content: mod_content,
166 };
167
168 Ok(GenerationResult { files, mod_file })
169 }
170
171 pub fn generate(&self, analysis: &mut SchemaAnalysis) -> Result<String> {
173 self.generate_types(analysis)
174 }
175
176 fn generate_types(&self, analysis: &mut SchemaAnalysis) -> Result<String> {
178 let mut type_definitions = TokenStream::new();
179
180 let mut discriminated_variant_info: BTreeMap<String, DiscriminatedVariantInfo> =
183 BTreeMap::new();
184
185 let mut sorted_schemas: Vec<_> = analysis.schemas.iter().collect();
187 sorted_schemas.sort_by_key(|(name, _)| name.as_str());
188
189 for (_parent_name, schema) in sorted_schemas {
190 if let crate::analysis::SchemaType::DiscriminatedUnion {
191 variants,
192 discriminator_field,
193 } = &schema.schema_type
194 {
195 let is_parent_untagged =
197 self.should_use_untagged_discriminated_union(schema, analysis);
198
199 for variant in variants {
200 if let Some(variant_schema) = analysis.schemas.get(&variant.type_name) {
203 if let crate::analysis::SchemaType::Object { properties, .. } =
204 &variant_schema.schema_type
205 {
206 if properties.contains_key(discriminator_field) {
207 discriminated_variant_info.insert(
208 variant.type_name.clone(),
209 DiscriminatedVariantInfo {
210 discriminator_field: discriminator_field.clone(),
211 discriminator_value: variant.discriminator_value.clone(),
212 is_parent_untagged,
213 },
214 );
215 }
216 }
217 }
218 }
219 }
220 }
221
222 let generation_order = analysis.dependencies.topological_sort()?;
224
225 let mut processed = std::collections::HashSet::new();
227
228 for schema_name in generation_order {
230 if let Some(schema) = analysis.schemas.get(&schema_name) {
231 let type_def =
232 self.generate_type_definition(schema, analysis, &discriminated_variant_info)?;
233 if !type_def.is_empty() {
234 type_definitions.extend(type_def);
235 }
236 processed.insert(schema_name);
237 }
238 }
239
240 let mut remaining_schemas: Vec<_> = analysis
243 .schemas
244 .iter()
245 .filter(|(name, _)| !processed.contains(*name))
246 .collect();
247 remaining_schemas.sort_by_key(|(name, _)| name.as_str());
248
249 for (_schema_name, schema) in remaining_schemas {
250 let type_def =
251 self.generate_type_definition(schema, analysis, &discriminated_variant_info)?;
252 if !type_def.is_empty() {
253 type_definitions.extend(type_def);
254 }
255 }
256
257 let generated = quote! {
259 #![allow(clippy::large_enum_variant)]
265 #![allow(clippy::format_in_format_args)]
266 #![allow(clippy::let_unit_value)]
267 #![allow(unreachable_patterns)]
268
269 use serde::{Deserialize, Serialize};
270
271 #type_definitions
272 };
273
274 let syntax_tree = syn::parse2::<syn::File>(generated).map_err(|e| {
276 GeneratorError::CodeGenError(format!("Failed to parse generated code: {e}"))
277 })?;
278
279 let formatted = prettyplease::unparse(&syntax_tree);
280
281 Ok(formatted)
282 }
283
284 fn generate_streaming_client(
286 &self,
287 streaming_config: &StreamingConfig,
288 analysis: &SchemaAnalysis,
289 ) -> Result<String> {
290 let mut client_code = TokenStream::new();
291
292 let imports = quote! {
294 #![allow(clippy::format_in_format_args)]
299 #![allow(clippy::let_unit_value)]
300 #![allow(unused_mut)]
301
302 use super::types::*;
303 use async_trait::async_trait;
304 use futures_util::{Stream, StreamExt};
305 use std::pin::Pin;
306 use std::time::Duration;
307 use reqwest::header::{HeaderMap, HeaderValue};
308 use tracing::{debug, error, info, warn, instrument};
309 };
310 client_code.extend(imports);
311
312 if streaming_config.generate_client {
314 let error_types = self.generate_streaming_error_types()?;
315 client_code.extend(error_types);
316 }
317
318 for endpoint in &streaming_config.endpoints {
320 let trait_code = self.generate_endpoint_trait(endpoint, analysis)?;
321 client_code.extend(trait_code);
322 }
323
324 if streaming_config.generate_client {
326 let client_impl = self.generate_streaming_client_impl(streaming_config, analysis)?;
327 client_code.extend(client_impl);
328 }
329
330 if streaming_config.event_parser_helpers {
332 let parser_code = self.generate_sse_parser_utilities(streaming_config)?;
333 client_code.extend(parser_code);
334 }
335
336 if let Some(reconnect_config) = &streaming_config.reconnection_config {
338 let reconnect_code = self.generate_reconnection_utilities(reconnect_config)?;
339 client_code.extend(reconnect_code);
340 }
341
342 let syntax_tree = syn::parse2::<syn::File>(client_code).map_err(|e| {
343 GeneratorError::CodeGenError(format!("Failed to parse streaming client code: {e}"))
344 })?;
345
346 Ok(prettyplease::unparse(&syntax_tree))
347 }
348
349 pub fn generate_http_client(&self, analysis: &SchemaAnalysis) -> Result<String> {
351 let error_types = self.generate_http_error_types();
352 let client_struct = self.generate_http_client_struct();
353 let operation_methods = self.generate_operation_methods(analysis);
354
355 let generated = quote! {
356 #![allow(clippy::format_in_format_args)]
361 #![allow(clippy::let_unit_value)]
362
363 use super::types::*;
364
365 #error_types
366
367 #client_struct
368
369 #operation_methods
370 };
371
372 let syntax_tree = syn::parse2::<syn::File>(generated).map_err(|e| {
373 GeneratorError::CodeGenError(format!("Failed to parse HTTP client code: {e}"))
374 })?;
375
376 Ok(prettyplease::unparse(&syntax_tree))
377 }
378
379 fn generate_http_error_types(&self) -> TokenStream {
381 quote! {
382 use thiserror::Error;
383
384 #[derive(Error, Debug)]
392 pub enum HttpError {
393 #[error("Network error: {0}")]
395 Network(#[from] reqwest::Error),
396
397 #[error("Middleware error: {0}")]
399 Middleware(#[from] reqwest_middleware::Error),
400
401 #[error("Failed to serialize request: {0}")]
403 Serialization(String),
404
405 #[error("Authentication error: {0}")]
407 Auth(String),
408
409 #[error("Request timeout")]
411 Timeout,
412
413 #[error("Configuration error: {0}")]
415 Config(String),
416
417 #[error("{0}")]
419 Other(String),
420 }
421
422 impl HttpError {
423 pub fn serialization_error(error: impl std::fmt::Display) -> Self {
425 Self::Serialization(error.to_string())
426 }
427
428 pub fn is_retryable(&self) -> bool {
430 matches!(self, Self::Network(_) | Self::Middleware(_) | Self::Timeout)
431 }
432 }
433
434 #[derive(Debug, Clone)]
444 pub struct ApiError<E> {
445 pub status: u16,
446 pub headers: reqwest::header::HeaderMap,
447 pub body: String,
448 pub typed: Option<E>,
449 pub parse_error: Option<String>,
450 }
451
452 impl<E> ApiError<E> {
453 pub fn is_client_error(&self) -> bool {
454 (400..500).contains(&self.status)
455 }
456
457 pub fn is_server_error(&self) -> bool {
458 (500..600).contains(&self.status)
459 }
460
461 pub fn is_retryable(&self) -> bool {
464 matches!(self.status, 429 | 500 | 502 | 503 | 504)
465 }
466 }
467
468 impl<E: std::fmt::Debug> std::fmt::Display for ApiError<E> {
469 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
470 write!(f, "API error {}: {}", self.status, self.body)
471 }
472 }
473
474 impl<E: std::fmt::Debug> std::error::Error for ApiError<E> {}
475
476 #[derive(Debug, Error)]
484 pub enum ApiOpError<E: std::fmt::Debug> {
485 #[error(transparent)]
486 Transport(#[from] HttpError),
487
488 #[error(transparent)]
489 Api(ApiError<E>),
490 }
491
492 impl<E: std::fmt::Debug> ApiOpError<E> {
493 pub fn api(&self) -> Option<&ApiError<E>> {
495 match self {
496 Self::Api(e) => Some(e),
497 Self::Transport(_) => None,
498 }
499 }
500
501 pub fn is_api_error(&self) -> bool {
504 matches!(self, Self::Api(_))
505 }
506 }
507
508 impl<E: std::fmt::Debug> From<reqwest::Error> for ApiOpError<E> {
511 fn from(e: reqwest::Error) -> Self {
512 Self::Transport(HttpError::Network(e))
513 }
514 }
515
516 impl<E: std::fmt::Debug> From<reqwest_middleware::Error> for ApiOpError<E> {
517 fn from(e: reqwest_middleware::Error) -> Self {
518 Self::Transport(HttpError::Middleware(e))
519 }
520 }
521
522 pub type HttpResult<T> = Result<T, HttpError>;
526 }
527 }
528
529 fn generate_mod_file(&self, files: &[GeneratedFile]) -> Result<String> {
531 let mut module_declarations = Vec::new();
532 let mut pub_uses = Vec::new();
533
534 for file in files {
535 if let Some(module_name) = file.path.file_stem().and_then(|s| s.to_str()) {
536 if module_name != "mod" {
537 module_declarations.push(format!("pub mod {module_name};"));
538 pub_uses.push(format!("pub use {module_name}::*;"));
539 }
540 }
541 }
542
543 let content = format!(
544 r#"//! Generated API modules
545//!
546//! This module exports all generated API types and clients.
547//! Do not edit manually - regenerate using the appropriate script.
548
549#![allow(unused_imports)]
550
551{}
552
553{}
554"#,
555 module_declarations.join("\n"),
556 pub_uses.join("\n")
557 );
558
559 Ok(content)
560 }
561
562 pub fn write_files(&self, result: &GenerationResult) -> Result<()> {
564 use std::fs;
565
566 fs::create_dir_all(&self.config.output_dir)?;
568
569 for file in &result.files {
571 let file_path = self.config.output_dir.join(&file.path);
572 fs::write(&file_path, &file.content)?;
573 }
574
575 let mod_path = self.config.output_dir.join(&result.mod_file.path);
577 fs::write(&mod_path, &result.mod_file.content)?;
578
579 Ok(())
580 }
581
582 fn generate_type_definition(
583 &self,
584 schema: &crate::analysis::AnalyzedSchema,
585 analysis: &crate::analysis::SchemaAnalysis,
586 discriminated_variant_info: &BTreeMap<String, DiscriminatedVariantInfo>,
587 ) -> Result<TokenStream> {
588 use crate::analysis::SchemaType;
589
590 match &schema.schema_type {
591 SchemaType::Primitive { rust_type } => {
592 self.generate_type_alias(schema, rust_type)
594 }
595 SchemaType::StringEnum { values } => self.generate_string_enum(schema, values),
596 SchemaType::ExtensibleEnum { known_values } => {
597 self.generate_extensible_enum(schema, known_values)
598 }
599 SchemaType::Object {
600 properties,
601 required,
602 additional_properties,
603 } => self.generate_struct(
604 schema,
605 properties,
606 required,
607 *additional_properties,
608 analysis,
609 discriminated_variant_info.get(&schema.name),
610 ),
611 SchemaType::DiscriminatedUnion {
612 discriminator_field,
613 variants,
614 } => {
615 if self.should_use_untagged_discriminated_union(schema, analysis) {
617 let schema_refs: Vec<crate::analysis::SchemaRef> = variants
619 .iter()
620 .map(|v| crate::analysis::SchemaRef {
621 target: v.type_name.clone(),
622 nullable: false,
623 })
624 .collect();
625 self.generate_union_enum(schema, &schema_refs)
626 } else {
627 self.generate_discriminated_enum(
628 schema,
629 discriminator_field,
630 variants,
631 analysis,
632 )
633 }
634 }
635 SchemaType::Union { variants } => self.generate_union_enum(schema, variants),
636 SchemaType::Reference { target } => {
637 if schema.name != *target {
640 let alias_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
642 let target_type = format_ident!("{}", self.to_rust_type_name(target));
643
644 let doc_comment = if let Some(desc) = &schema.description {
645 quote! { #[doc = #desc] }
646 } else {
647 TokenStream::new()
648 };
649
650 Ok(quote! {
651 #doc_comment
652 pub type #alias_name = #target_type;
653 })
654 } else {
655 Ok(TokenStream::new())
657 }
658 }
659 SchemaType::Array { item_type } => {
660 let array_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
668
669 if let SchemaType::Reference { target } = item_type.as_ref() {
671 if let Some(info) = discriminated_variant_info.get(target) {
672 if !info.is_parent_untagged {
673 let wrapper_name =
675 format_ident!("{}Item", self.to_rust_type_name(&schema.name));
676 let variant_type = format_ident!("{}", self.to_rust_type_name(target));
677 let disc_field = &info.discriminator_field;
678 let disc_value = &info.discriminator_value;
679
680 let doc_comment = if let Some(desc) = &schema.description {
681 quote! { #[doc = #desc] }
682 } else {
683 TokenStream::new()
684 };
685
686 return Ok(quote! {
687 #[derive(Debug, Clone, Deserialize, Serialize)]
691 #[serde(tag = #disc_field)]
692 pub enum #wrapper_name {
693 #[serde(rename = #disc_value)]
694 #variant_type(#variant_type),
695 }
696 #doc_comment
697 pub type #array_name = Vec<#wrapper_name>;
698 });
699 }
700 }
701 }
702
703 let inner_type = self.generate_array_item_type(item_type, analysis);
704
705 let doc_comment = if let Some(desc) = &schema.description {
706 quote! { #[doc = #desc] }
707 } else {
708 TokenStream::new()
709 };
710
711 Ok(quote! {
712 #doc_comment
713 pub type #array_name = Vec<#inner_type>;
714 })
715 }
716 SchemaType::Composition { schemas } => {
717 self.generate_composition_struct(schema, schemas)
718 }
719 }
720 }
721
722 fn generate_type_alias(
723 &self,
724 schema: &crate::analysis::AnalyzedSchema,
725 rust_type: &str,
726 ) -> Result<TokenStream> {
727 let type_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
728
729 let base_type = if rust_type.contains("::") {
731 let parts: Vec<&str> = rust_type.split("::").collect();
732 if parts.len() == 2 {
733 let module = format_ident!("{}", parts[0]);
734 let type_name_part = format_ident!("{}", parts[1]);
735 quote! { #module::#type_name_part }
736 } else {
737 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
739 quote! { #(#path_parts)::* }
740 }
741 } else {
742 let simple_type = format_ident!("{}", rust_type);
743 quote! { #simple_type }
744 };
745
746 let doc_comment = if let Some(desc) = &schema.description {
747 let sanitized_desc = self.sanitize_doc_comment(desc);
748 quote! { #[doc = #sanitized_desc] }
749 } else {
750 TokenStream::new()
751 };
752
753 Ok(quote! {
754 #doc_comment
755 pub type #type_name = #base_type;
756 })
757 }
758
759 fn generate_extensible_enum(
760 &self,
761 schema: &crate::analysis::AnalyzedSchema,
762 known_values: &[String],
763 ) -> Result<TokenStream> {
764 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
765
766 let doc_comment = if let Some(desc) = &schema.description {
767 quote! { #[doc = #desc] }
768 } else {
769 TokenStream::new()
770 };
771
772 let known_variants = known_values.iter().map(|value| {
777 let variant_name = self.to_rust_enum_variant(value);
778 let variant_ident = format_ident!("{}", variant_name);
779 quote! {
780 #variant_ident,
781 }
782 });
783
784 let match_arms_de = known_values.iter().map(|value| {
785 let variant_name = self.to_rust_enum_variant(value);
786 let variant_ident = format_ident!("{}", variant_name);
787 quote! {
788 #value => Ok(#enum_name::#variant_ident),
789 }
790 });
791
792 let match_arms_ser = known_values.iter().map(|value| {
793 let variant_name = self.to_rust_enum_variant(value);
794 let variant_ident = format_ident!("{}", variant_name);
795 quote! {
796 #enum_name::#variant_ident => #value,
797 }
798 });
799
800 let derives = if self.config.enable_specta {
801 quote! {
802 #[derive(Debug, Clone, PartialEq, Eq)]
803 #[cfg_attr(feature = "specta", derive(specta::Type))]
804 }
805 } else {
806 quote! {
807 #[derive(Debug, Clone, PartialEq, Eq)]
808 }
809 };
810
811 Ok(quote! {
812 #doc_comment
813 #derives
814 pub enum #enum_name {
815 #(#known_variants)*
816 Custom(String),
818 }
819
820 impl<'de> serde::Deserialize<'de> for #enum_name {
821 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
822 where
823 D: serde::Deserializer<'de>,
824 {
825 let value = String::deserialize(deserializer)?;
826 match value.as_str() {
827 #(#match_arms_de)*
828 _ => Ok(#enum_name::Custom(value)),
829 }
830 }
831 }
832
833 impl serde::Serialize for #enum_name {
834 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
835 where
836 S: serde::Serializer,
837 {
838 let value = match self {
839 #(#match_arms_ser)*
840 #enum_name::Custom(s) => s.as_str(),
841 };
842 serializer.serialize_str(value)
843 }
844 }
845 })
846 }
847
848 fn generate_string_enum(
849 &self,
850 schema: &crate::analysis::AnalyzedSchema,
851 values: &[String],
852 ) -> Result<TokenStream> {
853 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
854
855 let default_value = schema
857 .default
858 .as_ref()
859 .and_then(|v| v.as_str())
860 .map(|s| s.to_string());
861
862 let variants = values.iter().enumerate().map(|(i, value)| {
863 let variant_name = self.to_rust_enum_variant(value);
865 let variant_ident = format_ident!("{}", variant_name);
866
867 let is_default = if let Some(ref default) = default_value {
869 value == default
870 } else {
871 i == 0 };
873
874 if is_default {
875 quote! {
876 #[default]
877 #[serde(rename = #value)]
878 #variant_ident,
879 }
880 } else {
881 quote! {
882 #[serde(rename = #value)]
883 #variant_ident,
884 }
885 }
886 });
887
888 let doc_comment = if let Some(desc) = &schema.description {
889 quote! { #[doc = #desc] }
890 } else {
891 TokenStream::new()
892 };
893
894 let derives = if self.config.enable_specta {
896 quote! {
897 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
898 #[cfg_attr(feature = "specta", derive(specta::Type))]
899 }
900 } else {
901 quote! {
902 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
903 }
904 };
905
906 Ok(quote! {
907 #doc_comment
908 #derives
909 pub enum #enum_name {
910 #(#variants)*
911 }
912 })
913 }
914
915 fn generate_struct(
916 &self,
917 schema: &crate::analysis::AnalyzedSchema,
918 properties: &BTreeMap<String, crate::analysis::PropertyInfo>,
919 required: &std::collections::HashSet<String>,
920 additional_properties: bool,
921 analysis: &crate::analysis::SchemaAnalysis,
922 discriminator_info: Option<&DiscriminatedVariantInfo>,
923 ) -> Result<TokenStream> {
924 let struct_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
925
926 let mut sorted_properties: Vec<_> = properties.iter().collect();
928 sorted_properties.sort_by_key(|(name, _)| name.as_str());
929
930 let mut fields: Vec<TokenStream> = sorted_properties
931 .into_iter()
932 .filter(|(field_name, _)| {
933 if let Some(info) = discriminator_info {
937 if !info.is_parent_untagged
938 && field_name.as_str() == info.discriminator_field.as_str()
939 {
940 false } else {
942 true }
944 } else {
945 true }
947 })
948 .map(|(field_name, prop)| {
949 let field_ident = Self::to_field_ident(&self.to_rust_field_name(field_name));
950 let is_required = required.contains(field_name);
951 let field_type =
952 self.generate_field_type(&schema.name, field_name, prop, is_required, analysis);
953
954 let serde_attrs =
955 self.generate_serde_field_attrs(field_name, prop, is_required, analysis);
956 let specta_attrs = self.generate_specta_field_attrs(field_name);
957
958 let doc_comment = if let Some(desc) = &prop.description {
959 let sanitized_desc = self.sanitize_doc_comment(desc);
960 quote! { #[doc = #sanitized_desc] }
961 } else {
962 TokenStream::new()
963 };
964
965 quote! {
966 #doc_comment
967 #serde_attrs
968 #specta_attrs
969 pub #field_ident: #field_type,
970 }
971 })
972 .collect();
973
974 if additional_properties {
976 fields.push(quote! {
977 #[serde(flatten)]
979 pub additional_properties: std::collections::BTreeMap<String, serde_json::Value>,
980 });
981 }
982
983 let doc_comment = if let Some(desc) = &schema.description {
984 quote! { #[doc = #desc] }
985 } else {
986 TokenStream::new()
987 };
988
989 let derives = if self.config.enable_specta {
993 quote! {
994 #[derive(Debug, Clone, Deserialize, Serialize)]
995 #[cfg_attr(feature = "specta", derive(specta::Type))]
996 }
997 } else {
998 quote! {
999 #[derive(Debug, Clone, Deserialize, Serialize)]
1000 }
1001 };
1002
1003 Ok(quote! {
1004 #doc_comment
1005 #derives
1006 pub struct #struct_name {
1007 #(#fields)*
1008 }
1009 })
1010 }
1011
1012 fn generate_discriminated_enum(
1013 &self,
1014 schema: &crate::analysis::AnalyzedSchema,
1015 discriminator_field: &str,
1016 variants: &[crate::analysis::UnionVariant],
1017 analysis: &crate::analysis::SchemaAnalysis,
1018 ) -> Result<TokenStream> {
1019 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1020
1021 let has_nested_discriminated_union = variants.iter().any(|variant| {
1023 if let Some(variant_schema) = analysis.schemas.get(&variant.type_name) {
1024 matches!(
1025 variant_schema.schema_type,
1026 crate::analysis::SchemaType::DiscriminatedUnion { .. }
1027 )
1028 } else {
1029 false
1030 }
1031 });
1032
1033 if has_nested_discriminated_union {
1035 let schema_refs: Vec<crate::analysis::SchemaRef> = variants
1037 .iter()
1038 .map(|v| crate::analysis::SchemaRef {
1039 target: v.type_name.clone(),
1040 nullable: false,
1041 })
1042 .collect();
1043 return self.generate_union_enum(schema, &schema_refs);
1044 }
1045
1046 let enum_variants = variants.iter().map(|variant| {
1047 let variant_name = format_ident!("{}", variant.rust_name);
1048 let variant_value = &variant.discriminator_value;
1049
1050 let variant_type = format_ident!("{}", self.to_rust_type_name(&variant.type_name));
1053 quote! {
1054 #[serde(rename = #variant_value)]
1055 #variant_name(#variant_type),
1056 }
1057 });
1058
1059 let doc_comment = if let Some(desc) = &schema.description {
1060 quote! { #[doc = #desc] }
1061 } else {
1062 TokenStream::new()
1063 };
1064
1065 let derives = if self.config.enable_specta {
1067 quote! {
1068 #[derive(Debug, Clone, Deserialize, Serialize)]
1069 #[cfg_attr(feature = "specta", derive(specta::Type))]
1070 #[serde(tag = #discriminator_field)]
1071 }
1072 } else {
1073 quote! {
1074 #[derive(Debug, Clone, Deserialize, Serialize)]
1075 #[serde(tag = #discriminator_field)]
1076 }
1077 };
1078
1079 Ok(quote! {
1080 #doc_comment
1081 #derives
1082 pub enum #enum_name {
1083 #(#enum_variants)*
1084 }
1085 })
1086 }
1087
1088 fn should_use_untagged_discriminated_union(
1090 &self,
1091 schema: &crate::analysis::AnalyzedSchema,
1092 analysis: &crate::analysis::SchemaAnalysis,
1093 ) -> bool {
1094 for other_schema in analysis.schemas.values() {
1099 if let crate::analysis::SchemaType::DiscriminatedUnion {
1100 variants,
1101 discriminator_field: _,
1102 } = &other_schema.schema_type
1103 {
1104 for variant in variants {
1105 if variant.type_name == schema.name {
1106 if let crate::analysis::SchemaType::DiscriminatedUnion {
1111 discriminator_field: current_discriminator,
1112 variants: current_variants,
1113 ..
1114 } = &schema.schema_type
1115 {
1116 for current_variant in current_variants {
1118 if let Some(variant_schema) =
1119 analysis.schemas.get(¤t_variant.type_name)
1120 {
1121 if let crate::analysis::SchemaType::Object {
1122 properties, ..
1123 } = &variant_schema.schema_type
1124 {
1125 if properties.contains_key(current_discriminator) {
1126 return false;
1129 }
1130 }
1131 }
1132 }
1133 }
1134
1135 return true;
1137 }
1138 }
1139 }
1140 }
1141 false
1142 }
1143
1144 fn generate_union_enum(
1145 &self,
1146 schema: &crate::analysis::AnalyzedSchema,
1147 variants: &[crate::analysis::SchemaRef],
1148 ) -> Result<TokenStream> {
1149 let enum_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1150
1151 let mut used_variant_names = std::collections::HashSet::new();
1153 let enum_variants = variants.iter().enumerate().map(|(i, variant)| {
1154 let base_variant_name = self.type_name_to_variant_name(&variant.target);
1156 let variant_name = self.ensure_unique_variant_name_generator(
1157 base_variant_name,
1158 &mut used_variant_names,
1159 i,
1160 );
1161 let variant_name_ident = format_ident!("{}", variant_name);
1162
1163 let variant_type_tokens = if matches!(
1165 variant.target.as_str(),
1166 "bool"
1167 | "i8"
1168 | "i16"
1169 | "i32"
1170 | "i64"
1171 | "i128"
1172 | "u8"
1173 | "u16"
1174 | "u32"
1175 | "u64"
1176 | "u128"
1177 | "f32"
1178 | "f64"
1179 | "String"
1180 ) {
1181 let type_ident = format_ident!("{}", variant.target);
1182 quote! { #type_ident }
1183 } else if variant.target == "serde_json::Value" {
1184 quote! { serde_json::Value }
1187 } else if variant.target.starts_with("Vec<") && variant.target.ends_with(">") {
1188 let inner = &variant.target[4..variant.target.len() - 1];
1190
1191 if inner.starts_with("Vec<") && inner.ends_with(">") {
1193 let inner_inner = &inner[4..inner.len() - 1];
1194 if inner_inner == "serde_json::Value" {
1195 quote! { Vec<Vec<serde_json::Value>> }
1196 } else {
1197 let inner_inner_type = if matches!(
1198 inner_inner,
1199 "bool"
1200 | "i8"
1201 | "i16"
1202 | "i32"
1203 | "i64"
1204 | "i128"
1205 | "u8"
1206 | "u16"
1207 | "u32"
1208 | "u64"
1209 | "u128"
1210 | "f32"
1211 | "f64"
1212 | "String"
1213 ) {
1214 format_ident!("{}", inner_inner)
1215 } else {
1216 format_ident!("{}", self.to_rust_type_name(inner_inner))
1217 };
1218 quote! { Vec<Vec<#inner_inner_type>> }
1219 }
1220 } else if inner == "serde_json::Value" {
1221 quote! { Vec<serde_json::Value> }
1222 } else {
1223 let inner_type = if matches!(
1224 inner,
1225 "bool"
1226 | "i8"
1227 | "i16"
1228 | "i32"
1229 | "i64"
1230 | "i128"
1231 | "u8"
1232 | "u16"
1233 | "u32"
1234 | "u64"
1235 | "u128"
1236 | "f32"
1237 | "f64"
1238 | "String"
1239 ) {
1240 format_ident!("{}", inner)
1241 } else {
1242 format_ident!("{}", self.to_rust_type_name(inner))
1243 };
1244 quote! { Vec<#inner_type> }
1245 }
1246 } else {
1247 let type_ident = format_ident!("{}", self.to_rust_type_name(&variant.target));
1248 quote! { #type_ident }
1249 };
1250
1251 quote! {
1252 #variant_name_ident(#variant_type_tokens),
1253 }
1254 });
1255
1256 let doc_comment = if let Some(desc) = &schema.description {
1257 quote! { #[doc = #desc] }
1258 } else {
1259 TokenStream::new()
1260 };
1261
1262 let derives = if self.config.enable_specta {
1264 quote! {
1265 #[derive(Debug, Clone, Deserialize, Serialize)]
1266 #[cfg_attr(feature = "specta", derive(specta::Type))]
1267 #[serde(untagged)]
1268 }
1269 } else {
1270 quote! {
1271 #[derive(Debug, Clone, Deserialize, Serialize)]
1272 #[serde(untagged)]
1273 }
1274 };
1275
1276 Ok(quote! {
1277 #doc_comment
1278 #derives
1279 pub enum #enum_name {
1280 #(#enum_variants)*
1281 }
1282 })
1283 }
1284
1285 fn generate_field_type(
1286 &self,
1287 schema_name: &str,
1288 field_name: &str,
1289 prop: &crate::analysis::PropertyInfo,
1290 is_required: bool,
1291 analysis: &crate::analysis::SchemaAnalysis,
1292 ) -> TokenStream {
1293 use crate::analysis::SchemaType;
1294
1295 let base_type = match &prop.schema_type {
1296 SchemaType::Primitive { rust_type } => {
1297 if rust_type.contains("::") {
1299 let parts: Vec<&str> = rust_type.split("::").collect();
1300 if parts.len() == 2 {
1301 let module = format_ident!("{}", parts[0]);
1302 let type_name = format_ident!("{}", parts[1]);
1303 quote! { #module::#type_name }
1304 } else {
1305 let path_parts: Vec<_> =
1307 parts.iter().map(|p| format_ident!("{}", p)).collect();
1308 quote! { #(#path_parts)::* }
1309 }
1310 } else {
1311 let type_ident = format_ident!("{}", rust_type);
1312 quote! { #type_ident }
1313 }
1314 }
1315 SchemaType::Reference { target } => {
1316 let target_type = format_ident!("{}", self.to_rust_type_name(target));
1317 if analysis.dependencies.recursive_schemas.contains(target) {
1319 quote! { Box<#target_type> }
1320 } else {
1321 quote! { #target_type }
1322 }
1323 }
1324 SchemaType::Array { item_type } => {
1325 let inner_type = self.generate_array_item_type(item_type, analysis);
1326 quote! { Vec<#inner_type> }
1327 }
1328 _ => {
1329 quote! { serde_json::Value }
1331 }
1332 };
1333
1334 let override_key = format!("{schema_name}.{field_name}");
1336 let is_nullable_override = self
1337 .config
1338 .nullable_field_overrides
1339 .get(&override_key)
1340 .copied()
1341 .unwrap_or(false);
1342
1343 if is_required && !prop.nullable && !is_nullable_override {
1344 if prop.default.is_some() && self.type_lacks_default(&prop.schema_type, analysis) {
1347 quote! { Option<#base_type> }
1348 } else {
1349 base_type
1350 }
1351 } else {
1352 quote! { Option<#base_type> }
1353 }
1354 }
1355
1356 fn generate_serde_field_attrs(
1357 &self,
1358 field_name: &str,
1359 prop: &crate::analysis::PropertyInfo,
1360 is_required: bool,
1361 analysis: &crate::analysis::SchemaAnalysis,
1362 ) -> TokenStream {
1363 let mut attrs = Vec::new();
1364
1365 let rust_field_name = self.to_rust_field_name(field_name);
1368 let comparison_name = rust_field_name
1369 .strip_prefix("r#")
1370 .unwrap_or(&rust_field_name);
1371 if comparison_name != field_name {
1372 attrs.push(quote! { rename = #field_name });
1373 }
1374
1375 if !is_required || prop.nullable {
1377 attrs.push(quote! { skip_serializing_if = "Option::is_none" });
1378 }
1379
1380 if prop.default.is_some()
1384 && (is_required && !prop.nullable)
1385 && !self.type_lacks_default(&prop.schema_type, analysis)
1386 {
1387 attrs.push(quote! { default });
1388 }
1389
1390 if attrs.is_empty() {
1391 TokenStream::new()
1392 } else {
1393 quote! { #[serde(#(#attrs),*)] }
1394 }
1395 }
1396
1397 fn type_lacks_default(
1401 &self,
1402 schema_type: &crate::analysis::SchemaType,
1403 analysis: &crate::analysis::SchemaAnalysis,
1404 ) -> bool {
1405 use crate::analysis::SchemaType;
1406 match schema_type {
1407 SchemaType::DiscriminatedUnion { .. } | SchemaType::Union { .. } => true,
1408 SchemaType::Reference { target } => {
1409 if let Some(schema) = analysis.schemas.get(target) {
1410 self.type_lacks_default(&schema.schema_type, analysis)
1411 } else {
1412 false
1413 }
1414 }
1415 _ => false,
1416 }
1417 }
1418
1419 fn generate_specta_field_attrs(&self, field_name: &str) -> TokenStream {
1420 if !self.config.enable_specta {
1421 return TokenStream::new();
1422 }
1423
1424 let camel_case_name = self.to_camel_case(field_name);
1426
1427 if camel_case_name != field_name {
1429 quote! { #[cfg_attr(feature = "specta", specta(rename = #camel_case_name))] }
1430 } else {
1431 TokenStream::new()
1432 }
1433 }
1434
1435 fn to_rust_enum_variant(&self, s: &str) -> String {
1436 let mut result = String::new();
1438 let mut next_upper = true;
1439 let mut prev_was_upper = false;
1440
1441 for (i, c) in s.chars().enumerate() {
1442 match c {
1443 'a'..='z' => {
1444 if next_upper {
1445 result.push(c.to_ascii_uppercase());
1446 next_upper = false;
1447 } else {
1448 result.push(c);
1449 }
1450 prev_was_upper = false;
1451 }
1452 'A'..='Z' => {
1453 if next_upper || (!prev_was_upper && i > 0) {
1454 result.push(c);
1456 next_upper = false;
1457 } else {
1458 result.push(c.to_ascii_lowercase());
1460 }
1461 prev_was_upper = true;
1462 }
1463 '0'..='9' => {
1464 result.push(c);
1465 next_upper = false;
1466 prev_was_upper = false;
1467 }
1468 '.' | '-' | '_' | ' ' | '@' | '#' | '$' | '/' | '\\' => {
1469 next_upper = true;
1471 prev_was_upper = false;
1472 }
1473 _ => {
1474 next_upper = true;
1476 prev_was_upper = false;
1477 }
1478 }
1479 }
1480
1481 if result.is_empty() {
1483 result = "Value".to_string();
1484 }
1485
1486 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1488 result = format!("Variant{result}");
1489 }
1490
1491 match result.as_str() {
1493 "Null" => "NullValue".to_string(),
1494 "True" => "TrueValue".to_string(),
1495 "False" => "FalseValue".to_string(),
1496 "Type" => "Type_".to_string(),
1497 "Match" => "Match_".to_string(),
1498 "Fn" => "Fn_".to_string(),
1499 "Impl" => "Impl_".to_string(),
1500 "Trait" => "Trait_".to_string(),
1501 "Struct" => "Struct_".to_string(),
1502 "Enum" => "Enum_".to_string(),
1503 "Mod" => "Mod_".to_string(),
1504 "Use" => "Use_".to_string(),
1505 "Pub" => "Pub_".to_string(),
1506 "Const" => "Const_".to_string(),
1507 "Static" => "Static_".to_string(),
1508 "Let" => "Let_".to_string(),
1509 "Mut" => "Mut_".to_string(),
1510 "Ref" => "Ref_".to_string(),
1511 "Move" => "Move_".to_string(),
1512 "Return" => "Return_".to_string(),
1513 "If" => "If_".to_string(),
1514 "Else" => "Else_".to_string(),
1515 "While" => "While_".to_string(),
1516 "For" => "For_".to_string(),
1517 "Loop" => "Loop_".to_string(),
1518 "Break" => "Break_".to_string(),
1519 "Continue" => "Continue_".to_string(),
1520 "Self" => "Self_".to_string(),
1521 "Super" => "Super_".to_string(),
1522 "Crate" => "Crate_".to_string(),
1523 "Async" => "Async_".to_string(),
1524 "Await" => "Await_".to_string(),
1525 _ => result,
1526 }
1527 }
1528
1529 #[allow(dead_code)]
1530 fn to_rust_identifier(&self, s: &str) -> String {
1531 let mut result = s
1533 .chars()
1534 .map(|c| match c {
1535 'a'..='z' | 'A'..='Z' | '0'..='9' => c,
1536 '.' | '-' | '_' | ' ' | '@' | '#' | '$' | '/' | '\\' => '_',
1537 _ => '_',
1538 })
1539 .collect::<String>();
1540
1541 result = result.trim_matches('_').to_string();
1543
1544 if result.is_empty() {
1546 result = "value".to_string();
1547 }
1548
1549 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1551 result = format!("variant_{result}");
1552 }
1553
1554 match result.as_str() {
1556 "null" => "null_value".to_string(),
1557 "true" => "true_value".to_string(),
1558 "false" => "false_value".to_string(),
1559 "type" => "type_".to_string(),
1560 "match" => "match_".to_string(),
1561 "fn" => "fn_".to_string(),
1562 "impl" => "impl_".to_string(),
1563 "trait" => "trait_".to_string(),
1564 "struct" => "struct_".to_string(),
1565 "enum" => "enum_".to_string(),
1566 "mod" => "mod_".to_string(),
1567 "use" => "use_".to_string(),
1568 "pub" => "pub_".to_string(),
1569 "const" => "const_".to_string(),
1570 "static" => "static_".to_string(),
1571 "let" => "let_".to_string(),
1572 "mut" => "mut_".to_string(),
1573 "ref" => "ref_".to_string(),
1574 "move" => "move_".to_string(),
1575 "return" => "return_".to_string(),
1576 "if" => "if_".to_string(),
1577 "else" => "else_".to_string(),
1578 "while" => "while_".to_string(),
1579 "for" => "for_".to_string(),
1580 "loop" => "loop_".to_string(),
1581 "break" => "break_".to_string(),
1582 "continue" => "continue_".to_string(),
1583 "self" => "self_".to_string(),
1584 "super" => "super_".to_string(),
1585 "crate" => "crate_".to_string(),
1586 "async" => "async_".to_string(),
1587 "await" => "await_".to_string(),
1588 "override" => "override_".to_string(),
1590 "box" => "box_".to_string(),
1591 "dyn" => "dyn_".to_string(),
1592 "where" => "where_".to_string(),
1593 "in" => "in_".to_string(),
1594 "abstract" => "abstract_".to_string(),
1596 "become" => "become_".to_string(),
1597 "do" => "do_".to_string(),
1598 "final" => "final_".to_string(),
1599 "macro" => "macro_".to_string(),
1600 "priv" => "priv_".to_string(),
1601 "try" => "try_".to_string(),
1602 "typeof" => "typeof_".to_string(),
1603 "unsized" => "unsized_".to_string(),
1604 "virtual" => "virtual_".to_string(),
1605 "yield" => "yield_".to_string(),
1606 _ => result,
1607 }
1608 }
1609
1610 fn sanitize_doc_comment(&self, desc: &str) -> String {
1611 let mut result = desc.to_string();
1613
1614 if result.contains('\n')
1622 && (result.contains('{')
1623 || result.contains("```")
1624 || result.contains("Human:")
1625 || result.contains("Assistant:")
1626 || result
1627 .lines()
1628 .any(|line| line.trim().starts_with('"') && line.trim().ends_with('"')))
1629 {
1630 if result.contains("```") {
1632 result = result.replace("```", "```ignore");
1633 } else {
1634 if result.lines().any(|line| {
1636 let trimmed = line.trim();
1637 trimmed.starts_with('"') && trimmed.ends_with('"') && trimmed.len() > 2
1638 }) {
1639 result = format!("```ignore\n{result}\n```");
1640 }
1641 }
1642 }
1643
1644 result
1645 }
1646
1647 pub(crate) fn to_rust_type_name(&self, s: &str) -> String {
1648 let mut result = String::new();
1650 let mut next_upper = true;
1651 let mut prev_was_lower = false;
1652
1653 for c in s.chars() {
1654 match c {
1655 'a'..='z' => {
1656 if next_upper {
1657 result.push(c.to_ascii_uppercase());
1658 next_upper = false;
1659 } else {
1660 result.push(c);
1661 }
1662 prev_was_lower = true;
1663 }
1664 'A'..='Z' => {
1665 result.push(c);
1666 next_upper = false;
1667 prev_was_lower = false;
1668 }
1669 '0'..='9' => {
1670 if prev_was_lower && !result.chars().last().unwrap_or(' ').is_ascii_digit() {
1673 }
1675 result.push(c);
1676 next_upper = false;
1677 prev_was_lower = false;
1678 }
1679 '_' | '-' | '.' | ' ' => {
1680 next_upper = true;
1682 prev_was_lower = false;
1683 }
1684 _ => {
1685 next_upper = true;
1687 prev_was_lower = false;
1688 }
1689 }
1690 }
1691
1692 if result.is_empty() {
1694 result = "Type".to_string();
1695 }
1696
1697 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1699 result = format!("Type{result}");
1700 }
1701
1702 result
1703 }
1704
1705 fn to_rust_field_name(&self, s: &str) -> String {
1706 let mut result = String::new();
1708 let mut prev_was_upper = false;
1709 let mut prev_was_underscore = false;
1710
1711 for (i, c) in s.chars().enumerate() {
1712 match c {
1713 'A'..='Z' => {
1714 if i > 0 && !prev_was_upper && !prev_was_underscore {
1716 result.push('_');
1717 }
1718 result.push(c.to_ascii_lowercase());
1719 prev_was_upper = true;
1720 prev_was_underscore = false;
1721 }
1722 'a'..='z' | '0'..='9' => {
1723 result.push(c);
1724 prev_was_upper = false;
1725 prev_was_underscore = false;
1726 }
1727 '-' | '.' | '_' | '@' | '#' | '$' | ' ' => {
1728 if !prev_was_underscore && !result.is_empty() {
1729 result.push('_');
1730 prev_was_underscore = true;
1731 }
1732 prev_was_upper = false;
1733 }
1734 _ => {
1735 if !prev_was_underscore && !result.is_empty() {
1737 result.push('_');
1738 }
1739 prev_was_upper = false;
1740 prev_was_underscore = true;
1741 }
1742 }
1743 }
1744
1745 let mut result = result.trim_matches('_').to_string();
1747 if result.is_empty() {
1748 return "field".to_string();
1749 }
1750
1751 if result.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1753 result = format!("field_{result}");
1754 }
1755
1756 if Self::is_rust_keyword(&result) {
1758 format!("r#{result}")
1759 } else {
1760 result
1761 }
1762 }
1763
/// Returns `true` when `s` is a lowercase Rust keyword that cannot be used as a plain
/// identifier.
///
/// Covers the strict keywords, the keywords reserved for future use, and `gen` (reserved
/// from the 2024 edition). `Self` is not listed. Matching is case-sensitive.
pub fn is_rust_keyword(s: &str) -> bool {
    matches!(
        s,
        "abstract"
            | "as"
            | "async"
            | "await"
            | "become"
            | "box"
            | "break"
            | "const"
            | "continue"
            | "crate"
            | "do"
            | "dyn"
            | "else"
            | "enum"
            | "extern"
            | "false"
            | "final"
            | "fn"
            | "for"
            | "gen"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "macro"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "override"
            | "priv"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "try"
            | "type"
            | "typeof"
            | "unsafe"
            | "unsized"
            | "use"
            | "virtual"
            | "where"
            | "while"
            | "yield"
    )
}
1815
1816 pub fn to_field_ident(name: &str) -> proc_macro2::Ident {
1818 if let Some(raw) = name.strip_prefix("r#") {
1819 proc_macro2::Ident::new_raw(raw, proc_macro2::Span::call_site())
1820 } else {
1821 proc_macro2::Ident::new(name, proc_macro2::Span::call_site())
1822 }
1823 }
1824
1825 fn to_camel_case(&self, s: &str) -> String {
1826 let mut result = String::new();
1828 let mut capitalize_next = false;
1829
1830 for (i, c) in s.chars().enumerate() {
1831 match c {
1832 '_' | '-' | '.' | ' ' => {
1833 capitalize_next = true;
1835 }
1836 'A'..='Z' => {
1837 if i == 0 {
1838 result.push(c.to_ascii_lowercase());
1840 } else if capitalize_next {
1841 result.push(c);
1842 capitalize_next = false;
1843 } else {
1844 result.push(c.to_ascii_lowercase());
1845 }
1846 }
1847 'a'..='z' | '0'..='9' => {
1848 if capitalize_next {
1849 result.push(c.to_ascii_uppercase());
1850 capitalize_next = false;
1851 } else {
1852 result.push(c);
1853 }
1854 }
1855 _ => {
1856 capitalize_next = true;
1858 }
1859 }
1860 }
1861
1862 if result.is_empty() {
1863 return "field".to_string();
1864 }
1865
1866 result
1867 }
1868
1869 fn generate_composition_struct(
1870 &self,
1871 schema: &crate::analysis::AnalyzedSchema,
1872 schemas: &[crate::analysis::SchemaRef],
1873 ) -> Result<TokenStream> {
1874 let struct_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
1875
1876 let fields = schemas.iter().enumerate().map(|(i, schema_ref)| {
1882 let field_name = format_ident!("part_{}", i);
1883 let field_type = format_ident!("{}", self.to_rust_type_name(&schema_ref.target));
1884
1885 quote! {
1886 #[serde(flatten)]
1887 pub #field_name: #field_type,
1888 }
1889 });
1890
1891 let doc_comment = if let Some(desc) = &schema.description {
1892 quote! { #[doc = #desc] }
1893 } else {
1894 TokenStream::new()
1895 };
1896
1897 let derives = if self.config.enable_specta {
1899 quote! {
1900 #[derive(Debug, Clone, Deserialize, Serialize)]
1901 #[cfg_attr(feature = "specta", derive(specta::Type))]
1902 }
1903 } else {
1904 quote! {
1905 #[derive(Debug, Clone, Deserialize, Serialize)]
1906 }
1907 };
1908
1909 Ok(quote! {
1910 #doc_comment
1911 #derives
1912 pub struct #struct_name {
1913 #(#fields)*
1914 }
1915 })
1916 }
1917
1918 #[allow(dead_code)]
1919 fn find_missing_types(&self, analysis: &SchemaAnalysis) -> std::collections::HashSet<String> {
1920 let mut missing = std::collections::HashSet::new();
1921 let defined_types: std::collections::HashSet<String> =
1922 analysis.schemas.keys().cloned().collect();
1923
1924 for schema in analysis.schemas.values() {
1926 match &schema.schema_type {
1927 crate::analysis::SchemaType::Union { variants } => {
1928 for variant in variants {
1929 if !defined_types.contains(&variant.target) {
1930 missing.insert(variant.target.clone());
1931 }
1932 }
1933 }
1934 crate::analysis::SchemaType::DiscriminatedUnion { variants, .. } => {
1935 for variant in variants {
1936 if !defined_types.contains(&variant.type_name) {
1937 missing.insert(variant.type_name.clone());
1938 }
1939 }
1940 }
1941 crate::analysis::SchemaType::Object { properties, .. } => {
1942 let mut sorted_props: Vec<_> = properties.iter().collect();
1944 sorted_props.sort_by_key(|(name, _)| name.as_str());
1945 for (_, prop) in sorted_props {
1946 if let crate::analysis::SchemaType::Reference { target } = &prop.schema_type
1947 {
1948 if !defined_types.contains(target) {
1949 missing.insert(target.clone());
1950 }
1951 }
1952 }
1953 }
1954 crate::analysis::SchemaType::Reference { target }
1955 if !defined_types.contains(target) =>
1956 {
1957 missing.insert(target.clone());
1958 }
1959 _ => {}
1960 }
1961 }
1962
1963 missing
1964 }
1965
1966 #[allow(clippy::only_used_in_recursion)]
1967 fn generate_array_item_type(
1968 &self,
1969 item_type: &crate::analysis::SchemaType,
1970 analysis: &crate::analysis::SchemaAnalysis,
1971 ) -> TokenStream {
1972 use crate::analysis::SchemaType;
1973
1974 match item_type {
1975 SchemaType::Primitive { rust_type } => {
1976 if rust_type.contains("::") {
1978 let parts: Vec<&str> = rust_type.split("::").collect();
1979 if parts.len() == 2 {
1980 let module = format_ident!("{}", parts[0]);
1981 let type_name = format_ident!("{}", parts[1]);
1982 quote! { #module::#type_name }
1983 } else {
1984 let path_parts: Vec<_> =
1986 parts.iter().map(|p| format_ident!("{}", p)).collect();
1987 quote! { #(#path_parts)::* }
1988 }
1989 } else {
1990 let type_ident = format_ident!("{}", rust_type);
1991 quote! { #type_ident }
1992 }
1993 }
1994 SchemaType::Reference { target } => {
1995 let target_type = format_ident!("{}", self.to_rust_type_name(target));
1996 if analysis.dependencies.recursive_schemas.contains(target) {
1998 quote! { Box<#target_type> }
1999 } else {
2000 quote! { #target_type }
2001 }
2002 }
2003 SchemaType::Array { item_type } => {
2004 let inner_type = self.generate_array_item_type(item_type, analysis);
2006 quote! { Vec<#inner_type> }
2007 }
2008 _ => {
2009 quote! { serde_json::Value }
2011 }
2012 }
2013 }
2014
2015 fn type_name_to_variant_name(&self, type_name: &str) -> String {
2017 match type_name {
2019 "bool" => return "Boolean".to_string(),
2020 "i8" | "i16" | "i32" | "i64" | "i128" => return "Integer".to_string(),
2021 "u8" | "u16" | "u32" | "u64" | "u128" => return "UnsignedInteger".to_string(),
2022 "f32" | "f64" => return "Number".to_string(),
2023 "String" => return "String".to_string(),
2024 "serde_json::Value" => return "Value".to_string(),
2025 _ => {}
2026 }
2027
2028 if type_name.starts_with("Vec<") && type_name.ends_with(">") {
2030 let inner = &type_name[4..type_name.len() - 1];
2031 if inner.starts_with("Vec<") && inner.ends_with(">") {
2033 let inner_inner = &inner[4..inner.len() - 1];
2034 return format!("{}ArrayArray", self.type_name_to_variant_name(inner_inner));
2035 }
2036 return format!("{}Array", self.type_name_to_variant_name(inner));
2037 }
2038
2039 let clean_name = type_name
2045 .trim_end_matches("Type")
2046 .trim_end_matches("Schema")
2047 .trim_end_matches("Item");
2048
2049 self.to_rust_type_name(clean_name)
2051 }
2052
2053 fn ensure_unique_variant_name_generator(
2055 &self,
2056 base_name: String,
2057 used_names: &mut std::collections::HashSet<String>,
2058 fallback_index: usize,
2059 ) -> String {
2060 if used_names.insert(base_name.clone()) {
2061 return base_name;
2062 }
2063
2064 for i in 2..100 {
2066 let numbered_name = format!("{base_name}{i}");
2067 if used_names.insert(numbered_name.clone()) {
2068 return numbered_name;
2069 }
2070 }
2071
2072 let fallback = format!("Variant{fallback_index}");
2074 used_names.insert(fallback.clone());
2075 fallback
2076 }
2077
2078 fn find_request_type_for_operation(
2080 &self,
2081 operation_id: &str,
2082 analysis: &SchemaAnalysis,
2083 ) -> Option<String> {
2084 analysis.operations.get(operation_id).and_then(|op| {
2086 op.request_body
2087 .as_ref()
2088 .and_then(|rb| rb.schema_name().map(|s| s.to_string()))
2089 })
2090 }
2091
2092 fn resolve_streaming_event_type(
2094 &self,
2095 endpoint: &crate::streaming::StreamingEndpoint,
2096 analysis: &SchemaAnalysis,
2097 ) -> Result<String> {
2098 match &endpoint.event_flow {
2099 crate::streaming::EventFlow::Simple => {
2100 if analysis.schemas.contains_key(&endpoint.event_union_type) {
2103 Ok(endpoint.event_union_type.to_string())
2104 } else {
2105 Err(crate::error::GeneratorError::ValidationError(format!(
2106 "Streaming response type '{}' not found in schema for simple streaming endpoint '{}'",
2107 endpoint.event_union_type, endpoint.operation_id
2108 )))
2109 }
2110 }
2111 crate::streaming::EventFlow::StartDeltaStop { .. } => {
2112 if analysis.schemas.contains_key(&endpoint.event_union_type) {
2115 Ok(endpoint.event_union_type.to_string())
2116 } else {
2117 Err(crate::error::GeneratorError::ValidationError(format!(
2118 "Event union type '{}' not found in schema for complex streaming endpoint '{}'",
2119 endpoint.event_union_type, endpoint.operation_id
2120 )))
2121 }
2122 }
2123 }
2124 }
2125
/// Emits the `StreamingError` enum and its conversions for the generated streaming client.
///
/// The generated code contains:
/// * `StreamingError`, a `thiserror`-derived enum with connection, HTTP status, SSE parsing,
///   authentication, rate-limit, API, timeout, JSON and request variants;
/// * `From<reqwest::header::InvalidHeaderValue>`, mapping to `StreamingError::Api`;
/// * `From<reqwest::Error>`, mapping timeouts to `Timeout`, status errors to `Http` (or
///   `Connection` when no status code is available) and everything else to `Request`.
///
/// This function itself always returns `Ok`.
fn generate_streaming_error_types(&self) -> Result<TokenStream> {
    Ok(quote! {
        #[derive(Debug, thiserror::Error)]
        pub enum StreamingError {
            #[error("Connection error: {0}")]
            Connection(String),
            #[error("HTTP error: {status}")]
            Http { status: u16 },
            #[error("SSE parsing error: {0}")]
            Parsing(String),
            #[error("Authentication error: {0}")]
            Authentication(String),
            #[error("Rate limit error: {0}")]
            RateLimit(String),
            #[error("API error: {0}")]
            Api(String),
            #[error("Timeout error: {0}")]
            Timeout(String),
            #[error("JSON serialization/deserialization error: {0}")]
            Json(#[from] serde_json::Error),
            // No `#[from]` here: the conversion from `reqwest::Error` is hand-written below
            // so that timeouts and status errors land in their own variants.
            #[error("Request error: {0}")]
            Request(reqwest::Error),
        }

        impl From<reqwest::header::InvalidHeaderValue> for StreamingError {
            fn from(err: reqwest::header::InvalidHeaderValue) -> Self {
                StreamingError::Api(format!("Invalid header value: {}", err))
            }
        }

        impl From<reqwest::Error> for StreamingError {
            fn from(err: reqwest::Error) -> Self {
                if err.is_timeout() {
                    StreamingError::Timeout(err.to_string())
                } else if err.is_status() {
                    if let Some(status) = err.status() {
                        StreamingError::Http { status: status.as_u16() }
                    } else {
                        StreamingError::Connection(err.to_string())
                    }
                } else {
                    StreamingError::Request(err)
                }
            }
        }
    })
}
2175
2176 fn generate_endpoint_trait(
2178 &self,
2179 endpoint: &crate::streaming::StreamingEndpoint,
2180 analysis: &SchemaAnalysis,
2181 ) -> Result<TokenStream> {
2182 use crate::streaming::HttpMethod;
2183
2184 let trait_name = format_ident!(
2185 "{}StreamingClient",
2186 self.to_rust_type_name(&endpoint.operation_id)
2187 );
2188 let method_name =
2189 format_ident!("stream_{}", self.to_rust_field_name(&endpoint.operation_id));
2190 let event_type =
2191 format_ident!("{}", self.resolve_streaming_event_type(endpoint, analysis)?);
2192
2193 let method_signature = match endpoint.http_method {
2195 HttpMethod::Get => {
2196 let mut param_defs = Vec::new();
2198 for qp in &endpoint.query_parameters {
2199 let param_name = format_ident!("{}", self.to_rust_field_name(&qp.name));
2200 if qp.required {
2201 param_defs.push(quote! { #param_name: &str });
2202 } else {
2203 param_defs.push(quote! { #param_name: Option<&str> });
2204 }
2205 }
2206 quote! {
2207 async fn #method_name(
2208 &self,
2209 #(#param_defs),*
2210 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error>;
2211 }
2212 }
2213 HttpMethod::Post => {
2214 let request_type = self
2216 .find_request_type_for_operation(&endpoint.operation_id, analysis)
2217 .unwrap_or_else(|| "serde_json::Value".to_string());
2218 let request_type_ident = if request_type.contains("::") {
2219 let parts: Vec<&str> = request_type.split("::").collect();
2220 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
2221 quote! { #(#path_parts)::* }
2222 } else {
2223 let ident = format_ident!("{}", request_type);
2224 quote! { #ident }
2225 };
2226 quote! {
2227 async fn #method_name(
2228 &self,
2229 request: #request_type_ident,
2230 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error>;
2231 }
2232 }
2233 };
2234
2235 Ok(quote! {
2236 #[async_trait]
2238 pub trait #trait_name {
2239 type Error: std::error::Error + Send + Sync + 'static;
2240
2241 #method_signature
2243 }
2244 })
2245 }
2246
2247 fn generate_streaming_client_impl(
2249 &self,
2250 streaming_config: &crate::streaming::StreamingConfig,
2251 analysis: &SchemaAnalysis,
2252 ) -> Result<TokenStream> {
2253 let client_name = format_ident!(
2254 "{}Client",
2255 self.to_rust_type_name(&streaming_config.client_module_name)
2256 );
2257
2258 let mut struct_fields = vec![
2261 quote! { base_url: String },
2262 quote! { api_key: Option<String> },
2263 quote! { http_client: reqwest::Client },
2264 quote! { custom_headers: std::collections::BTreeMap<String, String> },
2265 ];
2266
2267 let has_optional_headers = !streaming_config
2268 .endpoints
2269 .iter()
2270 .all(|e| e.optional_headers.is_empty());
2271
2272 if has_optional_headers {
2273 struct_fields
2274 .push(quote! { optional_headers: std::collections::BTreeMap<String, String> });
2275 }
2276
2277 let default_base_url = if let Some(ref streaming_config) = self.config.streaming_config {
2280 streaming_config
2281 .endpoints
2282 .first()
2283 .and_then(|e| e.base_url.as_deref())
2284 .unwrap_or("https://api.example.com")
2285 } else {
2286 "https://api.example.com"
2287 };
2288
2289 let constructor_fields = if has_optional_headers {
2291 quote! {
2292 base_url: #default_base_url.to_string(),
2293 api_key: None,
2294 http_client: reqwest::Client::new(),
2295 custom_headers: std::collections::BTreeMap::new(),
2296 optional_headers: std::collections::BTreeMap::new(),
2297 }
2298 } else {
2299 quote! {
2300 base_url: #default_base_url.to_string(),
2301 api_key: None,
2302 http_client: reqwest::Client::new(),
2303 custom_headers: std::collections::BTreeMap::new(),
2304 }
2305 };
2306
2307 let optional_headers_method = if has_optional_headers {
2309 quote! {
2310 pub fn set_optional_headers(&mut self, headers: std::collections::BTreeMap<String, String>) {
2312 self.optional_headers = headers;
2313 }
2314 }
2315 } else {
2316 TokenStream::new()
2317 };
2318
2319 let constructor = quote! {
2320 impl #client_name {
2321 pub fn new() -> Self {
2323 Self {
2324 #constructor_fields
2325 }
2326 }
2327
2328 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
2330 self.base_url = base_url.into();
2331 self
2332 }
2333
2334 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
2336 self.api_key = Some(api_key.into());
2337 self
2338 }
2339
2340 pub fn with_header(
2342 mut self,
2343 name: impl Into<String>,
2344 value: impl Into<String>,
2345 ) -> Self {
2346 self.custom_headers.insert(name.into(), value.into());
2347 self
2348 }
2349
2350 pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
2352 self.http_client = client;
2353 self
2354 }
2355
2356 #optional_headers_method
2357 }
2358 };
2359
2360 let mut trait_impls = Vec::new();
2362 for endpoint in &streaming_config.endpoints {
2363 let trait_impl = self.generate_endpoint_trait_impl(endpoint, &client_name, analysis)?;
2364 trait_impls.push(trait_impl);
2365 }
2366
2367 let default_impl = quote! {
2369 impl Default for #client_name {
2370 fn default() -> Self {
2371 Self::new()
2372 }
2373 }
2374 };
2375
2376 Ok(quote! {
2377 #[derive(Debug, Clone)]
2379 pub struct #client_name {
2380 #(#struct_fields,)*
2381 }
2382
2383 #constructor
2384
2385 #default_impl
2386
2387 #(#trait_impls)*
2388 })
2389 }
2390
2391 fn generate_endpoint_trait_impl(
2393 &self,
2394 endpoint: &crate::streaming::StreamingEndpoint,
2395 client_name: &proc_macro2::Ident,
2396 analysis: &SchemaAnalysis,
2397 ) -> Result<TokenStream> {
2398 use crate::streaming::HttpMethod;
2399
2400 let trait_name = format_ident!(
2401 "{}StreamingClient",
2402 self.to_rust_type_name(&endpoint.operation_id)
2403 );
2404 let method_name =
2405 format_ident!("stream_{}", self.to_rust_field_name(&endpoint.operation_id));
2406 let event_type =
2407 format_ident!("{}", self.resolve_streaming_event_type(endpoint, analysis)?);
2408
2409 let mut header_setup = Vec::new();
2411 for (name, value) in &endpoint.required_headers {
2412 header_setup.push(quote! {
2413 headers.insert(#name, HeaderValue::from_static(#value));
2414 });
2415 }
2416
2417 if let Some(auth_header) = &endpoint.auth_header {
2420 match auth_header {
2421 crate::streaming::AuthHeader::Bearer(header_name) => {
2422 header_setup.push(quote! {
2423 if let Some(ref api_key) = self.api_key {
2424 headers.insert(#header_name, HeaderValue::from_str(&format!("Bearer {}", api_key))?);
2425 }
2426 });
2427 }
2428 crate::streaming::AuthHeader::ApiKey(header_name) => {
2429 header_setup.push(quote! {
2430 if let Some(ref api_key) = self.api_key {
2431 headers.insert(#header_name, HeaderValue::from_str(api_key)?);
2432 }
2433 });
2434 }
2435 }
2436 } else {
2437 header_setup.push(quote! {
2439 if let Some(ref api_key) = self.api_key {
2440 headers.insert("Authorization", HeaderValue::from_str(&format!("Bearer {}", api_key))?);
2441 }
2442 });
2443 }
2444
2445 header_setup.push(quote! {
2447 for (name, value) in &self.custom_headers {
2448 if let (Ok(header_name), Ok(header_value)) = (reqwest::header::HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(value)) {
2449 headers.insert(header_name, header_value);
2450 }
2451 }
2452 });
2453
2454 if !endpoint.optional_headers.is_empty() {
2456 header_setup.push(quote! {
2457 for (key, value) in &self.optional_headers {
2458 if let (Ok(header_name), Ok(header_value)) = (reqwest::header::HeaderName::from_bytes(key.as_bytes()), HeaderValue::from_str(value)) {
2459 headers.insert(header_name, header_value);
2460 }
2461 }
2462 });
2463 }
2464
2465 match endpoint.http_method {
2467 HttpMethod::Get => self.generate_get_streaming_impl(
2468 endpoint,
2469 client_name,
2470 &trait_name,
2471 &method_name,
2472 &event_type,
2473 &header_setup,
2474 ),
2475 HttpMethod::Post => self.generate_post_streaming_impl(
2476 endpoint,
2477 client_name,
2478 &trait_name,
2479 &method_name,
2480 &event_type,
2481 &header_setup,
2482 analysis,
2483 ),
2484 }
2485 }
2486
2487 fn generate_get_streaming_impl(
2489 &self,
2490 endpoint: &crate::streaming::StreamingEndpoint,
2491 client_name: &proc_macro2::Ident,
2492 trait_name: &proc_macro2::Ident,
2493 method_name: &proc_macro2::Ident,
2494 event_type: &proc_macro2::Ident,
2495 header_setup: &[TokenStream],
2496 ) -> Result<TokenStream> {
2497 let path = &endpoint.path;
2498
2499 let mut param_defs = Vec::new();
2501 let mut query_params = Vec::new();
2502
2503 for qp in &endpoint.query_parameters {
2504 let param_name = format_ident!("{}", self.to_rust_field_name(&qp.name));
2505 let param_name_str = &qp.name;
2506
2507 if qp.required {
2508 param_defs.push(quote! { #param_name: &str });
2509 query_params.push(quote! {
2510 url.query_pairs_mut().append_pair(#param_name_str, #param_name);
2511 });
2512 } else {
2513 param_defs.push(quote! { #param_name: Option<&str> });
2514 query_params.push(quote! {
2515 if let Some(v) = #param_name {
2516 url.query_pairs_mut().append_pair(#param_name_str, v);
2517 }
2518 });
2519 }
2520 }
2521
2522 let url_construction = quote! {
2524 let base_url = url::Url::parse(&self.base_url)
2525 .map_err(|e| StreamingError::Connection(format!("Invalid base URL: {}", e)))?;
2526 let path_to_join = #path.trim_start_matches('/');
2527 let mut url = base_url.join(path_to_join)
2528 .map_err(|e| StreamingError::Connection(format!("URL join error: {}", e)))?;
2529 #(#query_params)*
2530 };
2531
2532 let instrument_skip = quote! { #[instrument(skip(self), name = "streaming_get_request")] };
2533
2534 Ok(quote! {
2535 #[async_trait]
2536 impl #trait_name for #client_name {
2537 type Error = StreamingError;
2538
2539 #instrument_skip
2540 async fn #method_name(
2541 &self,
2542 #(#param_defs),*
2543 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error> {
2544 debug!("Starting streaming GET request");
2545
2546 let mut headers = HeaderMap::new();
2547 #(#header_setup)*
2548
2549 #url_construction
2550 let url_str = url.to_string();
2551 debug!("Making streaming GET request to: {}", url_str);
2552
2553 let request_builder = self.http_client
2554 .get(url_str)
2555 .headers(headers);
2556
2557 debug!("Creating SSE stream from request");
2558 let stream = parse_sse_stream::<#event_type>(request_builder).await?;
2559 info!("SSE stream created successfully");
2560 Ok(Box::pin(stream))
2561 }
2562 }
2563 })
2564 }
2565
2566 #[allow(clippy::too_many_arguments)]
2568 fn generate_post_streaming_impl(
2569 &self,
2570 endpoint: &crate::streaming::StreamingEndpoint,
2571 client_name: &proc_macro2::Ident,
2572 trait_name: &proc_macro2::Ident,
2573 method_name: &proc_macro2::Ident,
2574 event_type: &proc_macro2::Ident,
2575 header_setup: &[TokenStream],
2576 analysis: &SchemaAnalysis,
2577 ) -> Result<TokenStream> {
2578 let path = &endpoint.path;
2579
2580 let request_type = self
2582 .find_request_type_for_operation(&endpoint.operation_id, analysis)
2583 .unwrap_or_else(|| "serde_json::Value".to_string());
2584 let request_type_ident = if request_type.contains("::") {
2585 let parts: Vec<&str> = request_type.split("::").collect();
2586 let path_parts: Vec<_> = parts.iter().map(|p| format_ident!("{}", p)).collect();
2587 quote! { #(#path_parts)::* }
2588 } else {
2589 let ident = format_ident!("{}", request_type);
2590 quote! { #ident }
2591 };
2592
2593 let url_construction = quote! {
2595 let base_url = url::Url::parse(&self.base_url)
2596 .map_err(|e| StreamingError::Connection(format!("Invalid base URL: {}", e)))?;
2597 let path_to_join = #path.trim_start_matches('/');
2598 let url = base_url.join(path_to_join)
2599 .map_err(|e| StreamingError::Connection(format!("URL join error: {}", e)))?
2600 .to_string();
2601 };
2602
2603 let stream_param = &endpoint.stream_parameter;
2605 let stream_setup = if stream_param.is_empty() {
2606 quote! {
2607 let streaming_request = request;
2608 }
2609 } else {
2610 quote! {
2611 let mut streaming_request = request;
2613 if let Ok(mut request_value) = serde_json::to_value(&streaming_request) {
2614 if let Some(obj) = request_value.as_object_mut() {
2615 obj.insert(#stream_param.to_string(), serde_json::Value::Bool(true));
2616 }
2617 streaming_request = serde_json::from_value(request_value)?;
2618 }
2619 }
2620 };
2621
2622 Ok(quote! {
2623 #[async_trait]
2624 impl #trait_name for #client_name {
2625 type Error = StreamingError;
2626
2627 #[instrument(skip(self, request), name = "streaming_post_request")]
2628 async fn #method_name(
2629 &self,
2630 request: #request_type_ident,
2631 ) -> Result<Pin<Box<dyn Stream<Item = Result<#event_type, Self::Error>> + Send>>, Self::Error> {
2632 debug!("Starting streaming POST request");
2633
2634 #stream_setup
2635
2636 let mut headers = HeaderMap::new();
2637 #(#header_setup)*
2638
2639 #url_construction
2640 debug!("Making streaming POST request to: {}", url);
2641
2642 let request_builder = self.http_client
2643 .post(&url)
2644 .headers(headers)
2645 .json(&streaming_request);
2646
2647 debug!("Creating SSE stream from request");
2648 let stream = parse_sse_stream::<#event_type>(request_builder).await?;
2649 info!("SSE stream created successfully");
2650 Ok(Box::pin(stream))
2651 }
2652 }
2653 })
2654 }
2655
2656 fn generate_sse_parser_utilities(
2658 &self,
2659 _streaming_config: &crate::streaming::StreamingConfig,
2660 ) -> Result<TokenStream> {
2661 Ok(quote! {
2662 pub async fn parse_sse_stream<T>(
2664 request_builder: reqwest::RequestBuilder
2665 ) -> Result<impl Stream<Item = Result<T, StreamingError>>, StreamingError>
2666 where
2667 T: serde::de::DeserializeOwned + Send + 'static,
2668 {
2669 let mut event_source = reqwest_eventsource::EventSource::new(request_builder).map_err(|e| {
2670 StreamingError::Connection(format!("Failed to create event source: {}", e))
2671 })?;
2672
2673 let stream = event_source.filter_map(|event_result| async move {
2674 match event_result {
2675 Ok(reqwest_eventsource::Event::Open) => {
2676 debug!("SSE connection opened");
2677 None
2678 }
2679 Ok(reqwest_eventsource::Event::Message(message)) => {
2680 if message.event == "ping" {
2682 debug!("Received SSE ping event, skipping");
2683 return None;
2684 }
2685
2686 if message.data.trim().is_empty() {
2688 debug!("Empty SSE data, skipping");
2689 return None;
2690 }
2691
2692 if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&message.data) {
2694 if let Some(event_type) = json_value.get("event").and_then(|v| v.as_str()) {
2695 if event_type == "ping" {
2696 debug!("Received ping event in JSON data, skipping");
2697 return None;
2698 }
2699 }
2700
2701 match serde_json::from_value::<T>(json_value) {
2703 Ok(parsed_event) => {
2704 Some(Ok(parsed_event))
2705 }
2706 Err(e) => {
2707 if message.data.contains("ping") || message.event.contains("ping") {
2708 debug!("Ignoring ping-related event: {}", message.data);
2709 None
2710 } else {
2711 Some(Err(StreamingError::Parsing(
2712 format!("Failed to parse SSE event: {} (raw: {})", e, message.data)
2713 )))
2714 }
2715 }
2716 }
2717 } else {
2718 Some(Err(StreamingError::Parsing(
2720 format!("SSE event is not valid JSON: {}", message.data)
2721 )))
2722 }
2723 }
2724 Err(e) => {
2725 match e {
2727 reqwest_eventsource::Error::StreamEnded => {
2728 debug!("SSE stream completed normally");
2729 None }
2731 reqwest_eventsource::Error::InvalidStatusCode(status, response) => {
2732 let status_code = status.as_u16();
2734
2735 let error_body = match response.text().await {
2737 Ok(body) => body,
2738 Err(_) => "Failed to read error response body".to_string()
2739 };
2740
2741 error!("SSE connection error - HTTP {}: {}", status_code, error_body);
2742
2743 let detailed_error = format!(
2744 "HTTP {} error: {}",
2745 status_code,
2746 error_body
2747 );
2748
2749 Some(Err(StreamingError::Connection(detailed_error)))
2750 }
2751 _ => {
2752 let error_str = e.to_string();
2753 if error_str.contains("stream closed") {
2754 debug!("SSE stream closed");
2755 None
2756 } else {
2757 error!("SSE connection error: {}", e);
2758 Some(Err(StreamingError::Connection(error_str)))
2759 }
2760 }
2761 }
2762 }
2763 }
2764 });
2765
2766 Ok(stream)
2767 }
2768 })
2769 }
2770
2771 fn generate_reconnection_utilities(
2773 &self,
2774 reconnect_config: &crate::streaming::ReconnectionConfig,
2775 ) -> Result<TokenStream> {
2776 let max_retries = reconnect_config.max_retries;
2777 let initial_delay = reconnect_config.initial_delay_ms;
2778 let max_delay = reconnect_config.max_delay_ms;
2779 let backoff_multiplier = reconnect_config.backoff_multiplier;
2780
2781 Ok(quote! {
2782 #[derive(Debug, Clone)]
2784 pub struct ReconnectionManager {
2785 max_retries: u32,
2786 initial_delay_ms: u64,
2787 max_delay_ms: u64,
2788 backoff_multiplier: f64,
2789 current_attempt: u32,
2790 }
2791
2792 impl ReconnectionManager {
2793 pub fn new() -> Self {
2795 Self {
2796 max_retries: #max_retries,
2797 initial_delay_ms: #initial_delay,
2798 max_delay_ms: #max_delay,
2799 backoff_multiplier: #backoff_multiplier,
2800 current_attempt: 0,
2801 }
2802 }
2803
2804 pub fn should_retry(&self) -> bool {
2806 self.current_attempt < self.max_retries
2807 }
2808
2809 pub fn next_retry_delay(&mut self) -> Duration {
2811 if !self.should_retry() {
2812 return Duration::from_secs(0);
2813 }
2814
2815 let delay_ms = (self.initial_delay_ms as f64
2816 * self.backoff_multiplier.powi(self.current_attempt as i32)) as u64;
2817 let delay_ms = delay_ms.min(self.max_delay_ms);
2818
2819 self.current_attempt += 1;
2820 Duration::from_millis(delay_ms)
2821 }
2822
2823 pub fn reset(&mut self) {
2825 self.current_attempt = 0;
2826 }
2827
2828 pub fn current_attempt(&self) -> u32 {
2830 self.current_attempt
2831 }
2832 }
2833
2834 impl Default for ReconnectionManager {
2835 fn default() -> Self {
2836 Self::new()
2837 }
2838 }
2839 })
2840 }
2841}