1use std::collections::{HashMap, HashSet};
4use std::path::PathBuf;
5
6use crate::cli::GenerateArgs;
7use crate::config::{CONFIG_FILE_NAME, Config, SCHEMA_FILE_PATH};
8use crate::error::{CliError, CliResult};
9use crate::output::{self, success};
10
11pub async fn run(args: GenerateArgs) -> CliResult<()> {
13 output::header("Generate Prax Client");
14
15 let cwd = std::env::current_dir()?;
16
17 let config_path = cwd.join(CONFIG_FILE_NAME);
19 let config = if config_path.exists() {
20 Config::load(&config_path)?
21 } else {
22 Config::default()
23 };
24
25 let schema_path = args
27 .schema
28 .clone()
29 .unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
30 if !schema_path.exists() {
31 return Err(
32 CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
33 );
34 }
35
36 let output_dir = args
38 .output
39 .clone()
40 .unwrap_or_else(|| PathBuf::from(&config.generator.output));
41
42 output::kv("Schema", &schema_path.display().to_string());
43 output::kv("Output", &output_dir.display().to_string());
44 output::newline();
45
46 output::step(1, 4, "Reading schema...");
47
48 let schema_content = std::fs::read_to_string(&schema_path)?;
50 let schema = parse_schema(&schema_content)?;
51
52 output::step(2, 4, "Validating schema...");
53
54 validate_schema(&schema)?;
56
57 output::step(3, 4, "Generating code...");
58
59 std::fs::create_dir_all(&output_dir)?;
61
62 let generated_files = generate_code(&schema, &output_dir, &args, &config)?;
64
65 output::step(4, 4, "Writing files...");
66
67 output::newline();
69 output::section("Generated files");
70
71 for file in &generated_files {
72 let relative_path = file
73 .strip_prefix(&cwd)
74 .unwrap_or(file)
75 .display()
76 .to_string();
77 output::list_item(&relative_path);
78 }
79
80 output::newline();
81 success(&format!(
82 "Generated {} files in {:.2}s",
83 generated_files.len(),
84 0.0 ));
86
87 Ok(())
88}
89
90fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
92 prax_schema::validate_schema(content)
95 .map_err(|e| CliError::Schema(format!("Failed to parse/validate schema: {}", e)))
96}
97
98fn validate_schema(_schema: &prax_schema::Schema) -> CliResult<()> {
100 Ok(())
102}
103
104fn generate_code(
106 schema: &prax_schema::ast::Schema,
107 output_dir: &PathBuf,
108 args: &GenerateArgs,
109 config: &Config,
110) -> CliResult<Vec<PathBuf>> {
111 let mut generated_files = Vec::new();
112
113 let features = if !args.features.is_empty() {
115 args.features.clone()
116 } else {
117 config
118 .generator
119 .features
120 .clone()
121 .unwrap_or_else(|| vec!["client".to_string()])
122 };
123
124 let relation_graph = build_relation_graph(schema);
126
127 let client_path = output_dir.join("mod.rs");
129 let client_code = generate_client_module(schema, &features)?;
130 std::fs::write(&client_path, client_code)?;
131 generated_files.push(client_path);
132
133 for model in schema.models.values() {
135 let model_path = output_dir.join(format!("{}.rs", to_snake_case(model.name())));
136 let model_code = generate_model_module(model, &features, &relation_graph)?;
137 std::fs::write(&model_path, model_code)?;
138 generated_files.push(model_path);
139 }
140
141 for enum_def in schema.enums.values() {
143 let enum_path = output_dir.join(format!("{}.rs", to_snake_case(enum_def.name())));
144 let enum_code = generate_enum_module(enum_def)?;
145 std::fs::write(&enum_path, enum_code)?;
146 generated_files.push(enum_path);
147 }
148
149 let types_path = output_dir.join("types.rs");
151 let types_code = generate_types_module(schema)?;
152 std::fs::write(&types_path, types_code)?;
153 generated_files.push(types_path);
154
155 let filters_path = output_dir.join("filters.rs");
157 let filters_code = generate_filters_module(schema)?;
158 std::fs::write(&filters_path, filters_code)?;
159 generated_files.push(filters_path);
160
161 Ok(generated_files)
162}
163
164fn build_relation_graph(
168 schema: &prax_schema::ast::Schema,
169) -> HashMap<String, HashSet<String>> {
170 let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
171
172 for model in schema.models.values() {
173 let entry = graph.entry(model.name().to_string()).or_default();
174 for field in model.fields.values() {
175 if let prax_schema::ast::FieldType::Model(ref target) = field.field_type {
176 if !field.is_list() {
177 entry.insert(target.to_string());
178 }
179 }
180 }
181 }
182
183 graph
184}
185
/// Reports whether a singular relation from `source_model` to `target_model` must be boxed.
///
/// Boxing is required when `target_model` can reach `source_model` by following `graph`.
/// This includes the case where the two are the same model. Embedding the target by value
/// would then make the generated struct infinitely sized.
///
/// `graph` maps each model name to the models it references through non-list relations, as
/// built by `build_relation_graph`. Models missing from `graph` are treated as having no
/// outgoing edges.
fn needs_boxing(
    source_model: &str,
    target_model: &str,
    graph: &HashMap<String, HashSet<String>>,
) -> bool {
    // Iterative DFS over borrowed names. `visited` guarantees termination on cycles that do
    // not pass through `source_model`.
    let mut visited: HashSet<&str> = HashSet::new();
    let mut stack: Vec<&str> = vec![target_model];

    while let Some(current) = stack.pop() {
        if current == source_model {
            return true;
        }
        if !visited.insert(current) {
            continue;
        }
        if let Some(neighbors) = graph.get(current) {
            stack.extend(neighbors.iter().map(String::as_str));
        }
    }

    false
}
213
214fn generate_client_module(
216 schema: &prax_schema::ast::Schema,
217 _features: &[String],
218) -> CliResult<String> {
219 let mut code = String::new();
220
221 code.push_str("//! Auto-generated by Prax - DO NOT EDIT\n");
222 code.push_str("//!\n");
223 code.push_str("//! This module contains the generated Prax client.\n\n");
224
225 code.push_str("pub mod types;\n");
227 code.push_str("pub mod filters;\n\n");
228
229 for model in schema.models.values() {
230 code.push_str(&format!("pub mod {};\n", to_snake_case(model.name())));
231 }
232
233 for enum_def in schema.enums.values() {
234 code.push_str(&format!("pub mod {};\n", to_snake_case(enum_def.name())));
235 }
236
237 code.push_str("\n");
238
239 code.push_str("pub use types::*;\n");
241 code.push_str("pub use filters::*;\n\n");
242
243 for model in schema.models.values() {
244 code.push_str(&format!(
245 "pub use {}::{};\n",
246 to_snake_case(model.name()),
247 model.name()
248 ));
249 }
250
251 for enum_def in schema.enums.values() {
252 code.push_str(&format!(
253 "pub use {}::{};\n",
254 to_snake_case(enum_def.name()),
255 enum_def.name()
256 ));
257 }
258
259 code.push_str("\n");
260
261 code.push_str("/// The Prax database client\n");
263 code.push_str("#[derive(Clone)]\n");
264 code.push_str("pub struct PraxClient<E: prax_query::QueryEngine> {\n");
265 code.push_str(" engine: E,\n");
266 code.push_str("}\n\n");
267
268 code.push_str("impl<E: prax_query::QueryEngine> PraxClient<E> {\n");
269 code.push_str(" /// Create a new Prax client with the given query engine\n");
270 code.push_str(" pub fn new(engine: E) -> Self {\n");
271 code.push_str(" Self { engine }\n");
272 code.push_str(" }\n\n");
273
274 for model in schema.models.values() {
275 let snake_name = to_snake_case(model.name());
276 code.push_str(&format!(" /// Access {} operations\n", model.name()));
277 code.push_str(&format!(
278 " pub fn {}(&self) -> {}::{}Operations<E> {{\n",
279 snake_name,
280 snake_name,
281 model.name()
282 ));
283 code.push_str(&format!(
284 " {}::{}Operations::new(self.engine.clone())\n",
285 snake_name,
286 model.name()
287 ));
288 code.push_str(" }\n\n");
289 }
290
291 code.push_str("}\n");
292
293 Ok(code)
294}
295
296fn generate_model_module(
298 model: &prax_schema::ast::Model,
299 features: &[String],
300 relation_graph: &HashMap<String, HashSet<String>>,
301) -> CliResult<String> {
302 let mut code = String::new();
303
304 code.push_str(&format!(
305 "//! Auto-generated module for {} model\n\n",
306 model.name()
307 ));
308
309 code.push_str("use super::*;\n");
311 code.push_str("use prax_query::traits::Model;\n\n");
312
313 let mut derives = vec!["Debug", "Clone"];
315 if features.contains(&"serde".to_string()) {
316 derives.push("serde::Serialize");
317 derives.push("serde::Deserialize");
318 }
319
320 code.push_str(&format!("#[derive({})]\n", derives.join(", ")));
322 code.push_str(&format!("pub struct {} {{\n", model.name()));
323
324 for field in model.fields.values() {
325 let field_name = to_snake_case(field.name());
326
327 if let Some(attr) = field.get_attribute("map") {
329 if features.contains(&"serde".to_string()) {
330 if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
331 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", value));
332 }
333 }
334 }
335
336 let rust_type = field_type_to_rust_with_boxing(
337 &field.field_type,
338 field.modifier,
339 model.name(),
340 relation_graph,
341 );
342 code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
343 }
344
345 code.push_str("}\n\n");
346
347 let table_name = model.table_name();
349 let id_fields: Vec<&str> = model.id_fields().iter().map(|f| f.name()).collect();
350 let scalar_columns: Vec<String> = model
351 .scalar_fields()
352 .iter()
353 .map(|f| {
354 f.get_attribute("map")
356 .and_then(|a| a.first_arg())
357 .and_then(|v| v.as_string())
358 .map(|s| s.to_string())
359 .unwrap_or_else(|| to_snake_case(f.name()))
360 })
361 .collect();
362
363 code.push_str(&format!("impl Model for {} {{\n", model.name()));
364 code.push_str(&format!(
365 " const MODEL_NAME: &'static str = \"{}\";\n",
366 model.name()
367 ));
368 code.push_str(&format!(
369 " const TABLE_NAME: &'static str = \"{}\";\n",
370 table_name
371 ));
372 code.push_str(&format!(
373 " const PRIMARY_KEY: &'static [&'static str] = &[{}];\n",
374 id_fields
375 .iter()
376 .map(|f| format!("\"{}\"", to_snake_case(f)))
377 .collect::<Vec<_>>()
378 .join(", ")
379 ));
380 code.push_str(&format!(
381 " const COLUMNS: &'static [&'static str] = &[{}];\n",
382 scalar_columns
383 .iter()
384 .map(|c| format!("\"{}\"", c))
385 .collect::<Vec<_>>()
386 .join(", ")
387 ));
388 code.push_str("}\n\n");
389
390 code.push_str(&format!("/// Operations for the {} model\n", model.name()));
392 code.push_str(&format!(
393 "pub struct {}Operations<E: prax_query::QueryEngine> {{\n",
394 model.name()
395 ));
396 code.push_str(" engine: E,\n");
397 code.push_str("}\n\n");
398
399 code.push_str(&format!(
400 "impl<E: prax_query::QueryEngine> {}Operations<E> {{\n",
401 model.name()
402 ));
403 code.push_str(" pub fn new(engine: E) -> Self {\n");
404 code.push_str(" Self { engine }\n");
405 code.push_str(" }\n\n");
406
407 code.push_str(" /// Find many records\n");
409 code.push_str(&format!(
410 " pub fn find_many(&self) -> prax_query::FindManyOperation<E, {}> {{\n",
411 model.name()
412 ));
413 code.push_str(" prax_query::FindManyOperation::new(self.engine.clone())\n");
414 code.push_str(" }\n\n");
415
416 code.push_str(" /// Find a unique record\n");
417 code.push_str(&format!(
418 " pub fn find_unique(&self) -> prax_query::FindUniqueOperation<E, {}> {{\n",
419 model.name()
420 ));
421 code.push_str(" prax_query::FindUniqueOperation::new(self.engine.clone())\n");
422 code.push_str(" }\n\n");
423
424 code.push_str(" /// Find the first matching record\n");
425 code.push_str(&format!(
426 " pub fn find_first(&self) -> prax_query::FindFirstOperation<E, {}> {{\n",
427 model.name()
428 ));
429 code.push_str(" prax_query::FindFirstOperation::new(self.engine.clone())\n");
430 code.push_str(" }\n\n");
431
432 code.push_str(" /// Create a new record\n");
433 code.push_str(&format!(
434 " pub fn create(&self) -> prax_query::CreateOperation<E, {}> {{\n",
435 model.name()
436 ));
437 code.push_str(" prax_query::CreateOperation::new(self.engine.clone())\n");
438 code.push_str(" }\n\n");
439
440 code.push_str(" /// Update a record\n");
441 code.push_str(&format!(
442 " pub fn update(&self) -> prax_query::UpdateOperation<E, {}> {{\n",
443 model.name()
444 ));
445 code.push_str(" prax_query::UpdateOperation::new(self.engine.clone())\n");
446 code.push_str(" }\n\n");
447
448 code.push_str(" /// Delete a record\n");
449 code.push_str(&format!(
450 " pub fn delete(&self) -> prax_query::DeleteOperation<E, {}> {{\n",
451 model.name()
452 ));
453 code.push_str(" prax_query::DeleteOperation::new(self.engine.clone())\n");
454 code.push_str(" }\n\n");
455
456 code.push_str(" /// Count records\n");
457 code.push_str(&format!(
458 " pub fn count(&self) -> prax_query::CountOperation<E, {}> {{\n",
459 model.name()
460 ));
461 code.push_str(" prax_query::CountOperation::new(self.engine.clone())\n");
462 code.push_str(" }\n");
463
464 code.push_str("}\n");
465
466 Ok(code)
467}
468
469fn generate_enum_module(enum_def: &prax_schema::ast::Enum) -> CliResult<String> {
471 let mut code = String::new();
472
473 code.push_str(&format!(
474 "//! Auto-generated module for {} enum\n\n",
475 enum_def.name()
476 ));
477
478 code.push_str(
479 "#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\n",
480 );
481 code.push_str(&format!("pub enum {} {{\n", enum_def.name()));
482
483 for variant in &enum_def.variants {
484 let raw_name = variant.name();
485 let pascal_name = to_pascal_case(raw_name);
486
487 if let Some(attr) = variant.attributes.iter().find(|a| a.is("map")) {
489 if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
490 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", value));
491 code.push_str(&format!(" {},\n", pascal_name));
492 continue;
493 }
494 }
495
496 if raw_name != pascal_name {
498 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", raw_name));
499 }
500 code.push_str(&format!(" {},\n", pascal_name));
501 }
502
503 code.push_str("}\n\n");
504
505 code.push_str(&format!("impl std::fmt::Display for {} {{\n", enum_def.name()));
507 code.push_str(" fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n");
508 code.push_str(" match self {\n");
509 for variant in &enum_def.variants {
510 let raw_name = variant.name();
511 let pascal_name = to_pascal_case(raw_name);
512 let db_value = variant.db_value();
513 code.push_str(&format!(
514 " Self::{} => write!(f, \"{}\"),\n",
515 pascal_name, db_value
516 ));
517 }
518 code.push_str(" }\n");
519 code.push_str(" }\n");
520 code.push_str("}\n\n");
521
522 if let Some(default_variant) = enum_def.variants.first() {
524 let pascal_name = to_pascal_case(default_variant.name());
525 code.push_str(&format!("impl Default for {} {{\n", enum_def.name()));
526 code.push_str(&format!(
527 " fn default() -> Self {{\n Self::{}\n }}\n",
528 pascal_name
529 ));
530 code.push_str("}\n");
531 }
532
533 Ok(code)
534}
535
536fn generate_types_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
538 let mut code = String::new();
539
540 code.push_str("//! Common type definitions\n\n");
541 code.push_str("pub use chrono::{DateTime, Utc};\n");
542 code.push_str("pub use uuid::Uuid;\n");
543 code.push_str("pub use serde_json::Value as Json;\n");
544 code.push_str("\n");
545
546 for composite in schema.types.values() {
548 code.push_str("#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]\n");
549 code.push_str(&format!("pub struct {} {{\n", composite.name()));
550 for field in composite.fields.values() {
551 let rust_type = field_type_to_rust(&field.field_type, field.modifier);
552 let field_name = to_snake_case(field.name());
553 code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
554 }
555 code.push_str("}\n\n");
556 }
557
558 Ok(code)
559}
560
561fn generate_filters_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
563 let mut code = String::new();
564
565 code.push_str("//! Filter types for queries\n\n");
566 code.push_str("use prax_query::filter::{Filter, ScalarFilter};\n");
567
568 let mut referenced_enums = HashSet::new();
570 for model in schema.models.values() {
571 for field in model.fields.values() {
572 if !field.is_relation() {
573 if let prax_schema::ast::FieldType::Enum(ref name) = field.field_type {
574 referenced_enums.insert(name.to_string());
575 }
576 }
577 }
578 }
579
580 for enum_name in &referenced_enums {
582 code.push_str(&format!(
583 "use super::{}::{};\n",
584 to_snake_case(enum_name),
585 enum_name
586 ));
587 }
588
589 code.push_str("\n");
590
591 for model in schema.models.values() {
592 code.push_str(&format!("/// Filter input for {} queries\n", model.name()));
594 code.push_str("#[derive(Debug, Default, Clone)]\n");
595 code.push_str(&format!("pub struct {}WhereInput {{\n", model.name()));
596
597 for field in model.fields.values() {
598 if !field.is_relation() {
599 let filter_type = field_to_filter_type(&field.field_type);
600 let field_name = to_snake_case(field.name());
601 code.push_str(&format!(
602 " pub {}: Option<{}>,\n",
603 field_name, filter_type
604 ));
605 }
606 }
607
608 code.push_str(" pub and: Option<Vec<Self>>,\n");
609 code.push_str(" pub or: Option<Vec<Self>>,\n");
610 code.push_str(" pub not: Option<Box<Self>>,\n");
611 code.push_str("}\n\n");
612
613 code.push_str(&format!(
615 "/// Order by input for {} queries\n",
616 model.name()
617 ));
618 code.push_str("#[derive(Debug, Default, Clone)]\n");
619 code.push_str(&format!("pub struct {}OrderByInput {{\n", model.name()));
620
621 for field in model.fields.values() {
622 if !field.is_relation() {
623 let field_name = to_snake_case(field.name());
624 code.push_str(&format!(
625 " pub {}: Option<prax_query::SortOrder>,\n",
626 field_name
627 ));
628 }
629 }
630
631 code.push_str("}\n\n");
632 }
633
634 Ok(code)
635}
636
637fn field_type_to_rust(
639 field_type: &prax_schema::ast::FieldType,
640 modifier: prax_schema::ast::TypeModifier,
641) -> String {
642 use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
643
644 let base_type = match field_type {
645 FieldType::Scalar(scalar) => match scalar {
646 ScalarType::Int => "i32".to_string(),
647 ScalarType::BigInt => "i64".to_string(),
648 ScalarType::Float => "f64".to_string(),
649 ScalarType::String => "String".to_string(),
650 ScalarType::Boolean => "bool".to_string(),
651 ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".to_string(),
652 ScalarType::Date => "chrono::NaiveDate".to_string(),
653 ScalarType::Time => "chrono::NaiveTime".to_string(),
654 ScalarType::Json => "serde_json::Value".to_string(),
655 ScalarType::Bytes => "Vec<u8>".to_string(),
656 ScalarType::Decimal => "rust_decimal::Decimal".to_string(),
657 ScalarType::Uuid => "uuid::Uuid".to_string(),
658 ScalarType::Cuid => "String".to_string(),
659 ScalarType::Cuid2 => "String".to_string(),
660 ScalarType::NanoId => "String".to_string(),
661 ScalarType::Ulid => "String".to_string(),
662 ScalarType::Vector(_) | ScalarType::HalfVector(_) => "Vec<f32>".to_string(),
663 ScalarType::SparseVector(_) => "Vec<(u32, f32)>".to_string(),
664 ScalarType::Bit(_) => "Vec<u8>".to_string(),
665 },
666 FieldType::Model(name) => name.to_string(),
667 FieldType::Enum(name) => name.to_string(),
668 FieldType::Composite(name) => name.to_string(),
669 FieldType::Unsupported(_) => "serde_json::Value".to_string(),
670 };
671
672 match modifier {
673 TypeModifier::Optional | TypeModifier::OptionalList => format!("Option<{}>", base_type),
674 TypeModifier::List => format!("Vec<{}>", base_type),
675 TypeModifier::Required => base_type,
676 }
677}
678
679fn field_type_to_rust_with_boxing(
681 field_type: &prax_schema::ast::FieldType,
682 modifier: prax_schema::ast::TypeModifier,
683 source_model: &str,
684 relation_graph: &HashMap<String, HashSet<String>>,
685) -> String {
686 use prax_schema::ast::{FieldType, TypeModifier};
687
688 if let FieldType::Model(target) = field_type {
690 if !matches!(modifier, TypeModifier::List) {
691 let should_box = needs_boxing(source_model, target, relation_graph);
692 let base = target.to_string();
693 return match modifier {
694 TypeModifier::Optional | TypeModifier::OptionalList => {
695 if should_box {
696 format!("Option<Box<{}>>", base)
697 } else {
698 format!("Option<{}>", base)
699 }
700 }
701 TypeModifier::Required => {
702 if should_box {
703 format!("Box<{}>", base)
704 } else {
705 base
706 }
707 }
708 TypeModifier::List => unreachable!(),
709 };
710 }
711 }
712
713 field_type_to_rust(field_type, modifier)
715}
716
717fn field_to_filter_type(field_type: &prax_schema::ast::FieldType) -> String {
719 use prax_schema::ast::{FieldType, ScalarType};
720
721 match field_type {
722 FieldType::Scalar(scalar) => match scalar {
723 ScalarType::Int | ScalarType::BigInt => "ScalarFilter<i64>".to_string(),
724 ScalarType::Float | ScalarType::Decimal => "ScalarFilter<f64>".to_string(),
725 ScalarType::String
726 | ScalarType::Uuid
727 | ScalarType::Cuid
728 | ScalarType::Cuid2
729 | ScalarType::NanoId
730 | ScalarType::Ulid => "ScalarFilter<String>".to_string(),
731 ScalarType::Boolean => "ScalarFilter<bool>".to_string(),
732 ScalarType::DateTime => "ScalarFilter<chrono::DateTime<chrono::Utc>>".to_string(),
733 ScalarType::Date => "ScalarFilter<chrono::NaiveDate>".to_string(),
734 ScalarType::Time => "ScalarFilter<chrono::NaiveTime>".to_string(),
735 ScalarType::Json => "ScalarFilter<serde_json::Value>".to_string(),
736 ScalarType::Bytes => "ScalarFilter<Vec<u8>>".to_string(),
737 ScalarType::Vector(_) | ScalarType::HalfVector(_) => "VectorFilter".to_string(),
739 ScalarType::SparseVector(_) => "SparseVectorFilter".to_string(),
740 ScalarType::Bit(_) => "BitFilter".to_string(),
741 },
742 FieldType::Enum(name) => format!("ScalarFilter<{}>", name),
743 _ => "Filter".to_string(),
744 }
745}
746
/// Converts a PascalCase or camelCase identifier to snake_case.
///
/// Every uppercase character after the first starts a new `_`-separated word, and every
/// uppercase character is lowercased. Examples:
/// - `"BoardMember"` → `"board_member"`;
/// - `"createdAt"` → `"created_at"`;
/// - consecutive capitals are split individually: `"ID"` → `"i_d"`.
fn to_snake_case(name: &str) -> String {
    let mut result = String::with_capacity(name.len() + name.len() / 2);
    for (i, c) in name.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` can yield several chars (e.g. 'İ' → "i\u{307}"). Keep all of
            // them rather than only the first.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
762
/// Converts an enum variant name to PascalCase for use as a Rust identifier.
///
/// Names that already look PascalCase are returned unchanged (`"SuperAdmin"` →
/// `"SuperAdmin"`). Such a name has a leading uppercase letter, contains a lowercase letter,
/// and has no `_`.
///
/// Any other name is split on `_`, empty segments are dropped, and each segment is
/// capitalized:
/// - SCREAMING segments are lowercased after their first letter: `"CARD_CREATED"` →
///   `"CardCreated"`;
/// - segments that already contain lowercase letters keep their internal capitals:
///   `"cardCreated"` → `"CardCreated"`. Previously this produced `"Cardcreated"`.
fn to_pascal_case(name: &str) -> String {
    let first = match name.chars().next() {
        Some(c) => c,
        None => return String::new(),
    };

    if first.is_uppercase() && name.chars().any(char::is_lowercase) && !name.contains('_') {
        return name.to_string();
    }

    let mut result = String::with_capacity(name.len());
    for segment in name.split('_').filter(|s| !s.is_empty()) {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            result.extend(head.to_uppercase());
            let tail = chars.as_str();
            if segment.chars().any(char::is_lowercase) {
                result.push_str(tail);
            } else {
                result.push_str(&tail.to_lowercase());
            }
        }
    }
    result
}
790
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a relation graph from directed `(from, to)` edges.
    fn graph_from_edges(edges: &[(&str, &str)]) -> HashMap<String, HashSet<String>> {
        let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
        for &(from, to) in edges {
            graph
                .entry(from.to_string())
                .or_default()
                .insert(to.to_string());
        }
        graph
    }

    /// Asserts `to_pascal_case(input) == expected` for every pair.
    fn assert_pascal(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("BoardMember", "board_member"),
            ("User", "user"),
            ("JiraImportConfig", "jira_import_config"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_to_pascal_case_from_snake() {
        assert_pascal(&[
            ("card_created", "CardCreated"),
            ("branch_deleted", "BranchDeleted"),
            ("pr_merged", "PrMerged"),
        ]);
    }

    #[test]
    fn test_to_pascal_case_from_screaming() {
        assert_pascal(&[("CARD_CREATED", "CardCreated"), ("PR_MERGED", "PrMerged")]);
    }

    #[test]
    fn test_to_pascal_case_already_pascal() {
        assert_pascal(&[
            ("Admin", "Admin"),
            ("SuperAdmin", "SuperAdmin"),
            ("Low", "Low"),
        ]);
    }

    #[test]
    fn test_to_pascal_case_single_word() {
        assert_pascal(&[("active", "Active"), ("ACTIVE", "Active")]);
    }

    #[test]
    fn test_needs_boxing_direct_cycle() {
        let graph = graph_from_edges(&[("Board", "JiraConfig"), ("JiraConfig", "Board")]);

        assert!(needs_boxing("Board", "JiraConfig", &graph));
        assert!(needs_boxing("JiraConfig", "Board", &graph));
    }

    #[test]
    fn test_needs_boxing_no_cycle() {
        let mut graph = graph_from_edges(&[("Post", "User")]);
        graph.insert("User".to_string(), HashSet::new());

        assert!(!needs_boxing("Post", "User", &graph));
    }

    #[test]
    fn test_needs_boxing_indirect_cycle() {
        let graph = graph_from_edges(&[("A", "B"), ("B", "C"), ("C", "A")]);

        for (source, target) in [("A", "B"), ("B", "C"), ("C", "A")] {
            assert!(
                needs_boxing(source, target, &graph),
                "{} -> {} should need boxing",
                source,
                target
            );
        }
    }
}