1use console::style;
2use quote::ToTokens;
3use std::fs;
4use std::path::Path;
5use syn::visit::Visit;
6use syn::{Attribute, Fields, ItemStruct, Type};
7use walkdir::WalkDir;
8
/// One named field pulled out of a model's `Model` struct.
#[derive(Debug, Clone)]
struct ModelField {
    /// Field name taken from the struct field's identifier.
    name: String,
    /// The field's type rendered as source tokens with every space removed,
    /// e.g. `Option<String>`.
    rust_type: String,
    /// Set when a `#[sea_orm(...)]` attribute on the field marks it as the primary key.
    is_primary_key: bool,
    /// Set when the rendered type begins with `Option<`.
    is_nullable: bool,
}
20
21struct ModelVisitor {
26 fields: Vec<ModelField>,
27 found: bool,
28}
29
30impl ModelVisitor {
31 fn new() -> Self {
32 Self {
33 fields: Vec::new(),
34 found: false,
35 }
36 }
37
38 fn has_model_derive(attrs: &[Attribute]) -> bool {
39 for attr in attrs {
40 if attr.path().is_ident("derive") {
41 if let Ok(nested) = attr.parse_args_with(
42 syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
43 ) {
44 for path in nested {
45 let ident = path.segments.last().map(|s| s.ident.to_string());
46 if matches!(
47 ident.as_deref(),
48 Some("DeriveEntityModel") | Some("FerroModel")
49 ) {
50 return true;
51 }
52 }
53 }
54 }
55 }
56 false
57 }
58
59 fn is_field_primary_key(attrs: &[Attribute]) -> bool {
60 for attr in attrs {
61 if attr.path().is_ident("sea_orm") {
62 let tokens = attr.meta.to_token_stream().to_string();
63 if tokens.contains("primary_key") {
64 return true;
65 }
66 }
67 }
68 false
69 }
70
71 fn type_to_string(ty: &Type) -> String {
72 ty.to_token_stream().to_string().replace(' ', "")
73 }
74
75 fn extract_fields(fields: &Fields) -> Vec<ModelField> {
76 let mut result = Vec::new();
77 if let Fields::Named(named) = fields {
78 for field in &named.named {
79 if let Some(ident) = &field.ident {
80 let name = ident.to_string();
81 let rust_type = Self::type_to_string(&field.ty);
82 let is_nullable = rust_type.starts_with("Option<");
83 let is_primary_key = Self::is_field_primary_key(&field.attrs);
84 result.push(ModelField {
85 name,
86 rust_type,
87 is_primary_key,
88 is_nullable,
89 });
90 }
91 }
92 }
93 result
94 }
95}
96
97impl<'ast> Visit<'ast> for ModelVisitor {
98 fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
99 if Self::has_model_derive(&node.attrs) && node.ident == "Model" {
100 self.fields = Self::extract_fields(&node.fields);
101 self.found = true;
102 }
103 syn::visit::visit_item_struct(self, node);
104 }
105}
106
107fn scan_models(project_root: &Path) -> Vec<(String, Vec<ModelField>)> {
113 let models_dir = project_root.join("src/models");
114 if !models_dir.exists() || !models_dir.is_dir() {
115 return Vec::new();
116 }
117
118 let mut results = Vec::new();
119
120 for entry in WalkDir::new(&models_dir)
121 .into_iter()
122 .filter_map(|e| e.ok())
123 .filter(|e| e.path().extension().is_some_and(|ext| ext == "rs"))
124 {
125 let file_stem = entry
126 .path()
127 .file_stem()
128 .map(|s| s.to_string_lossy().to_string())
129 .unwrap_or_default();
130
131 if file_stem == "mod" {
132 continue;
133 }
134
135 let Ok(content) = fs::read_to_string(entry.path()) else {
136 continue;
137 };
138 let Ok(syntax) = syn::parse_file(&content) else {
139 continue;
140 };
141
142 let mut visitor = ModelVisitor::new();
143 visitor.visit_file(&syntax);
144
145 if visitor.found {
146 let is_entity_file = entry
147 .path()
148 .parent()
149 .and_then(|p| p.file_name())
150 .is_some_and(|dir| dir == "entities");
151
152 let singular_stem = if is_entity_file {
153 singularize(&file_stem)
154 } else {
155 file_stem.clone()
156 };
157
158 results.push((singular_stem, visitor.fields));
159 }
160 }
161
162 results
163}
164
#[cfg(test)]
/// Test-only variant of `scan_models`. It takes the models directory itself
/// and names each model after its file stem, with no entity singularization.
fn scan_models_from_dir(models_dir: &Path) -> Vec<(String, Vec<ModelField>)> {
    if !models_dir.exists() || !models_dir.is_dir() {
        return Vec::new();
    }

    WalkDir::new(models_dir)
        .into_iter()
        .flatten()
        .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "rs"))
        .filter_map(|entry| {
            let stem = entry
                .path()
                .file_stem()
                .map(|s| s.to_string_lossy().into_owned())
                .unwrap_or_default();
            if stem == "mod" {
                return None;
            }

            // Unreadable or unparsable files simply produce no model.
            let source = fs::read_to_string(entry.path()).ok()?;
            let parsed = syn::parse_file(&source).ok()?;

            let mut visitor = ModelVisitor::new();
            visitor.visit_file(&parsed);
            visitor.found.then(|| (stem, visitor.fields))
        })
        .collect()
}
206
/// Turn a plural snake_case table/file name into a singular model name using
/// simple English suffix rules:
/// `categories` -> `category`, `addresses` -> `address`, `statuses` -> `status`,
/// `courses` -> `course`, `users` -> `user`.
///
/// This is a heuristic rather than a dictionary. Irregular plurals are not
/// handled, and names ending in `ss` are returned unchanged.
fn singularize(name: &str) -> String {
    /// Whether `name` is a plural formed by appending "es" (not just "s").
    fn takes_es_plural(name: &str) -> bool {
        // Sibilant endings: addresses, boxes, matches, wishes, aliases/biases.
        const ES_ENDINGS: [&str; 5] = ["sses", "xes", "ches", "shes", "iases"];
        if ES_ENDINGS.iter().any(|ending| name.ends_with(ending)) {
            return true;
        }
        // `-us` singulars take "es" (statuses, bonuses, buses), while `-use`
        // singulars take only "s" (houses, causes). A vowel right before the
        // "u" indicates the latter.
        name.strip_suffix("uses")
            .and_then(|stem| stem.chars().last())
            .is_some_and(|c| !matches!(c, 'a' | 'e' | 'i' | 'o' | 'u'))
    }

    // "-ies" -> "-y"; a bare "ies" has no stem to keep and falls through.
    if let Some(stem) = name.strip_suffix("ies").filter(|stem| !stem.is_empty()) {
        return format!("{stem}y");
    }
    if takes_es_plural(name) {
        return name.strip_suffix("es").unwrap_or(name).to_string();
    }
    // Plain "-s" plural (this also covers `courses`, `responses`, `houses`).
    match name.strip_suffix('s') {
        Some(stem) if !name.ends_with("ss") => stem.to_string(),
        _ => name.to_string(),
    }
}
222
223fn rust_type_to_data_type(rust_type: &str) -> &'static str {
229 let inner = unwrap_option(rust_type);
230
231 match inner {
232 "String" | "&str" => "DataType::String",
233 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "usize" | "isize" => {
234 "DataType::Integer"
235 }
236 "f32" | "f64" | "Decimal" => "DataType::Float",
237 "bool" => "DataType::Boolean",
238 "DateTime"
239 | "NaiveDateTime"
240 | "DateTimeUtc"
241 | "DateTimeWithTimeZone"
242 | "chrono::DateTime<chrono::Utc>"
243 | "chrono::NaiveDateTime" => "DataType::DateTime",
244 "NaiveDate" | "Date" | "chrono::NaiveDate" => "DataType::Date",
245 "Uuid" | "uuid::Uuid" => "DataType::Uuid",
246 "Vec<u8>" => "DataType::Binary",
247 "Json" | "serde_json::Value" | "JsonValue" => "DataType::Json",
248 _ => "DataType::String",
249 }
250}
251
/// Peel a single outer `Option<...>` off a compact type string. Any input that
/// is not exactly of that shape is returned unchanged.
fn unwrap_option(ty: &str) -> &str {
    ty.strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
        .unwrap_or(ty)
}
261
262fn infer_meaning(field_name: &str) -> &'static str {
265 match field_name {
266 "id" => "FieldMeaning::Identifier",
267 "email" => "FieldMeaning::Email",
268 "created_at" => "FieldMeaning::CreatedAt",
269 "updated_at" => "FieldMeaning::UpdatedAt",
270 _ => {
271 if field_name.ends_with("_id") {
272 return "FieldMeaning::ForeignKey";
273 }
274 if field_name.ends_with("_at") {
275 return "FieldMeaning::DateTime";
276 }
277 if field_name.starts_with("is_") || field_name.starts_with("has_") {
278 return "FieldMeaning::Boolean";
279 }
280 if is_sensitive_field(field_name) {
281 return "FieldMeaning::Sensitive";
282 }
283 "custom"
284 }
285 }
286}
287
/// Return `true` when the field name contains any credential-like marker
/// (`password`, `secret`, `token`, `api_key`, `hashed_key`) anywhere in it.
fn is_sensitive_field(field_name: &str) -> bool {
    const MARKERS: [&str; 5] = ["password", "secret", "token", "api_key", "hashed_key"];
    for marker in MARKERS {
        if field_name.contains(marker) {
            return true;
        }
    }
    false
}
293
294fn field_builder_method(field: &ModelField, meaning: &str) -> Option<&'static str> {
297 if meaning == "FieldMeaning::Sensitive" {
298 return None;
299 }
300
301 if field.is_primary_key
302 || meaning == "FieldMeaning::CreatedAt"
303 || meaning == "FieldMeaning::UpdatedAt"
304 || meaning == "FieldMeaning::ForeignKey"
305 {
306 return Some("read_only_field");
307 }
308
309 if field.is_nullable {
310 return Some("optional_field");
311 }
312
313 Some("field")
314}
315
316fn model_aware_template(name: &str, display_name: &str, fields: &[ModelField]) -> String {
322 let mut field_lines = Vec::new();
323 let mut belongs_to_lines = Vec::new();
324
325 for field in fields {
326 let meaning = infer_meaning(&field.name);
327 let data_type = rust_type_to_data_type(&field.rust_type);
328
329 let Some(builder) = field_builder_method(field, meaning) else {
330 continue;
331 };
332
333 let meaning_str = if meaning == "custom" {
334 format!("FieldMeaning::Custom(\"{}\".into())", field.name)
335 } else {
336 meaning.to_string()
337 };
338
339 field_lines.push(format!(
340 " .{builder}(\"{}\", {data_type}, {meaning_str})",
341 field.name
342 ));
343
344 if field.name.ends_with("_id") && field.name.len() > 3 {
345 let rel_name = &field.name[..field.name.len() - 3];
346 belongs_to_lines.push(format!(
347 " .belongs_to(\"{rel_name}\", \"{rel_name}\")"
348 ));
349 }
350 }
351
352 let all_lines: Vec<String> = field_lines.into_iter().chain(belongs_to_lines).collect();
353
354 let builder_calls = all_lines.join("\n");
355
356 format!(
357 r#"use ferro::{{
358 DataType, FieldMeaning, ServiceDef,
359}};
360
361/// Build the {display_name} service projection.
362///
363/// Derived from the {display_name} model.
364/// Describes the {display_name} entity's fields, relationships,
365/// and behavioral semantics for intent derivation and UI rendering.
366pub fn {name}_service() -> ServiceDef {{
367 ServiceDef::new("{name}")
368 .display_name("{display_name}")
369{builder_calls}
370}}
371"#
372 )
373}
374
375pub fn execute(name: &str, from_model: bool) {
380 let file_name = to_snake_case(name);
381
382 if !is_valid_identifier(&file_name) {
383 eprintln!(
384 "{} '{}' is not a valid projection name",
385 style("Error:").red().bold(),
386 name
387 );
388 std::process::exit(1);
389 }
390
391 let display_name = to_pascal_case(name);
392 let projections_dir = Path::new("src/projections");
393 let projection_file = projections_dir.join(format!("{file_name}.rs"));
394 let mod_file = projections_dir.join("mod.rs");
395
396 if !projections_dir.exists() {
397 if let Err(e) = fs::create_dir_all(projections_dir) {
398 eprintln!(
399 "{} Failed to create src/projections directory: {}",
400 style("Error:").red().bold(),
401 e
402 );
403 std::process::exit(1);
404 }
405 println!("{} Created src/projections/", style("✓").green());
406 }
407
408 if projection_file.exists() {
409 eprintln!(
410 "{} Projection '{}' already exists at {}",
411 style("Info:").yellow().bold(),
412 file_name,
413 projection_file.display()
414 );
415 std::process::exit(0);
416 }
417
418 if mod_file.exists() {
419 let mod_content = fs::read_to_string(&mod_file).unwrap_or_default();
420 let mod_decl = format!("mod {file_name};");
421 let pub_mod_decl = format!("pub mod {file_name};");
422 if mod_content.contains(&mod_decl) || mod_content.contains(&pub_mod_decl) {
423 eprintln!(
424 "{} Module '{}' is already declared in src/projections/mod.rs",
425 style("Info:").yellow().bold(),
426 file_name
427 );
428 std::process::exit(0);
429 }
430 }
431
432 let content = if from_model {
433 let project_root = Path::new(".");
434 let available = scan_models(project_root);
435
436 let matched = available
437 .iter()
438 .find(|(sn, _)| sn.eq_ignore_ascii_case(&file_name));
439
440 match matched {
441 Some((_, fields)) => {
442 println!(
443 "{} Found model '{}' with {} fields",
444 style("✓").green(),
445 display_name,
446 fields.len()
447 );
448 model_aware_template(&file_name, &display_name, fields)
449 }
450 None => {
451 let model_names: Vec<&str> = available.iter().map(|(n, _)| n.as_str()).collect();
452 if model_names.is_empty() {
453 eprintln!(
454 "{} No models found in src/models/",
455 style("Error:").red().bold()
456 );
457 } else {
458 eprintln!(
459 "{} Model '{}' not found. Available models: {}",
460 style("Error:").red().bold(),
461 file_name,
462 model_names.join(", ")
463 );
464 }
465 std::process::exit(1);
466 }
467 }
468 } else {
469 projection_template(&file_name, &display_name)
470 };
471
472 if let Err(e) = fs::write(&projection_file, &content) {
473 eprintln!(
474 "{} Failed to write projection file: {}",
475 style("Error:").red().bold(),
476 e
477 );
478 std::process::exit(1);
479 }
480 println!(
481 "{} Created {}",
482 style("✓").green(),
483 projection_file.display()
484 );
485
486 if mod_file.exists() {
487 if let Err(e) = update_mod_file(&mod_file, &file_name) {
488 eprintln!(
489 "{} Failed to update mod.rs: {}",
490 style("Error:").red().bold(),
491 e
492 );
493 std::process::exit(1);
494 }
495 println!("{} Updated src/projections/mod.rs", style("✓").green());
496 } else {
497 let mod_content = format!("pub mod {file_name};\n");
498 if let Err(e) = fs::write(&mod_file, mod_content) {
499 eprintln!(
500 "{} Failed to create mod.rs: {}",
501 style("Error:").red().bold(),
502 e
503 );
504 std::process::exit(1);
505 }
506 println!("{} Created src/projections/mod.rs", style("✓").green());
507 }
508
509 println!();
510 println!(
511 "Projection {} created successfully!",
512 style(&file_name).cyan().bold()
513 );
514 println!();
515 println!("Usage:");
516 println!(
517 " {} Define fields matching your model in src/projections/{file_name}.rs",
518 style("1.").dim()
519 );
520 println!(" {} Use in a handler:", style("2.").dim());
521 println!(" use crate::projections::{file_name};");
522 println!();
523 println!(" let service = {file_name}::{file_name}_service();");
524 println!(" let intents = derive_intents(&service);");
525 println!();
526}
527
/// Render the starter projection module for `name`.
///
/// The module declares a single `id` identifier field, followed by
/// commented-out examples of common fields for the user to fill in.
fn projection_template(name: &str, display_name: &str) -> String {
    format!(
        r#"use ferro::{{
 DataType, FieldMeaning, ServiceDef,
}};

/// Build the {display_name} service projection.
///
/// Describes the {display_name} entity's fields, relationships,
/// and behavioral semantics for intent derivation and UI rendering.
pub fn {name}_service() -> ServiceDef {{
 ServiceDef::new("{name}")
 .display_name("{display_name}")
 .field("id", DataType::Integer, FieldMeaning::Identifier)
 // Add fields matching your model:
 // .field("name", DataType::String, FieldMeaning::EntityName)
 // .field("email", DataType::String, FieldMeaning::Email)
 // .field("status", DataType::String, FieldMeaning::Status)
 // .field("created_at", DataType::DateTime, FieldMeaning::CreatedAt)
 // .field("updated_at", DataType::DateTime, FieldMeaning::UpdatedAt)
}}
"#,
        name = name,
        display_name = display_name,
    )
}
552
/// Return `true` if `name` can be used as a projection module name, both in
/// `pub mod {name};` and as the prefix of `{name}_service`.
///
/// Requirements:
/// - non-empty, made of ASCII letters, digits and `_`, and not starting with a
///   digit. rustc rejects non-ASCII names in file-backed `mod` declarations
///   (error E0754).
/// - not a lone `_`, which is not a usable identifier.
/// - not a strict or reserved Rust keyword, which could not be declared with
///   `pub mod`.
fn is_valid_identifier(name: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        // Strict keywords.
        "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
        "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod",
        "move", "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super",
        "trait", "true", "type", "unsafe", "use", "where", "while",
        // Reserved for future use.
        "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try",
        "typeof", "unsized", "virtual", "yield",
    ];

    let mut chars = name.chars();
    let starts_ok = chars
        .next()
        .is_some_and(|c| c.is_ascii_alphabetic() || c == '_');

    starts_ok
        && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
        && name != "_"
        && !KEYWORDS.contains(&name)
}
567
/// Convert names such as `UserProfile`, `userProfile`, `HTTPServer`,
/// `user-profile` or `user profile` to snake_case (`user_profile`, `http_server`).
///
/// Rules:
/// - An underscore goes before an uppercase letter that starts a new word,
///   i.e. one that follows a lowercase letter or digit, or that ends an
///   acronym run (the `S` in `HTTPServer`).
/// - `-` and spaces become `_`.
/// - No extra underscore is added next to one that is already there, so
///   `User_Profile` yields `user_profile`.
/// - Underscores written in the input are kept as-is.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &c) in chars.iter().enumerate() {
        if c == '-' || c == ' ' {
            if !result.ends_with('_') {
                result.push('_');
            }
        } else if c.is_uppercase() {
            let prev = i.checked_sub(1).map(|p| chars[p]);
            let next = chars.get(i + 1).copied();
            let starts_word = match prev {
                Some(p) if p.is_lowercase() || p.is_ascii_digit() => true,
                // Inside an acronym, only the last capital before a lowercase
                // letter begins a new word: "HTTPServer" -> "http_server".
                Some(p) if p.is_uppercase() => next.is_some_and(char::is_lowercase),
                _ => false,
            };
            if starts_word && !result.ends_with('_') {
                result.push('_');
            }
            // Some characters lowercase to several chars; keep all of them.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }

    result
}
582
/// Convert `snake_case`, `kebab-case` or space-separated words to PascalCase.
///
/// `_`, `-` and spaces are dropped, and the first character of each word
/// (including the very first character) is uppercased. All other characters
/// are kept as-is, so `UserProfile` stays `UserProfile`.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut at_word_start = true;

    for c in s.chars() {
        match c {
            '_' | '-' | ' ' => at_word_start = true,
            _ if at_word_start => {
                // `to_uppercase` may yield several chars ('ß' -> "SS"); keep them all
                // instead of silently truncating to the first one.
                result.extend(c.to_uppercase());
                at_word_start = false;
            }
            _ => result.push(c),
        }
    }

    result
}
599
#[cfg(test)]
/// Test-only, non-exiting counterpart of `execute`, rooted at `base_dir`.
///
/// Steps:
/// 1. Create `src/projections`.
/// 2. Write the starter template, overwriting any existing projection file.
/// 3. Ensure `mod.rs` contains `pub mod <name>;`.
///
/// Returns the projection file path and the `mod.rs` path.
fn generate_in_dir(
    base_dir: &Path,
    name: &str,
) -> Result<(std::path::PathBuf, std::path::PathBuf), String> {
    let module = to_snake_case(name);
    let projections_dir = base_dir.join("src/projections");
    let projection_file = projections_dir.join(format!("{module}.rs"));
    let mod_file = projections_dir.join("mod.rs");

    fs::create_dir_all(&projections_dir)
        .map_err(|e| format!("Failed to create projections directory: {e}"))?;

    let body = projection_template(&module, &to_pascal_case(name));
    fs::write(&projection_file, body)
        .map_err(|e| format!("Failed to write projection file: {e}"))?;

    let declaration = format!("pub mod {module};");
    if !mod_file.exists() {
        fs::write(&mod_file, format!("{declaration}\n"))
            .map_err(|e| format!("Failed to create mod.rs: {e}"))?;
    } else if !fs::read_to_string(&mod_file)
        .unwrap_or_default()
        .contains(&declaration)
    {
        update_mod_file(&mod_file, &module)?;
    }

    Ok((projection_file, mod_file))
}
633
/// Add `pub mod {file_name};` to the `mod.rs` at `mod_file`.
///
/// Where the declaration goes:
/// - After the last unindented `pub mod ...` line that does not open a block.
///   Inline modules (`pub mod x { ... }`) and declarations nested inside them
///   are not used as anchors, so the new line never lands inside one.
/// - If there is no such line, after any leading `//!` doc lines and blank lines.
///
/// Other guarantees:
/// - If the exact declaration is already present, the file is left untouched.
/// - The file's line-ending style (`\n` or `\r\n`) is kept, and the result
///   always ends with a line terminator.
///
/// # Errors
/// Returns a human-readable message if `mod.rs` cannot be read or written.
fn update_mod_file(mod_file: &Path, file_name: &str) -> Result<(), String> {
    let content =
        fs::read_to_string(mod_file).map_err(|e| format!("Failed to read mod.rs: {e}"))?;

    let pub_mod_decl = format!("pub mod {file_name};");
    let mut lines: Vec<&str> = content.lines().collect();

    // Already declared: nothing to do, so repeated calls are harmless.
    if lines.iter().any(|line| line.trim() == pub_mod_decl) {
        return Ok(());
    }

    // Only top-level (unindented) declarations that do not open an inline module.
    let is_top_level_decl = |line: &str| line.starts_with("pub mod ") && !line.contains('{');

    let insert_idx = match lines.iter().rposition(|&line| is_top_level_decl(line)) {
        Some(idx) => idx + 1,
        // No declarations yet: skip leading `//!` module docs and blank lines.
        None => lines
            .iter()
            .position(|line| !(line.starts_with("//!") || line.is_empty()))
            .unwrap_or(lines.len()),
    };
    lines.insert(insert_idx, &pub_mod_decl);

    let eol = if content.contains("\r\n") { "\r\n" } else { "\n" };
    let mut new_content = lines.join(eol);
    new_content.push_str(eol);
    fs::write(mod_file, new_content).map_err(|e| format!("Failed to write mod.rs: {e}"))?;

    Ok(())
}
670
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Write `content` to `dir/filename`, creating `dir` first if needed.
    fn write_mock_model(dir: &Path, filename: &str, content: &str) {
        fs::create_dir_all(dir).unwrap();
        fs::write(dir.join(filename), content).unwrap();
    }

    /// Source text of a SeaORM entity file whose `Model` struct body is `body`.
    fn entity_source(table: &str, body: &str) -> String {
        format!(
            "use sea_orm::entity::prelude::*;\n\n\
             #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]\n\
             #[sea_orm(table_name = \"{table}\")]\n\
             pub struct Model {{\n{body}}}\n"
        )
    }

    /// Write a single mock model into a fresh `models` directory and scan it.
    fn scan_single_model(filename: &str, source: &str) -> Vec<(String, Vec<ModelField>)> {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(&models_dir, filename, source);
        scan_models_from_dir(&models_dir)
    }

    /// Assert that every snippet occurs somewhere in `output`.
    fn assert_contains_all(output: &str, snippets: &[&str]) {
        for snippet in snippets {
            assert!(output.contains(snippet), "missing {snippet:?} in:\n{output}");
        }
    }

    #[test]
    fn test_template_generation() {
        let template = projection_template("user", "User");
        assert_contains_all(
            &template,
            &[
                "pub fn user_service() -> ServiceDef",
                "ServiceDef::new(\"user\")",
                ".display_name(\"User\")",
                "DataType::Integer, FieldMeaning::Identifier",
                "use ferro::{",
                "/// Build the User service projection.",
            ],
        );
    }

    #[test]
    fn test_creates_directory_and_file() {
        let tmp = TempDir::new().unwrap();
        let (proj_file, _) = generate_in_dir(tmp.path(), "order").unwrap();

        assert!(tmp.path().join("src/projections").exists());
        assert!(proj_file.exists());

        let content = fs::read_to_string(&proj_file).unwrap();
        assert_contains_all(
            &content,
            &["pub fn order_service() -> ServiceDef", ".display_name(\"Order\")"],
        );
    }

    #[test]
    fn test_mod_rs_creation() {
        let tmp = TempDir::new().unwrap();
        let (_, mod_file) = generate_in_dir(tmp.path(), "product").unwrap();

        assert!(mod_file.exists());
        assert!(fs::read_to_string(&mod_file)
            .unwrap()
            .contains("pub mod product;"));
    }

    #[test]
    fn test_mod_rs_append() {
        let tmp = TempDir::new().unwrap();
        let mod_file = tmp.path().join("src/projections/mod.rs");
        let read_mod = || fs::read_to_string(&mod_file).unwrap();

        generate_in_dir(tmp.path(), "user").unwrap();
        assert!(read_mod().contains("pub mod user;"));

        generate_in_dir(tmp.path(), "order").unwrap();
        assert_contains_all(&read_mod(), &["pub mod user;", "pub mod order;"]);

        generate_in_dir(tmp.path(), "order").unwrap();
        assert_eq!(
            read_mod().matches("pub mod order;").count(),
            1,
            "pub mod order; should appear exactly once"
        );
    }

    #[test]
    fn test_model_aware_template_basic() {
        let source = entity_source(
            "users",
            "    #[sea_orm(primary_key)]\n    pub id: i32,\n    pub name: String,\n    pub email: String,\n    pub created_at: DateTime,\n    pub updated_at: DateTime,\n",
        );
        let models = scan_single_model("user.rs", &source);
        assert_eq!(models.len(), 1);
        let (name, fields) = &models[0];
        assert_eq!(name, "user");

        let output = model_aware_template("user", "User", fields);
        assert_contains_all(
            &output,
            &[
                "pub fn user_service() -> ServiceDef",
                "Derived from the User model",
                r#".read_only_field("id", DataType::Integer, FieldMeaning::Identifier)"#,
                r#".field("name", DataType::String, FieldMeaning::Custom("name".into()))"#,
                r#".field("email", DataType::String, FieldMeaning::Email)"#,
                r#".read_only_field("created_at", DataType::DateTime, FieldMeaning::CreatedAt)"#,
                r#".read_only_field("updated_at", DataType::DateTime, FieldMeaning::UpdatedAt)"#,
            ],
        );
    }

    #[test]
    fn test_model_aware_excludes_sensitive() {
        let source = entity_source(
            "users",
            "    #[sea_orm(primary_key)]\n    pub id: i32,\n    pub name: String,\n    pub password_hash: String,\n    pub remember_token: Option<String>,\n",
        );
        let models = scan_single_model("user.rs", &source);
        let (_, fields) = &models[0];
        let output = model_aware_template("user", "User", fields);

        assert_contains_all(&output, &[r#".read_only_field("id""#, r#".field("name""#]);
        for hidden in ["password_hash", "remember_token"] {
            assert!(!output.contains(hidden), "{hidden} leaked into:\n{output}");
        }
    }

    #[test]
    fn test_model_aware_foreign_keys() {
        let source = entity_source(
            "orders",
            "    #[sea_orm(primary_key)]\n    pub id: i32,\n    pub user_id: i32,\n    pub total: f64,\n",
        );
        let models = scan_single_model("order.rs", &source);
        let (_, fields) = &models[0];
        let output = model_aware_template("order", "Order", fields);

        assert_contains_all(
            &output,
            &[
                r#".read_only_field("user_id", DataType::Integer, FieldMeaning::ForeignKey)"#,
                r#".belongs_to("user", "user")"#,
            ],
        );
    }

    #[test]
    fn test_model_aware_optional_fields() {
        let source = entity_source(
            "posts",
            "    #[sea_orm(primary_key)]\n    pub id: i32,\n    pub title: String,\n    pub notes: Option<String>,\n",
        );
        let models = scan_single_model("post.rs", &source);
        let (_, fields) = &models[0];
        let output = model_aware_template("post", "Post", fields);

        assert_contains_all(
            &output,
            &[r#".optional_field("notes", DataType::String, FieldMeaning::Custom("notes".into()))"#],
        );
    }

    #[test]
    fn test_rust_type_to_data_type() {
        let cases = [
            ("String", "DataType::String"),
            ("i32", "DataType::Integer"),
            ("i64", "DataType::Integer"),
            ("u32", "DataType::Integer"),
            ("f32", "DataType::Float"),
            ("f64", "DataType::Float"),
            ("bool", "DataType::Boolean"),
            ("DateTime", "DataType::DateTime"),
            ("NaiveDateTime", "DataType::DateTime"),
            ("NaiveDate", "DataType::Date"),
            ("Uuid", "DataType::Uuid"),
            ("Vec<u8>", "DataType::Binary"),
            ("Json", "DataType::Json"),
            ("Option<String>", "DataType::String"),
            ("Option<i32>", "DataType::Integer"),
            ("SomeUnknownType", "DataType::String"),
        ];
        for (rust_type, expected) in cases {
            assert_eq!(rust_type_to_data_type(rust_type), expected, "for {rust_type}");
        }
    }

    #[test]
    fn test_infer_field_meaning() {
        let cases = [
            ("id", "FieldMeaning::Identifier"),
            ("email", "FieldMeaning::Email"),
            ("created_at", "FieldMeaning::CreatedAt"),
            ("updated_at", "FieldMeaning::UpdatedAt"),
            ("user_id", "FieldMeaning::ForeignKey"),
            ("deleted_at", "FieldMeaning::DateTime"),
            ("is_active", "FieldMeaning::Boolean"),
            ("has_premium", "FieldMeaning::Boolean"),
            ("password", "FieldMeaning::Sensitive"),
            ("remember_token", "FieldMeaning::Sensitive"),
            ("title", "custom"),
        ];
        for (field_name, expected) in cases {
            assert_eq!(infer_meaning(field_name), expected, "for {field_name}");
        }
    }
}