Skip to main content

arvalez_target_go/
lib.rs

1use std::{
2    collections::BTreeMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use anyhow::{Context, Result};
8use arvalez_ir::{CoreIr, HttpMethod, Operation, ParameterLocation, RequestBody, TypeRef};
9use serde::Serialize;
10use tera::{Context as TeraContext, Tera};
11
12const TEMPLATE_GO_MOD: &str = "package/go.mod.tera";
13const TEMPLATE_README: &str = "package/README.md.tera";
14const TEMPLATE_MODELS: &str = "package/models.go.tera";
15const TEMPLATE_CLIENT: &str = "package/client.go.tera";
16const TEMPLATE_MODEL_STRUCT: &str = "partials/model_struct.go.tera";
17const TEMPLATE_SERVICE: &str = "partials/service.go.tera";
18const TEMPLATE_CLIENT_METHOD: &str = "partials/client_method.go.tera";
19
20const BUILTIN_TEMPLATES: &[(&str, &str)] = &[
21    (
22        TEMPLATE_GO_MOD,
23        include_str!("../templates/package/go.mod.tera"),
24    ),
25    (
26        TEMPLATE_README,
27        include_str!("../templates/package/README.md.tera"),
28    ),
29    (
30        TEMPLATE_MODELS,
31        include_str!("../templates/package/models.go.tera"),
32    ),
33    (
34        TEMPLATE_CLIENT,
35        include_str!("../templates/package/client.go.tera"),
36    ),
37    (
38        TEMPLATE_MODEL_STRUCT,
39        include_str!("../templates/partials/model_struct.go.tera"),
40    ),
41    (
42        TEMPLATE_SERVICE,
43        include_str!("../templates/partials/service.go.tera"),
44    ),
45    (
46        TEMPLATE_CLIENT_METHOD,
47        include_str!("../templates/partials/client_method.go.tera"),
48    ),
49];
50
51const OVERRIDABLE_TEMPLATES: &[&str] = &[
52    TEMPLATE_GO_MOD,
53    TEMPLATE_README,
54    TEMPLATE_MODELS,
55    TEMPLATE_CLIENT,
56    TEMPLATE_MODEL_STRUCT,
57    TEMPLATE_SERVICE,
58    TEMPLATE_CLIENT_METHOD,
59];
60
61#[derive(Debug, Clone)]
62pub struct GoPackageConfig {
63    pub module_path: String,
64    pub package_name: String,
65    pub version: String,
66    pub template_dir: Option<PathBuf>,
67    pub group_by_tag: bool,
68}
69
70impl GoPackageConfig {
71    pub fn new(module_path: impl Into<String>) -> Self {
72        let module_path = module_path.into();
73        let package_name = default_package_name(&module_path);
74        Self {
75            module_path,
76            package_name,
77            version: "0.1.0".into(),
78            template_dir: None,
79            group_by_tag: false,
80        }
81    }
82
83    pub fn with_package_name(mut self, package_name: impl Into<String>) -> Self {
84        self.package_name = sanitize_package_name(&package_name.into());
85        self
86    }
87
88    pub fn with_version(mut self, version: impl Into<String>) -> Self {
89        self.version = version.into();
90        self
91    }
92
93    pub fn with_template_dir(mut self, template_dir: Option<PathBuf>) -> Self {
94        self.template_dir = template_dir;
95        self
96    }
97
98    pub fn with_group_by_tag(mut self, group_by_tag: bool) -> Self {
99        self.group_by_tag = group_by_tag;
100        self
101    }
102}
103
104#[derive(Debug, Clone)]
105pub struct GeneratedFile {
106    pub path: PathBuf,
107    pub contents: String,
108}
109
110pub fn generate_package(ir: &CoreIr, config: &GoPackageConfig) -> Result<Vec<GeneratedFile>> {
111    let tera = load_templates(config.template_dir.as_deref())?;
112    let package_context = PackageTemplateContext::from_ir(ir, config, &tera)?;
113    let mut context = TeraContext::new();
114    context.insert("package", &package_context);
115
116    Ok(vec![
117        GeneratedFile {
118            path: PathBuf::from("go.mod"),
119            contents: tera
120                .render(TEMPLATE_GO_MOD, &context)
121                .context("failed to render go.mod template")?,
122        },
123        GeneratedFile {
124            path: PathBuf::from("README.md"),
125            contents: tera
126                .render(TEMPLATE_README, &context)
127                .context("failed to render README template")?,
128        },
129        GeneratedFile {
130            path: PathBuf::from("models.go"),
131            contents: tera
132                .render(TEMPLATE_MODELS, &context)
133                .context("failed to render models template")?,
134        },
135        GeneratedFile {
136            path: PathBuf::from("client.go"),
137            contents: tera
138                .render(TEMPLATE_CLIENT, &context)
139                .context("failed to render client template")?,
140        },
141    ])
142}
143
144pub fn write_package(output_dir: impl AsRef<Path>, files: &[GeneratedFile]) -> Result<()> {
145    let output_dir = output_dir.as_ref();
146    fs::create_dir_all(output_dir).with_context(|| {
147        format!(
148            "failed to create output directory `{}`",
149            output_dir.display()
150        )
151    })?;
152
153    for file in files {
154        let path = output_dir.join(&file.path);
155        if let Some(parent) = path.parent() {
156            fs::create_dir_all(parent)
157                .with_context(|| format!("failed to create directory `{}`", parent.display()))?;
158        }
159        fs::write(&path, &file.contents)
160            .with_context(|| format!("failed to write `{}`", path.display()))?;
161    }
162
163    Ok(())
164}
165
166fn load_templates(template_dir: Option<&Path>) -> Result<Tera> {
167    let mut tera = Tera::default();
168    for (name, contents) in BUILTIN_TEMPLATES {
169        tera.add_raw_template(name, contents)
170            .with_context(|| format!("failed to register builtin template `{name}`"))?;
171    }
172
173    if let Some(template_dir) = template_dir {
174        for name in OVERRIDABLE_TEMPLATES {
175            let candidate = template_dir.join(name);
176            if !candidate.exists() {
177                continue;
178            }
179            let contents = fs::read_to_string(&candidate).with_context(|| {
180                format!("failed to read template override `{}`", candidate.display())
181            })?;
182            tera.add_raw_template(name, &contents).with_context(|| {
183                format!(
184                    "failed to register template override `{}`",
185                    candidate.display()
186                )
187            })?;
188        }
189    }
190
191    Ok(tera)
192}
193
194#[derive(Debug, Serialize)]
195struct PackageTemplateContext {
196    module_path: String,
197    package_name: String,
198    version: String,
199    model_blocks: Vec<String>,
200    service_fields_block: String,
201    service_init_block: String,
202    client_method_blocks: String,
203    service_blocks: Vec<String>,
204}
205
206impl PackageTemplateContext {
207    fn from_ir(ir: &CoreIr, config: &GoPackageConfig, tera: &Tera) -> Result<Self> {
208        let model_blocks = sorted_models(ir)
209            .into_iter()
210            .map(|model| render_model_block(tera, ModelView::from_model(model)))
211            .collect::<Result<Vec<_>>>()?;
212
213        let layout = ClientLayout::from_ir(ir);
214
215        let client_method_blocks = render_method_blocks(
216            tera,
217            if config.group_by_tag {
218                layout
219                    .untagged_operations
220                    .iter()
221                    .map(|operation| OperationMethodView::client_method(operation))
222                    .collect::<Vec<_>>()
223            } else {
224                layout
225                    .all_operations
226                    .iter()
227                    .map(|operation| OperationMethodView::client_method(operation))
228                    .collect::<Vec<_>>()
229            },
230        )?;
231
232        let service_fields_block = if config.group_by_tag {
233            indent_block(
234                &layout
235                    .tag_groups
236                    .iter()
237                    .map(|group| format!("{} *{}", group.field_name, group.struct_name))
238                    .collect::<Vec<_>>(),
239                4,
240            )
241        } else {
242            String::new()
243        };
244
245        let service_init_block = if config.group_by_tag {
246            indent_block(
247                &layout
248                    .tag_groups
249                    .iter()
250                    .map(|group| {
251                        format!(
252                            "client.{} = &{}{{client: client}}",
253                            group.field_name, group.struct_name
254                        )
255                    })
256                    .collect::<Vec<_>>(),
257                4,
258            )
259        } else {
260            String::new()
261        };
262
263        let service_blocks = if config.group_by_tag {
264            layout
265                .tag_groups
266                .iter()
267                .map(|group| render_service_block(tera, ServiceView::from_group(group, tera)))
268                .collect::<Result<Vec<_>>>()?
269        } else {
270            Vec::new()
271        };
272
273        Ok(Self {
274            module_path: config.module_path.clone(),
275            package_name: config.package_name.clone(),
276            version: config.version.clone(),
277            model_blocks,
278            service_fields_block,
279            service_init_block,
280            client_method_blocks,
281            service_blocks,
282        })
283    }
284}
285
286#[derive(Debug)]
287struct ClientLayout<'a> {
288    all_operations: Vec<&'a Operation>,
289    untagged_operations: Vec<&'a Operation>,
290    tag_groups: Vec<TagGroup<'a>>,
291}
292
293impl<'a> ClientLayout<'a> {
294    fn from_ir(ir: &'a CoreIr) -> Self {
295        let all_operations = sorted_operations(ir);
296        let mut tag_map: BTreeMap<String, Vec<&Operation>> = BTreeMap::new();
297        let mut untagged_operations = Vec::new();
298
299        for operation in &all_operations {
300            match operation_primary_tag(operation) {
301                Some(tag) => tag_map.entry(tag).or_default().push(*operation),
302                None => untagged_operations.push(*operation),
303            }
304        }
305
306        let tag_groups = tag_map
307            .into_iter()
308            .map(|(tag, operations)| TagGroup::new(tag, operations))
309            .collect::<Vec<_>>();
310
311        Self {
312            all_operations,
313            untagged_operations,
314            tag_groups,
315        }
316    }
317}
318
319#[derive(Debug)]
320struct TagGroup<'a> {
321    field_name: String,
322    struct_name: String,
323    operations: Vec<&'a Operation>,
324}
325
326impl<'a> TagGroup<'a> {
327    fn new(tag: String, operations: Vec<&'a Operation>) -> Self {
328        Self {
329            field_name: sanitize_exported_identifier(&tag),
330            struct_name: format!("{}Service", sanitize_exported_identifier(&tag)),
331            operations,
332        }
333    }
334}
335
336#[derive(Debug, Serialize)]
337struct ModelView {
338    struct_name: String,
339    has_fields: bool,
340    fields_block: String,
341}
342
343impl ModelView {
344    fn from_model(model: &arvalez_ir::Model) -> Self {
345        let field_lines = model
346            .fields
347            .iter()
348            .map(ModelFieldView::from_field)
349            .map(|field| field.declaration)
350            .collect::<Vec<_>>();
351
352        Self {
353            struct_name: sanitize_exported_identifier(&model.name),
354            has_fields: !field_lines.is_empty(),
355            fields_block: indent_block(&field_lines, 4),
356        }
357    }
358}
359
360#[derive(Debug)]
361struct ModelFieldView {
362    declaration: String,
363}
364
365impl ModelFieldView {
366    fn from_field(field: &arvalez_ir::Field) -> Self {
367        let field_name = sanitize_exported_identifier(&field.name);
368        let field_type = go_field_type(&field.type_ref, field.optional, field.nullable);
369        let tag = if field.optional {
370            format!("`json:\"{},omitempty\"`", field.name)
371        } else {
372            format!("`json:\"{}\"`", field.name)
373        };
374
375        Self {
376            declaration: format!("{field_name} {field_type} {tag}"),
377        }
378    }
379}
380
381#[derive(Debug, Serialize)]
382struct ServiceView {
383    struct_name: String,
384    methods_block: String,
385}
386
387impl ServiceView {
388    fn from_group(group: &TagGroup<'_>, tera: &Tera) -> Result<Self> {
389        Ok(Self {
390            struct_name: group.struct_name.clone(),
391            methods_block: render_method_blocks(
392                tera,
393                group
394                    .operations
395                    .iter()
396                    .map(|operation| {
397                        OperationMethodView::service_method(operation, &group.struct_name)
398                    })
399                    .collect::<Vec<_>>(),
400            )?,
401        })
402    }
403}
404
405#[derive(Debug, Serialize)]
406struct OperationMethodView {
407    receiver_name: String,
408    receiver_type: String,
409    client_expression: String,
410    method_name: String,
411    raw_method_name: String,
412    args_signature: String,
413    return_signature: String,
414    raw_block: String,
415    wrapper_request_call_line: String,
416    wrapper_error_block: String,
417    wrapper_post_request_block: String,
418}
419
420impl OperationMethodView {
421    fn client_method(operation: &Operation) -> Self {
422        Self::from_operation(operation, "c", "Client", "c")
423    }
424
425    fn service_method(operation: &Operation, service_name: &str) -> Self {
426        Self::from_operation(operation, "s", service_name, "s.client")
427    }
428
429    fn from_operation(
430        operation: &Operation,
431        receiver_name: &str,
432        receiver_type: &str,
433        client_expression: &str,
434    ) -> Self {
435        let return_shape = go_return_shape(operation);
436        let wrapper_forward_arguments = build_forward_arguments(operation);
437
438        Self {
439            receiver_name: receiver_name.into(),
440            receiver_type: receiver_type.into(),
441            client_expression: client_expression.into(),
442            method_name: sanitize_exported_identifier(&operation.name),
443            raw_method_name: format!("{}Raw", sanitize_exported_identifier(&operation.name)),
444            args_signature: build_method_args(operation).join(", "),
445            return_signature: return_shape.signature.clone(),
446            raw_block: build_raw_block(operation),
447            wrapper_request_call_line: format!(
448                "response, err := {receiver_name}.{}({wrapper_forward_arguments})",
449                format!("{}Raw", sanitize_exported_identifier(&operation.name))
450            ),
451            wrapper_error_block: indent_block(&return_shape.raw_error_lines, 4),
452            wrapper_post_request_block: indent_block(&return_shape.post_response_lines, 4),
453        }
454    }
455}
456
457#[derive(Debug, Clone)]
458struct GoReturnShape {
459    signature: String,
460    raw_error_lines: Vec<String>,
461    post_response_lines: Vec<String>,
462}
463
464fn go_return_shape(operation: &Operation) -> GoReturnShape {
465    let success = operation
466        .responses
467        .iter()
468        .find(|response| response.status.starts_with('2'));
469
470    match success.and_then(|response| response.type_ref.as_ref()) {
471        Some(type_ref) => {
472            let result_type = go_result_type(type_ref);
473            let zero_lines = if returns_nil_on_error(type_ref) {
474                vec![
475                    "if err != nil {".into(),
476                    "    return nil, err".into(),
477                    "}".into(),
478                ]
479            } else {
480                vec![
481                    "if err != nil {".into(),
482                    format!("    var zero {result_type}"),
483                    "    return zero, err".into(),
484                    "}".into(),
485                ]
486            };
487
488            let mut post_response_lines = vec![
489                "defer response.Body.Close()".into(),
490                "if err := client.handleError(response, requestOptions); err != nil {".into(),
491            ];
492            if returns_nil_on_error(type_ref) {
493                post_response_lines.push("    return nil, err".into());
494            } else {
495                post_response_lines.push(format!("    var zero {result_type}"));
496                post_response_lines.push("    return zero, err".into());
497            }
498            post_response_lines.push("}".into());
499
500            let decode_type = go_decode_type(type_ref);
501            post_response_lines.push(format!("var result {decode_type}"));
502            post_response_lines.push(
503                "if err := client.decodeJSONResponse(response, &result); err != nil {".into(),
504            );
505            if returns_nil_on_error(type_ref) {
506                post_response_lines.push("    return nil, err".into());
507            } else {
508                post_response_lines.push(format!("    var zero {result_type}"));
509                post_response_lines.push("    return zero, err".into());
510            }
511            post_response_lines.push("}".into());
512            if returns_pointer_result(type_ref) {
513                post_response_lines.push("return &result, nil".into());
514            } else {
515                post_response_lines.push("return result, nil".into());
516            }
517
518            GoReturnShape {
519                signature: format!("({result_type}, error)"),
520                raw_error_lines: zero_lines,
521                post_response_lines,
522            }
523        }
524        None => GoReturnShape {
525            signature: "error".into(),
526            raw_error_lines: vec![
527                "if err != nil {".into(),
528                "    return err".into(),
529                "}".into(),
530            ],
531            post_response_lines: vec![
532                "defer response.Body.Close()".into(),
533                "if err := client.handleError(response, requestOptions); err != nil {".into(),
534                "    return err".into(),
535                "}".into(),
536                "return nil".into(),
537            ],
538        },
539    }
540}
541
542fn build_raw_block(operation: &Operation) -> String {
543    let mut lines = vec![
544        render_go_path_line(&operation.path, &operation.params),
545        "query := url.Values{}".into(),
546    ];
547
548    for param in operation
549        .params
550        .iter()
551        .filter(|param| matches!(param.location, ParameterLocation::Query))
552    {
553        let name = sanitize_identifier(&param.name);
554        if param.required {
555            lines.push(format!("query.Set({:?}, fmt.Sprint({name}))", param.name));
556        } else {
557            lines.push(format!("if {name} != nil {{"));
558            lines.push(format!(
559                "    query.Set({:?}, fmt.Sprint(*{name}))",
560                param.name
561            ));
562            lines.push("}".into());
563        }
564    }
565    lines.push("query = client.mergeQuery(query, requestOptions)".into());
566
567    lines.push("headers := http.Header{}".into());
568    for param in operation
569        .params
570        .iter()
571        .filter(|param| matches!(param.location, ParameterLocation::Header))
572    {
573        let name = sanitize_identifier(&param.name);
574        if param.required {
575            lines.push(format!("headers.Set({:?}, fmt.Sprint({name}))", param.name));
576        } else {
577            lines.push(format!("if {name} != nil {{"));
578            lines.push(format!(
579                "    headers.Set({:?}, fmt.Sprint(*{name}))",
580                param.name
581            ));
582            lines.push("}".into());
583        }
584    }
585
586    lines.push("cookies := []*http.Cookie{}".into());
587    for param in operation
588        .params
589        .iter()
590        .filter(|param| matches!(param.location, ParameterLocation::Cookie))
591    {
592        let name = sanitize_identifier(&param.name);
593        if param.required {
594            lines.push(format!(
595                "cookies = append(cookies, &http.Cookie{{Name: {:?}, Value: fmt.Sprint({name})}})",
596                param.name
597            ));
598        } else {
599            lines.push(format!("if {name} != nil {{"));
600            lines.push(format!(
601                "    cookies = append(cookies, &http.Cookie{{Name: {:?}, Value: fmt.Sprint(*{name})}})",
602                param.name
603            ));
604            lines.push("}".into());
605        }
606    }
607    lines.push("cookies = client.mergeCookies(cookies, requestOptions)".into());
608
609    lines.push("var bodyReader io.Reader".into());
610    lines.push("var err error".into());
611    if let Some(request_body) = &operation.request_body {
612        match classify_request_body(request_body) {
613            RequestBodyKind::Json => {
614                lines.push("headers.Set(\"Content-Type\", \"application/json\")".into());
615                if request_body.required {
616                    lines.push("bodyReader, err = client.encodeJSONBody(body)".into());
617                    lines.push("if err != nil {".into());
618                    lines.push("    return nil, err".into());
619                    lines.push("}".into());
620                } else {
621                    lines.push("if body != nil {".into());
622                    lines.push("    bodyReader, err = client.encodeJSONBody(body)".into());
623                    lines.push("    if err != nil {".into());
624                    lines.push("        return nil, err".into());
625                    lines.push("    }".into());
626                    lines.push("}".into());
627                }
628            }
629            RequestBodyKind::Multipart => {
630                lines.push("var contentType string".into());
631                if request_body.required {
632                    lines.push(
633                        "bodyReader, contentType, err = client.encodeMultipartBody(body)".into(),
634                    );
635                    lines.push("if err != nil {".into());
636                    lines.push("    return nil, err".into());
637                    lines.push("}".into());
638                    lines.push("headers.Set(\"Content-Type\", contentType)".into());
639                } else {
640                    lines.push("if body != nil {".into());
641                    lines.push(
642                        "    bodyReader, contentType, err = client.encodeMultipartBody(body)"
643                            .into(),
644                    );
645                    lines.push("    if err != nil {".into());
646                    lines.push("        return nil, err".into());
647                    lines.push("    }".into());
648                    lines.push("    headers.Set(\"Content-Type\", contentType)".into());
649                    lines.push("}".into());
650                }
651            }
652            RequestBodyKind::BinaryOrOther => {
653                if request_body.required {
654                    lines.push("bodyReader = body".into());
655                } else {
656                    lines.push("bodyReader = body".into());
657                }
658            }
659        }
660    }
661    lines.push("headers = client.mergeHeaders(headers, requestOptions)".into());
662    lines.push(format!(
663        "request, err := http.NewRequestWithContext(client.resolveContext(ctx, requestOptions), {}, client.buildURL(path, query), bodyReader)",
664        go_http_method(operation.method)
665    ));
666    lines.push("if err != nil {".into());
667    lines.push("    return nil, err".into());
668    lines.push("}".into());
669    lines.push("request.Header = headers".into());
670    lines.push("for _, cookie := range cookies {".into());
671    lines.push("    request.AddCookie(cookie)".into());
672    lines.push("}".into());
673    lines.push("return client.httpClient.Do(request)".into());
674
675    indent_block(&lines, 4)
676}
677
678#[derive(Debug, Clone, Copy)]
679enum RequestBodyKind {
680    Json,
681    Multipart,
682    BinaryOrOther,
683}
684
685fn classify_request_body(request_body: &RequestBody) -> RequestBodyKind {
686    if request_body.media_type == "application/json" {
687        RequestBodyKind::Json
688    } else if request_body.media_type.starts_with("multipart/form-data") {
689        RequestBodyKind::Multipart
690    } else {
691        RequestBodyKind::BinaryOrOther
692    }
693}
694
695fn render_go_path_line(path: &str, params: &[arvalez_ir::Parameter]) -> String {
696    let path_params = params
697        .iter()
698        .filter(|param| matches!(param.location, ParameterLocation::Path))
699        .collect::<Vec<_>>();
700
701    if path_params.is_empty() {
702        format!("path := {:?}", path)
703    } else {
704        let mut format_path = String::new();
705        let mut chars = path.chars().peekable();
706        while let Some(ch) = chars.next() {
707            match ch {
708                '{' => {
709                    while let Some(next) = chars.peek() {
710                        if *next == '}' {
711                            chars.next();
712                            break;
713                        }
714                        chars.next();
715                    }
716                    format_path.push_str("%s");
717                }
718                '%' => format_path.push_str("%%"),
719                _ => format_path.push(ch),
720            }
721        }
722        let arguments = path_params
723            .into_iter()
724            .map(|param| {
725                format!(
726                    "url.PathEscape(fmt.Sprint({}))",
727                    sanitize_identifier(&param.name)
728                )
729            })
730            .collect::<Vec<_>>()
731            .join(", ");
732        format!("path := fmt.Sprintf({format_path:?}, {arguments})")
733    }
734}
735
736fn render_model_block(tera: &Tera, model: ModelView) -> Result<String> {
737    let mut context = TeraContext::new();
738    context.insert("model", &model);
739    tera.render(TEMPLATE_MODEL_STRUCT, &context)
740        .context("failed to render model struct partial")
741}
742
743fn render_service_block(tera: &Tera, service: Result<ServiceView>) -> Result<String> {
744    let service = service?;
745    let mut context = TeraContext::new();
746    context.insert("service", &service);
747    tera.render(TEMPLATE_SERVICE, &context)
748        .context("failed to render service partial")
749}
750
751fn render_method_blocks(tera: &Tera, methods: Vec<OperationMethodView>) -> Result<String> {
752    methods
753        .into_iter()
754        .map(|method| {
755            let mut context = TeraContext::new();
756            context.insert("operation", &method);
757            tera.render(TEMPLATE_CLIENT_METHOD, &context)
758                .context("failed to render client method partial")
759        })
760        .collect::<Result<Vec<_>>>()
761        .map(|blocks| blocks.join("\n\n"))
762}
763
764fn build_method_args(operation: &Operation) -> Vec<String> {
765    let mut args = vec!["ctx context.Context".into()];
766
767    for param in operation.params.iter().filter(|param| param.required) {
768        args.push(format!(
769            "{} {}",
770            sanitize_identifier(&param.name),
771            go_required_arg_type(&param.type_ref)
772        ));
773    }
774
775    if let Some(request_body) = &operation.request_body
776        && request_body.required
777    {
778        args.push(format!("body {}", go_body_arg_type(request_body, true)));
779    }
780
781    for param in operation.params.iter().filter(|param| !param.required) {
782        args.push(format!(
783            "{} {}",
784            sanitize_identifier(&param.name),
785            go_optional_arg_type(&param.type_ref)
786        ));
787    }
788
789    if let Some(request_body) = &operation.request_body
790        && !request_body.required
791    {
792        args.push(format!("body {}", go_body_arg_type(request_body, false)));
793    }
794
795    args.push("requestOptions *RequestOptions".into());
796    args
797}
798
799fn build_forward_arguments(operation: &Operation) -> String {
800    let mut args = vec!["ctx".into()];
801
802    for param in operation.params.iter().filter(|param| param.required) {
803        args.push(sanitize_identifier(&param.name));
804    }
805    if let Some(request_body) = &operation.request_body
806        && request_body.required
807    {
808        let _ = request_body;
809        args.push("body".into());
810    }
811    for param in operation.params.iter().filter(|param| !param.required) {
812        args.push(sanitize_identifier(&param.name));
813    }
814    if let Some(request_body) = &operation.request_body
815        && !request_body.required
816    {
817        let _ = request_body;
818        args.push("body".into());
819    }
820
821    args.push("requestOptions".into());
822    args.join(", ")
823}
824
825fn sorted_models(ir: &CoreIr) -> Vec<&arvalez_ir::Model> {
826    let mut models = ir.models.iter().collect::<Vec<_>>();
827    models.sort_by(|left, right| left.name.cmp(&right.name));
828    models
829}
830
831fn sorted_operations(ir: &CoreIr) -> Vec<&Operation> {
832    let mut operations = ir.operations.iter().collect::<Vec<_>>();
833    operations.sort_by(|left, right| left.name.cmp(&right.name));
834    operations
835}
836
837fn operation_primary_tag(operation: &Operation) -> Option<String> {
838    operation
839        .attributes
840        .get("tags")
841        .and_then(|value| value.as_array())
842        .and_then(|tags| tags.first())
843        .and_then(|tag| tag.as_str())
844        .map(str::trim)
845        .filter(|tag| !tag.is_empty())
846        .map(ToOwned::to_owned)
847}
848
849fn go_field_type(type_ref: &TypeRef, optional: bool, nullable: bool) -> String {
850    let base = go_type_ref(type_ref);
851    if optional || nullable {
852        match type_ref {
853            TypeRef::Primitive { name } if name == "string" => "*string".into(),
854            TypeRef::Primitive { name } if name == "integer" => "*int64".into(),
855            TypeRef::Primitive { name } if name == "number" => "*float64".into(),
856            TypeRef::Primitive { name } if name == "boolean" => "*bool".into(),
857            TypeRef::Named { .. } => format!("*{base}"),
858            _ => base,
859        }
860    } else {
861        base
862    }
863}
864
865fn go_type_ref(type_ref: &TypeRef) -> String {
866    match type_ref {
867        TypeRef::Primitive { name } => match name.as_str() {
868            "string" => "string".into(),
869            "integer" => "int64".into(),
870            "number" => "float64".into(),
871            "boolean" => "bool".into(),
872            "binary" => "[]byte".into(),
873            "null" => "any".into(),
874            "any" | "object" => "any".into(),
875            _ => "any".into(),
876        },
877        TypeRef::Named { name } => sanitize_exported_identifier(name),
878        TypeRef::Array { item } => format!("[]{}", go_type_ref(item)),
879        TypeRef::Map { value } => format!("map[string]{}", go_type_ref(value)),
880        TypeRef::Union { .. } => "any".into(),
881    }
882}
883
884fn go_body_arg_type(request_body: &RequestBody, required: bool) -> String {
885    match request_body.type_ref.as_ref() {
886        Some(TypeRef::Named { name }) => format!("*{}", sanitize_exported_identifier(name)),
887        Some(type_ref) => {
888            let base = go_type_ref(type_ref);
889            if required {
890                base
891            } else {
892                match type_ref {
893                    TypeRef::Primitive { name } if name == "string" => "*string".into(),
894                    TypeRef::Primitive { name } if name == "integer" => "*int64".into(),
895                    TypeRef::Primitive { name } if name == "number" => "*float64".into(),
896                    TypeRef::Primitive { name } if name == "boolean" => "*bool".into(),
897                    _ => base,
898                }
899            }
900        }
901        None => "io.Reader".into(),
902    }
903}
904
905fn go_required_arg_type(type_ref: &TypeRef) -> String {
906    go_type_ref(type_ref)
907}
908
909fn go_optional_arg_type(type_ref: &TypeRef) -> String {
910    match type_ref {
911        TypeRef::Primitive { name } if name == "string" => "*string".into(),
912        TypeRef::Primitive { name } if name == "integer" => "*int64".into(),
913        TypeRef::Primitive { name } if name == "number" => "*float64".into(),
914        TypeRef::Primitive { name } if name == "boolean" => "*bool".into(),
915        TypeRef::Named { name } => format!("*{}", sanitize_exported_identifier(name)),
916        _ => go_type_ref(type_ref),
917    }
918}
919
920fn go_result_type(type_ref: &TypeRef) -> String {
921    if returns_pointer_result(type_ref) {
922        format!("*{}", go_decode_type(type_ref))
923    } else {
924        go_decode_type(type_ref)
925    }
926}
927
928fn go_decode_type(type_ref: &TypeRef) -> String {
929    match type_ref {
930        TypeRef::Named { name } => sanitize_exported_identifier(name),
931        _ => go_type_ref(type_ref),
932    }
933}
934
935fn returns_pointer_result(type_ref: &TypeRef) -> bool {
936    matches!(type_ref, TypeRef::Named { .. })
937}
938
939fn returns_nil_on_error(type_ref: &TypeRef) -> bool {
940    match type_ref {
941        TypeRef::Named { .. } | TypeRef::Array { .. } | TypeRef::Map { .. } => true,
942        TypeRef::Primitive { name } => name == "binary",
943        TypeRef::Union { .. } => false,
944    }
945}
946
947fn go_http_method(method: HttpMethod) -> &'static str {
948    match method {
949        HttpMethod::Get => "http.MethodGet",
950        HttpMethod::Post => "http.MethodPost",
951        HttpMethod::Put => "http.MethodPut",
952        HttpMethod::Patch => "http.MethodPatch",
953        HttpMethod::Delete => "http.MethodDelete",
954    }
955}
956
/// Prefixes every line with `spaces` spaces and joins the lines with `\n`.
///
/// Empty lines are left empty rather than padded. This keeps the generated Go
/// source free of trailing whitespace, which `gofmt` would otherwise have to
/// strip. No trailing newline is appended, and an empty slice yields an empty
/// string.
fn indent_block(lines: &[String], spaces: usize) -> String {
    let indent = " ".repeat(spaces);
    lines
        .iter()
        .map(|line| {
            if line.is_empty() {
                // Indenting an empty line would only add trailing whitespace.
                String::new()
            } else {
                format!("{indent}{line}")
            }
        })
        .collect::<Vec<_>>()
        .join("\n")
}
965
966fn default_package_name(module_path: &str) -> String {
967    module_path
968        .rsplit('/')
969        .next()
970        .map(sanitize_package_name)
971        .unwrap_or_else(|| "client".into())
972}
973
974fn sanitize_package_name(name: &str) -> String {
975    let mut out = split_words(name).join("");
976    if out.is_empty() {
977        out = "client".into();
978    }
979    if out.chars().next().is_some_and(|ch| ch.is_ascii_digit()) {
980        out.insert(0, 'x');
981    }
982    if is_go_keyword(&out) {
983        out.push_str("pkg");
984    }
985    out.to_ascii_lowercase()
986}
987
988fn sanitize_exported_identifier(name: &str) -> String {
989    let mut out = String::new();
990    for word in split_words(name) {
991        let mut chars = word.chars();
992        if let Some(first) = chars.next() {
993            out.extend(first.to_uppercase());
994            out.push_str(chars.as_str());
995        }
996    }
997    if out.is_empty() {
998        out = "Generated".into();
999    }
1000    if out.chars().next().is_some_and(|ch| ch.is_ascii_digit()) {
1001        out.insert(0, 'X');
1002    }
1003    if is_go_keyword(&out.to_ascii_lowercase()) {
1004        out.push('_');
1005    }
1006    out
1007}
1008
1009fn sanitize_identifier(name: &str) -> String {
1010    let words = split_words(name);
1011    let mut out = if words.is_empty() {
1012        "value".into()
1013    } else {
1014        let mut iter = words.into_iter();
1015        let mut result = iter.next().unwrap_or_else(|| "value".into());
1016        for word in iter {
1017            let mut chars = word.chars();
1018            if let Some(first) = chars.next() {
1019                result.extend(first.to_uppercase());
1020                result.push_str(chars.as_str());
1021            }
1022        }
1023        result
1024    };
1025
1026    if out.chars().next().is_some_and(|ch| ch.is_ascii_digit()) {
1027        out.insert(0, 'x');
1028    }
1029    if is_go_keyword(&out) {
1030        out.push('_');
1031    }
1032    out
1033}
1034
/// Splits `input` into lower-cased ASCII alphanumeric words.
///
/// A new word starts at every ASCII upper-case letter, and at every character
/// that is not ASCII alphanumeric (such characters, including non-ASCII letters,
/// are dropped). Each upper-case letter starts its own word, so acronyms split
/// into single letters: `HTTPServer` -> `["h", "t", "t", "p", "server"]`. Empty
/// words are never produced.
fn split_words(input: &str) -> Vec<String> {
    let mut words: Vec<String> = Vec::new();
    let mut word = String::new();

    for ch in input.chars() {
        let keep = ch.is_ascii_alphanumeric();
        let starts_new_word = !keep || ch.is_ascii_uppercase();
        if starts_new_word && !word.is_empty() {
            words.push(std::mem::take(&mut word));
        }
        if keep {
            word.push(ch.to_ascii_lowercase());
        }
    }

    if !word.is_empty() {
        words.push(word);
    }
    words
}
1058
/// Reports whether `value` is one of Go's 25 reserved keywords.
///
/// The comparison is case-sensitive, as in Go: `type` is reserved, `Type` is not.
fn is_go_keyword(value: &str) -> bool {
    const GO_KEYWORDS: [&str; 25] = [
        "break", "case", "chan", "const", "continue", "default", "defer", "else",
        "fallthrough", "for", "func", "go", "goto", "if", "import", "interface", "map",
        "package", "range", "return", "select", "struct", "switch", "type", "var",
    ];
    GO_KEYWORDS.contains(&value)
}
1089
#[cfg(test)]
mod tests {
    use std::fs;

    use super::*;
    use arvalez_ir::{Attributes, Field, Parameter, Response};
    use serde_json::json;
    use tempfile::tempdir;

    /// Minimal IR shared by the tests: a `Widget` model (required `id`, optional
    /// `count`) and a `get_widget` operation tagged `widgets`. The operation takes
    /// a path parameter, an optional query flag and an optional `Widget` JSON body,
    /// and returns a `Widget` on `200`.
    fn sample_ir() -> CoreIr {
        let optional_count = Field {
            name: "count".into(),
            type_ref: TypeRef::primitive("integer"),
            optional: true,
            nullable: false,
            attributes: Attributes::default(),
        };
        let widget = arvalez_ir::Model {
            id: "model.widget".into(),
            name: "Widget".into(),
            fields: vec![Field::new("id", TypeRef::primitive("string")), optional_count],
            attributes: Attributes::default(),
            source: None,
        };

        let widget_id = Parameter {
            name: "widget_id".into(),
            location: ParameterLocation::Path,
            type_ref: TypeRef::primitive("string"),
            required: true,
        };
        let include_count = Parameter {
            name: "include_count".into(),
            location: ParameterLocation::Query,
            type_ref: TypeRef::primitive("boolean"),
            required: false,
        };
        let get_widget = Operation {
            id: "operation.get_widget".into(),
            name: "get_widget".into(),
            method: HttpMethod::Get,
            path: "/widgets/{widget_id}".into(),
            params: vec![widget_id, include_count],
            request_body: Some(RequestBody {
                required: false,
                media_type: "application/json".into(),
                type_ref: Some(TypeRef::named("Widget")),
            }),
            responses: vec![Response {
                status: "200".into(),
                media_type: Some("application/json".into()),
                type_ref: Some(TypeRef::named("Widget")),
                attributes: Attributes::default(),
            }],
            attributes: Attributes::from([("tags".into(), json!(["widgets"]))]),
            source: None,
        };

        CoreIr {
            models: vec![widget],
            operations: vec![get_widget],
            ..Default::default()
        }
    }

    #[test]
    fn renders_basic_go_package() {
        let config = GoPackageConfig::new("github.com/demo/client");
        let files = generate_package(&sample_ir(), &config).expect("package should render");
        let file_named = |suffix: &str| {
            files
                .iter()
                .find(|file| file.path.ends_with(suffix))
                .expect(suffix)
        };

        let go_mod = file_named("go.mod");
        let models = file_named("models.go");
        let client = file_named("client.go");

        assert!(go_mod.contents.contains("module github.com/demo/client"));
        assert!(models.contents.contains("type Widget struct"));
        assert!(models
            .contents
            .contains("Count *int64 `json:\"count,omitempty\"`"));

        for snippet in [
            "type ErrorHandler func(*http.Response) error",
            "type RequestOptions struct",
            "func (c *Client) GetWidgetRaw(",
            "func (c *Client) GetWidget(",
            "requestOptions *RequestOptions",
            "if err := client.handleError(response, requestOptions); err != nil {",
            "response, err := c.GetWidgetRaw(",
        ] {
            assert!(client.contents.contains(snippet), "client.go is missing `{snippet}`");
        }
    }

    #[test]
    fn groups_operations_by_tag_when_enabled() {
        let config = GoPackageConfig::new("github.com/demo/client").with_group_by_tag(true);
        let files = generate_package(&sample_ir(), &config).expect("package should render");
        let client = files
            .iter()
            .find(|file| file.path.ends_with("client.go"))
            .expect("client.go");

        for snippet in [
            "Widgets *WidgetsService",
            "client.Widgets = &WidgetsService{client: client}",
            "type WidgetsService struct",
            "func (s *WidgetsService) GetWidgetRaw(",
        ] {
            assert!(client.contents.contains(snippet), "client.go is missing `{snippet}`");
        }
    }

    #[test]
    fn supports_selective_template_overrides() {
        // Only the service partial is overridden; every other template must still
        // come from the built-in set.
        let template_root = tempdir().expect("tempdir");
        let partial_dir = template_root.path().join("partials");
        fs::create_dir_all(&partial_dir).expect("partials dir");
        fs::write(
            partial_dir.join("service.go.tera"),
            "type {{ service.struct_name }} struct { Overridden bool }\n",
        )
        .expect("override template");

        let config = GoPackageConfig::new("github.com/demo/client")
            .with_group_by_tag(true)
            .with_template_dir(Some(template_root.path().to_path_buf()));
        let files = generate_package(&sample_ir(), &config).expect("package should render");
        let client = files
            .iter()
            .find(|file| file.path.ends_with("client.go"))
            .expect("client.go");

        assert!(client
            .contents
            .contains("type WidgetsService struct { Overridden bool }"));
    }
}