1use std::{
2 collections::BTreeMap,
3 fs,
4 path::{Path, PathBuf},
5};
6
7use anyhow::{Context, Result};
8use arvalez_ir::{CoreIr, HttpMethod, Operation, ParameterLocation, RequestBody, TypeRef};
9use serde::Serialize;
10use tera::{Context as TeraContext, Tera};
11
// Logical template names. Each name is the key a template is registered under
// in Tera and also the relative path probed inside a user-supplied override
// directory (see `load_templates`).

// Whole-file templates rendered by `generate_package`.
const TEMPLATE_GO_MOD: &str = "package/go.mod.tera";
const TEMPLATE_README: &str = "package/README.md.tera";
const TEMPLATE_MODELS: &str = "package/models.go.tera";
const TEMPLATE_CLIENT: &str = "package/client.go.tera";
// Partials rendered once per model, per tag group and per operation.
const TEMPLATE_MODEL_STRUCT: &str = "partials/model_struct.go.tera";
const TEMPLATE_SERVICE: &str = "partials/service.go.tera";
const TEMPLATE_CLIENT_METHOD: &str = "partials/client_method.go.tera";
19
/// Default templates as `(name, contents)` pairs.
///
/// The contents are embedded at compile time with `include_str!`, and every
/// entry is registered by `load_templates` before any override is applied.
const BUILTIN_TEMPLATES: &[(&str, &str)] = &[
    (
        TEMPLATE_GO_MOD,
        include_str!("../templates/package/go.mod.tera"),
    ),
    (
        TEMPLATE_README,
        include_str!("../templates/package/README.md.tera"),
    ),
    (
        TEMPLATE_MODELS,
        include_str!("../templates/package/models.go.tera"),
    ),
    (
        TEMPLATE_CLIENT,
        include_str!("../templates/package/client.go.tera"),
    ),
    (
        TEMPLATE_MODEL_STRUCT,
        include_str!("../templates/partials/model_struct.go.tera"),
    ),
    (
        TEMPLATE_SERVICE,
        include_str!("../templates/partials/service.go.tera"),
    ),
    (
        TEMPLATE_CLIENT_METHOD,
        include_str!("../templates/partials/client_method.go.tera"),
    ),
];
50
/// Template names that `load_templates` looks for inside a user-supplied
/// template directory; a file found there replaces the built-in of the same
/// name.
const OVERRIDABLE_TEMPLATES: &[&str] = &[
    TEMPLATE_GO_MOD,
    TEMPLATE_README,
    TEMPLATE_MODELS,
    TEMPLATE_CLIENT,
    TEMPLATE_MODEL_STRUCT,
    TEMPLATE_SERVICE,
    TEMPLATE_CLIENT_METHOD,
];
60
/// Settings that control how a Go client package is generated.
#[derive(Debug, Clone)]
pub struct GoPackageConfig {
    /// Go module path, exposed to templates as `package.module_path`.
    pub module_path: String,
    /// Go package name, exposed to templates as `package.package_name`.
    pub package_name: String,
    /// Package version string, exposed to templates as `package.version`.
    pub version: String,
    /// Optional directory holding template overrides.
    pub template_dir: Option<PathBuf>,
    /// When `true`, tagged operations are rendered on per-tag service structs
    /// and only untagged operations become methods on the client itself.
    pub group_by_tag: bool,
}
69
70impl GoPackageConfig {
71 pub fn new(module_path: impl Into<String>) -> Self {
72 let module_path = module_path.into();
73 let package_name = default_package_name(&module_path);
74 Self {
75 module_path,
76 package_name,
77 version: "0.1.0".into(),
78 template_dir: None,
79 group_by_tag: false,
80 }
81 }
82
83 pub fn with_package_name(mut self, package_name: impl Into<String>) -> Self {
84 self.package_name = sanitize_package_name(&package_name.into());
85 self
86 }
87
88 pub fn with_version(mut self, version: impl Into<String>) -> Self {
89 self.version = version.into();
90 self
91 }
92
93 pub fn with_template_dir(mut self, template_dir: Option<PathBuf>) -> Self {
94 self.template_dir = template_dir;
95 self
96 }
97
98 pub fn with_group_by_tag(mut self, group_by_tag: bool) -> Self {
99 self.group_by_tag = group_by_tag;
100 self
101 }
102}
103
/// One rendered output file held in memory.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// Destination path; `write_package` joins it onto the output directory.
    pub path: PathBuf,
    /// Full text of the file.
    pub contents: String,
}
109
110pub fn generate_package(ir: &CoreIr, config: &GoPackageConfig) -> Result<Vec<GeneratedFile>> {
111 let tera = load_templates(config.template_dir.as_deref())?;
112 let package_context = PackageTemplateContext::from_ir(ir, config, &tera)?;
113 let mut context = TeraContext::new();
114 context.insert("package", &package_context);
115
116 Ok(vec![
117 GeneratedFile {
118 path: PathBuf::from("go.mod"),
119 contents: tera
120 .render(TEMPLATE_GO_MOD, &context)
121 .context("failed to render go.mod template")?,
122 },
123 GeneratedFile {
124 path: PathBuf::from("README.md"),
125 contents: tera
126 .render(TEMPLATE_README, &context)
127 .context("failed to render README template")?,
128 },
129 GeneratedFile {
130 path: PathBuf::from("models.go"),
131 contents: tera
132 .render(TEMPLATE_MODELS, &context)
133 .context("failed to render models template")?,
134 },
135 GeneratedFile {
136 path: PathBuf::from("client.go"),
137 contents: tera
138 .render(TEMPLATE_CLIENT, &context)
139 .context("failed to render client template")?,
140 },
141 ])
142}
143
144pub fn write_package(output_dir: impl AsRef<Path>, files: &[GeneratedFile]) -> Result<()> {
145 let output_dir = output_dir.as_ref();
146 fs::create_dir_all(output_dir).with_context(|| {
147 format!(
148 "failed to create output directory `{}`",
149 output_dir.display()
150 )
151 })?;
152
153 for file in files {
154 let path = output_dir.join(&file.path);
155 if let Some(parent) = path.parent() {
156 fs::create_dir_all(parent)
157 .with_context(|| format!("failed to create directory `{}`", parent.display()))?;
158 }
159 fs::write(&path, &file.contents)
160 .with_context(|| format!("failed to write `{}`", path.display()))?;
161 }
162
163 Ok(())
164}
165
166fn load_templates(template_dir: Option<&Path>) -> Result<Tera> {
167 let mut tera = Tera::default();
168 for (name, contents) in BUILTIN_TEMPLATES {
169 tera.add_raw_template(name, contents)
170 .with_context(|| format!("failed to register builtin template `{name}`"))?;
171 }
172
173 if let Some(template_dir) = template_dir {
174 for name in OVERRIDABLE_TEMPLATES {
175 let candidate = template_dir.join(name);
176 if !candidate.exists() {
177 continue;
178 }
179 let contents = fs::read_to_string(&candidate).with_context(|| {
180 format!("failed to read template override `{}`", candidate.display())
181 })?;
182 tera.add_raw_template(name, &contents).with_context(|| {
183 format!(
184 "failed to register template override `{}`",
185 candidate.display()
186 )
187 })?;
188 }
189 }
190
191 Ok(tera)
192}
193
/// Data exposed to the top-level templates under the `package` key.
///
/// The `*_block` / `*_blocks` fields hold Go source text that was rendered or
/// formatted before the top-level templates run.
#[derive(Debug, Serialize)]
struct PackageTemplateContext {
    module_path: String,
    package_name: String,
    version: String,
    /// One rendered struct per model, ordered by model name.
    model_blocks: Vec<String>,
    /// Indented service field declarations; empty when grouping is off.
    service_fields_block: String,
    /// Indented service initialisation statements; empty when grouping is off.
    service_init_block: String,
    /// Rendered methods attached directly to the client, joined by blank lines.
    client_method_blocks: String,
    /// One rendered service per tag group; empty when grouping is off.
    service_blocks: Vec<String>,
}
205
206impl PackageTemplateContext {
207 fn from_ir(ir: &CoreIr, config: &GoPackageConfig, tera: &Tera) -> Result<Self> {
208 let model_blocks = sorted_models(ir)
209 .into_iter()
210 .map(|model| render_model_block(tera, ModelView::from_model(model)))
211 .collect::<Result<Vec<_>>>()?;
212
213 let layout = ClientLayout::from_ir(ir);
214
215 let client_method_blocks = render_method_blocks(
216 tera,
217 if config.group_by_tag {
218 layout
219 .untagged_operations
220 .iter()
221 .map(|operation| OperationMethodView::client_method(operation))
222 .collect::<Vec<_>>()
223 } else {
224 layout
225 .all_operations
226 .iter()
227 .map(|operation| OperationMethodView::client_method(operation))
228 .collect::<Vec<_>>()
229 },
230 )?;
231
232 let service_fields_block = if config.group_by_tag {
233 indent_block(
234 &layout
235 .tag_groups
236 .iter()
237 .map(|group| format!("{} *{}", group.field_name, group.struct_name))
238 .collect::<Vec<_>>(),
239 4,
240 )
241 } else {
242 String::new()
243 };
244
245 let service_init_block = if config.group_by_tag {
246 indent_block(
247 &layout
248 .tag_groups
249 .iter()
250 .map(|group| {
251 format!(
252 "client.{} = &{}{{client: client}}",
253 group.field_name, group.struct_name
254 )
255 })
256 .collect::<Vec<_>>(),
257 4,
258 )
259 } else {
260 String::new()
261 };
262
263 let service_blocks = if config.group_by_tag {
264 layout
265 .tag_groups
266 .iter()
267 .map(|group| render_service_block(tera, ServiceView::from_group(group, tera)))
268 .collect::<Result<Vec<_>>>()?
269 } else {
270 Vec::new()
271 };
272
273 Ok(Self {
274 module_path: config.module_path.clone(),
275 package_name: config.package_name.clone(),
276 version: config.version.clone(),
277 model_blocks,
278 service_fields_block,
279 service_init_block,
280 client_method_blocks,
281 service_blocks,
282 })
283 }
284}
285
/// Operations of an IR arranged for client generation.
#[derive(Debug)]
struct ClientLayout<'a> {
    /// Every operation, ordered by name.
    all_operations: Vec<&'a Operation>,
    /// Operations for which `operation_primary_tag` returned `None`.
    untagged_operations: Vec<&'a Operation>,
    /// Remaining operations, bucketed by their primary tag.
    tag_groups: Vec<TagGroup<'a>>,
}
292
293impl<'a> ClientLayout<'a> {
294 fn from_ir(ir: &'a CoreIr) -> Self {
295 let all_operations = sorted_operations(ir);
296 let mut tag_map: BTreeMap<String, Vec<&Operation>> = BTreeMap::new();
297 let mut untagged_operations = Vec::new();
298
299 for operation in &all_operations {
300 match operation_primary_tag(operation) {
301 Some(tag) => tag_map.entry(tag).or_default().push(*operation),
302 None => untagged_operations.push(*operation),
303 }
304 }
305
306 let tag_groups = tag_map
307 .into_iter()
308 .map(|(tag, operations)| TagGroup::new(tag, operations))
309 .collect::<Vec<_>>();
310
311 Self {
312 all_operations,
313 untagged_operations,
314 tag_groups,
315 }
316 }
317}
318
/// Operations sharing a primary tag, plus the Go names derived from the tag.
#[derive(Debug)]
struct TagGroup<'a> {
    /// Exported identifier used for the service field on the client.
    field_name: String,
    /// `field_name` with a `Service` suffix; names the generated service type.
    struct_name: String,
    /// Operations belonging to this group.
    operations: Vec<&'a Operation>,
}
325
326impl<'a> TagGroup<'a> {
327 fn new(tag: String, operations: Vec<&'a Operation>) -> Self {
328 Self {
329 field_name: sanitize_exported_identifier(&tag),
330 struct_name: format!("{}Service", sanitize_exported_identifier(&tag)),
331 operations,
332 }
333 }
334}
335
/// Data handed to the model struct partial under the `model` key.
#[derive(Debug, Serialize)]
struct ModelView {
    /// Exported Go name of the struct.
    struct_name: String,
    /// `true` when the model has at least one field.
    has_fields: bool,
    /// Indented field declarations, one per line.
    fields_block: String,
}
342
343impl ModelView {
344 fn from_model(model: &arvalez_ir::Model) -> Self {
345 let field_lines = model
346 .fields
347 .iter()
348 .map(ModelFieldView::from_field)
349 .map(|field| field.declaration)
350 .collect::<Vec<_>>();
351
352 Self {
353 struct_name: sanitize_exported_identifier(&model.name),
354 has_fields: !field_lines.is_empty(),
355 fields_block: indent_block(&field_lines, 4),
356 }
357 }
358}
359
/// A single Go struct field rendered as source text.
#[derive(Debug)]
struct ModelFieldView {
    /// Complete declaration line: name, type and JSON struct tag.
    declaration: String,
}
364
365impl ModelFieldView {
366 fn from_field(field: &arvalez_ir::Field) -> Self {
367 let field_name = sanitize_exported_identifier(&field.name);
368 let field_type = go_field_type(&field.type_ref, field.optional, field.nullable);
369 let tag = if field.optional {
370 format!("`json:\"{},omitempty\"`", field.name)
371 } else {
372 format!("`json:\"{}\"`", field.name)
373 };
374
375 Self {
376 declaration: format!("{field_name} {field_type} {tag}"),
377 }
378 }
379}
380
/// Data handed to the service partial under the `service` key.
#[derive(Debug, Serialize)]
struct ServiceView {
    /// Name of the generated service struct.
    struct_name: String,
    /// Rendered methods of the service, joined by blank lines.
    methods_block: String,
}
386
387impl ServiceView {
388 fn from_group(group: &TagGroup<'_>, tera: &Tera) -> Result<Self> {
389 Ok(Self {
390 struct_name: group.struct_name.clone(),
391 methods_block: render_method_blocks(
392 tera,
393 group
394 .operations
395 .iter()
396 .map(|operation| {
397 OperationMethodView::service_method(operation, &group.struct_name)
398 })
399 .collect::<Vec<_>>(),
400 )?,
401 })
402 }
403}
404
/// Data handed to the client method partial under the `operation` key.
#[derive(Debug, Serialize)]
struct OperationMethodView {
    /// Go receiver variable name (`c` for the client, `s` for a service).
    receiver_name: String,
    /// Go receiver type name.
    receiver_type: String,
    /// Go expression that reaches the client from the receiver.
    client_expression: String,
    /// Exported name of the decoded (wrapper) method.
    method_name: String,
    /// `method_name` with a `Raw` suffix.
    raw_method_name: String,
    /// Comma-separated Go parameter list shared by both methods.
    args_signature: String,
    /// Go return signature of the wrapper method.
    return_signature: String,
    /// Indented body of the raw method.
    raw_block: String,
    /// Statement that calls the raw method from the wrapper.
    wrapper_request_call_line: String,
    /// Indented error check following that call.
    wrapper_error_block: String,
    /// Indented statements that handle the response and return.
    wrapper_post_request_block: String,
}
419
420impl OperationMethodView {
421 fn client_method(operation: &Operation) -> Self {
422 Self::from_operation(operation, "c", "Client", "c")
423 }
424
425 fn service_method(operation: &Operation, service_name: &str) -> Self {
426 Self::from_operation(operation, "s", service_name, "s.client")
427 }
428
429 fn from_operation(
430 operation: &Operation,
431 receiver_name: &str,
432 receiver_type: &str,
433 client_expression: &str,
434 ) -> Self {
435 let return_shape = go_return_shape(operation);
436 let wrapper_forward_arguments = build_forward_arguments(operation);
437
438 Self {
439 receiver_name: receiver_name.into(),
440 receiver_type: receiver_type.into(),
441 client_expression: client_expression.into(),
442 method_name: sanitize_exported_identifier(&operation.name),
443 raw_method_name: format!("{}Raw", sanitize_exported_identifier(&operation.name)),
444 args_signature: build_method_args(operation).join(", "),
445 return_signature: return_shape.signature.clone(),
446 raw_block: build_raw_block(operation),
447 wrapper_request_call_line: format!(
448 "response, err := {receiver_name}.{}({wrapper_forward_arguments})",
449 format!("{}Raw", sanitize_exported_identifier(&operation.name))
450 ),
451 wrapper_error_block: indent_block(&return_shape.raw_error_lines, 4),
452 wrapper_post_request_block: indent_block(&return_shape.post_response_lines, 4),
453 }
454 }
455}
456
/// Return signature of a wrapper method together with the Go statements that
/// produce its return values.
#[derive(Debug, Clone)]
struct GoReturnShape {
    /// Go return signature, e.g. `error` or `(<type>, error)`.
    signature: String,
    /// Error check emitted right after the raw call.
    raw_error_lines: Vec<String>,
    /// Statements that handle the response and return the result.
    post_response_lines: Vec<String>,
}
463
464fn go_return_shape(operation: &Operation) -> GoReturnShape {
465 let success = operation
466 .responses
467 .iter()
468 .find(|response| response.status.starts_with('2'));
469
470 match success.and_then(|response| response.type_ref.as_ref()) {
471 Some(type_ref) => {
472 let result_type = go_result_type(type_ref);
473 let zero_lines = if returns_nil_on_error(type_ref) {
474 vec![
475 "if err != nil {".into(),
476 " return nil, err".into(),
477 "}".into(),
478 ]
479 } else {
480 vec![
481 "if err != nil {".into(),
482 format!(" var zero {result_type}"),
483 " return zero, err".into(),
484 "}".into(),
485 ]
486 };
487
488 let mut post_response_lines = vec![
489 "defer response.Body.Close()".into(),
490 "if err := client.handleError(response, requestOptions); err != nil {".into(),
491 ];
492 if returns_nil_on_error(type_ref) {
493 post_response_lines.push(" return nil, err".into());
494 } else {
495 post_response_lines.push(format!(" var zero {result_type}"));
496 post_response_lines.push(" return zero, err".into());
497 }
498 post_response_lines.push("}".into());
499
500 let decode_type = go_decode_type(type_ref);
501 post_response_lines.push(format!("var result {decode_type}"));
502 post_response_lines.push(
503 "if err := client.decodeJSONResponse(response, &result); err != nil {".into(),
504 );
505 if returns_nil_on_error(type_ref) {
506 post_response_lines.push(" return nil, err".into());
507 } else {
508 post_response_lines.push(format!(" var zero {result_type}"));
509 post_response_lines.push(" return zero, err".into());
510 }
511 post_response_lines.push("}".into());
512 if returns_pointer_result(type_ref) {
513 post_response_lines.push("return &result, nil".into());
514 } else {
515 post_response_lines.push("return result, nil".into());
516 }
517
518 GoReturnShape {
519 signature: format!("({result_type}, error)"),
520 raw_error_lines: zero_lines,
521 post_response_lines,
522 }
523 }
524 None => GoReturnShape {
525 signature: "error".into(),
526 raw_error_lines: vec![
527 "if err != nil {".into(),
528 " return err".into(),
529 "}".into(),
530 ],
531 post_response_lines: vec![
532 "defer response.Body.Close()".into(),
533 "if err := client.handleError(response, requestOptions); err != nil {".into(),
534 " return err".into(),
535 "}".into(),
536 "return nil".into(),
537 ],
538 },
539 }
540}
541
542fn build_raw_block(operation: &Operation) -> String {
543 let mut lines = vec![
544 render_go_path_line(&operation.path, &operation.params),
545 "query := url.Values{}".into(),
546 ];
547
548 for param in operation
549 .params
550 .iter()
551 .filter(|param| matches!(param.location, ParameterLocation::Query))
552 {
553 let name = sanitize_identifier(¶m.name);
554 if param.required {
555 lines.push(format!("query.Set({:?}, fmt.Sprint({name}))", param.name));
556 } else {
557 lines.push(format!("if {name} != nil {{"));
558 lines.push(format!(
559 " query.Set({:?}, fmt.Sprint(*{name}))",
560 param.name
561 ));
562 lines.push("}".into());
563 }
564 }
565 lines.push("query = client.mergeQuery(query, requestOptions)".into());
566
567 lines.push("headers := http.Header{}".into());
568 for param in operation
569 .params
570 .iter()
571 .filter(|param| matches!(param.location, ParameterLocation::Header))
572 {
573 let name = sanitize_identifier(¶m.name);
574 if param.required {
575 lines.push(format!("headers.Set({:?}, fmt.Sprint({name}))", param.name));
576 } else {
577 lines.push(format!("if {name} != nil {{"));
578 lines.push(format!(
579 " headers.Set({:?}, fmt.Sprint(*{name}))",
580 param.name
581 ));
582 lines.push("}".into());
583 }
584 }
585
586 lines.push("cookies := []*http.Cookie{}".into());
587 for param in operation
588 .params
589 .iter()
590 .filter(|param| matches!(param.location, ParameterLocation::Cookie))
591 {
592 let name = sanitize_identifier(¶m.name);
593 if param.required {
594 lines.push(format!(
595 "cookies = append(cookies, &http.Cookie{{Name: {:?}, Value: fmt.Sprint({name})}})",
596 param.name
597 ));
598 } else {
599 lines.push(format!("if {name} != nil {{"));
600 lines.push(format!(
601 " cookies = append(cookies, &http.Cookie{{Name: {:?}, Value: fmt.Sprint(*{name})}})",
602 param.name
603 ));
604 lines.push("}".into());
605 }
606 }
607 lines.push("cookies = client.mergeCookies(cookies, requestOptions)".into());
608
609 lines.push("var bodyReader io.Reader".into());
610 lines.push("var err error".into());
611 if let Some(request_body) = &operation.request_body {
612 match classify_request_body(request_body) {
613 RequestBodyKind::Json => {
614 lines.push("headers.Set(\"Content-Type\", \"application/json\")".into());
615 if request_body.required {
616 lines.push("bodyReader, err = client.encodeJSONBody(body)".into());
617 lines.push("if err != nil {".into());
618 lines.push(" return nil, err".into());
619 lines.push("}".into());
620 } else {
621 lines.push("if body != nil {".into());
622 lines.push(" bodyReader, err = client.encodeJSONBody(body)".into());
623 lines.push(" if err != nil {".into());
624 lines.push(" return nil, err".into());
625 lines.push(" }".into());
626 lines.push("}".into());
627 }
628 }
629 RequestBodyKind::Multipart => {
630 lines.push("var contentType string".into());
631 if request_body.required {
632 lines.push(
633 "bodyReader, contentType, err = client.encodeMultipartBody(body)".into(),
634 );
635 lines.push("if err != nil {".into());
636 lines.push(" return nil, err".into());
637 lines.push("}".into());
638 lines.push("headers.Set(\"Content-Type\", contentType)".into());
639 } else {
640 lines.push("if body != nil {".into());
641 lines.push(
642 " bodyReader, contentType, err = client.encodeMultipartBody(body)"
643 .into(),
644 );
645 lines.push(" if err != nil {".into());
646 lines.push(" return nil, err".into());
647 lines.push(" }".into());
648 lines.push(" headers.Set(\"Content-Type\", contentType)".into());
649 lines.push("}".into());
650 }
651 }
652 RequestBodyKind::BinaryOrOther => {
653 if request_body.required {
654 lines.push("bodyReader = body".into());
655 } else {
656 lines.push("bodyReader = body".into());
657 }
658 }
659 }
660 }
661 lines.push("headers = client.mergeHeaders(headers, requestOptions)".into());
662 lines.push(format!(
663 "request, err := http.NewRequestWithContext(client.resolveContext(ctx, requestOptions), {}, client.buildURL(path, query), bodyReader)",
664 go_http_method(operation.method)
665 ));
666 lines.push("if err != nil {".into());
667 lines.push(" return nil, err".into());
668 lines.push("}".into());
669 lines.push("request.Header = headers".into());
670 lines.push("for _, cookie := range cookies {".into());
671 lines.push(" request.AddCookie(cookie)".into());
672 lines.push("}".into());
673 lines.push("return client.httpClient.Do(request)".into());
674
675 indent_block(&lines, 4)
676}
677
/// How a request body is turned into the Go `bodyReader` by `build_raw_block`.
#[derive(Debug, Clone, Copy)]
enum RequestBodyKind {
    /// Encoded with `client.encodeJSONBody`.
    Json,
    /// Encoded with `client.encodeMultipartBody`.
    Multipart,
    /// Assigned to `bodyReader` without encoding.
    BinaryOrOther,
}
684
685fn classify_request_body(request_body: &RequestBody) -> RequestBodyKind {
686 if request_body.media_type == "application/json" {
687 RequestBodyKind::Json
688 } else if request_body.media_type.starts_with("multipart/form-data") {
689 RequestBodyKind::Multipart
690 } else {
691 RequestBodyKind::BinaryOrOther
692 }
693}
694
695fn render_go_path_line(path: &str, params: &[arvalez_ir::Parameter]) -> String {
696 let path_params = params
697 .iter()
698 .filter(|param| matches!(param.location, ParameterLocation::Path))
699 .collect::<Vec<_>>();
700
701 if path_params.is_empty() {
702 format!("path := {:?}", path)
703 } else {
704 let mut format_path = String::new();
705 let mut chars = path.chars().peekable();
706 while let Some(ch) = chars.next() {
707 match ch {
708 '{' => {
709 while let Some(next) = chars.peek() {
710 if *next == '}' {
711 chars.next();
712 break;
713 }
714 chars.next();
715 }
716 format_path.push_str("%s");
717 }
718 '%' => format_path.push_str("%%"),
719 _ => format_path.push(ch),
720 }
721 }
722 let arguments = path_params
723 .into_iter()
724 .map(|param| {
725 format!(
726 "url.PathEscape(fmt.Sprint({}))",
727 sanitize_identifier(¶m.name)
728 )
729 })
730 .collect::<Vec<_>>()
731 .join(", ");
732 format!("path := fmt.Sprintf({format_path:?}, {arguments})")
733 }
734}
735
736fn render_model_block(tera: &Tera, model: ModelView) -> Result<String> {
737 let mut context = TeraContext::new();
738 context.insert("model", &model);
739 tera.render(TEMPLATE_MODEL_STRUCT, &context)
740 .context("failed to render model struct partial")
741}
742
743fn render_service_block(tera: &Tera, service: Result<ServiceView>) -> Result<String> {
744 let service = service?;
745 let mut context = TeraContext::new();
746 context.insert("service", &service);
747 tera.render(TEMPLATE_SERVICE, &context)
748 .context("failed to render service partial")
749}
750
751fn render_method_blocks(tera: &Tera, methods: Vec<OperationMethodView>) -> Result<String> {
752 methods
753 .into_iter()
754 .map(|method| {
755 let mut context = TeraContext::new();
756 context.insert("operation", &method);
757 tera.render(TEMPLATE_CLIENT_METHOD, &context)
758 .context("failed to render client method partial")
759 })
760 .collect::<Result<Vec<_>>>()
761 .map(|blocks| blocks.join("\n\n"))
762}
763
764fn build_method_args(operation: &Operation) -> Vec<String> {
765 let mut args = vec!["ctx context.Context".into()];
766
767 for param in operation.params.iter().filter(|param| param.required) {
768 args.push(format!(
769 "{} {}",
770 sanitize_identifier(¶m.name),
771 go_required_arg_type(¶m.type_ref)
772 ));
773 }
774
775 if let Some(request_body) = &operation.request_body
776 && request_body.required
777 {
778 args.push(format!("body {}", go_body_arg_type(request_body, true)));
779 }
780
781 for param in operation.params.iter().filter(|param| !param.required) {
782 args.push(format!(
783 "{} {}",
784 sanitize_identifier(¶m.name),
785 go_optional_arg_type(¶m.type_ref)
786 ));
787 }
788
789 if let Some(request_body) = &operation.request_body
790 && !request_body.required
791 {
792 args.push(format!("body {}", go_body_arg_type(request_body, false)));
793 }
794
795 args.push("requestOptions *RequestOptions".into());
796 args
797}
798
799fn build_forward_arguments(operation: &Operation) -> String {
800 let mut args = vec!["ctx".into()];
801
802 for param in operation.params.iter().filter(|param| param.required) {
803 args.push(sanitize_identifier(¶m.name));
804 }
805 if let Some(request_body) = &operation.request_body
806 && request_body.required
807 {
808 let _ = request_body;
809 args.push("body".into());
810 }
811 for param in operation.params.iter().filter(|param| !param.required) {
812 args.push(sanitize_identifier(¶m.name));
813 }
814 if let Some(request_body) = &operation.request_body
815 && !request_body.required
816 {
817 let _ = request_body;
818 args.push("body".into());
819 }
820
821 args.push("requestOptions".into());
822 args.join(", ")
823}
824
825fn sorted_models(ir: &CoreIr) -> Vec<&arvalez_ir::Model> {
826 let mut models = ir.models.iter().collect::<Vec<_>>();
827 models.sort_by(|left, right| left.name.cmp(&right.name));
828 models
829}
830
831fn sorted_operations(ir: &CoreIr) -> Vec<&Operation> {
832 let mut operations = ir.operations.iter().collect::<Vec<_>>();
833 operations.sort_by(|left, right| left.name.cmp(&right.name));
834 operations
835}
836
837fn operation_primary_tag(operation: &Operation) -> Option<String> {
838 operation
839 .attributes
840 .get("tags")
841 .and_then(|value| value.as_array())
842 .and_then(|tags| tags.first())
843 .and_then(|tag| tag.as_str())
844 .map(str::trim)
845 .filter(|tag| !tag.is_empty())
846 .map(ToOwned::to_owned)
847}
848
849fn go_field_type(type_ref: &TypeRef, optional: bool, nullable: bool) -> String {
850 let base = go_type_ref(type_ref);
851 if optional || nullable {
852 match type_ref {
853 TypeRef::Primitive { name } if name == "string" => "*string".into(),
854 TypeRef::Primitive { name } if name == "integer" => "*int64".into(),
855 TypeRef::Primitive { name } if name == "number" => "*float64".into(),
856 TypeRef::Primitive { name } if name == "boolean" => "*bool".into(),
857 TypeRef::Named { .. } => format!("*{base}"),
858 _ => base,
859 }
860 } else {
861 base
862 }
863}
864
865fn go_type_ref(type_ref: &TypeRef) -> String {
866 match type_ref {
867 TypeRef::Primitive { name } => match name.as_str() {
868 "string" => "string".into(),
869 "integer" => "int64".into(),
870 "number" => "float64".into(),
871 "boolean" => "bool".into(),
872 "binary" => "[]byte".into(),
873 "null" => "any".into(),
874 "any" | "object" => "any".into(),
875 _ => "any".into(),
876 },
877 TypeRef::Named { name } => sanitize_exported_identifier(name),
878 TypeRef::Array { item } => format!("[]{}", go_type_ref(item)),
879 TypeRef::Map { value } => format!("map[string]{}", go_type_ref(value)),
880 TypeRef::Union { .. } => "any".into(),
881 }
882}
883
884fn go_body_arg_type(request_body: &RequestBody, required: bool) -> String {
885 match request_body.type_ref.as_ref() {
886 Some(TypeRef::Named { name }) => format!("*{}", sanitize_exported_identifier(name)),
887 Some(type_ref) => {
888 let base = go_type_ref(type_ref);
889 if required {
890 base
891 } else {
892 match type_ref {
893 TypeRef::Primitive { name } if name == "string" => "*string".into(),
894 TypeRef::Primitive { name } if name == "integer" => "*int64".into(),
895 TypeRef::Primitive { name } if name == "number" => "*float64".into(),
896 TypeRef::Primitive { name } if name == "boolean" => "*bool".into(),
897 _ => base,
898 }
899 }
900 }
901 None => "io.Reader".into(),
902 }
903}
904
/// Go type of a required method parameter: the plain `go_type_ref` mapping,
/// with no pointer wrapping.
fn go_required_arg_type(type_ref: &TypeRef) -> String {
    go_type_ref(type_ref)
}
908
909fn go_optional_arg_type(type_ref: &TypeRef) -> String {
910 match type_ref {
911 TypeRef::Primitive { name } if name == "string" => "*string".into(),
912 TypeRef::Primitive { name } if name == "integer" => "*int64".into(),
913 TypeRef::Primitive { name } if name == "number" => "*float64".into(),
914 TypeRef::Primitive { name } if name == "boolean" => "*bool".into(),
915 TypeRef::Named { name } => format!("*{}", sanitize_exported_identifier(name)),
916 _ => go_type_ref(type_ref),
917 }
918}
919
920fn go_result_type(type_ref: &TypeRef) -> String {
921 if returns_pointer_result(type_ref) {
922 format!("*{}", go_decode_type(type_ref))
923 } else {
924 go_decode_type(type_ref)
925 }
926}
927
928fn go_decode_type(type_ref: &TypeRef) -> String {
929 match type_ref {
930 TypeRef::Named { name } => sanitize_exported_identifier(name),
931 _ => go_type_ref(type_ref),
932 }
933}
934
/// `true` when the wrapper returns a pointer to the decoded value, which is
/// the case only for named model types.
fn returns_pointer_result(type_ref: &TypeRef) -> bool {
    matches!(type_ref, TypeRef::Named { .. })
}
938
/// `true` when the generated error paths can return a literal `nil` as the
/// result value; otherwise `go_return_shape` emits a `var zero <type>` line.
fn returns_nil_on_error(type_ref: &TypeRef) -> bool {
    match type_ref {
        // Rendered as a pointer, slice or map by the type mapping functions.
        TypeRef::Named { .. } | TypeRef::Array { .. } | TypeRef::Map { .. } => true,
        // Only `binary` (mapped to `[]byte`) qualifies among primitives.
        TypeRef::Primitive { name } => name == "binary",
        TypeRef::Union { .. } => false,
    }
}
946
/// Maps an IR HTTP method to the name of the matching `net/http` constant.
fn go_http_method(method: HttpMethod) -> &'static str {
    match method {
        HttpMethod::Get => "http.MethodGet",
        HttpMethod::Post => "http.MethodPost",
        HttpMethod::Put => "http.MethodPut",
        HttpMethod::Patch => "http.MethodPatch",
        HttpMethod::Delete => "http.MethodDelete",
    }
}
956
/// Prefixes every line with `spaces` space characters and joins the lines
/// with `\n`. There is no trailing newline; an empty slice gives an empty
/// string.
fn indent_block(lines: &[String], spaces: usize) -> String {
    let padding = " ".repeat(spaces);
    let mut block = String::new();
    for (index, line) in lines.iter().enumerate() {
        if index > 0 {
            block.push('\n');
        }
        block.push_str(&padding);
        block.push_str(line);
    }
    block
}
965
966fn default_package_name(module_path: &str) -> String {
967 module_path
968 .rsplit('/')
969 .next()
970 .map(sanitize_package_name)
971 .unwrap_or_else(|| "client".into())
972}
973
974fn sanitize_package_name(name: &str) -> String {
975 let mut out = split_words(name).join("");
976 if out.is_empty() {
977 out = "client".into();
978 }
979 if out.chars().next().is_some_and(|ch| ch.is_ascii_digit()) {
980 out.insert(0, 'x');
981 }
982 if is_go_keyword(&out) {
983 out.push_str("pkg");
984 }
985 out.to_ascii_lowercase()
986}
987
988fn sanitize_exported_identifier(name: &str) -> String {
989 let mut out = String::new();
990 for word in split_words(name) {
991 let mut chars = word.chars();
992 if let Some(first) = chars.next() {
993 out.extend(first.to_uppercase());
994 out.push_str(chars.as_str());
995 }
996 }
997 if out.is_empty() {
998 out = "Generated".into();
999 }
1000 if out.chars().next().is_some_and(|ch| ch.is_ascii_digit()) {
1001 out.insert(0, 'X');
1002 }
1003 if is_go_keyword(&out.to_ascii_lowercase()) {
1004 out.push('_');
1005 }
1006 out
1007}
1008
1009fn sanitize_identifier(name: &str) -> String {
1010 let words = split_words(name);
1011 let mut out = if words.is_empty() {
1012 "value".into()
1013 } else {
1014 let mut iter = words.into_iter();
1015 let mut result = iter.next().unwrap_or_else(|| "value".into());
1016 for word in iter {
1017 let mut chars = word.chars();
1018 if let Some(first) = chars.next() {
1019 result.extend(first.to_uppercase());
1020 result.push_str(chars.as_str());
1021 }
1022 }
1023 result
1024 };
1025
1026 if out.chars().next().is_some_and(|ch| ch.is_ascii_digit()) {
1027 out.insert(0, 'x');
1028 }
1029 if is_go_keyword(&out) {
1030 out.push('_');
1031 }
1032 out
1033}
1034
/// Splits `input` into lowercase ASCII words.
///
/// A word ends at any character that is not ASCII alphanumeric (the character
/// is dropped) and before every ASCII uppercase letter (which starts the next
/// word). Empty words are never produced.
fn split_words(input: &str) -> Vec<String> {
    let mut words: Vec<String> = Vec::new();
    let mut pending = String::new();

    for ch in input.chars() {
        let keep = ch.is_ascii_alphanumeric();
        let boundary = !keep || ch.is_ascii_uppercase();

        if boundary && !pending.is_empty() {
            words.push(std::mem::take(&mut pending));
        }
        if keep {
            pending.push(ch.to_ascii_lowercase());
        }
    }

    if !pending.is_empty() {
        words.push(pending);
    }
    words
}
1058
/// Returns `true` when `value` is exactly one of Go's reserved keywords.
/// The comparison is case-sensitive.
fn is_go_keyword(value: &str) -> bool {
    const KEYWORDS: [&str; 25] = [
        "break",
        "case",
        "chan",
        "const",
        "continue",
        "default",
        "defer",
        "else",
        "fallthrough",
        "for",
        "func",
        "go",
        "goto",
        "if",
        "import",
        "interface",
        "map",
        "package",
        "range",
        "return",
        "select",
        "struct",
        "switch",
        "type",
        "var",
    ];
    KEYWORDS.contains(&value)
}
1089
/// End-to-end tests that render a small IR through the built-in templates.
#[cfg(test)]
mod tests {
    use std::fs;

    use super::*;
    use arvalez_ir::{Attributes, Field, Parameter, Response};
    use serde_json::json;
    use tempfile::tempdir;

    /// Minimal IR: one `Widget` model (a required `id` and an optional
    /// `count`) and one `get_widget` operation tagged `widgets`, with a
    /// required path parameter, an optional query parameter, an optional JSON
    /// body and a `200` JSON response.
    fn sample_ir() -> CoreIr {
        CoreIr {
            models: vec![arvalez_ir::Model {
                id: "model.widget".into(),
                name: "Widget".into(),
                fields: vec![
                    Field::new("id", TypeRef::primitive("string")),
                    Field {
                        name: "count".into(),
                        type_ref: TypeRef::primitive("integer"),
                        optional: true,
                        nullable: false,
                        attributes: Attributes::default(),
                    },
                ],
                attributes: Attributes::default(),
                source: None,
            }],
            operations: vec![Operation {
                id: "operation.get_widget".into(),
                name: "get_widget".into(),
                method: HttpMethod::Get,
                path: "/widgets/{widget_id}".into(),
                params: vec![
                    Parameter {
                        name: "widget_id".into(),
                        location: ParameterLocation::Path,
                        type_ref: TypeRef::primitive("string"),
                        required: true,
                    },
                    Parameter {
                        name: "include_count".into(),
                        location: ParameterLocation::Query,
                        type_ref: TypeRef::primitive("boolean"),
                        required: false,
                    },
                ],
                request_body: Some(RequestBody {
                    required: false,
                    media_type: "application/json".into(),
                    type_ref: Some(TypeRef::named("Widget")),
                }),
                responses: vec![Response {
                    status: "200".into(),
                    media_type: Some("application/json".into()),
                    type_ref: Some(TypeRef::named("Widget")),
                    attributes: Attributes::default(),
                }],
                attributes: Attributes::from([("tags".into(), json!(["widgets"]))]),
                source: None,
            }],
            ..Default::default()
        }
    }

    /// With default settings every operation is a method on `Client`, and the
    /// rendered files contain the expected module line, model struct and
    /// client scaffolding.
    #[test]
    fn renders_basic_go_package() {
        let files = generate_package(
            &sample_ir(),
            &GoPackageConfig::new("github.com/demo/client"),
        )
        .expect("package should render");

        let go_mod = files
            .iter()
            .find(|file| file.path.ends_with("go.mod"))
            .expect("go.mod");
        let models = files
            .iter()
            .find(|file| file.path.ends_with("models.go"))
            .expect("models.go");
        let client = files
            .iter()
            .find(|file| file.path.ends_with("client.go"))
            .expect("client.go");

        assert!(go_mod.contents.contains("module github.com/demo/client"));
        assert!(models.contents.contains("type Widget struct"));
        // Optional integer field: pointer type plus `omitempty`.
        assert!(
            models
                .contents
                .contains("Count *int64 `json:\"count,omitempty\"`")
        );
        assert!(
            client
                .contents
                .contains("type ErrorHandler func(*http.Response) error")
        );
        assert!(client.contents.contains("type RequestOptions struct"));
        // Both the raw and the decoded variant are generated.
        assert!(client.contents.contains("func (c *Client) GetWidgetRaw("));
        assert!(client.contents.contains("func (c *Client) GetWidget("));
        assert!(client.contents.contains("requestOptions *RequestOptions"));
        assert!(
            client
                .contents
                .contains("if err := client.handleError(response, requestOptions); err != nil {")
        );
        assert!(client.contents.contains("response, err := c.GetWidgetRaw("));
    }

    /// With tag grouping enabled the tagged operation moves onto a
    /// `WidgetsService` that the client exposes through a `Widgets` field.
    #[test]
    fn groups_operations_by_tag_when_enabled() {
        let files = generate_package(
            &sample_ir(),
            &GoPackageConfig::new("github.com/demo/client").with_group_by_tag(true),
        )
        .expect("package should render");
        let client = files
            .iter()
            .find(|file| file.path.ends_with("client.go"))
            .expect("client.go");

        assert!(client.contents.contains("Widgets *WidgetsService"));
        assert!(
            client
                .contents
                .contains("client.Widgets = &WidgetsService{client: client}")
        );
        assert!(client.contents.contains("type WidgetsService struct"));
        assert!(
            client
                .contents
                .contains("func (s *WidgetsService) GetWidgetRaw(")
        );
    }

    /// A single partial placed in the override directory replaces only that
    /// template; everything else still renders from the built-ins.
    #[test]
    fn supports_selective_template_overrides() {
        let tempdir = tempdir().expect("tempdir");
        let partial_dir = tempdir.path().join("partials");
        fs::create_dir_all(&partial_dir).expect("partials dir");
        fs::write(
            partial_dir.join("service.go.tera"),
            "type {{ service.struct_name }} struct { Overridden bool }\n",
        )
        .expect("override template");

        let files = generate_package(
            &sample_ir(),
            &GoPackageConfig::new("github.com/demo/client")
                .with_group_by_tag(true)
                .with_template_dir(Some(tempdir.path().to_path_buf())),
        )
        .expect("package should render");
        let client = files
            .iter()
            .find(|file| file.path.ends_with("client.go"))
            .expect("client.go");

        assert!(
            client
                .contents
                .contains("type WidgetsService struct { Overridden bool }")
        );
    }
}