1use std::collections::{HashMap, HashSet};
4use std::path::PathBuf;
5
6use crate::cli::GenerateArgs;
7use crate::config::{CONFIG_FILE_NAME, Config, SCHEMA_FILE_PATH};
8use crate::error::{CliError, CliResult};
9use crate::output::{self, success};
10
11pub async fn run(args: GenerateArgs) -> CliResult<()> {
13 output::header("Generate Prax Client");
14
15 let cwd = std::env::current_dir()?;
16
17 let config_path = cwd.join(CONFIG_FILE_NAME);
19 let config = if config_path.exists() {
20 Config::load(&config_path)?
21 } else {
22 Config::default()
23 };
24
25 let schema_path = args
27 .schema
28 .clone()
29 .unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
30 if !schema_path.exists() {
31 return Err(
32 CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
33 );
34 }
35
36 let output_dir = args
38 .output
39 .clone()
40 .unwrap_or_else(|| PathBuf::from(&config.generator.output));
41
42 output::kv("Schema", &schema_path.display().to_string());
43 output::kv("Output", &output_dir.display().to_string());
44 output::newline();
45
46 output::step(1, 4, "Reading schema...");
47
48 let schema_content = std::fs::read_to_string(&schema_path)?;
50 let schema = parse_schema(&schema_content)?;
51
52 output::step(2, 4, "Validating schema...");
53
54 validate_schema(&schema)?;
56
57 output::step(3, 4, "Generating code...");
58
59 std::fs::create_dir_all(&output_dir)?;
61
62 let generated_files = generate_code(&schema, &output_dir, &args, &config)?;
64
65 output::step(4, 4, "Writing files...");
66
67 output::newline();
69 output::section("Generated files");
70
71 for file in &generated_files {
72 let relative_path = file
73 .strip_prefix(&cwd)
74 .unwrap_or(file)
75 .display()
76 .to_string();
77 output::list_item(&relative_path);
78 }
79
80 output::newline();
81 success(&format!(
82 "Generated {} files in {:.2}s",
83 generated_files.len(),
84 0.0 ));
86
87 Ok(())
88}
89
90fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
92 prax_schema::validate_schema(content)
95 .map_err(|e| CliError::Schema(format!("Failed to parse/validate schema: {}", e)))
96}
97
98fn validate_schema(_schema: &prax_schema::Schema) -> CliResult<()> {
100 Ok(())
102}
103
104fn generate_code(
106 schema: &prax_schema::ast::Schema,
107 output_dir: &PathBuf,
108 args: &GenerateArgs,
109 config: &Config,
110) -> CliResult<Vec<PathBuf>> {
111 let mut generated_files = Vec::new();
112
113 let features = if !args.features.is_empty() {
115 args.features.clone()
116 } else {
117 config
118 .generator
119 .features
120 .clone()
121 .unwrap_or_else(|| vec!["client".to_string()])
122 };
123
124 let relation_graph = build_relation_graph(schema);
126
127 let client_path = output_dir.join("mod.rs");
129 let client_code = generate_client_module(schema, &features)?;
130 std::fs::write(&client_path, client_code)?;
131 generated_files.push(client_path);
132
133 for model in schema.models.values() {
135 let model_path = output_dir.join(format!("{}.rs", to_snake_case(model.name())));
136 let model_code = generate_model_module(model, &features, &relation_graph)?;
137 std::fs::write(&model_path, model_code)?;
138 generated_files.push(model_path);
139 }
140
141 for enum_def in schema.enums.values() {
143 let enum_path = output_dir.join(format!("{}.rs", to_snake_case(enum_def.name())));
144 let enum_code = generate_enum_module(enum_def)?;
145 std::fs::write(&enum_path, enum_code)?;
146 generated_files.push(enum_path);
147 }
148
149 let types_path = output_dir.join("types.rs");
151 let types_code = generate_types_module(schema)?;
152 std::fs::write(&types_path, types_code)?;
153 generated_files.push(types_path);
154
155 let filters_path = output_dir.join("filters.rs");
157 let filters_code = generate_filters_module(schema)?;
158 std::fs::write(&filters_path, filters_code)?;
159 generated_files.push(filters_path);
160
161 Ok(generated_files)
162}
163
164fn build_relation_graph(
168 schema: &prax_schema::ast::Schema,
169) -> HashMap<String, HashSet<String>> {
170 let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
171
172 for model in schema.models.values() {
173 let entry = graph.entry(model.name().to_string()).or_default();
174 for field in model.fields.values() {
175 if let prax_schema::ast::FieldType::Model(ref target) = field.field_type {
176 if !field.is_list() {
177 entry.insert(target.to_string());
178 }
179 }
180 }
181 }
182
183 graph
184}
185
/// Reports whether a field of type `target_model` declared on `source_model`
/// must be boxed to keep the generated struct finitely sized.
///
/// Returns `true` when `source_model` is reachable from `target_model` by
/// following edges in `graph`, including the direct self-reference case
/// `target_model == source_model`.
///
/// The search is an iterative depth-first walk with a visited set, so cycles
/// that do not pass through `source_model` still terminate. Model names missing
/// from `graph` are treated as having no outgoing edges.
fn needs_boxing(
    source_model: &str,
    target_model: &str,
    graph: &HashMap<String, HashSet<String>>,
) -> bool {
    // Borrow names from the inputs instead of cloning a `String` per visit.
    let mut visited: HashSet<&str> = HashSet::new();
    let mut stack: Vec<&str> = vec![target_model];

    while let Some(current) = stack.pop() {
        if current == source_model {
            return true;
        }
        if !visited.insert(current) {
            continue;
        }
        if let Some(neighbors) = graph.get(current) {
            stack.extend(neighbors.iter().map(String::as_str));
        }
    }

    false
}
213
214fn generate_client_module(
216 schema: &prax_schema::ast::Schema,
217 _features: &[String],
218) -> CliResult<String> {
219 let mut code = String::new();
220
221 code.push_str("//! Auto-generated by Prax - DO NOT EDIT\n");
222 code.push_str("//!\n");
223 code.push_str("//! This module contains the generated Prax client.\n\n");
224
225 code.push_str("pub mod types;\n");
227 code.push_str("pub mod filters;\n\n");
228
229 for model in schema.models.values() {
230 code.push_str(&format!("pub mod {};\n", to_snake_case(model.name())));
231 }
232
233 for enum_def in schema.enums.values() {
234 code.push_str(&format!("pub mod {};\n", to_snake_case(enum_def.name())));
235 }
236
237 code.push_str("\n");
238
239 code.push_str("#[allow(unused_imports)]\npub use types::*;\n");
241 code.push_str("#[allow(unused_imports)]\npub use filters::*;\n\n");
242
243 for model in schema.models.values() {
244 code.push_str(&format!(
245 "#[allow(unused_imports)]\npub use {}::{};\n",
246 to_snake_case(model.name()),
247 model.name()
248 ));
249 }
250
251 for enum_def in schema.enums.values() {
252 code.push_str(&format!(
253 "#[allow(unused_imports)]\npub use {}::{};\n",
254 to_snake_case(enum_def.name()),
255 enum_def.name()
256 ));
257 }
258
259 code.push_str("\n");
260
261 code.push_str("#[allow(dead_code)]\n");
263 code.push_str("/// The Prax database client\n");
264 code.push_str("#[derive(Clone)]\n");
265 code.push_str("pub struct PraxClient<E: prax_query::QueryEngine> {\n");
266 code.push_str(" engine: E,\n");
267 code.push_str("}\n\n");
268
269 code.push_str("impl<E: prax_query::QueryEngine> PraxClient<E> {\n");
270 code.push_str(" /// Create a new Prax client with the given query engine\n");
271 code.push_str(" pub fn new(engine: E) -> Self {\n");
272 code.push_str(" Self { engine }\n");
273 code.push_str(" }\n\n");
274
275 for model in schema.models.values() {
276 let snake_name = to_snake_case(model.name());
277 code.push_str(&format!(" /// Access {} operations\n", model.name()));
278 code.push_str(&format!(
279 " pub fn {}(&self) -> {}::{}Operations<E> {{\n",
280 snake_name,
281 snake_name,
282 model.name()
283 ));
284 code.push_str(&format!(
285 " {}::{}Operations::new(self.engine.clone())\n",
286 snake_name,
287 model.name()
288 ));
289 code.push_str(" }\n\n");
290 }
291
292 code.push_str("}\n");
293
294 Ok(code)
295}
296
297fn generate_model_module(
299 model: &prax_schema::ast::Model,
300 features: &[String],
301 relation_graph: &HashMap<String, HashSet<String>>,
302) -> CliResult<String> {
303 let mut code = String::new();
304
305 code.push_str(&format!(
306 "//! Auto-generated module for {} model\n\n",
307 model.name()
308 ));
309
310 code.push_str("#[allow(unused_imports)]\n");
312 code.push_str("use super::*;\n");
313 code.push_str("#[allow(unused_imports)]\n");
314 code.push_str("use prax_query::traits::Model;\n\n");
315
316 let mut derives = vec!["Debug", "Clone"];
318 if features.contains(&"serde".to_string()) {
319 derives.push("serde::Serialize");
320 derives.push("serde::Deserialize");
321 }
322
323 code.push_str("#[allow(dead_code)]\n");
325 code.push_str(&format!("#[derive({})]\n", derives.join(", ")));
326 code.push_str(&format!("pub struct {} {{\n", model.name()));
327
328 for field in model.fields.values() {
329 let field_name = to_snake_case(field.name());
330
331 if let Some(attr) = field.get_attribute("map") {
333 if features.contains(&"serde".to_string()) {
334 if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
335 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", value));
336 }
337 }
338 }
339
340 let rust_type = field_type_to_rust_with_boxing(
341 &field.field_type,
342 field.modifier,
343 model.name(),
344 relation_graph,
345 );
346 code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
347 }
348
349 code.push_str("}\n\n");
350
351 let table_name = model.table_name();
353 let id_fields: Vec<&str> = model.id_fields().iter().map(|f| f.name()).collect();
354 let scalar_columns: Vec<String> = model
355 .scalar_fields()
356 .iter()
357 .map(|f| {
358 f.get_attribute("map")
360 .and_then(|a| a.first_arg())
361 .and_then(|v| v.as_string())
362 .map(|s| s.to_string())
363 .unwrap_or_else(|| to_snake_case(f.name()))
364 })
365 .collect();
366
367 code.push_str(&format!("impl Model for {} {{\n", model.name()));
368 code.push_str(&format!(
369 " const MODEL_NAME: &'static str = \"{}\";\n",
370 model.name()
371 ));
372 code.push_str(&format!(
373 " const TABLE_NAME: &'static str = \"{}\";\n",
374 table_name
375 ));
376 code.push_str(&format!(
377 " const PRIMARY_KEY: &'static [&'static str] = &[{}];\n",
378 id_fields
379 .iter()
380 .map(|f| format!("\"{}\"", to_snake_case(f)))
381 .collect::<Vec<_>>()
382 .join(", ")
383 ));
384 code.push_str(&format!(
385 " const COLUMNS: &'static [&'static str] = &[{}];\n",
386 scalar_columns
387 .iter()
388 .map(|c| format!("\"{}\"", c))
389 .collect::<Vec<_>>()
390 .join(", ")
391 ));
392 code.push_str("}\n\n");
393
394 code.push_str("#[allow(dead_code)]\n");
396 code.push_str(&format!("/// Operations for the {} model\n", model.name()));
397 code.push_str(&format!(
398 "pub struct {}Operations<E: prax_query::QueryEngine> {{\n",
399 model.name()
400 ));
401 code.push_str(" engine: E,\n");
402 code.push_str("}\n\n");
403
404 code.push_str(&format!(
405 "impl<E: prax_query::QueryEngine> {}Operations<E> {{\n",
406 model.name()
407 ));
408 code.push_str(" pub fn new(engine: E) -> Self {\n");
409 code.push_str(" Self { engine }\n");
410 code.push_str(" }\n\n");
411
412 code.push_str(" /// Find many records\n");
414 code.push_str(&format!(
415 " pub fn find_many(&self) -> prax_query::FindManyOperation<E, {}> {{\n",
416 model.name()
417 ));
418 code.push_str(" prax_query::FindManyOperation::new(self.engine.clone())\n");
419 code.push_str(" }\n\n");
420
421 code.push_str(" /// Find a unique record\n");
422 code.push_str(&format!(
423 " pub fn find_unique(&self) -> prax_query::FindUniqueOperation<E, {}> {{\n",
424 model.name()
425 ));
426 code.push_str(" prax_query::FindUniqueOperation::new(self.engine.clone())\n");
427 code.push_str(" }\n\n");
428
429 code.push_str(" /// Find the first matching record\n");
430 code.push_str(&format!(
431 " pub fn find_first(&self) -> prax_query::FindFirstOperation<E, {}> {{\n",
432 model.name()
433 ));
434 code.push_str(" prax_query::FindFirstOperation::new(self.engine.clone())\n");
435 code.push_str(" }\n\n");
436
437 code.push_str(" /// Create a new record\n");
438 code.push_str(&format!(
439 " pub fn create(&self) -> prax_query::CreateOperation<E, {}> {{\n",
440 model.name()
441 ));
442 code.push_str(" prax_query::CreateOperation::new(self.engine.clone())\n");
443 code.push_str(" }\n\n");
444
445 code.push_str(" /// Update a record\n");
446 code.push_str(&format!(
447 " pub fn update(&self) -> prax_query::UpdateOperation<E, {}> {{\n",
448 model.name()
449 ));
450 code.push_str(" prax_query::UpdateOperation::new(self.engine.clone())\n");
451 code.push_str(" }\n\n");
452
453 code.push_str(" /// Delete a record\n");
454 code.push_str(&format!(
455 " pub fn delete(&self) -> prax_query::DeleteOperation<E, {}> {{\n",
456 model.name()
457 ));
458 code.push_str(" prax_query::DeleteOperation::new(self.engine.clone())\n");
459 code.push_str(" }\n\n");
460
461 code.push_str(" /// Count records\n");
462 code.push_str(&format!(
463 " pub fn count(&self) -> prax_query::CountOperation<E, {}> {{\n",
464 model.name()
465 ));
466 code.push_str(" prax_query::CountOperation::new(self.engine.clone())\n");
467 code.push_str(" }\n");
468
469 code.push_str("}\n");
470
471 Ok(code)
472}
473
474fn generate_enum_module(enum_def: &prax_schema::ast::Enum) -> CliResult<String> {
476 let mut code = String::new();
477
478 code.push_str(&format!(
479 "//! Auto-generated module for {} enum\n\n",
480 enum_def.name()
481 ));
482
483 code.push_str("#[allow(dead_code)]\n");
484 code.push_str(
485 "#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\n",
486 );
487 code.push_str(&format!("pub enum {} {{\n", enum_def.name()));
488
489 for variant in &enum_def.variants {
490 let raw_name = variant.name();
491 let pascal_name = to_pascal_case(raw_name);
492
493 if let Some(attr) = variant.attributes.iter().find(|a| a.is("map")) {
495 if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
496 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", value));
497 code.push_str(&format!(" {},\n", pascal_name));
498 continue;
499 }
500 }
501
502 if raw_name != pascal_name {
504 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", raw_name));
505 }
506 code.push_str(&format!(" {},\n", pascal_name));
507 }
508
509 code.push_str("}\n\n");
510
511 code.push_str(&format!("impl std::fmt::Display for {} {{\n", enum_def.name()));
513 code.push_str(" fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n");
514 code.push_str(" match self {\n");
515 for variant in &enum_def.variants {
516 let raw_name = variant.name();
517 let pascal_name = to_pascal_case(raw_name);
518 let db_value = variant.db_value();
519 code.push_str(&format!(
520 " Self::{} => write!(f, \"{}\"),\n",
521 pascal_name, db_value
522 ));
523 }
524 code.push_str(" }\n");
525 code.push_str(" }\n");
526 code.push_str("}\n\n");
527
528 if let Some(default_variant) = enum_def.variants.first() {
530 let pascal_name = to_pascal_case(default_variant.name());
531 code.push_str(&format!("impl Default for {} {{\n", enum_def.name()));
532 code.push_str(&format!(
533 " fn default() -> Self {{\n Self::{}\n }}\n",
534 pascal_name
535 ));
536 code.push_str("}\n");
537 }
538
539 Ok(code)
540}
541
542fn generate_types_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
544 let mut code = String::new();
545
546 code.push_str("//! Common type definitions\n\n");
547 code.push_str("#[allow(unused_imports)]\npub use chrono::{DateTime, Utc};\n");
548 code.push_str("#[allow(unused_imports)]\npub use uuid::Uuid;\n");
549 code.push_str("#[allow(unused_imports)]\npub use serde_json::Value as Json;\n");
550 code.push_str("\n");
551
552 for composite in schema.types.values() {
554 code.push_str("#[allow(dead_code)]\n");
555 code.push_str("#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]\n");
556 code.push_str(&format!("pub struct {} {{\n", composite.name()));
557 for field in composite.fields.values() {
558 let rust_type = field_type_to_rust(&field.field_type, field.modifier);
559 let field_name = to_snake_case(field.name());
560 code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
561 }
562 code.push_str("}\n\n");
563 }
564
565 Ok(code)
566}
567
568fn generate_filters_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
570 let mut code = String::new();
571
572 code.push_str("//! Filter types for queries\n\n");
573 code.push_str("#[allow(unused_imports)]\n");
574 code.push_str("use prax_query::filter::{Filter, ScalarFilter};\n");
575
576 let mut referenced_enums = HashSet::new();
578 for model in schema.models.values() {
579 for field in model.fields.values() {
580 if !field.is_relation() {
581 if let prax_schema::ast::FieldType::Enum(ref name) = field.field_type {
582 referenced_enums.insert(name.to_string());
583 }
584 }
585 }
586 }
587
588 for enum_name in &referenced_enums {
590 code.push_str(&format!(
591 "#[allow(unused_imports)]\nuse super::{}::{};\n",
592 to_snake_case(enum_name),
593 enum_name
594 ));
595 }
596
597 code.push_str("\n");
598
599 for model in schema.models.values() {
600 code.push_str("#[allow(dead_code)]\n");
602 code.push_str(&format!("/// Filter input for {} queries\n", model.name()));
603 code.push_str("#[derive(Debug, Default, Clone)]\n");
604 code.push_str(&format!("pub struct {}WhereInput {{\n", model.name()));
605
606 for field in model.fields.values() {
607 if !field.is_relation() {
608 let filter_type = field_to_filter_type(&field.field_type);
609 let field_name = to_snake_case(field.name());
610 code.push_str(&format!(
611 " pub {}: Option<{}>,\n",
612 field_name, filter_type
613 ));
614 }
615 }
616
617 code.push_str(" pub and: Option<Vec<Self>>,\n");
618 code.push_str(" pub or: Option<Vec<Self>>,\n");
619 code.push_str(" pub not: Option<Box<Self>>,\n");
620 code.push_str("}\n\n");
621
622 code.push_str("#[allow(dead_code)]\n");
624 code.push_str(&format!(
625 "/// Order by input for {} queries\n",
626 model.name()
627 ));
628 code.push_str("#[derive(Debug, Default, Clone)]\n");
629 code.push_str(&format!("pub struct {}OrderByInput {{\n", model.name()));
630
631 for field in model.fields.values() {
632 if !field.is_relation() {
633 let field_name = to_snake_case(field.name());
634 code.push_str(&format!(
635 " pub {}: Option<prax_query::SortOrder>,\n",
636 field_name
637 ));
638 }
639 }
640
641 code.push_str("}\n\n");
642 }
643
644 Ok(code)
645}
646
647fn field_type_to_rust(
649 field_type: &prax_schema::ast::FieldType,
650 modifier: prax_schema::ast::TypeModifier,
651) -> String {
652 use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
653
654 let base_type = match field_type {
655 FieldType::Scalar(scalar) => match scalar {
656 ScalarType::Int => "i32".to_string(),
657 ScalarType::BigInt => "i64".to_string(),
658 ScalarType::Float => "f64".to_string(),
659 ScalarType::String => "String".to_string(),
660 ScalarType::Boolean => "bool".to_string(),
661 ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".to_string(),
662 ScalarType::Date => "chrono::NaiveDate".to_string(),
663 ScalarType::Time => "chrono::NaiveTime".to_string(),
664 ScalarType::Json => "serde_json::Value".to_string(),
665 ScalarType::Bytes => "Vec<u8>".to_string(),
666 ScalarType::Decimal => "rust_decimal::Decimal".to_string(),
667 ScalarType::Uuid => "uuid::Uuid".to_string(),
668 ScalarType::Cuid => "String".to_string(),
669 ScalarType::Cuid2 => "String".to_string(),
670 ScalarType::NanoId => "String".to_string(),
671 ScalarType::Ulid => "String".to_string(),
672 ScalarType::Vector(_) | ScalarType::HalfVector(_) => "Vec<f32>".to_string(),
673 ScalarType::SparseVector(_) => "Vec<(u32, f32)>".to_string(),
674 ScalarType::Bit(_) => "Vec<u8>".to_string(),
675 },
676 FieldType::Model(name) => name.to_string(),
677 FieldType::Enum(name) => name.to_string(),
678 FieldType::Composite(name) => name.to_string(),
679 FieldType::Unsupported(_) => "serde_json::Value".to_string(),
680 };
681
682 match modifier {
683 TypeModifier::Optional | TypeModifier::OptionalList => format!("Option<{}>", base_type),
684 TypeModifier::List => format!("Vec<{}>", base_type),
685 TypeModifier::Required => base_type,
686 }
687}
688
689fn field_type_to_rust_with_boxing(
691 field_type: &prax_schema::ast::FieldType,
692 modifier: prax_schema::ast::TypeModifier,
693 source_model: &str,
694 relation_graph: &HashMap<String, HashSet<String>>,
695) -> String {
696 use prax_schema::ast::{FieldType, TypeModifier};
697
698 if let FieldType::Model(target) = field_type {
700 if !matches!(modifier, TypeModifier::List) {
701 let should_box = needs_boxing(source_model, target, relation_graph);
702 let base = target.to_string();
703 return match modifier {
704 TypeModifier::Optional | TypeModifier::OptionalList => {
705 if should_box {
706 format!("Option<Box<{}>>", base)
707 } else {
708 format!("Option<{}>", base)
709 }
710 }
711 TypeModifier::Required => {
712 if should_box {
713 format!("Box<{}>", base)
714 } else {
715 base
716 }
717 }
718 TypeModifier::List => unreachable!(),
719 };
720 }
721 }
722
723 field_type_to_rust(field_type, modifier)
725}
726
727fn field_to_filter_type(field_type: &prax_schema::ast::FieldType) -> String {
729 use prax_schema::ast::{FieldType, ScalarType};
730
731 match field_type {
732 FieldType::Scalar(scalar) => match scalar {
733 ScalarType::Int | ScalarType::BigInt => "ScalarFilter<i64>".to_string(),
734 ScalarType::Float | ScalarType::Decimal => "ScalarFilter<f64>".to_string(),
735 ScalarType::String
736 | ScalarType::Uuid
737 | ScalarType::Cuid
738 | ScalarType::Cuid2
739 | ScalarType::NanoId
740 | ScalarType::Ulid => "ScalarFilter<String>".to_string(),
741 ScalarType::Boolean => "ScalarFilter<bool>".to_string(),
742 ScalarType::DateTime => "ScalarFilter<chrono::DateTime<chrono::Utc>>".to_string(),
743 ScalarType::Date => "ScalarFilter<chrono::NaiveDate>".to_string(),
744 ScalarType::Time => "ScalarFilter<chrono::NaiveTime>".to_string(),
745 ScalarType::Json => "ScalarFilter<serde_json::Value>".to_string(),
746 ScalarType::Bytes => "ScalarFilter<Vec<u8>>".to_string(),
747 ScalarType::Vector(_) | ScalarType::HalfVector(_) => "VectorFilter".to_string(),
749 ScalarType::SparseVector(_) => "SparseVectorFilter".to_string(),
750 ScalarType::Bit(_) => "BitFilter".to_string(),
751 },
752 FieldType::Enum(name) => format!("ScalarFilter<{}>", name),
753 _ => "Filter".to_string(),
754 }
755}
756
/// Converts a PascalCase or camelCase identifier to snake_case.
///
/// Word boundaries go before an uppercase character that follows:
/// - a lowercase letter or ASCII digit (`createdAt` → `created_at`); or
/// - another uppercase letter, when it is itself followed by a lowercase letter
///   (the end of an acronym: `HTTPServer` → `http_server`).
///
/// Acronyms therefore stay together (`UserID` → `user_id`), and existing
/// underscores are not doubled (`already_Snake` → `already_snake`). The full
/// Unicode lowercase mapping of each character is kept.
///
/// Note: the generator also uses this as the default column name, so this rule
/// determines column names for unmapped fields.
fn to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut result = String::with_capacity(name.len() + 4);
    let mut prev: Option<char> = None;

    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() {
            let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
            let starts_word = match prev {
                Some(p) => {
                    p.is_lowercase() || p.is_ascii_digit() || (p.is_uppercase() && next_is_lower)
                }
                None => false,
            };
            if starts_word {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
        prev = Some(c);
    }

    result
}
772
/// Converts an enum variant name to a PascalCase Rust identifier.
///
/// Names that mix upper- and lowercase letters and contain no underscore
/// already carry their word boundaries. Only their first character is
/// uppercased:
/// - `Admin` → `Admin`
/// - `SuperAdmin` → `SuperAdmin`
/// - `inProgress` → `InProgress`
///
/// Every other name is split on `_`, with empty segments dropped. Each segment
/// is then capitalized and the rest lowercased:
/// - `card_created` → `CardCreated`
/// - `PR_MERGED` → `PrMerged`
/// - `active` → `Active`
///
/// An empty name yields an empty string.
fn to_pascal_case(name: &str) -> String {
    let has_lower = name.chars().any(char::is_lowercase);
    let has_upper = name.chars().any(char::is_uppercase);

    // Mixed case without underscores: keep the existing humps. Previously a
    // lowercase-initial camelCase name was flattened (`inProgress` → `Inprogress`).
    if has_lower && has_upper && !name.contains('_') {
        let mut chars = name.chars();
        return match chars.next() {
            Some(first) => first.to_uppercase().chain(chars).collect(),
            None => String::new(),
        };
    }

    name.split('_')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let mut chars = segment.chars();
            match chars.next() {
                None => String::new(),
                Some(first) => {
                    let rest: String = chars.collect();
                    format!("{}{}", first.to_uppercase(), rest.to_lowercase())
                }
            }
        })
        .collect()
}
800
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned edge set for a relation graph from borrowed model names.
    fn edges(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("BoardMember", "board_member"),
            ("User", "user"),
            ("JiraImportConfig", "jira_import_config"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_to_pascal_case_from_snake() {
        let cases = [
            ("card_created", "CardCreated"),
            ("branch_deleted", "BranchDeleted"),
            ("pr_merged", "PrMerged"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    #[test]
    fn test_to_pascal_case_from_screaming() {
        let cases = [("CARD_CREATED", "CardCreated"), ("PR_MERGED", "PrMerged")];
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    #[test]
    fn test_to_pascal_case_already_pascal() {
        for name in ["Admin", "SuperAdmin", "Low"] {
            assert_eq!(to_pascal_case(name), name);
        }
    }

    #[test]
    fn test_to_pascal_case_single_word() {
        for input in ["active", "ACTIVE"] {
            assert_eq!(to_pascal_case(input), "Active");
        }
    }

    #[test]
    fn test_needs_boxing_direct_cycle() {
        let graph = HashMap::from([
            ("Board".to_string(), edges(&["JiraConfig"])),
            ("JiraConfig".to_string(), edges(&["Board"])),
        ]);

        assert!(needs_boxing("Board", "JiraConfig", &graph));
        assert!(needs_boxing("JiraConfig", "Board", &graph));
    }

    #[test]
    fn test_needs_boxing_no_cycle() {
        let graph = HashMap::from([
            ("Post".to_string(), edges(&["User"])),
            ("User".to_string(), edges(&[])),
        ]);

        assert!(!needs_boxing("Post", "User", &graph));
    }

    #[test]
    fn test_needs_boxing_indirect_cycle() {
        let graph = HashMap::from([
            ("A".to_string(), edges(&["B"])),
            ("B".to_string(), edges(&["C"])),
            ("C".to_string(), edges(&["A"])),
        ]);

        for (source, target) in [("A", "B"), ("B", "C"), ("C", "A")] {
            assert!(needs_boxing(source, target, &graph));
        }
    }
}