1use crate::analysis::{OperationInfo, SchemaAnalysis};
151use crate::generator::CodeGenerator;
152use heck::ToSnakeCase;
153use proc_macro2::TokenStream;
154use quote::quote;
155
156impl CodeGenerator {
157 pub fn generate_http_client_struct(&self) -> TokenStream {
159 let has_retry = self.config().retry_config.is_some();
160 let has_tracing = self.config().tracing_enabled;
161
162 let retry_config_struct = if has_retry {
164 quote! {
165 #[derive(Debug, Clone)]
167 pub struct RetryConfig {
168 pub max_retries: u32,
169 pub initial_delay_ms: u64,
170 pub max_delay_ms: u64,
171 }
172
173 impl Default for RetryConfig {
174 fn default() -> Self {
175 Self {
176 max_retries: 3,
177 initial_delay_ms: 500,
178 max_delay_ms: 16000,
179 }
180 }
181 }
182 }
183 } else {
184 quote! {}
185 };
186
187 let client_struct = quote! {
189 use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
190 use std::collections::BTreeMap;
191
192 #[derive(Clone)]
194 pub struct HttpClient {
195 base_url: String,
196 api_key: Option<String>,
197 http_client: ClientWithMiddleware,
198 custom_headers: BTreeMap<String, String>,
199 }
200 };
201
202 let constructor = self.generate_constructor(has_retry, has_tracing);
204
205 let builder_methods = self.generate_builder_methods();
207
208 let default_impl = quote! {
210 impl Default for HttpClient {
211 fn default() -> Self {
212 Self::new()
213 }
214 }
215 };
216
217 quote! {
219 #retry_config_struct
220 #client_struct
221
222 impl HttpClient {
223 #constructor
224 #builder_methods
225 }
226
227 #default_impl
228 }
229 }
230
231 fn generate_constructor(&self, has_retry: bool, has_tracing: bool) -> TokenStream {
233 let retry_param = if has_retry {
234 quote! { retry_config: Option<RetryConfig>, }
235 } else {
236 quote! {}
237 };
238
239 let tracing_param = if has_tracing {
240 quote! { enable_tracing: bool, }
241 } else {
242 quote! {}
243 };
244
245 let retry_middleware = if has_retry {
246 quote! {
247 if let Some(config) = retry_config {
248 use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
249
250 let retry_policy = ExponentialBackoff::builder()
251 .retry_bounds(
252 std::time::Duration::from_millis(config.initial_delay_ms),
253 std::time::Duration::from_millis(config.max_delay_ms),
254 )
255 .build_with_max_retries(config.max_retries);
256
257 let retry_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
258 client_builder = client_builder.with(retry_middleware);
259 }
260 }
261 } else {
262 quote! {}
263 };
264
265 let tracing_middleware = if has_tracing {
266 quote! {
267 if enable_tracing {
268 use reqwest_tracing::TracingMiddleware;
269 client_builder = client_builder.with(TracingMiddleware::default());
270 }
271 }
272 } else {
273 quote! {}
274 };
275
276 let default_constructor = if has_retry && has_tracing {
277 quote! {
278 pub fn new() -> Self {
280 Self::with_config(None, true)
281 }
282 }
283 } else if has_retry {
284 quote! {
285 pub fn new() -> Self {
287 Self::with_config(None)
288 }
289 }
290 } else if has_tracing {
291 quote! {
292 pub fn new() -> Self {
294 Self::with_config(true)
295 }
296 }
297 } else {
298 quote! {
299 pub fn new() -> Self {
301 let reqwest_client = reqwest::Client::new();
302 let client_builder = ClientBuilder::new(reqwest_client);
303 let http_client = client_builder.build();
304
305 Self {
306 base_url: String::new(),
307 api_key: None,
308 http_client,
309 custom_headers: BTreeMap::new(),
310 }
311 }
312 }
313 };
314
315 if has_retry || has_tracing {
316 quote! {
317 #default_constructor
318
319 pub fn with_config(#retry_param #tracing_param) -> Self {
321 let reqwest_client = reqwest::Client::new();
322 let mut client_builder = ClientBuilder::new(reqwest_client);
323
324 #tracing_middleware
325 #retry_middleware
326
327 let http_client = client_builder.build();
328
329 Self {
330 base_url: String::new(),
331 api_key: None,
332 http_client,
333 custom_headers: BTreeMap::new(),
334 }
335 }
336 }
337 } else {
338 default_constructor
339 }
340 }
341
342 fn generate_builder_methods(&self) -> TokenStream {
344 quote! {
345 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
347 self.base_url = base_url.into();
348 self
349 }
350
351 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
353 self.api_key = Some(api_key.into());
354 self
355 }
356
357 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
359 self.custom_headers.insert(name.into(), value.into());
360 self
361 }
362
363 pub fn with_headers(mut self, headers: BTreeMap<String, String>) -> Self {
365 self.custom_headers.extend(headers);
366 self
367 }
368 }
369 }
370
371 pub fn generate_operation_methods(&self, analysis: &SchemaAnalysis) -> TokenStream {
373 let methods: Vec<TokenStream> = analysis
374 .operations
375 .values()
376 .map(|op| self.generate_single_operation_method(op))
377 .collect();
378
379 quote! {
380 impl HttpClient {
381 #(#methods)*
382 }
383 }
384 }
385
386 fn generate_single_operation_method(&self, op: &OperationInfo) -> TokenStream {
388 let method_name = self.get_method_name(op);
389 let http_method = self.get_http_method(op);
390 let path = &op.path;
391 let request_param = self.generate_request_param(op);
392 let request_body = self.generate_request_body(op);
393 let query_params = self.generate_query_params(op);
394 let response_type = self.get_response_type(op);
395 let has_response_body = self.get_success_response_schema(op).is_some();
396 let error_handling = self.generate_error_handling(has_response_body);
397 let url_construction = self.generate_url_construction(path, op);
398 let doc_comment = self.generate_operation_doc_comment(op);
399
400 quote! {
401 #doc_comment
402 pub async fn #method_name(
403 &self,
404 #request_param
405 ) -> HttpResult<#response_type> {
406 #url_construction
407
408 let mut req = self.http_client
409 .#http_method(url)
410 #request_body;
411
412 #query_params
413
414 if let Some(api_key) = &self.api_key {
416 req = req.bearer_auth(api_key);
417 }
418
419 for (name, value) in &self.custom_headers {
421 req = req.header(name, value);
422 }
423
424 let response = req.send().await?;
425 #error_handling
426 }
427 }
428 }
429
430 fn generate_query_params(&self, op: &OperationInfo) -> TokenStream {
432 let query_params: Vec<_> = op
433 .parameters
434 .iter()
435 .filter(|p| p.location == "query")
436 .collect();
437
438 if query_params.is_empty() {
439 return quote! {};
440 }
441
442 let mut param_building = Vec::new();
443
444 for param in query_params {
445 let param_name_snake = self.sanitize_param_name(¶m.name);
447 let param_name = syn::Ident::new(¶m_name_snake, proc_macro2::Span::call_site());
448
449 let param_key = ¶m.name;
451
452 if param.required {
453 if param.rust_type == "String" {
455 param_building.push(quote! {
456 query_params.push((#param_key, #param_name.as_ref().to_string()));
457 });
458 } else {
459 param_building.push(quote! {
460 query_params.push((#param_key, #param_name.to_string()));
461 });
462 }
463 } else {
464 if param.rust_type == "String" {
466 param_building.push(quote! {
467 if let Some(v) = #param_name {
468 query_params.push((#param_key, v.as_ref().to_string()));
469 }
470 });
471 } else {
472 param_building.push(quote! {
473 if let Some(v) = #param_name {
474 query_params.push((#param_key, v.to_string()));
475 }
476 });
477 }
478 }
479 }
480
481 quote! {
482 {
484 let mut query_params: Vec<(&str, String)> = Vec::new();
485 #(#param_building)*
486 if !query_params.is_empty() {
487 req = req.query(&query_params);
488 }
489 }
490 }
491 }
492
493 fn generate_operation_doc_comment(&self, op: &OperationInfo) -> TokenStream {
495 let method = op.method.to_uppercase();
496 let path = &op.path;
497 let doc = format!("{} {}", method, path);
498
499 quote! {
500 #[doc = #doc]
501 }
502 }
503
504 fn get_method_name(&self, op: &OperationInfo) -> syn::Ident {
506 let name = if !op.operation_id.is_empty() {
507 op.operation_id.to_snake_case()
508 } else {
509 format!(
511 "{}_{}",
512 op.method,
513 op.path.replace('/', "_").replace(['{', '}'], "")
514 )
515 .to_snake_case()
516 };
517
518 syn::Ident::new(&name, proc_macro2::Span::call_site())
519 }
520
521 fn get_http_method(&self, op: &OperationInfo) -> syn::Ident {
523 let method = match op.method.to_uppercase().as_str() {
524 "GET" => "get",
525 "POST" => "post",
526 "PUT" => "put",
527 "DELETE" => "delete",
528 "PATCH" => "patch",
529 _ => "get", };
531
532 syn::Ident::new(method, proc_macro2::Span::call_site())
533 }
534
535 fn generate_request_param(&self, op: &OperationInfo) -> TokenStream {
537 let mut params = Vec::new();
538
539 for param in &op.parameters {
541 if param.location == "path" {
542 let param_name_snake = self.sanitize_param_name(¶m.name);
543 let param_name = syn::Ident::new(¶m_name_snake, proc_macro2::Span::call_site());
544 let param_type = self.get_param_rust_type(param);
545 params.push(quote! { #param_name: #param_type });
546 }
547 }
548
549 for param in &op.parameters {
551 if param.location == "query" {
552 let param_name_snake = self.sanitize_param_name(¶m.name);
553 let param_name = syn::Ident::new(¶m_name_snake, proc_macro2::Span::call_site());
554 let param_type = self.get_param_rust_type(param);
555
556 if param.required {
558 params.push(quote! { #param_name: #param_type });
559 } else {
560 params.push(quote! { #param_name: Option<#param_type> });
561 }
562 }
563 }
564
565 if let Some(ref rb) = op.request_body {
567 use crate::analysis::RequestBodyContent;
568 match rb {
569 RequestBodyContent::Json { schema_name }
570 | RequestBodyContent::FormUrlEncoded { schema_name } => {
571 let request_ident =
572 syn::Ident::new(schema_name, proc_macro2::Span::call_site());
573 params.push(quote! { request: #request_ident });
574 }
575 RequestBodyContent::Multipart => {
576 params.push(quote! { form: reqwest::multipart::Form });
577 }
578 RequestBodyContent::OctetStream => {
579 params.push(quote! { body: Vec<u8> });
580 }
581 RequestBodyContent::TextPlain => {
582 params.push(quote! { body: String });
583 }
584 }
585 }
586
587 if params.is_empty() {
588 quote! {}
589 } else {
590 quote! { #(#params),* }
591 }
592 }
593
594 fn get_param_rust_type(&self, param: &crate::analysis::ParameterInfo) -> TokenStream {
596 let type_str = ¶m.rust_type;
597 match type_str.as_str() {
598 "String" => quote! { impl AsRef<str> },
599 "i64" => quote! { i64 },
600 "i32" => quote! { i32 },
601 "f64" => quote! { f64 },
602 "bool" => quote! { bool },
603 _ => {
604 let type_ident = syn::Ident::new(type_str, proc_macro2::Span::call_site());
605 quote! { #type_ident }
606 }
607 }
608 }
609
610 fn generate_request_body(&self, op: &OperationInfo) -> TokenStream {
612 if let Some(ref rb) = op.request_body {
613 use crate::analysis::RequestBodyContent;
614 match rb {
615 RequestBodyContent::Json { .. } => {
616 quote! {
617 .body(serde_json::to_vec(&request).map_err(HttpError::serialization_error)?)
618 .header("content-type", "application/json")
619 }
620 }
621 RequestBodyContent::FormUrlEncoded { .. } => {
622 quote! {
623 .body(serde_urlencoded::to_string(&request).map_err(HttpError::serialization_error)?)
624 .header("content-type", "application/x-www-form-urlencoded")
625 }
626 }
627 RequestBodyContent::Multipart => {
628 quote! {
629 .multipart(form)
630 }
631 }
632 RequestBodyContent::OctetStream => {
633 quote! {
634 .body(body)
635 .header("content-type", "application/octet-stream")
636 }
637 }
638 RequestBodyContent::TextPlain => {
639 quote! {
640 .body(body)
641 .header("content-type", "text/plain")
642 }
643 }
644 }
645 } else {
646 quote! {}
647 }
648 }
649
650 fn get_success_response_schema<'a>(&self, op: &'a OperationInfo) -> Option<&'a String> {
656 op.response_schemas
657 .get("200")
658 .or_else(|| op.response_schemas.get("201"))
659 .or_else(|| {
660 op.response_schemas
661 .iter()
662 .find(|(code, _)| code.starts_with('2'))
663 .map(|(_, v)| v)
664 })
665 }
666
667 fn get_response_type(&self, op: &OperationInfo) -> TokenStream {
669 if let Some(response_type) = self.get_success_response_schema(op) {
670 let rust_type_name = self.to_rust_type_name(response_type);
672 let response_ident = syn::Ident::new(&rust_type_name, proc_macro2::Span::call_site());
673 quote! { #response_ident }
674 } else {
675 quote! { () }
676 }
677 }
678
679 fn generate_error_handling(&self, has_response_body: bool) -> TokenStream {
684 let success_branch = if has_response_body {
685 quote! {
686 let body = response.json().await
687 .map_err(HttpError::deserialization_error)?;
688 Ok(body)
689 }
690 } else {
691 quote! {
692 Ok(())
693 }
694 };
695
696 quote! {
697 let status = response.status();
698
699 if status.is_success() {
700 #success_branch
701 } else {
702 let status_code = status.as_u16();
703 let message = status.canonical_reason().unwrap_or("Unknown error");
704 let body = response.text().await.ok();
705 Err(HttpError::from_status(status_code, message, body))
706 }
707 }
708 }
709
710 fn generate_url_construction(&self, path: &str, op: &OperationInfo) -> TokenStream {
712 if path.contains('{') {
714 self.generate_url_with_params(path, op)
715 } else {
716 quote! {
717 let url = format!("{}{}", self.base_url, #path);
718 }
719 }
720 }
721
722 fn generate_url_with_params(&self, path: &str, op: &OperationInfo) -> TokenStream {
724 let mut format_string = path.to_string();
726 let mut format_args = Vec::new();
727
728 let path_params: Vec<_> = op
730 .parameters
731 .iter()
732 .filter(|p| p.location == "path")
733 .collect();
734
735 for param in &path_params {
737 let placeholder = format!("{{{}}}", param.name);
738 if format_string.contains(&placeholder) {
739 format_string = format_string.replace(&placeholder, "{}");
740
741 let param_name_snake = self.sanitize_param_name(¶m.name);
743 let param_ident =
744 syn::Ident::new(¶m_name_snake, proc_macro2::Span::call_site());
745
746 if param.rust_type == "String" {
748 format_args.push(quote! { #param_ident.as_ref() });
749 } else {
750 format_args.push(quote! { #param_ident });
751 }
752 }
753 }
754
755 if format_args.is_empty() {
756 quote! {
758 let url = format!("{}{}", self.base_url, #path);
759 }
760 } else {
761 quote! {
763 let url = format!("{}{}", self.base_url, format!(#format_string, #(#format_args),*));
764 }
765 }
766 }
767
768 fn sanitize_param_name(&self, name: &str) -> String {
770 let snake_case = name.to_snake_case();
771 match snake_case.as_str() {
772 "type" => "type_".to_string(),
773 "match" => "match_".to_string(),
774 "fn" => "fn_".to_string(),
775 "impl" => "impl_".to_string(),
776 "trait" => "trait_".to_string(),
777 "struct" => "struct_".to_string(),
778 "enum" => "enum_".to_string(),
779 "mod" => "mod_".to_string(),
780 "use" => "use_".to_string(),
781 "pub" => "pub_".to_string(),
782 "const" => "const_".to_string(),
783 "static" => "static_".to_string(),
784 "let" => "let_".to_string(),
785 "mut" => "mut_".to_string(),
786 "ref" => "ref_".to_string(),
787 "move" => "move_".to_string(),
788 "return" => "return_".to_string(),
789 "if" => "if_".to_string(),
790 "else" => "else_".to_string(),
791 "while" => "while_".to_string(),
792 "for" => "for_".to_string(),
793 "loop" => "loop_".to_string(),
794 "break" => "break_".to_string(),
795 "continue" => "continue_".to_string(),
796 "self" => "self_".to_string(),
797 "super" => "super_".to_string(),
798 "crate" => "crate_".to_string(),
799 "async" => "async_".to_string(),
800 "await" => "await_".to_string(),
801 "override" => "override_".to_string(),
802 "box" => "box_".to_string(),
803 "dyn" => "dyn_".to_string(),
804 "where" => "where_".to_string(),
805 "in" => "in_".to_string(),
806 "abstract" => "abstract_".to_string(),
807 "become" => "become_".to_string(),
808 "do" => "do_".to_string(),
809 "final" => "final_".to_string(),
810 "macro" => "macro_".to_string(),
811 "priv" => "priv_".to_string(),
812 "try" => "try_".to_string(),
813 "typeof" => "typeof_".to_string(),
814 "unsized" => "unsized_".to_string(),
815 "virtual" => "virtual_".to_string(),
816 "yield" => "yield_".to_string(),
817 _ => snake_case,
818 }
819 }
820}