Skip to main content

ferro_cli/commands/
make_projection.rs

1use console::style;
2use quote::ToTokens;
3use std::fs;
4use std::path::Path;
5use syn::visit::Visit;
6use syn::{Attribute, Fields, ItemStruct, Type};
7use walkdir::WalkDir;
8
9// ---------------------------------------------------------------------------
10// Model field metadata (self-contained, mirrors make_api::FieldInfo)
11// ---------------------------------------------------------------------------
12
/// Metadata about a single named field of a model struct, as extracted
/// from the parsed source file.
#[derive(Debug, Clone)]
struct ModelField {
    /// Field identifier as written in the model struct.
    name: String,
    /// Field type rendered as a token string with all spaces removed.
    rust_type: String,
    /// `true` when a `#[sea_orm(...)]` attribute on the field mentions `primary_key`.
    is_primary_key: bool,
    /// `true` when the rendered type starts with `Option<`.
    is_nullable: bool,
}
20
21// ---------------------------------------------------------------------------
22// AST visitor for model detection
23// ---------------------------------------------------------------------------
24
/// `syn` visitor that looks for a model struct and records its fields.
struct ModelVisitor {
    /// Fields of the most recently matched model struct (empty until one is found).
    fields: Vec<ModelField>,
    /// Set to `true` once a matching model struct has been visited.
    found: bool,
}
29
30impl ModelVisitor {
31    fn new() -> Self {
32        Self {
33            fields: Vec::new(),
34            found: false,
35        }
36    }
37
38    fn has_model_derive(attrs: &[Attribute]) -> bool {
39        for attr in attrs {
40            if attr.path().is_ident("derive") {
41                if let Ok(nested) = attr.parse_args_with(
42                    syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
43                ) {
44                    for path in nested {
45                        let ident = path.segments.last().map(|s| s.ident.to_string());
46                        if matches!(
47                            ident.as_deref(),
48                            Some("DeriveEntityModel") | Some("FerroModel")
49                        ) {
50                            return true;
51                        }
52                    }
53                }
54            }
55        }
56        false
57    }
58
59    fn is_field_primary_key(attrs: &[Attribute]) -> bool {
60        for attr in attrs {
61            if attr.path().is_ident("sea_orm") {
62                let tokens = attr.meta.to_token_stream().to_string();
63                if tokens.contains("primary_key") {
64                    return true;
65                }
66            }
67        }
68        false
69    }
70
71    fn type_to_string(ty: &Type) -> String {
72        ty.to_token_stream().to_string().replace(' ', "")
73    }
74
75    fn extract_fields(fields: &Fields) -> Vec<ModelField> {
76        let mut result = Vec::new();
77        if let Fields::Named(named) = fields {
78            for field in &named.named {
79                if let Some(ident) = &field.ident {
80                    let name = ident.to_string();
81                    let rust_type = Self::type_to_string(&field.ty);
82                    let is_nullable = rust_type.starts_with("Option<");
83                    let is_primary_key = Self::is_field_primary_key(&field.attrs);
84                    result.push(ModelField {
85                        name,
86                        rust_type,
87                        is_primary_key,
88                        is_nullable,
89                    });
90                }
91            }
92        }
93        result
94    }
95}
96
97impl<'ast> Visit<'ast> for ModelVisitor {
98    fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
99        if Self::has_model_derive(&node.attrs) && node.ident == "Model" {
100            self.fields = Self::extract_fields(&node.fields);
101            self.found = true;
102        }
103        syn::visit::visit_item_struct(self, node);
104    }
105}
106
107// ---------------------------------------------------------------------------
108// Model scanning
109// ---------------------------------------------------------------------------
110
111/// Scan `src/models/` for model files and return matching model fields.
112fn scan_models(project_root: &Path) -> Vec<(String, Vec<ModelField>)> {
113    let models_dir = project_root.join("src/models");
114    if !models_dir.exists() || !models_dir.is_dir() {
115        return Vec::new();
116    }
117
118    let mut results = Vec::new();
119
120    for entry in WalkDir::new(&models_dir)
121        .into_iter()
122        .filter_map(|e| e.ok())
123        .filter(|e| e.path().extension().is_some_and(|ext| ext == "rs"))
124    {
125        let file_stem = entry
126            .path()
127            .file_stem()
128            .map(|s| s.to_string_lossy().to_string())
129            .unwrap_or_default();
130
131        if file_stem == "mod" {
132            continue;
133        }
134
135        let Ok(content) = fs::read_to_string(entry.path()) else {
136            continue;
137        };
138        let Ok(syntax) = syn::parse_file(&content) else {
139            continue;
140        };
141
142        let mut visitor = ModelVisitor::new();
143        visitor.visit_file(&syntax);
144
145        if visitor.found {
146            let is_entity_file = entry
147                .path()
148                .parent()
149                .and_then(|p| p.file_name())
150                .is_some_and(|dir| dir == "entities");
151
152            let singular_stem = if is_entity_file {
153                singularize(&file_stem)
154            } else {
155                file_stem.clone()
156            };
157
158            results.push((singular_stem, visitor.fields));
159        }
160    }
161
162    results
163}
164
/// Test-only variant of model scanning that walks `models_dir` directly.
///
/// Unlike the project-root scanner, the file stem is used as the model name
/// without any singularization.
#[cfg(test)]
fn scan_models_from_dir(models_dir: &Path) -> Vec<(String, Vec<ModelField>)> {
    if !models_dir.is_dir() {
        return Vec::new();
    }

    let rust_files = WalkDir::new(models_dir)
        .into_iter()
        .filter_map(Result::ok)
        .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "rs"));

    let mut found_models = Vec::new();
    for entry in rust_files {
        let stem = entry
            .path()
            .file_stem()
            .map(|s| s.to_string_lossy().to_string())
            .unwrap_or_default();
        if stem == "mod" {
            continue;
        }

        // Skip files that cannot be read or are not valid Rust.
        let parsed = fs::read_to_string(entry.path())
            .ok()
            .and_then(|source| syn::parse_file(&source).ok());
        let syntax = match parsed {
            Some(file) => file,
            None => continue,
        };

        let mut visitor = ModelVisitor::new();
        visitor.visit_file(&syntax);
        if visitor.found {
            found_models.push((stem, visitor.fields));
        }
    }

    found_models
}
206
/// Turn a plural English name into a singular one using simple suffix rules.
///
/// Rules, applied in order:
/// - `...ies` (longer than three bytes) becomes `...y`
/// - `...ses`, `...xes`, `...ches`, `...shes` lose their final `es`
/// - a trailing `s` (but not `ss`) is dropped
/// - anything else is returned unchanged
fn singularize(name: &str) -> String {
    if name.len() > 3 {
        if let Some(stem) = name.strip_suffix("ies") {
            return format!("{stem}y");
        }
    }

    let es_plural = ["ses", "xes", "ches", "shes"]
        .iter()
        .any(|suffix| name.ends_with(suffix));
    if es_plural {
        return name[..name.len() - 2].to_string();
    }

    match name.strip_suffix('s') {
        Some(stem) if !name.ends_with("ss") => stem.to_string(),
        _ => name.to_string(),
    }
}
222
223// ---------------------------------------------------------------------------
224// Type and meaning mapping
225// ---------------------------------------------------------------------------
226
227/// Map a Rust type string to a DataType name for code generation.
228fn rust_type_to_data_type(rust_type: &str) -> &'static str {
229    let inner = unwrap_option(rust_type);
230
231    match inner {
232        "String" | "&str" => "DataType::String",
233        "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "usize" | "isize" => {
234            "DataType::Integer"
235        }
236        "f32" | "f64" | "Decimal" => "DataType::Float",
237        "bool" => "DataType::Boolean",
238        "DateTime"
239        | "NaiveDateTime"
240        | "DateTimeUtc"
241        | "DateTimeWithTimeZone"
242        | "chrono::DateTime<chrono::Utc>"
243        | "chrono::NaiveDateTime" => "DataType::DateTime",
244        "NaiveDate" | "Date" | "chrono::NaiveDate" => "DataType::Date",
245        "Uuid" | "uuid::Uuid" => "DataType::Uuid",
246        "Vec<u8>" => "DataType::Binary",
247        "Json" | "serde_json::Value" | "JsonValue" => "DataType::Json",
248        _ => "DataType::String",
249    }
250}
251
/// Strip one `Option<` ... `>` wrapper from a type string.
///
/// Returns the input unchanged when it does not both start with `Option<`
/// and end with `>`.
fn unwrap_option(ty: &str) -> &str {
    ty.strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
        .unwrap_or(ty)
}
261
262/// Infer a FieldMeaning variant name from a field name.
263/// Replicates the logic from ferro-projections/src/field.rs.
264fn infer_meaning(field_name: &str) -> &'static str {
265    match field_name {
266        "id" => "FieldMeaning::Identifier",
267        "email" => "FieldMeaning::Email",
268        "created_at" => "FieldMeaning::CreatedAt",
269        "updated_at" => "FieldMeaning::UpdatedAt",
270        _ => {
271            if field_name.ends_with("_id") {
272                return "FieldMeaning::ForeignKey";
273            }
274            if field_name.ends_with("_at") {
275                return "FieldMeaning::DateTime";
276            }
277            if field_name.starts_with("is_") || field_name.starts_with("has_") {
278                return "FieldMeaning::Boolean";
279            }
280            if is_sensitive_field(field_name) {
281                return "FieldMeaning::Sensitive";
282            }
283            "custom"
284        }
285    }
286}
287
/// Report whether a field name contains any of the known sensitive markers
/// (`password`, `secret`, `token`, `api_key`, `hashed_key`) as a substring.
fn is_sensitive_field(field_name: &str) -> bool {
    const MARKERS: &[&str] = &["password", "secret", "token", "api_key", "hashed_key"];
    for marker in MARKERS {
        if field_name.contains(marker) {
            return true;
        }
    }
    false
}
293
294/// Determine the builder method for a field.
295/// Returns None if the field should be skipped (sensitive).
296fn field_builder_method(field: &ModelField, meaning: &str) -> Option<&'static str> {
297    if meaning == "FieldMeaning::Sensitive" {
298        return None;
299    }
300
301    if field.is_primary_key
302        || meaning == "FieldMeaning::CreatedAt"
303        || meaning == "FieldMeaning::UpdatedAt"
304        || meaning == "FieldMeaning::ForeignKey"
305    {
306        return Some("read_only_field");
307    }
308
309    if field.is_nullable {
310        return Some("optional_field");
311    }
312
313    Some("field")
314}
315
316// ---------------------------------------------------------------------------
317// Model-aware template generation
318// ---------------------------------------------------------------------------
319
320/// Generate a projection template populated from model fields.
321fn model_aware_template(name: &str, display_name: &str, fields: &[ModelField]) -> String {
322    let mut field_lines = Vec::new();
323    let mut belongs_to_lines = Vec::new();
324
325    for field in fields {
326        let meaning = infer_meaning(&field.name);
327        let data_type = rust_type_to_data_type(&field.rust_type);
328
329        let Some(builder) = field_builder_method(field, meaning) else {
330            continue;
331        };
332
333        let meaning_str = if meaning == "custom" {
334            format!("FieldMeaning::Custom(\"{}\".into())", field.name)
335        } else {
336            meaning.to_string()
337        };
338
339        field_lines.push(format!(
340            "        .{builder}(\"{}\", {data_type}, {meaning_str})",
341            field.name
342        ));
343
344        if field.name.ends_with("_id") && field.name.len() > 3 {
345            let rel_name = &field.name[..field.name.len() - 3];
346            belongs_to_lines.push(format!(
347                "        .belongs_to(\"{rel_name}\", \"{rel_name}\")"
348            ));
349        }
350    }
351
352    let all_lines: Vec<String> = field_lines.into_iter().chain(belongs_to_lines).collect();
353
354    let builder_calls = all_lines.join("\n");
355
356    format!(
357        r#"use ferro::{{
358    DataType, FieldMeaning, ServiceDef,
359}};
360
361/// Build the {display_name} service projection.
362///
363/// Derived from the {display_name} model.
364/// Describes the {display_name} entity's fields, relationships,
365/// and behavioral semantics for intent derivation and UI rendering.
366pub fn {name}_service() -> ServiceDef {{
367    ServiceDef::new("{name}")
368        .display_name("{display_name}")
369{builder_calls}
370}}
371"#
372    )
373}
374
375// ---------------------------------------------------------------------------
376// Public API
377// ---------------------------------------------------------------------------
378
/// Create a new projection module under `src/projections/` (relative to the
/// current working directory) and register it in `src/projections/mod.rs`.
///
/// `name` is converted to snake case for the file/module name and to Pascal
/// case for the display name. When `from_model` is `true`, the projection is
/// populated from a matching model found under `src/models/`; otherwise a
/// generic starter template is written.
///
/// This function does not return errors: it prints a message and terminates
/// the process with exit code 1 on an invalid name, an I/O failure or a
/// missing model, and with exit code 0 when the projection file or module
/// declaration already exists.
pub fn execute(name: &str, from_model: bool) {
    let file_name = to_snake_case(name);

    // Validate the derived module name before touching the filesystem.
    if !is_valid_identifier(&file_name) {
        eprintln!(
            "{} '{}' is not a valid projection name",
            style("Error:").red().bold(),
            name
        );
        std::process::exit(1);
    }

    let display_name = to_pascal_case(name);
    let projections_dir = Path::new("src/projections");
    let projection_file = projections_dir.join(format!("{file_name}.rs"));
    let mod_file = projections_dir.join("mod.rs");

    // Make sure the target directory exists.
    if !projections_dir.exists() {
        if let Err(e) = fs::create_dir_all(projections_dir) {
            eprintln!(
                "{} Failed to create src/projections directory: {}",
                style("Error:").red().bold(),
                e
            );
            std::process::exit(1);
        }
        println!("{} Created src/projections/", style("✓").green());
    }

    // An existing projection file is never overwritten; this is reported as
    // informational and exits successfully.
    if projection_file.exists() {
        eprintln!(
            "{} Projection '{}' already exists at {}",
            style("Info:").yellow().bold(),
            file_name,
            projection_file.display()
        );
        std::process::exit(0);
    }

    // Likewise stop early when mod.rs already declares the module. This is a
    // plain substring check on the file's text.
    if mod_file.exists() {
        let mod_content = fs::read_to_string(&mod_file).unwrap_or_default();
        let mod_decl = format!("mod {file_name};");
        let pub_mod_decl = format!("pub mod {file_name};");
        if mod_content.contains(&mod_decl) || mod_content.contains(&pub_mod_decl) {
            eprintln!(
                "{} Module '{}' is already declared in src/projections/mod.rs",
                style("Info:").yellow().bold(),
                file_name
            );
            std::process::exit(0);
        }
    }

    // Build the file contents, either from a scanned model or from the
    // generic template.
    let content = if from_model {
        let project_root = Path::new(".");
        let available = scan_models(project_root);

        // Model names are compared case-insensitively (ASCII) with the module name.
        let matched = available
            .iter()
            .find(|(sn, _)| sn.eq_ignore_ascii_case(&file_name));

        match matched {
            Some((_, fields)) => {
                println!(
                    "{} Found model '{}' with {} fields",
                    style("✓").green(),
                    display_name,
                    fields.len()
                );
                model_aware_template(&file_name, &display_name, fields)
            }
            None => {
                // No match: list what was found to help the user, then fail.
                let model_names: Vec<&str> = available.iter().map(|(n, _)| n.as_str()).collect();
                if model_names.is_empty() {
                    eprintln!(
                        "{} No models found in src/models/",
                        style("Error:").red().bold()
                    );
                } else {
                    eprintln!(
                        "{} Model '{}' not found. Available models: {}",
                        style("Error:").red().bold(),
                        file_name,
                        model_names.join(", ")
                    );
                }
                std::process::exit(1);
            }
        }
    } else {
        projection_template(&file_name, &display_name)
    };

    // Write the projection source file.
    if let Err(e) = fs::write(&projection_file, &content) {
        eprintln!(
            "{} Failed to write projection file: {}",
            style("Error:").red().bold(),
            e
        );
        std::process::exit(1);
    }
    println!(
        "{} Created {}",
        style("✓").green(),
        projection_file.display()
    );

    // Register the module: extend an existing mod.rs or create a new one.
    if mod_file.exists() {
        if let Err(e) = update_mod_file(&mod_file, &file_name) {
            eprintln!(
                "{} Failed to update mod.rs: {}",
                style("Error:").red().bold(),
                e
            );
            std::process::exit(1);
        }
        println!("{} Updated src/projections/mod.rs", style("✓").green());
    } else {
        let mod_content = format!("pub mod {file_name};\n");
        if let Err(e) = fs::write(&mod_file, mod_content) {
            eprintln!(
                "{} Failed to create mod.rs: {}",
                style("Error:").red().bold(),
                e
            );
            std::process::exit(1);
        }
        println!("{} Created src/projections/mod.rs", style("✓").green());
    }

    // Print a success summary followed by short usage instructions.
    println!();
    println!(
        "Projection {} created successfully!",
        style(&file_name).cyan().bold()
    );
    println!();
    println!("Usage:");
    println!(
        "  {} Define fields matching your model in src/projections/{file_name}.rs",
        style("1.").dim()
    );
    println!("  {} Use in a handler:", style("2.").dim());
    println!("     use crate::projections::{file_name};");
    println!();
    println!("     let service = {file_name}::{file_name}_service();");
    println!("     let intents = derive_intents(&service);");
    println!();
}
527
/// Render the generic starter source for a projection module.
///
/// `name` is used for the `<name>_service` function and the service
/// identifier; `display_name` is used in the doc comment and the
/// `.display_name(...)` call. The output declares only an `id` field and
/// lists further example fields as comments.
fn projection_template(name: &str, display_name: &str) -> String {
    let source = format!(
        r#"use ferro::{{
    DataType, FieldMeaning, ServiceDef,
}};

/// Build the {display_name} service projection.
///
/// Describes the {display_name} entity's fields, relationships,
/// and behavioral semantics for intent derivation and UI rendering.
pub fn {name}_service() -> ServiceDef {{
    ServiceDef::new("{name}")
        .display_name("{display_name}")
        .field("id", DataType::Integer, FieldMeaning::Identifier)
        // Add fields matching your model:
        // .field("name", DataType::String, FieldMeaning::EntityName)
        // .field("email", DataType::String, FieldMeaning::Email)
        // .field("status", DataType::String, FieldMeaning::Status)
        // .field("created_at", DataType::DateTime, FieldMeaning::CreatedAt)
        // .field("updated_at", DataType::DateTime, FieldMeaning::UpdatedAt)
}}
"#,
        name = name,
        display_name = display_name,
    );
    source
}
552
/// Check whether `name` can be used as a Rust module and function name.
///
/// A valid name is non-empty, starts with an alphabetic character or `_`,
/// and continues with alphanumeric characters or `_`. A lone `_` and Rust
/// keywords are rejected: they pass the character check but cannot be used
/// in `pub mod <name>;`, and the name `mod` would also make the generated
/// file collide with `mod.rs`.
fn is_valid_identifier(name: &str) -> bool {
    // Strict and reserved keywords from the Rust Reference.
    const KEYWORDS: &[&str] = &[
        "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
        "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
        "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
        "true", "type", "unsafe", "use", "where", "while", "abstract", "become", "box", "do",
        "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
    ];

    let mut chars = name.chars();

    // An empty name fails here because `next()` yields `None`.
    let starts_ok = matches!(chars.next(), Some(c) if c.is_alphabetic() || c == '_');
    if !starts_ok || !chars.all(|c| c.is_alphanumeric() || c == '_') {
        return false;
    }

    name != "_" && !KEYWORDS.iter().any(|keyword| *keyword == name)
}
567
/// Convert a name such as `UserProfile` into snake case (`user_profile`).
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by `_`. No separator is inserted when the output
/// already ends with `_`, so `User_Profile` becomes `user_profile` rather
/// than `user__profile`. All other characters are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    for c in s.chars() {
        if c.is_uppercase() {
            // Separate words, but never double up an existing underscore.
            if !snake.is_empty() && !snake.ends_with('_') {
                snake.push('_');
            }
            // `to_lowercase` always yields at least one char; keep the first.
            snake.push(c.to_lowercase().next().unwrap_or(c));
        } else {
            snake.push(c);
        }
    }
    snake
}
582
/// Convert a name into Pascal case.
///
/// `_`, `-` and space act as word separators and are dropped; the first
/// character of each word is uppercased and the rest are copied unchanged.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    let mut at_word_start = true;

    for ch in s.chars() {
        match ch {
            '_' | '-' | ' ' => at_word_start = true,
            _ if at_word_start => {
                // `to_uppercase` always yields at least one char; keep the first.
                pascal.push(ch.to_uppercase().next().unwrap_or(ch));
                at_word_start = false;
            }
            _ => pascal.push(ch),
        }
    }
    pascal
}
599
/// Test helper: write a generic projection for `name` under
/// `<base_dir>/src/projections` and make sure `mod.rs` declares it.
///
/// Returns `(projection_file_path, mod_file_path)` on success, or a
/// human-readable error string when a filesystem operation fails.
#[cfg(test)]
fn generate_in_dir(
    base_dir: &Path,
    name: &str,
) -> Result<(std::path::PathBuf, std::path::PathBuf), String> {
    let module_name = to_snake_case(name);
    let projections_dir = base_dir.join("src/projections");
    let projection_file = projections_dir.join(format!("{module_name}.rs"));
    let mod_file = projections_dir.join("mod.rs");

    fs::create_dir_all(&projections_dir)
        .map_err(|e| format!("Failed to create projections directory: {e}"))?;

    let source = projection_template(&module_name, &to_pascal_case(name));
    fs::write(&projection_file, source)
        .map_err(|e| format!("Failed to write projection file: {e}"))?;

    let declaration = format!("pub mod {module_name};");
    if !mod_file.exists() {
        fs::write(&mod_file, format!("{declaration}\n"))
            .map_err(|e| format!("Failed to create mod.rs: {e}"))?;
    } else {
        // Only touch mod.rs when the declaration is not there yet.
        let already_declared = fs::read_to_string(&mod_file)
            .unwrap_or_default()
            .contains(&declaration);
        if !already_declared {
            update_mod_file(&mod_file, &module_name)?;
        }
    }

    Ok((projection_file, mod_file))
}
633
/// Insert `pub mod <file_name>;` into an existing `mod.rs`.
///
/// The declaration is placed directly after the last existing `pub mod` line.
/// When there is none, it goes after the leading run of `//!` doc-comment
/// lines and blank lines. If the exact declaration is already present the
/// file is left untouched. The rewritten file always ends with a newline
/// (`str::lines` drops the final line terminator, so it is added back).
///
/// # Errors
///
/// Returns a descriptive message when `mod_file` cannot be read or written.
fn update_mod_file(mod_file: &Path, file_name: &str) -> Result<(), String> {
    let content =
        fs::read_to_string(mod_file).map_err(|e| format!("Failed to read mod.rs: {e}"))?;

    let pub_mod_decl = format!("pub mod {file_name};");

    let mut lines: Vec<&str> = content.lines().collect();

    // Nothing to do when the module is already declared.
    if lines.iter().any(|line| line.trim() == pub_mod_decl) {
        return Ok(());
    }

    let last_pub_mod_idx = lines
        .iter()
        .rposition(|line| line.trim().starts_with("pub mod "));

    let insert_idx = match last_pub_mod_idx {
        Some(idx) => idx + 1,
        // No `pub mod` yet: skip the module-level docs and blank lines at the top.
        None => lines
            .iter()
            .take_while(|line| line.starts_with("//!") || line.is_empty())
            .count(),
    };
    lines.insert(insert_idx, &pub_mod_decl);

    let mut new_content = lines.join("\n");
    new_content.push('\n');
    fs::write(mod_file, new_content).map_err(|e| format!("Failed to write mod.rs: {e}"))?;

    Ok(())
}
670
// Unit tests for template rendering, on-disk generation, model scanning and
// the type/meaning inference helpers. Filesystem tests run inside a
// temporary directory.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // The generic template contains the service function, identifiers and
    // default `id` field for the given names.
    #[test]
    fn test_template_generation() {
        let template = projection_template("user", "User");

        assert!(template.contains("pub fn user_service() -> ServiceDef"));
        assert!(template.contains("ServiceDef::new(\"user\")"));
        assert!(template.contains(".display_name(\"User\")"));
        assert!(template.contains("DataType::Integer, FieldMeaning::Identifier"));
        assert!(template.contains("use ferro::{"));
        assert!(template.contains("/// Build the User service projection."));
    }

    // Generation creates `src/projections/` and writes the projection file.
    #[test]
    fn test_creates_directory_and_file() {
        let tmp = TempDir::new().unwrap();
        let (proj_file, _mod_file) = generate_in_dir(tmp.path(), "order").unwrap();

        assert!(tmp.path().join("src/projections").exists());
        assert!(proj_file.exists());

        let content = fs::read_to_string(&proj_file).unwrap();
        assert!(content.contains("pub fn order_service() -> ServiceDef"));
        assert!(content.contains(".display_name(\"Order\")"));
    }

    // A missing mod.rs is created with the module declaration.
    #[test]
    fn test_mod_rs_creation() {
        let tmp = TempDir::new().unwrap();
        let (_proj_file, mod_file) = generate_in_dir(tmp.path(), "product").unwrap();

        assert!(mod_file.exists());
        let mod_content = fs::read_to_string(&mod_file).unwrap();
        assert!(mod_content.contains("pub mod product;"));
    }

    // An existing mod.rs gains new declarations, keeps old ones, and does
    // not get a duplicate when the same projection is generated twice.
    #[test]
    fn test_mod_rs_append() {
        let tmp = TempDir::new().unwrap();

        generate_in_dir(tmp.path(), "user").unwrap();
        let mod_file = tmp.path().join("src/projections/mod.rs");
        let content = fs::read_to_string(&mod_file).unwrap();
        assert!(content.contains("pub mod user;"));

        generate_in_dir(tmp.path(), "order").unwrap();
        let content = fs::read_to_string(&mod_file).unwrap();
        assert!(content.contains("pub mod user;"));
        assert!(content.contains("pub mod order;"));

        generate_in_dir(tmp.path(), "order").unwrap();
        let content = fs::read_to_string(&mod_file).unwrap();
        let count = content.matches("pub mod order;").count();
        assert_eq!(count, 1, "pub mod order; should appear exactly once");
    }

    // -- Model-aware tests --

    // Write a mock model source file into `dir`, creating the directory first.
    fn write_mock_model(dir: &Path, filename: &str, content: &str) {
        fs::create_dir_all(dir).unwrap();
        fs::write(dir.join(filename), content).unwrap();
    }

    // A simple model is scanned and each field maps to the expected builder
    // call, data type and meaning.
    #[test]
    fn test_model_aware_template_basic() {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(
            &models_dir,
            "user.rs",
            r#"
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "users")]
pub struct Model {
    #[sea_orm(primary_key)]
    pub id: i32,
    pub name: String,
    pub email: String,
    pub created_at: DateTime,
    pub updated_at: DateTime,
}
"#,
        );

        let models = scan_models_from_dir(&models_dir);
        assert_eq!(models.len(), 1);
        let (name, fields) = &models[0];
        assert_eq!(name, "user");

        let output = model_aware_template("user", "User", fields);

        assert!(output.contains("pub fn user_service() -> ServiceDef"));
        assert!(output.contains("Derived from the User model"));
        assert!(output
            .contains(r#".read_only_field("id", DataType::Integer, FieldMeaning::Identifier)"#));
        assert!(output
            .contains(r#".field("name", DataType::String, FieldMeaning::Custom("name".into()))"#));
        assert!(output.contains(r#".field("email", DataType::String, FieldMeaning::Email)"#));
        assert!(output.contains(
            r#".read_only_field("created_at", DataType::DateTime, FieldMeaning::CreatedAt)"#
        ));
        assert!(output.contains(
            r#".read_only_field("updated_at", DataType::DateTime, FieldMeaning::UpdatedAt)"#
        ));
    }

    // Fields with sensitive names never appear in the generated projection.
    #[test]
    fn test_model_aware_excludes_sensitive() {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(
            &models_dir,
            "user.rs",
            r#"
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "users")]
pub struct Model {
    #[sea_orm(primary_key)]
    pub id: i32,
    pub name: String,
    pub password_hash: String,
    pub remember_token: Option<String>,
}
"#,
        );

        let models = scan_models_from_dir(&models_dir);
        let (_, fields) = &models[0];
        let output = model_aware_template("user", "User", fields);

        assert!(output.contains(r#".read_only_field("id""#));
        assert!(output.contains(r#".field("name""#));
        assert!(!output.contains("password_hash"));
        assert!(!output.contains("remember_token"));
    }

    // A `<rel>_id` field becomes a read-only foreign key plus a `belongs_to` call.
    #[test]
    fn test_model_aware_foreign_keys() {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(
            &models_dir,
            "order.rs",
            r#"
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "orders")]
pub struct Model {
    #[sea_orm(primary_key)]
    pub id: i32,
    pub user_id: i32,
    pub total: f64,
}
"#,
        );

        let models = scan_models_from_dir(&models_dir);
        let (_, fields) = &models[0];
        let output = model_aware_template("order", "Order", fields);

        assert!(output.contains(
            r#".read_only_field("user_id", DataType::Integer, FieldMeaning::ForeignKey)"#
        ));
        assert!(output.contains(r#".belongs_to("user", "user")"#));
    }

    // An `Option<T>` field is emitted with `optional_field` and the inner type.
    #[test]
    fn test_model_aware_optional_fields() {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(
            &models_dir,
            "post.rs",
            r#"
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "posts")]
pub struct Model {
    #[sea_orm(primary_key)]
    pub id: i32,
    pub title: String,
    pub notes: Option<String>,
}
"#,
        );

        let models = scan_models_from_dir(&models_dir);
        let (_, fields) = &models[0];
        let output = model_aware_template("post", "Post", fields);

        assert!(output.contains(
            r#".optional_field("notes", DataType::String, FieldMeaning::Custom("notes".into()))"#
        ));
    }

    // Known Rust types map to their DataType; `Option` is unwrapped and
    // unknown types fall back to String.
    #[test]
    fn test_rust_type_to_data_type() {
        assert_eq!(rust_type_to_data_type("String"), "DataType::String");
        assert_eq!(rust_type_to_data_type("i32"), "DataType::Integer");
        assert_eq!(rust_type_to_data_type("i64"), "DataType::Integer");
        assert_eq!(rust_type_to_data_type("u32"), "DataType::Integer");
        assert_eq!(rust_type_to_data_type("f32"), "DataType::Float");
        assert_eq!(rust_type_to_data_type("f64"), "DataType::Float");
        assert_eq!(rust_type_to_data_type("bool"), "DataType::Boolean");
        assert_eq!(rust_type_to_data_type("DateTime"), "DataType::DateTime");
        assert_eq!(
            rust_type_to_data_type("NaiveDateTime"),
            "DataType::DateTime"
        );
        assert_eq!(rust_type_to_data_type("NaiveDate"), "DataType::Date");
        assert_eq!(rust_type_to_data_type("Uuid"), "DataType::Uuid");
        assert_eq!(rust_type_to_data_type("Vec<u8>"), "DataType::Binary");
        assert_eq!(rust_type_to_data_type("Json"), "DataType::Json");
        assert_eq!(rust_type_to_data_type("Option<String>"), "DataType::String");
        assert_eq!(rust_type_to_data_type("Option<i32>"), "DataType::Integer");
        assert_eq!(
            rust_type_to_data_type("SomeUnknownType"),
            "DataType::String"
        );
    }

    // Field names map to the expected meaning; unmatched names yield "custom".
    #[test]
    fn test_infer_field_meaning() {
        assert_eq!(infer_meaning("id"), "FieldMeaning::Identifier");
        assert_eq!(infer_meaning("email"), "FieldMeaning::Email");
        assert_eq!(infer_meaning("created_at"), "FieldMeaning::CreatedAt");
        assert_eq!(infer_meaning("updated_at"), "FieldMeaning::UpdatedAt");
        assert_eq!(infer_meaning("user_id"), "FieldMeaning::ForeignKey");
        assert_eq!(infer_meaning("deleted_at"), "FieldMeaning::DateTime");
        assert_eq!(infer_meaning("is_active"), "FieldMeaning::Boolean");
        assert_eq!(infer_meaning("has_premium"), "FieldMeaning::Boolean");
        assert_eq!(infer_meaning("password"), "FieldMeaning::Sensitive");
        assert_eq!(infer_meaning("remember_token"), "FieldMeaning::Sensitive");
        assert_eq!(infer_meaning("title"), "custom");
    }
}
916}