1use anyhow::Result;
2use rustex_ir::{Field, Function, FunctionKind, IrPackage, Table, TypeNode};
3use rustex_project::RustexConfig;
4use std::collections::{BTreeMap, BTreeSet};
5use tracing::debug;
6
/// One file produced by [`generate`], addressed by its path within the
/// generated crate (e.g. `"Cargo.toml"`, `"lib.rs"`).
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// Path of the file inside the generated crate.
    pub path: String,
    /// Complete text contents of the file.
    pub contents: String,
}
12
13pub fn generate(package: &IrPackage, config: &RustexConfig) -> Result<Vec<GeneratedFile>> {
14 let _span = tracing::info_span!(
15 "rustex_rustgen.generate",
16 package = %package.project.name,
17 tables = package.tables.len(),
18 functions = package.functions.len()
19 )
20 .entered();
21 debug!(
22 emit_custom_derives = config.custom_derives.len(),
23 "rendering Rust bindings"
24 );
25 Ok(vec![
26 GeneratedFile {
27 path: "Cargo.toml".into(),
28 contents: cargo_toml(package),
29 },
30 GeneratedFile {
31 path: "lib.rs".into(),
32 contents: lib_rs(),
33 },
34 GeneratedFile {
35 path: "ids.rs".into(),
36 contents: ids_rs(package),
37 },
38 GeneratedFile {
39 path: "models.rs".into(),
40 contents: models_rs(package, config),
41 },
42 GeneratedFile {
43 path: "api.rs".into(),
44 contents: api_rs(package, config),
45 },
46 ])
47}
48
49fn cargo_toml(package: &IrPackage) -> String {
50 let runtime_dependency = runtime_dependency();
51 format!(
52 "[package]\nname = \"{}-generated\"\nversion = \"0.1.0\"\nedition = \"2024\"\n\n[lib]\npath = \"lib.rs\"\n\n[dependencies]\nserde = {{ version = \"1\", features = [\"derive\"] }}\nserde_json = \"1\"\n{runtime_dependency}\ntracing = \"0.1\"\ntracing-subscriber = {{ version = \"0.3\", features = [\"env-filter\", \"fmt\"] }}\n",
53 package.project.name,
54 )
55}
56
57fn runtime_dependency() -> String {
58 if let Ok(path) = std::env::var("RUSTEX_GENERATED_RUNTIME_PATH") {
59 return format!(
60 "rustex-runtime = {{ path = \"{}\", version = \"{}\" }}",
61 escape_toml_string(&path),
62 env!("CARGO_PKG_VERSION")
63 );
64 }
65
66 let local_runtime_path =
67 std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("../rustex-runtime");
68 if local_runtime_path.join("Cargo.toml").is_file() {
69 if let Ok(path) = local_runtime_path.canonicalize() {
70 return format!(
71 "rustex-runtime = {{ path = \"{}\", version = \"{}\" }}",
72 escape_toml_string(&path.display().to_string()),
73 env!("CARGO_PKG_VERSION")
74 );
75 }
76 }
77
78 format!("rustex-runtime = \"{}\"", env!("CARGO_PKG_VERSION"))
79}
80
/// Escapes `value` for use inside a TOML basic (double-quoted) string.
///
/// TOML requires escaping `"` and `\`, and forbids raw control characters
/// (U+0000–U+001F except tab, and U+007F). Common control characters use
/// their short escapes (`\n`, `\r`, `\t`, `\b`, `\f`). Any other control
/// character is written as `\uXXXX`. Everything else is copied verbatim.
fn escape_toml_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            '\u{8}' => escaped.push_str("\\b"),
            '\u{c}' => escaped.push_str("\\f"),
            control if control.is_control() => {
                escaped.push_str(&format!("\\u{:04X}", control as u32));
            }
            other => escaped.push(other),
        }
    }
    escaped
}
84
/// Returns the generated crate root: declares the `api`, `ids` and `models`
/// modules and re-exports the runtime client and tracing initializer.
fn lib_rs() -> String {
    const CRATE_ROOT: &str = "pub mod api;\npub mod ids;\npub mod models;\n\npub use rustex_runtime::{RustexClient, init_default_tracing};\n";
    CRATE_ROOT.to_owned()
}
88
89fn ids_rs(package: &IrPackage) -> String {
90 let mut seen = BTreeSet::new();
91 let mut out = String::from("use serde::{Deserialize, Serialize};\n\n");
92 for table in &package.tables {
93 if seen.insert(table.name.clone()) {
94 let id_name = format!("{}Id", pascal_case(&table.name));
95 out.push_str("#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]\n");
96 out.push_str(&format!("pub struct {id_name}(pub String);\n\n"));
97 }
98 }
99 out
100}
101
102fn models_rs(package: &IrPackage, config: &RustexConfig) -> String {
103 let mut generator = TypeGenerator::new(config);
104 let mut out = String::from(
105 "#![allow(unused_imports)]\nuse serde::{Deserialize, Serialize};\nuse std::collections::BTreeMap;\nuse crate::ids::*;\n\n",
106 );
107 for table in &package.tables {
108 render_table(table, &mut generator);
109 }
110 out.push_str(&generator.finish());
111 out
112}
113
114fn render_table(table: &Table, generator: &mut TypeGenerator) {
115 if let TypeNode::Object { fields, .. } = &table.document_type {
116 let mut struct_fields = vec![
117 RenderedField {
118 rust_name: "_id".into(),
119 original_name: "_id".into(),
120 ty: format!("{}Id", pascal_case(&table.name)),
121 required: true,
122 },
123 RenderedField {
124 rust_name: "_creation_time".into(),
125 original_name: "_creation_time".into(),
126 ty: "f64".into(),
127 required: true,
128 },
129 ];
130 struct_fields.extend(generator.render_fields(fields, &table.doc_name));
131 generator.push_struct_named(&table.doc_name, &struct_fields);
132 }
133}
134
135fn api_rs(package: &IrPackage, config: &RustexConfig) -> String {
136 let mut out = String::from(
137 "#![allow(unused_imports)]\nuse serde::{Deserialize, Serialize};\nuse std::collections::BTreeMap;\nuse crate::ids::*;\nuse crate::models::*;\nuse rustex_runtime::{ActionSpec, FunctionSpec, MutationSpec, QuerySpec};\n\n",
138 );
139
140 let mut grouped: BTreeMap<String, Vec<&Function>> = BTreeMap::new();
141 for function in &package.functions {
142 grouped
143 .entry(function.module_path.clone())
144 .or_default()
145 .push(function);
146 }
147
148 for (module_path, functions) in grouped {
149 let mut generator = TypeGenerator::new_with_indent(" ", config);
150 out.push_str(&format!("pub mod {} {{\n", module_ident(&module_path)));
151 out.push_str(" use super::*;\n\n");
152 for function in functions {
153 render_function(function, &mut generator);
154 }
155 out.push_str(&generator.finish());
156 out.push_str("}\n\n");
157 }
158
159 out.push_str(
160 "#[doc(hidden)]\n#[macro_export]\nmacro_rules! __rustex_arg_value {\n ($field:ident, $value:expr) => {\n ::core::convert::Into::into($value)\n };\n ($field:ident) => {\n ::core::convert::Into::into($field)\n };\n}\n\n",
161 );
162 out.push_str(&render_operation_macro(
163 "query",
164 "query",
165 FunctionKind::Query,
166 &package.functions,
167 ));
168 out.push_str(&render_operation_macro(
169 "mutation",
170 "mutation",
171 FunctionKind::Mutation,
172 &package.functions,
173 ));
174 out.push_str(&render_operation_macro(
175 "action",
176 "action",
177 FunctionKind::Action,
178 &package.functions,
179 ));
180 out.push_str(&render_operation_macro(
181 "subscribe",
182 "subscribe",
183 FunctionKind::Query,
184 &package.functions,
185 ));
186
187 out
188}
189
190fn render_operation_macro(
191 macro_name: &str,
192 method_name: &str,
193 function_kind: FunctionKind,
194 functions: &[Function],
195) -> String {
196 let mut out = format!("#[macro_export]\nmacro_rules! {macro_name} {{\n");
197 let mut has_rule = false;
198 for function in functions
199 .iter()
200 .filter(|function| function.kind == function_kind)
201 {
202 has_rule = true;
203 let module = module_ident(&function.module_path);
204 let function_name = snake_case(&function.export_name);
205 let fn_path = format!("$crate::api::{module}::{function_name}()");
206 let args_ty = format!(
207 "$crate::api::{module}::{}Args",
208 pascal_case(&function.export_name)
209 );
210 match &function.args_type {
211 None => {
212 out.push_str(&format!(
213 " ($client:expr, {module}::{function_name}) => {{\n $client.{method_name}({fn_path}, &())\n }};\n",
214 ));
215 out.push_str(&format!(
216 " ($client:expr, {module}::{function_name}, {{}}) => {{\n $client.{method_name}({fn_path}, &())\n }};\n",
217 ));
218 }
219 Some(TypeNode::Object { .. }) => {
220 out.push_str(&format!(
221 " ($client:expr, {module}::{function_name}, {{ $($field:ident $( : $value:expr )?),* $(,)? }}) => {{\n $client.{method_name}({fn_path}, &{args_ty} {{\n $( $field: $crate::__rustex_arg_value!($field $(, $value)?), )*\n }})\n }};\n",
222 ));
223 }
224 Some(_) => {
225 out.push_str(&format!(
226 " ($client:expr, {module}::{function_name}, $args:expr) => {{\n $client.{method_name}({fn_path}, &$args)\n }};\n",
227 ));
228 }
229 }
230 }
231 if !has_rule {
232 out.push_str(
233 " ($($tt:tt)*) => {\n compile_error!(\"no generated functions support this operation macro in this crate\")\n };\n",
234 );
235 }
236 out.push_str("}\n\n");
237 out
238}
239
240fn render_function(function: &Function, generator: &mut TypeGenerator) {
241 let base = pascal_case(&function.export_name);
242 let args_ty = format!("{base}Args");
243 let output_ty = format!("{base}Response");
244
245 match &function.args_type {
246 Some(TypeNode::Object { fields, .. }) => {
247 let rendered = generator.render_fields(fields, &args_ty);
248 generator.push_struct_named(&args_ty, &rendered);
249 }
250 Some(other) => {
251 let ty = generator.render_type(other, true, &args_ty);
252 generator.push_alias_named(&args_ty, &ty);
253 }
254 None => generator.push_alias_named(&args_ty, "()"),
255 }
256
257 match &function.returns_type {
258 Some(node) => match node {
259 TypeNode::Object { fields, .. } => {
260 let rendered = generator.render_fields(fields, &output_ty);
261 generator.push_struct_named(&output_ty, &rendered);
262 }
263 _ => {
264 let ty = generator.render_type(node, true, &output_ty);
265 generator.push_alias_named(&output_ty, &ty);
266 }
267 },
268 None => generator.push_alias_named(&output_ty, "()"),
269 }
270
271 generator.push_raw(&format!(
272 "#[derive(Clone, Copy, Debug, Default)]\npub struct {base};\n\n"
273 ));
274 generator.push_raw(&format!(
275 "pub fn {}() -> {base} {{\n {base}\n}}\n\n",
276 snake_case(&function.export_name)
277 ));
278 generator.push_raw(&format!("impl FunctionSpec for {base} {{\n"));
279 generator.push_raw(&format!(" type Args = {args_ty};\n"));
280 generator.push_raw(&format!(" type Output = {output_ty};\n"));
281 generator.push_raw(&format!(
282 " const PATH: &'static str = \"{}\";\n",
283 function.canonical_path
284 ));
285 generator.push_raw("}\n");
286
287 match function.kind {
288 FunctionKind::Query => generator.push_raw(&format!("impl QuerySpec for {base} {{}}\n\n")),
289 FunctionKind::Mutation => {
290 generator.push_raw(&format!("impl MutationSpec for {base} {{}}\n\n"))
291 }
292 FunctionKind::Action => generator.push_raw(&format!("impl ActionSpec for {base} {{}}\n\n")),
293 }
294}
295
/// A struct or enum-variant field whose Rust type has already been rendered,
/// ready to be written out by `TypeGenerator`.
#[derive(Debug, Clone)]
struct RenderedField {
    /// Identifier used in the generated Rust source.
    rust_name: String,
    /// Name on the wire; a `#[serde(rename = ...)]` is emitted when it differs from `rust_name`.
    original_name: String,
    /// Rendered Rust type; `render_fields` already wraps optional fields in `Option<...>`.
    ty: String,
    /// When false, the field gets `#[serde(skip_serializing_if = "Option::is_none")]`.
    required: bool,
}
303
/// Accumulates generated Rust type definitions as text and hands out
/// deduplicated type names within its scope.
struct TypeGenerator {
    /// Prefix written before every emitted line (`""` from `new`, or the
    /// value passed to `new_with_indent`).
    indent: &'static str,
    /// Emitted source fragments; `finish` concatenates them in order.
    items: Vec<String>,
    /// Type names already claimed by this generator (see `claim_name`).
    used_names: BTreeSet<String>,
    /// Extra derives from `RustexConfig::custom_derives`, appended to the default derive list.
    derives: Vec<String>,
    /// Extra attributes from `RustexConfig::custom_attributes`, emitted as `#[...]` on each type.
    attributes: Vec<String>,
}
311
312impl TypeGenerator {
313 fn new(config: &RustexConfig) -> Self {
314 Self {
315 indent: "",
316 items: Vec::new(),
317 used_names: BTreeSet::new(),
318 derives: config.custom_derives.clone(),
319 attributes: config.custom_attributes.clone(),
320 }
321 }
322
323 fn new_with_indent(indent: &'static str, config: &RustexConfig) -> Self {
324 Self {
325 indent,
326 items: Vec::new(),
327 used_names: BTreeSet::new(),
328 derives: config.custom_derives.clone(),
329 attributes: config.custom_attributes.clone(),
330 }
331 }
332
333 fn finish(self) -> String {
334 self.items.concat()
335 }
336
337 fn push_raw(&mut self, raw: &str) {
338 for line in raw.lines() {
339 self.items.push(format!("{}{}\n", self.indent, line));
340 }
341 }
342
343 fn push_alias_named(&mut self, name: &str, ty: &str) {
344 let name = self.claim_name(name);
345 self.items
346 .push(format!("{}pub type {name} = {ty};\n\n", self.indent));
347 }
348
349 fn push_struct_named(&mut self, name: &str, fields: &[RenderedField]) {
350 let name = self.claim_name(name);
351 self.push_type_header(true);
352 self.items
353 .push(format!("{}pub struct {name} {{\n", self.indent));
354 for field in fields {
355 if field.rust_name != field.original_name {
356 self.items.push(format!(
357 "{} #[serde(rename = \"{}\")]\n",
358 self.indent, field.original_name
359 ));
360 }
361 if !field.required {
362 self.items.push(format!(
363 "{} #[serde(skip_serializing_if = \"Option::is_none\")]\n",
364 self.indent
365 ));
366 }
367 self.items.push(format!(
368 "{} pub {}: {},\n",
369 self.indent, field.rust_name, field.ty
370 ));
371 }
372 self.items.push(format!("{}}}\n\n", self.indent));
373 }
374
375 fn push_literal_enum_named(&mut self, name: &str, values: &[String]) -> String {
376 let name = self.claim_name(name);
377 self.push_type_header(false);
378 self.items
379 .push(format!("{}pub enum {name} {{\n", self.indent));
380 let mut used_variants = BTreeSet::new();
381 for value in values {
382 let base = sanitize_variant(value);
383 let variant = dedupe_name(&mut used_variants, &base);
384 self.items.push(format!(
385 "{} #[serde(rename = \"{}\")]\n",
386 self.indent, value
387 ));
388 self.items
389 .push(format!("{} {},\n", self.indent, variant));
390 }
391 self.items.push(format!("{}}}\n\n", self.indent));
392 name
393 }
394
395 fn push_discriminated_enum_named(
396 &mut self,
397 name: &str,
398 tag: &str,
399 variants: &[(String, Vec<RenderedField>)],
400 ) -> String {
401 let name = self.claim_name(name);
402 self.push_type_header(false);
403 self.items
404 .push(format!("{}#[serde(tag = \"{}\")]\n", self.indent, tag));
405 self.items
406 .push(format!("{}pub enum {name} {{\n", self.indent));
407 let mut used_variants = BTreeSet::new();
408 for (value, fields) in variants {
409 let variant = dedupe_name(&mut used_variants, &sanitize_variant(value));
410 self.items.push(format!(
411 "{} #[serde(rename = \"{}\")]\n",
412 self.indent, value
413 ));
414 if fields.is_empty() {
415 self.items
416 .push(format!("{} {},\n", self.indent, variant));
417 } else {
418 self.items
419 .push(format!("{} {} {{\n", self.indent, variant));
420 for field in fields {
421 if field.rust_name != field.original_name {
422 self.items.push(format!(
423 "{} #[serde(rename = \"{}\")]\n",
424 self.indent, field.original_name
425 ));
426 }
427 if !field.required {
428 self.items.push(format!(
429 "{} #[serde(skip_serializing_if = \"Option::is_none\")]\n",
430 self.indent
431 ));
432 }
433 self.items.push(format!(
434 "{} {}: {},\n",
435 self.indent, field.rust_name, field.ty
436 ));
437 }
438 self.items.push(format!("{} }},\n", self.indent));
439 }
440 }
441 self.items.push(format!("{}}}\n\n", self.indent));
442 name
443 }
444
445 fn push_untagged_enum_named(
446 &mut self,
447 name: &str,
448 variants: &[(String, Vec<RenderedField>)],
449 ) -> String {
450 let name = self.claim_name(name);
451 self.push_type_header(false);
452 self.items
453 .push(format!("{}#[serde(untagged)]\n", self.indent));
454 self.items
455 .push(format!("{}pub enum {name} {{\n", self.indent));
456 let mut used_variants = BTreeSet::new();
457 for (variant_name, fields) in variants {
458 let variant = dedupe_name(&mut used_variants, &sanitize_variant(variant_name));
459 if fields.is_empty() {
460 self.items
461 .push(format!("{} {},\n", self.indent, variant));
462 } else {
463 self.items
464 .push(format!("{} {} {{\n", self.indent, variant));
465 for field in fields {
466 if field.rust_name != field.original_name {
467 self.items.push(format!(
468 "{} #[serde(rename = \"{}\")]\n",
469 self.indent, field.original_name
470 ));
471 }
472 if !field.required {
473 self.items.push(format!(
474 "{} #[serde(skip_serializing_if = \"Option::is_none\")]\n",
475 self.indent
476 ));
477 }
478 self.items.push(format!(
479 "{} {}: {},\n",
480 self.indent, field.rust_name, field.ty
481 ));
482 }
483 self.items.push(format!("{} }},\n", self.indent));
484 }
485 }
486 self.items.push(format!("{}}}\n\n", self.indent));
487 name
488 }
489
490 fn claim_name(&mut self, base: &str) -> String {
491 let name = dedupe_name(&mut self.used_names, base);
492 name
493 }
494
495 fn render_fields(&mut self, fields: &[Field], owner_name: &str) -> Vec<RenderedField> {
496 let mut used = BTreeSet::new();
497 fields
498 .iter()
499 .map(|field| {
500 let rust_name = dedupe_name(&mut used, &snake_case(&field.name));
501 let hint = format!("{owner_name}{}", pascal_case(&field.name));
502 RenderedField {
503 rust_name,
504 original_name: field.name.clone(),
505 ty: self.render_type(&field.r#type, field.required, &hint),
506 required: field.required,
507 }
508 })
509 .collect()
510 }
511
512 fn render_type(&mut self, node: &TypeNode, required: bool, hint: &str) -> String {
513 let base = match node {
514 TypeNode::String => "String".into(),
515 TypeNode::Float64 => "f64".into(),
516 TypeNode::Int64 => "i64".into(),
517 TypeNode::Boolean => "bool".into(),
518 TypeNode::Null => "()".into(),
519 TypeNode::Bytes => "Vec<u8>".into(),
520 TypeNode::Any => "serde_json::Value".into(),
521 TypeNode::LiteralString { value } => {
522 let enum_name = self.push_literal_enum_named(hint, std::slice::from_ref(value));
523 enum_name
524 }
525 TypeNode::LiteralNumber { .. } => "f64".into(),
526 TypeNode::LiteralBoolean { .. } => "bool".into(),
527 TypeNode::Id { table } => format!("{}Id", pascal_case(table)),
528 TypeNode::Array { element } => {
529 let inner = self.render_type(element, true, &format!("{hint}Item"));
530 format!("Vec<{inner}>")
531 }
532 TypeNode::Record { value } => {
533 let inner = self.render_type(value, true, &format!("{hint}Value"));
534 format!("BTreeMap<String, {inner}>")
535 }
536 TypeNode::Object { fields, .. } => {
537 let struct_name = self.claim_name(hint);
538 let rendered = self.render_fields(fields, &struct_name);
539 self.push_struct_body(&struct_name, &rendered);
540 struct_name
541 }
542 TypeNode::Union { members } => self.render_union(members, hint),
543 TypeNode::Unknown { .. } => "serde_json::Value".into(),
544 };
545
546 if required {
547 base
548 } else {
549 format!("Option<{base}>")
550 }
551 }
552
553 fn render_union(&mut self, members: &[TypeNode], hint: &str) -> String {
554 if let Some(non_null) = optional_member(members) {
555 let inner = self.render_type(non_null, true, hint);
556 return format!("Option<{inner}>");
557 }
558
559 if let Some(literals) = literal_string_union(members) {
560 return self.push_literal_enum_named(hint, &literals);
561 }
562
563 if let Some((tag, variants)) = discriminated_union_members(members) {
564 let rendered_variants = variants
565 .into_iter()
566 .map(|(value, fields)| {
567 let rendered =
568 self.render_fields(&fields, &format!("{hint}{}", sanitize_variant(&value)));
569 (value, rendered)
570 })
571 .collect::<Vec<_>>();
572 return self.push_discriminated_enum_named(hint, &tag, &rendered_variants);
573 }
574
575 if let Some(variants) = object_union_members(members) {
576 let rendered_variants = variants
577 .into_iter()
578 .enumerate()
579 .map(|(index, fields)| {
580 let variant_name = object_union_variant_name(&fields, index);
581 let rendered =
582 self.render_fields(&fields, &format!("{hint}Variant{}", index + 1));
583 (variant_name, rendered)
584 })
585 .collect::<Vec<_>>();
586 return self.push_untagged_enum_named(hint, &rendered_variants);
587 }
588
589 "serde_json::Value".into()
590 }
591
592 fn push_struct_body(&mut self, name: &str, fields: &[RenderedField]) {
593 self.push_type_header(true);
594 self.items
595 .push(format!("{}pub struct {name} {{\n", self.indent));
596 for field in fields {
597 if field.rust_name != field.original_name {
598 self.items.push(format!(
599 "{} #[serde(rename = \"{}\")]\n",
600 self.indent, field.original_name
601 ));
602 }
603 if !field.required {
604 self.items.push(format!(
605 "{} #[serde(skip_serializing_if = \"Option::is_none\")]\n",
606 self.indent
607 ));
608 }
609 self.items.push(format!(
610 "{} pub {}: {},\n",
611 self.indent, field.rust_name, field.ty
612 ));
613 }
614 self.items.push(format!("{}}}\n\n", self.indent));
615 }
616
617 fn push_type_header(&mut self, _is_struct: bool) {
618 let mut derives = vec!["Clone", "Debug", "Serialize", "Deserialize", "PartialEq"];
619 for derive in &self.derives {
620 derives.push(derive);
621 }
622 self.items.push(format!(
623 "{}#[derive({})]\n",
624 self.indent,
625 derives.join(", ")
626 ));
627 for attribute in &self.attributes {
628 self.items
629 .push(format!("{}#[{}]\n", self.indent, attribute));
630 }
631 }
632}
633
634fn optional_member(members: &[TypeNode]) -> Option<&TypeNode> {
635 if members.len() == 2
636 && members
637 .iter()
638 .any(|member| matches!(member, TypeNode::Null))
639 {
640 members
641 .iter()
642 .find(|member| !matches!(member, TypeNode::Null))
643 } else {
644 None
645 }
646}
647
648fn literal_string_union(members: &[TypeNode]) -> Option<Vec<String>> {
649 let mut values = Vec::new();
650 for member in members {
651 if let TypeNode::LiteralString { value } = member {
652 values.push(value.clone());
653 } else {
654 return None;
655 }
656 }
657 if values.is_empty() {
658 None
659 } else {
660 Some(values)
661 }
662}
663
664fn discriminated_union_members(
665 members: &[TypeNode],
666) -> Option<(String, Vec<(String, Vec<Field>)>)> {
667 let object_members = members
668 .iter()
669 .map(|member| match member {
670 TypeNode::Object { fields, .. } => Some(fields.clone()),
671 _ => None,
672 })
673 .collect::<Option<Vec<_>>>()?;
674
675 let candidate_tags = object_members
676 .first()?
677 .iter()
678 .filter_map(|field| match &field.r#type {
679 TypeNode::LiteralString { .. } => Some(field.name.clone()),
680 _ => None,
681 })
682 .collect::<Vec<_>>();
683
684 for tag in candidate_tags {
685 let mut variants = Vec::new();
686 let mut seen_values = BTreeSet::new();
687 let mut valid = true;
688 for fields in &object_members {
689 let Some(discriminant) = fields.iter().find(|field| field.name == tag) else {
690 valid = false;
691 break;
692 };
693 let TypeNode::LiteralString { value } = &discriminant.r#type else {
694 valid = false;
695 break;
696 };
697 if !seen_values.insert(value.clone()) {
698 valid = false;
699 break;
700 }
701 let variant_fields = fields
702 .iter()
703 .filter(|field| field.name != tag)
704 .cloned()
705 .collect::<Vec<_>>();
706 variants.push((value.clone(), variant_fields));
707 }
708 if valid {
709 return Some((tag, variants));
710 }
711 }
712
713 None
714}
715
716fn object_union_members(members: &[TypeNode]) -> Option<Vec<Vec<Field>>> {
717 let mut object_members = members
718 .iter()
719 .map(|member| match member {
720 TypeNode::Object { fields, .. } => Some(fields.clone()),
721 _ => None,
722 })
723 .collect::<Option<Vec<_>>>()?;
724 object_members.sort_by(|left, right| {
725 right.len().cmp(&left.len()).then_with(|| {
726 object_union_variant_name(left, 0).cmp(&object_union_variant_name(right, 0))
727 })
728 });
729 Some(object_members)
730}
731
732fn object_union_variant_name(fields: &[Field], index: usize) -> String {
733 let joined = fields
734 .iter()
735 .map(|field| pascal_case(&field.name))
736 .filter(|name| !name.is_empty())
737 .collect::<Vec<_>>()
738 .join("");
739 if joined.is_empty() {
740 format!("Variant{}", index + 1)
741 } else {
742 joined
743 }
744}
745
/// Reserves and returns a name not yet present in `used`.
///
/// Returns `base` if it is free. Otherwise returns the first free `base2`,
/// `base3`, …. The returned name is inserted into `used`.
fn dedupe_name(used: &mut BTreeSet<String>, base: &str) -> String {
    let mut candidate = base.to_string();
    let mut suffix = 1;
    while used.contains(&candidate) {
        suffix += 1;
        candidate = format!("{base}{suffix}");
    }
    used.insert(candidate.clone());
    candidate
}
759
760fn sanitize_variant(input: &str) -> String {
761 let base = pascal_case(input);
762 match base.as_str() {
763 "" => "Unknown".into(),
764 "Self" | "Super" | "Crate" => format!("{base}Value"),
765 _ if base
766 .chars()
767 .next()
768 .map(|ch| ch.is_ascii_digit())
769 .unwrap_or(false) =>
770 {
771 format!("V{base}")
772 }
773 _ => base,
774 }
775}
776
/// Converts `input` to PascalCase.
///
/// The input is split on every non-alphanumeric character and empty pieces
/// are dropped. Each piece's first character is uppercased and the rest are
/// kept as-is, so `fooBar-baz` becomes `FooBarBaz`.
fn pascal_case(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for word in input
        .split(|c: char| !c.is_alphanumeric())
        .filter(|word| !word.is_empty())
    {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
790
/// Converts `input` to snake_case.
///
/// Uppercase letters start a new word (`fooBar` → `foo_bar`). Runs of
/// non-alphanumeric characters collapse into a single `_`. Leading and
/// trailing separators are dropped. Input without any alphanumeric
/// characters yields `"value"`.
fn snake_case(input: &str) -> String {
    let mut out = String::with_capacity(input.len() + 4);
    let mut at_boundary = true;
    for ch in input.chars() {
        if !ch.is_alphanumeric() {
            if !at_boundary && !out.is_empty() {
                out.push('_');
                at_boundary = true;
            }
            continue;
        }
        if ch.is_uppercase() && !at_boundary && !out.is_empty() {
            out.push('_');
        }
        out.extend(ch.to_lowercase());
        at_boundary = false;
    }

    if out.is_empty() {
        return "value".into();
    }
    out.trim_end_matches('_').to_string()
}
813
814fn module_ident(module_path: &str) -> String {
815 module_path
816 .split('/')
817 .filter(|segment| !segment.is_empty())
818 .map(snake_case)
819 .collect::<Vec<_>>()
820 .join("_")
821}
822
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a required, undocumented field with the given name and type.
    fn required_field(name: &str, r#type: TypeNode) -> Field {
        Field {
            name: name.into(),
            required: true,
            r#type,
            doc: None,
            source: None,
        }
    }

    /// Wraps `fields` in a closed (non-open) object type.
    fn closed_object(fields: Vec<Field>) -> TypeNode {
        TypeNode::Object {
            fields,
            open: false,
        }
    }

    /// Renders `ty` under the `MultiReturnDemoResponse` hint with a default
    /// config. Returns the type expression and all emitted source.
    fn render_demo_response(ty: &TypeNode) -> (String, String) {
        let mut generator = TypeGenerator::new(&RustexConfig::default());
        let rendered = generator.render_type(ty, true, "MultiReturnDemoResponse");
        (rendered, generator.finish())
    }

    #[test]
    fn renders_object_unions_as_untagged_enums() {
        let ty = TypeNode::Union {
            members: vec![
                closed_object(vec![required_field("error", TypeNode::String)]),
                closed_object(vec![
                    required_field("count", TypeNode::Float64),
                    required_field("error", TypeNode::String),
                ]),
            ],
        };

        let (rendered, output) = render_demo_response(&ty);

        assert_eq!(rendered, "MultiReturnDemoResponse");
        for expected in [
            "#[serde(untagged)]",
            "pub enum MultiReturnDemoResponse",
            "CountError {",
            "count: f64",
            "Error {",
        ] {
            assert!(
                output.contains(expected),
                "missing `{expected}` in:\n{output}"
            );
        }
        // The larger variant must come first so the untagged deserializer tries it first.
        assert!(output.find("CountError {") < output.find("Error {"));
    }

    #[test]
    fn renders_short_unique_nested_names_for_object_union_variants() {
        let message = closed_object(vec![required_field("body", TypeNode::String)]);
        let ty = TypeNode::Union {
            members: vec![
                closed_object(vec![required_field("error", TypeNode::String)]),
                closed_object(vec![
                    required_field(
                        "messages",
                        TypeNode::Array {
                            element: Box::new(message),
                        },
                    ),
                    required_field("count", TypeNode::Float64),
                    required_field("error", TypeNode::String),
                ]),
            ],
        };

        let (_, output) = render_demo_response(&ty);

        assert!(output.contains("MultiReturnDemoResponseVariant1MessagesItem"));
        assert!(!output.contains("MultiReturnDemoResponseMessagesCountErrorMessagesItem"));
    }
}