1use crate::analysis::{OperationInfo, ParameterInfo, SchemaAnalysis};
151use crate::generator::CodeGenerator;
152use heck::ToSnakeCase;
153use proc_macro2::TokenStream;
154use quote::{format_ident, quote};
155use std::collections::BTreeMap;
156
157impl CodeGenerator {
158 pub fn generate_http_client_struct(&self) -> TokenStream {
160 let has_retry = self.config().retry_config.is_some();
161 let has_tracing = self.config().tracing_enabled;
162
163 let retry_config_struct = if has_retry {
165 quote! {
166 #[derive(Debug, Clone)]
168 pub struct RetryConfig {
169 pub max_retries: u32,
170 pub initial_delay_ms: u64,
171 pub max_delay_ms: u64,
172 }
173
174 impl Default for RetryConfig {
175 fn default() -> Self {
176 Self {
177 max_retries: 3,
178 initial_delay_ms: 500,
179 max_delay_ms: 16000,
180 }
181 }
182 }
183 }
184 } else {
185 quote! {}
186 };
187
188 let client_struct = quote! {
190 use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
191 use std::collections::BTreeMap;
192
193 #[derive(Clone)]
195 pub struct HttpClient {
196 base_url: String,
197 api_key: Option<String>,
198 http_client: ClientWithMiddleware,
199 custom_headers: BTreeMap<String, String>,
200 }
201 };
202
203 let constructor = self.generate_constructor(has_retry, has_tracing);
205
206 let builder_methods = self.generate_builder_methods();
208
209 let default_impl = quote! {
211 impl Default for HttpClient {
212 fn default() -> Self {
213 Self::new()
214 }
215 }
216 };
217
218 quote! {
220 #retry_config_struct
221 #client_struct
222
223 impl HttpClient {
224 #constructor
225 #builder_methods
226 }
227
228 #default_impl
229 }
230 }
231
232 fn generate_constructor(&self, has_retry: bool, has_tracing: bool) -> TokenStream {
234 let retry_param = if has_retry {
235 quote! { retry_config: Option<RetryConfig>, }
236 } else {
237 quote! {}
238 };
239
240 let tracing_param = if has_tracing {
241 quote! { enable_tracing: bool, }
242 } else {
243 quote! {}
244 };
245
246 let retry_middleware = if has_retry {
247 quote! {
248 if let Some(config) = retry_config {
249 use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
250
251 let retry_policy = ExponentialBackoff::builder()
252 .retry_bounds(
253 std::time::Duration::from_millis(config.initial_delay_ms),
254 std::time::Duration::from_millis(config.max_delay_ms),
255 )
256 .build_with_max_retries(config.max_retries);
257
258 let retry_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
259 client_builder = client_builder.with(retry_middleware);
260 }
261 }
262 } else {
263 quote! {}
264 };
265
266 let tracing_middleware = if has_tracing {
267 quote! {
268 if enable_tracing {
269 use reqwest_tracing::TracingMiddleware;
270 client_builder = client_builder.with(TracingMiddleware::default());
271 }
272 }
273 } else {
274 quote! {}
275 };
276
277 let default_constructor = if has_retry && has_tracing {
278 quote! {
279 pub fn new() -> Self {
281 Self::with_config(None, true)
282 }
283 }
284 } else if has_retry {
285 quote! {
286 pub fn new() -> Self {
288 Self::with_config(None)
289 }
290 }
291 } else if has_tracing {
292 quote! {
293 pub fn new() -> Self {
295 Self::with_config(true)
296 }
297 }
298 } else {
299 quote! {
300 pub fn new() -> Self {
302 let reqwest_client = reqwest::Client::new();
303 let client_builder = ClientBuilder::new(reqwest_client);
304 let http_client = client_builder.build();
305
306 Self {
307 base_url: String::new(),
308 api_key: None,
309 http_client,
310 custom_headers: BTreeMap::new(),
311 }
312 }
313 }
314 };
315
316 if has_retry || has_tracing {
317 quote! {
318 #default_constructor
319
320 pub fn with_config(#retry_param #tracing_param) -> Self {
322 let reqwest_client = reqwest::Client::new();
323 let mut client_builder = ClientBuilder::new(reqwest_client);
324
325 #tracing_middleware
326 #retry_middleware
327
328 let http_client = client_builder.build();
329
330 Self {
331 base_url: String::new(),
332 api_key: None,
333 http_client,
334 custom_headers: BTreeMap::new(),
335 }
336 }
337 }
338 } else {
339 default_constructor
340 }
341 }
342
343 fn generate_builder_methods(&self) -> TokenStream {
345 quote! {
346 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
348 self.base_url = base_url.into();
349 self
350 }
351
352 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
354 self.api_key = Some(api_key.into());
355 self
356 }
357
358 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
360 self.custom_headers.insert(name.into(), value.into());
361 self
362 }
363
364 pub fn with_headers(mut self, headers: BTreeMap<String, String>) -> Self {
366 self.custom_headers.extend(headers);
367 self
368 }
369 }
370 }
371
372 pub fn generate_operation_methods(&self, analysis: &SchemaAnalysis) -> TokenStream {
378 let param_enums = self.generate_param_enum_types(analysis);
379
380 let op_error_enums: Vec<TokenStream> = analysis
381 .operations
382 .values()
383 .filter_map(|op| self.generate_op_error_enum(op))
384 .collect();
385
386 let methods: Vec<TokenStream> = analysis
387 .operations
388 .values()
389 .map(|op| self.generate_single_operation_method(op))
390 .collect();
391
392 quote! {
393 #param_enums
394
395 #(#op_error_enums)*
396
397 impl HttpClient {
398 #(#methods)*
399 }
400 }
401 }
402
403 fn generate_param_enum_types(&self, analysis: &SchemaAnalysis) -> TokenStream {
408 let mut by_name: BTreeMap<String, &ParameterInfo> = BTreeMap::new();
409 for op in analysis.operations.values() {
410 for param in &op.parameters {
411 if param.enum_values.is_some() {
412 by_name.entry(param.rust_type.clone()).or_insert(param);
413 }
414 }
415 }
416
417 if by_name.is_empty() {
418 return quote! {};
419 }
420
421 let defs: Vec<TokenStream> = by_name
422 .values()
423 .map(|param| self.generate_single_param_enum(param))
424 .collect();
425
426 quote! { #(#defs)* }
427 }
428
429 fn generate_single_param_enum(&self, param: &ParameterInfo) -> TokenStream {
430 let Some(values) = param.enum_values.as_deref() else {
431 return quote! {};
432 };
433
434 let enum_ident = format_ident!("{}", param.rust_type);
435
436 let variants: Vec<TokenStream> = values
437 .iter()
438 .map(|value| {
439 let variant_ident = format_ident!("{}", self.to_rust_enum_variant(value));
440 quote! {
441 #[serde(rename = #value)]
442 #variant_ident,
443 }
444 })
445 .collect();
446
447 let display_arms: Vec<TokenStream> = values
448 .iter()
449 .map(|value| {
450 let variant_ident = format_ident!("{}", self.to_rust_enum_variant(value));
451 quote! { Self::#variant_ident => #value, }
452 })
453 .collect();
454
455 let doc = format!(
456 "Allowed values for the `{}` {} parameter.",
457 param.name, param.location
458 );
459
460 quote! {
461 #[doc = #doc]
462 #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
463 pub enum #enum_ident {
464 #(#variants)*
465 }
466
467 impl #enum_ident {
468 pub fn as_str(&self) -> &'static str {
469 match self {
470 #(#display_arms)*
471 }
472 }
473 }
474
475 impl std::fmt::Display for #enum_ident {
476 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477 f.write_str(self.as_str())
478 }
479 }
480
481 impl AsRef<str> for #enum_ident {
482 fn as_ref(&self) -> &str {
483 self.as_str()
484 }
485 }
486 }
487 }
488
489 fn generate_op_error_enum(&self, op: &OperationInfo) -> Option<TokenStream> {
494 let variants: Vec<(String, String)> = op
495 .response_schemas
496 .iter()
497 .filter(|(code, _)| !code.starts_with('2'))
498 .map(|(code, schema)| (code.clone(), schema.clone()))
499 .collect();
500
501 if variants.is_empty() {
502 return None;
503 }
504
505 let enum_ident = self.op_error_enum_ident(op);
506 let variant_decls: Vec<TokenStream> = variants
507 .iter()
508 .map(|(code, schema)| {
509 let variant_ident = Self::op_error_variant_ident(code);
510 let payload_ty_name = self.to_rust_type_name(schema);
511 let payload_ty = syn::Ident::new(&payload_ty_name, proc_macro2::Span::call_site());
512 quote! { #variant_ident(#payload_ty) }
513 })
514 .collect();
515
516 let doc = format!(
517 "Typed error responses for `{}`. One variant per declared non-2xx response.",
518 op.operation_id
519 );
520
521 Some(quote! {
522 #[doc = #doc]
523 #[derive(Debug, Clone)]
524 pub enum #enum_ident {
525 #(#variant_decls,)*
526 }
527 })
528 }
529
530 fn op_error_enum_ident(&self, op: &OperationInfo) -> syn::Ident {
532 use heck::ToPascalCase;
533 let name = format!(
534 "{}ApiError",
535 op.operation_id.replace('.', "_").to_pascal_case()
536 );
537 syn::Ident::new(&name, proc_macro2::Span::call_site())
538 }
539
540 fn op_error_variant_ident(status_code: &str) -> syn::Ident {
543 let raw = match status_code {
544 "default" | "Default" => "Default".to_string(),
545 other if other.chars().all(|c| c.is_ascii_digit()) => format!("Status{other}"),
546 other => format!("Status{}", other.to_ascii_lowercase()),
547 };
548 syn::Ident::new(&raw, proc_macro2::Span::call_site())
549 }
550
551 fn op_error_type_token(&self, op: &OperationInfo) -> TokenStream {
555 if op
556 .response_schemas
557 .iter()
558 .any(|(code, _)| !code.starts_with('2'))
559 {
560 let ident = self.op_error_enum_ident(op);
561 quote! { #ident }
562 } else {
563 quote! { serde_json::Value }
564 }
565 }
566
567 fn generate_single_operation_method(&self, op: &OperationInfo) -> TokenStream {
569 let method_name = self.get_method_name(op);
570 let http_method = self.get_http_method(op);
571 let path = &op.path;
572 let request_param = self.generate_request_param(op);
573 let request_body = self.generate_request_body(op);
574 let query_params = self.generate_query_params(op);
575 let response_type = self.get_response_type(op);
576 let has_response_body = self.get_success_response_schema(op).is_some();
577 let op_error_type = self.op_error_type_token(op);
578 let error_handling = self.generate_error_handling(op, has_response_body);
579 let url_construction = self.generate_url_construction(path, op);
580 let doc_comment = self.generate_operation_doc_comment(op);
581
582 quote! {
583 #doc_comment
584 pub async fn #method_name(
585 &self,
586 #request_param
587 ) -> Result<#response_type, ApiOpError<#op_error_type>> {
588 #url_construction
589
590 let mut req = self.http_client
591 .#http_method(request_url)
592 #request_body;
593
594 #query_params
595
596 if let Some(api_key) = &self.api_key {
598 req = req.bearer_auth(api_key);
599 }
600
601 for (name, value) in &self.custom_headers {
603 req = req.header(name, value);
604 }
605
606 let response = req.send().await?;
607 #error_handling
608 }
609 }
610 }
611
612 fn generate_query_params(&self, op: &OperationInfo) -> TokenStream {
614 let query_params: Vec<_> = op
615 .parameters
616 .iter()
617 .filter(|p| p.location == "query")
618 .collect();
619
620 if query_params.is_empty() {
621 return quote! {};
622 }
623
624 let mut param_building = Vec::new();
625
626 for param in query_params {
627 let param_name_snake = self.sanitize_param_name(¶m.name);
629 let param_name = Self::to_field_ident(¶m_name_snake);
630
631 let param_key = ¶m.name;
633
634 if param.required {
635 if param.rust_type == "String" {
637 param_building.push(quote! {
638 query_params.push((#param_key, #param_name.as_ref().to_string()));
639 });
640 } else {
641 param_building.push(quote! {
642 query_params.push((#param_key, #param_name.to_string()));
643 });
644 }
645 } else {
646 if param.rust_type == "String" {
648 param_building.push(quote! {
649 if let Some(v) = #param_name {
650 query_params.push((#param_key, v.as_ref().to_string()));
651 }
652 });
653 } else {
654 param_building.push(quote! {
655 if let Some(v) = #param_name {
656 query_params.push((#param_key, v.to_string()));
657 }
658 });
659 }
660 }
661 }
662
663 quote! {
664 {
666 let mut query_params: Vec<(&str, String)> = Vec::new();
667 #(#param_building)*
668 if !query_params.is_empty() {
669 req = req.query(&query_params);
670 }
671 }
672 }
673 }
674
675 fn generate_operation_doc_comment(&self, op: &OperationInfo) -> TokenStream {
677 let method = op.method.to_uppercase();
678 let path = &op.path;
679 let doc = format!("{} {}", method, path);
680
681 quote! {
682 #[doc = #doc]
683 }
684 }
685
686 fn get_method_name(&self, op: &OperationInfo) -> syn::Ident {
688 let name = if !op.operation_id.is_empty() {
689 op.operation_id.to_snake_case()
690 } else {
691 format!(
693 "{}_{}",
694 op.method,
695 op.path.replace('/', "_").replace(['{', '}'], "")
696 )
697 .to_snake_case()
698 };
699
700 syn::Ident::new(&name, proc_macro2::Span::call_site())
701 }
702
703 fn get_http_method(&self, op: &OperationInfo) -> syn::Ident {
705 let method = match op.method.to_uppercase().as_str() {
706 "GET" => "get",
707 "POST" => "post",
708 "PUT" => "put",
709 "DELETE" => "delete",
710 "PATCH" => "patch",
711 _ => "get", };
713
714 syn::Ident::new(method, proc_macro2::Span::call_site())
715 }
716
717 fn generate_request_param(&self, op: &OperationInfo) -> TokenStream {
719 let mut params = Vec::new();
720
721 for param in &op.parameters {
723 if param.location == "path" {
724 let param_name_snake = self.sanitize_param_name(¶m.name);
725 let param_name = Self::to_field_ident(¶m_name_snake);
726 let param_type = self.get_param_rust_type(param);
727 params.push(quote! { #param_name: #param_type });
728 }
729 }
730
731 for param in &op.parameters {
733 if param.location == "query" {
734 let param_name_snake = self.sanitize_param_name(¶m.name);
735 let param_name = Self::to_field_ident(¶m_name_snake);
736 let param_type = self.get_param_rust_type(param);
737
738 if param.required {
740 params.push(quote! { #param_name: #param_type });
741 } else {
742 params.push(quote! { #param_name: Option<#param_type> });
743 }
744 }
745 }
746
747 if let Some(ref rb) = op.request_body {
749 use crate::analysis::RequestBodyContent;
750 match rb {
751 RequestBodyContent::Json { schema_name }
752 | RequestBodyContent::FormUrlEncoded { schema_name } => {
753 let rust_type_name = self.to_rust_type_name(schema_name);
754 let request_ident =
755 syn::Ident::new(&rust_type_name, proc_macro2::Span::call_site());
756 params.push(quote! { request: #request_ident });
757 }
758 RequestBodyContent::Multipart => {
759 params.push(quote! { form: reqwest::multipart::Form });
760 }
761 RequestBodyContent::OctetStream => {
762 params.push(quote! { body: Vec<u8> });
763 }
764 RequestBodyContent::TextPlain => {
765 params.push(quote! { body: String });
766 }
767 }
768 }
769
770 if params.is_empty() {
771 quote! {}
772 } else {
773 quote! { #(#params),* }
774 }
775 }
776
777 fn get_param_rust_type(&self, param: &crate::analysis::ParameterInfo) -> TokenStream {
779 let type_str = ¶m.rust_type;
780 match type_str.as_str() {
781 "String" => quote! { impl AsRef<str> },
782 "i64" => quote! { i64 },
783 "i32" => quote! { i32 },
784 "f64" => quote! { f64 },
785 "bool" => quote! { bool },
786 _ => {
787 let type_ident = syn::Ident::new(type_str, proc_macro2::Span::call_site());
788 quote! { #type_ident }
789 }
790 }
791 }
792
793 fn generate_request_body(&self, op: &OperationInfo) -> TokenStream {
795 if let Some(ref rb) = op.request_body {
796 use crate::analysis::RequestBodyContent;
797 match rb {
798 RequestBodyContent::Json { .. } => {
799 quote! {
800 .body(serde_json::to_vec(&request).map_err(HttpError::serialization_error)?)
801 .header("content-type", "application/json")
802 }
803 }
804 RequestBodyContent::FormUrlEncoded { .. } => {
805 quote! {
806 .body(serde_urlencoded::to_string(&request).map_err(HttpError::serialization_error)?)
807 .header("content-type", "application/x-www-form-urlencoded")
808 }
809 }
810 RequestBodyContent::Multipart => {
811 quote! {
812 .multipart(form)
813 }
814 }
815 RequestBodyContent::OctetStream => {
816 quote! {
817 .body(body)
818 .header("content-type", "application/octet-stream")
819 }
820 }
821 RequestBodyContent::TextPlain => {
822 quote! {
823 .body(body)
824 .header("content-type", "text/plain")
825 }
826 }
827 }
828 } else {
829 quote! {}
830 }
831 }
832
833 fn get_success_response_schema<'a>(&self, op: &'a OperationInfo) -> Option<&'a String> {
839 op.response_schemas
840 .get("200")
841 .or_else(|| op.response_schemas.get("201"))
842 .or_else(|| {
843 op.response_schemas
844 .iter()
845 .find(|(code, _)| code.starts_with('2'))
846 .map(|(_, v)| v)
847 })
848 }
849
850 fn get_response_type(&self, op: &OperationInfo) -> TokenStream {
852 if let Some(response_type) = self.get_success_response_schema(op) {
853 let rust_type_name = self.to_rust_type_name(response_type);
855 let response_ident = syn::Ident::new(&rust_type_name, proc_macro2::Span::call_site());
856 quote! { #response_ident }
857 } else {
858 quote! { () }
859 }
860 }
861
862 fn generate_error_handling(&self, op: &OperationInfo, has_response_body: bool) -> TokenStream {
871 let op_error_type = self.op_error_type_token(op);
872
873 let success_branch = if has_response_body {
874 quote! {
875 match serde_json::from_str(&body_text) {
876 Ok(body) => Ok(body),
877 Err(e) => Err(ApiOpError::Api(ApiError {
878 status: status_code,
879 headers: headers,
880 body: body_text,
881 typed: None,
882 parse_error: Some(format!(
883 "failed to deserialize 2xx response body: {}",
884 e
885 )),
886 })),
887 }
888 }
889 } else {
890 quote! {
891 let _ = body_text;
892 let _ = headers;
893 Ok(())
894 }
895 };
896
897 let error_match_arms = self.generate_error_match_arms(op);
898
899 quote! {
900 let status = response.status();
901 let status_code = status.as_u16();
902 let headers = response.headers().clone();
903 let body_text = response.text().await
904 .map_err(|e| ApiOpError::Transport(HttpError::Network(e)))?;
905
906 if status.is_success() {
907 #success_branch
908 } else {
909 let typed: Option<#op_error_type>;
910 let parse_error: Option<String>;
911 #error_match_arms
912 Err(ApiOpError::Api(ApiError {
913 status: status_code,
914 headers,
915 body: body_text,
916 typed,
917 parse_error,
918 }))
919 }
920 }
921 }
922
923 fn generate_error_match_arms(&self, op: &OperationInfo) -> TokenStream {
926 let arms: Vec<TokenStream> = op
927 .response_schemas
928 .iter()
929 .filter(|(code, _)| !code.starts_with('2'))
930 .filter_map(|(code, schema)| {
931 let variant_ident = Self::op_error_variant_ident(code);
932 let payload_ty_name = self.to_rust_type_name(schema);
933 let payload_ty = syn::Ident::new(&payload_ty_name, proc_macro2::Span::call_site());
934 let enum_ident = self.op_error_enum_ident(op);
935
936 let pattern = match code.as_str() {
937 "default" | "Default" => return None, other if other.chars().all(|c| c.is_ascii_digit()) => {
939 let n: u16 = other.parse().ok()?;
940 quote! { #n }
941 }
942 _ => return None,
945 };
946
947 Some(quote! {
948 #pattern => {
949 match serde_json::from_str::<#payload_ty>(&body_text) {
950 Ok(v) => {
951 typed = Some(#enum_ident::#variant_ident(v));
952 parse_error = None;
953 }
954 Err(e) => {
955 typed = None;
956 parse_error = Some(e.to_string());
957 }
958 }
959 }
960 })
961 })
962 .collect();
963
964 let has_typed_enum = op.response_schemas.iter().any(|(code, _)| {
968 !code.starts_with('2') && !matches!(code.as_str(), "default" | "Default")
969 });
970
971 let default_arm = if has_typed_enum {
972 quote! {
973 _ => {
974 typed = None;
975 parse_error = None;
976 }
977 }
978 } else {
979 quote! {
981 _ => {
982 match serde_json::from_str::<serde_json::Value>(&body_text) {
983 Ok(v) => {
984 typed = Some(v);
985 parse_error = None;
986 }
987 Err(e) => {
988 typed = None;
989 parse_error = Some(e.to_string());
990 }
991 }
992 }
993 }
994 };
995
996 if arms.is_empty() {
997 quote! {
999 match status_code {
1000 #default_arm
1001 }
1002 }
1003 } else {
1004 quote! {
1005 match status_code {
1006 #(#arms)*
1007 #default_arm
1008 }
1009 }
1010 }
1011 }
1012
1013 fn generate_url_construction(&self, path: &str, op: &OperationInfo) -> TokenStream {
1015 if path.contains('{') {
1017 self.generate_url_with_params(path, op)
1018 } else {
1019 quote! {
1020 let request_url = format!("{}{}", self.base_url, #path);
1021 }
1022 }
1023 }
1024
1025 fn generate_url_with_params(&self, path: &str, op: &OperationInfo) -> TokenStream {
1027 let mut format_string = path.to_string();
1029 let mut format_args = Vec::new();
1030
1031 let path_params: Vec<_> = op
1033 .parameters
1034 .iter()
1035 .filter(|p| p.location == "path")
1036 .collect();
1037
1038 for param in &path_params {
1040 let placeholder = format!("{{{}}}", param.name);
1041 if format_string.contains(&placeholder) {
1042 format_string = format_string.replace(&placeholder, "{}");
1043
1044 let param_name_snake = self.sanitize_param_name(¶m.name);
1046 let param_ident = Self::to_field_ident(¶m_name_snake);
1047
1048 if param.rust_type == "String" {
1050 format_args.push(quote! { #param_ident.as_ref() });
1051 } else {
1052 format_args.push(quote! { #param_ident });
1053 }
1054 }
1055 }
1056
1057 if format_args.is_empty() {
1058 quote! {
1060 let request_url = format!("{}{}", self.base_url, #path);
1061 }
1062 } else {
1063 quote! {
1065 let request_url = format!("{}{}", self.base_url, format!(#format_string, #(#format_args),*));
1066 }
1067 }
1068 }
1069
1070 fn sanitize_param_name(&self, name: &str) -> String {
1072 let snake_case = name.to_snake_case();
1073 if Self::is_rust_keyword(&snake_case) {
1074 format!("r#{snake_case}")
1075 } else {
1076 snake_case
1077 }
1078 }
1079}