Skip to main content

ferro_cli/commands/
make_projection.rs

1use crate::naming::{is_valid_identifier, to_snake_case};
2use console::style;
3use quote::ToTokens;
4use std::fs;
5use std::path::Path;
6use syn::visit::Visit;
7use syn::{Attribute, Fields, ItemStruct, Type};
8use walkdir::WalkDir;
9
10// ---------------------------------------------------------------------------
11// Model field metadata (self-contained, mirrors make_api::FieldInfo)
12// ---------------------------------------------------------------------------
13
/// One named field of a model's `Model` struct, as extracted from source by
/// the model scanner.
#[derive(Debug, Clone)]
struct ModelField {
    /// Field name taken from the struct field's identifier.
    name: String,
    /// Field type rendered from its tokens with all spaces removed
    /// (e.g. `Option<String>`, `chrono::DateTime<chrono::Utc>`).
    rust_type: String,
    /// Whether a `#[sea_orm(...)]` attribute on the field marks it as the primary key.
    is_primary_key: bool,
    /// Whether `rust_type` starts with `Option<`.
    is_nullable: bool,
}
21
22// ---------------------------------------------------------------------------
23// AST visitor for model detection
24// ---------------------------------------------------------------------------
25
/// `syn` visitor that looks for a struct named `Model` carrying a model derive
/// and records that struct's fields.
struct ModelVisitor {
    /// Fields of the most recently matched `Model` struct (empty until one is found).
    fields: Vec<ModelField>,
    /// Set to `true` once a matching `Model` struct has been visited.
    found: bool,
}
30
31impl ModelVisitor {
32    fn new() -> Self {
33        Self {
34            fields: Vec::new(),
35            found: false,
36        }
37    }
38
39    fn has_model_derive(attrs: &[Attribute]) -> bool {
40        for attr in attrs {
41            if attr.path().is_ident("derive") {
42                if let Ok(nested) = attr.parse_args_with(
43                    syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
44                ) {
45                    for path in nested {
46                        let ident = path.segments.last().map(|s| s.ident.to_string());
47                        if matches!(
48                            ident.as_deref(),
49                            Some("DeriveEntityModel") | Some("FerroModel")
50                        ) {
51                            return true;
52                        }
53                    }
54                }
55            }
56        }
57        false
58    }
59
60    fn is_field_primary_key(attrs: &[Attribute]) -> bool {
61        for attr in attrs {
62            if attr.path().is_ident("sea_orm") {
63                let tokens = attr.meta.to_token_stream().to_string();
64                if tokens.contains("primary_key") {
65                    return true;
66                }
67            }
68        }
69        false
70    }
71
72    fn type_to_string(ty: &Type) -> String {
73        ty.to_token_stream().to_string().replace(' ', "")
74    }
75
76    fn extract_fields(fields: &Fields) -> Vec<ModelField> {
77        let mut result = Vec::new();
78        if let Fields::Named(named) = fields {
79            for field in &named.named {
80                if let Some(ident) = &field.ident {
81                    let name = ident.to_string();
82                    let rust_type = Self::type_to_string(&field.ty);
83                    let is_nullable = rust_type.starts_with("Option<");
84                    let is_primary_key = Self::is_field_primary_key(&field.attrs);
85                    result.push(ModelField {
86                        name,
87                        rust_type,
88                        is_primary_key,
89                        is_nullable,
90                    });
91                }
92            }
93        }
94        result
95    }
96}
97
98impl<'ast> Visit<'ast> for ModelVisitor {
99    fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
100        if Self::has_model_derive(&node.attrs) && node.ident == "Model" {
101            self.fields = Self::extract_fields(&node.fields);
102            self.found = true;
103        }
104        syn::visit::visit_item_struct(self, node);
105    }
106}
107
108// ---------------------------------------------------------------------------
109// Model scanning
110// ---------------------------------------------------------------------------
111
112/// Scan `src/models/` for model files and return matching model fields.
113fn scan_models(project_root: &Path) -> Vec<(String, Vec<ModelField>)> {
114    let models_dir = project_root.join("src/models");
115    if !models_dir.exists() || !models_dir.is_dir() {
116        return Vec::new();
117    }
118
119    let mut results = Vec::new();
120
121    for entry in WalkDir::new(&models_dir)
122        .into_iter()
123        .filter_map(|e| e.ok())
124        .filter(|e| e.path().extension().is_some_and(|ext| ext == "rs"))
125    {
126        let file_stem = entry
127            .path()
128            .file_stem()
129            .map(|s| s.to_string_lossy().to_string())
130            .unwrap_or_default();
131
132        if file_stem == "mod" {
133            continue;
134        }
135
136        let Ok(content) = fs::read_to_string(entry.path()) else {
137            continue;
138        };
139        let Ok(syntax) = syn::parse_file(&content) else {
140            continue;
141        };
142
143        let mut visitor = ModelVisitor::new();
144        visitor.visit_file(&syntax);
145
146        if visitor.found {
147            let is_entity_file = entry
148                .path()
149                .parent()
150                .and_then(|p| p.file_name())
151                .is_some_and(|dir| dir == "entities");
152
153            let singular_stem = if is_entity_file {
154                singularize(&file_stem)
155            } else {
156                file_stem.clone()
157            };
158
159            results.push((singular_stem, visitor.fields));
160        }
161    }
162
163    results
164}
165
/// Scan models from a specific directory (for testing).
///
/// This is the test-only counterpart of `scan_models`. It walks `models_dir`
/// directly instead of `<project_root>/src/models`, but otherwise applies the
/// same rules:
///
/// - entries are visited in file-name order;
/// - `mod.rs`, unreadable and unparsable files are skipped;
/// - stems of files inside an `entities/` directory are singularized.
///
/// Tests therefore exercise the same naming logic as the real command.
#[cfg(test)]
fn scan_models_from_dir(models_dir: &Path) -> Vec<(String, Vec<ModelField>)> {
    // `is_dir` is false for paths that do not exist, so this covers both cases.
    if !models_dir.is_dir() {
        return Vec::new();
    }

    let mut results = Vec::new();

    for entry in WalkDir::new(models_dir)
        .sort_by(|a, b| a.file_name().cmp(b.file_name()))
        .into_iter()
        .filter_map(|e| e.ok())
        .filter(|e| e.path().extension().is_some_and(|ext| ext == "rs"))
    {
        let path = entry.path();
        let file_stem = path
            .file_stem()
            .map(|s| s.to_string_lossy().to_string())
            .unwrap_or_default();

        if file_stem == "mod" {
            continue;
        }

        let Ok(content) = fs::read_to_string(path) else {
            continue;
        };
        let Ok(syntax) = syn::parse_file(&content) else {
            continue;
        };

        let mut visitor = ModelVisitor::new();
        visitor.visit_file(&syntax);
        if !visitor.found {
            continue;
        }

        let is_entity_file = path
            .parent()
            .and_then(|p| p.file_name())
            .is_some_and(|dir| dir == "entities");

        let model_name = if is_entity_file {
            singularize(&file_stem)
        } else {
            file_stem
        };

        results.push((model_name, visitor.fields));
    }

    results
}
207
/// Heuristically convert an English plural (typically a table name) to its singular.
///
/// Rules, checked in order:
/// 1. `…ies` (only when longer than 3 chars) becomes `…y` (`categories` → `category`).
/// 2. `…ses`, `…xes`, `…ches`, `…shes` lose `es` (`boxes` → `box`).
/// 3. A trailing `s` that is not part of `ss` is dropped (`users` → `user`).
/// 4. Anything else is returned unchanged (`class`, `data`).
fn singularize(name: &str) -> String {
    const ES_PLURAL_SUFFIXES: [&str; 4] = ["ses", "xes", "ches", "shes"];

    match name.strip_suffix("ies") {
        Some(stem) if name.len() > 3 => return format!("{stem}y"),
        _ => {}
    }

    if ES_PLURAL_SUFFIXES
        .iter()
        .any(|suffix| name.ends_with(suffix))
    {
        // Every suffix in the list ends with "es", so this always strips two bytes.
        return name.strip_suffix("es").unwrap_or(name).to_string();
    }

    if !name.ends_with("ss") {
        if let Some(stem) = name.strip_suffix('s') {
            return stem.to_string();
        }
    }

    name.to_owned()
}
223
224// ---------------------------------------------------------------------------
225// Type and meaning mapping
226// ---------------------------------------------------------------------------
227
/// Map a Rust type string to a DataType name for code generation.
///
/// `rust_type` is expected in the compact form produced by the model scanner:
/// token text with spaces removed, e.g. `Option<chrono::DateTime<chrono::Utc>>`.
///
/// Steps:
/// 1. One outer `Option<…>` is unwrapped.
/// 2. The type is classified by its last path segment, ignoring any generic
///    arguments. So `DateTime<Utc>`, `chrono::DateTime<chrono::FixedOffset>`
///    and `DateTimeUtc` all map to `DataType::DateTime`.
///
/// Special cases:
/// - `Vec<u8>` is the only generic type with its own mapping (`DataType::Binary`).
/// - A bare `Value` is only treated as JSON when written as `serde_json::Value`.
///
/// Anything unrecognised falls back to `DataType::String`.
fn rust_type_to_data_type(rust_type: &str) -> &'static str {
    let inner = unwrap_option(rust_type);

    // "chrono::DateTime<chrono::Utc>" -> ("chrono::DateTime", "<chrono::Utc>")
    let (path, generics) = match inner.find('<') {
        Some(idx) => inner.split_at(idx),
        None => (inner, ""),
    };
    // "chrono::DateTime" -> "DateTime"
    let base = path.rsplit("::").next().unwrap_or(path);

    match base {
        "String" | "&str" => "DataType::String",
        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
        | "u128" | "usize" => "DataType::Integer",
        "f32" | "f64" | "Decimal" => "DataType::Float",
        "bool" => "DataType::Boolean",
        "DateTime" | "NaiveDateTime" | "DateTimeUtc" | "DateTimeLocal"
        | "DateTimeWithTimeZone" => "DataType::DateTime",
        "NaiveDate" | "Date" => "DataType::Date",
        "Uuid" => "DataType::Uuid",
        "Vec" if generics == "<u8>" => "DataType::Binary",
        "Json" | "JsonValue" => "DataType::Json",
        // `Value` alone is too generic a name; only serde_json's is known to be JSON.
        "Value" if path == "serde_json::Value" => "DataType::Json",
        _ => "DataType::String",
    }
}

/// Unwrap `Option<T>` to get the inner type string.
///
/// Only a single outer `Option<…>` layer is removed. Other strings are
/// returned unchanged.
fn unwrap_option(ty: &str) -> &str {
    if let Some(inner) = ty.strip_prefix("Option<") {
        if let Some(inner) = inner.strip_suffix('>') {
            return inner;
        }
    }
    ty
}
262
/// Infer a FieldMeaning variant name from a field name.
/// Replicates the logic from ferro-projections/src/field.rs.
///
/// Exact names (`id`, `email`, `created_at`, `updated_at`) win first. The
/// remaining conventions are then tried in a fixed order:
/// 1. `_id` suffix → foreign key
/// 2. `_at` suffix → datetime
/// 3. `is_`/`has_` prefix → boolean
/// 4. sensitive substring → sensitive
///
/// Because of this order, `password_reset_at` is a datetime, not sensitive.
/// The sentinel `"custom"` is returned when nothing matches.
fn infer_meaning(field_name: &str) -> &'static str {
    match field_name {
        "id" => return "FieldMeaning::Identifier",
        "email" => return "FieldMeaning::Email",
        "created_at" => return "FieldMeaning::CreatedAt",
        "updated_at" => return "FieldMeaning::UpdatedAt",
        _ => {}
    }

    let is_flag = ["is_", "has_"]
        .iter()
        .any(|prefix| field_name.starts_with(prefix));

    if field_name.ends_with("_id") {
        "FieldMeaning::ForeignKey"
    } else if field_name.ends_with("_at") {
        "FieldMeaning::DateTime"
    } else if is_flag {
        "FieldMeaning::Boolean"
    } else if is_sensitive_field(field_name) {
        "FieldMeaning::Sensitive"
    } else {
        "custom"
    }
}

/// Return `true` when `field_name` contains any marker of secret data
/// (`password`, `secret`, `token`, `api_key`, `hashed_key`) as a substring.
fn is_sensitive_field(field_name: &str) -> bool {
    const SENSITIVE_MARKERS: [&str; 5] = ["password", "secret", "token", "api_key", "hashed_key"];

    for marker in SENSITIVE_MARKERS.iter() {
        if field_name.contains(marker) {
            return true;
        }
    }
    false
}
294
295/// Determine the builder method for a field.
296/// Returns None if the field should be skipped (sensitive).
297fn field_builder_method(field: &ModelField, meaning: &str) -> Option<&'static str> {
298    if meaning == "FieldMeaning::Sensitive" {
299        return None;
300    }
301
302    if field.is_primary_key
303        || meaning == "FieldMeaning::CreatedAt"
304        || meaning == "FieldMeaning::UpdatedAt"
305        || meaning == "FieldMeaning::ForeignKey"
306    {
307        return Some("read_only_field");
308    }
309
310    if field.is_nullable {
311        return Some("optional_field");
312    }
313
314    Some("field")
315}
316
317// ---------------------------------------------------------------------------
318// Model-aware template generation
319// ---------------------------------------------------------------------------
320
321/// Generate a projection template populated from model fields.
322fn model_aware_template(name: &str, display_name: &str, fields: &[ModelField]) -> String {
323    let mut field_lines = Vec::new();
324    let mut belongs_to_lines = Vec::new();
325
326    for field in fields {
327        let meaning = infer_meaning(&field.name);
328        let data_type = rust_type_to_data_type(&field.rust_type);
329
330        let Some(builder) = field_builder_method(field, meaning) else {
331            continue;
332        };
333
334        let meaning_str = if meaning == "custom" {
335            // Use {:?} debug-formatting (same as ai_make.rs) so field names with
336            // special characters produce valid Rust source rather than broken code.
337            format!("FieldMeaning::Custom({:?}.into())", field.name)
338        } else {
339            meaning.to_string()
340        };
341
342        field_lines.push(format!(
343            "        .{builder}(\"{}\", {data_type}, {meaning_str})",
344            field.name
345        ));
346
347        if field.name.ends_with("_id") && field.name.len() > 3 {
348            let rel_name = &field.name[..field.name.len() - 3];
349            belongs_to_lines.push(format!(
350                "        .belongs_to(\"{rel_name}\", \"{rel_name}\")"
351            ));
352        }
353    }
354
355    let all_lines: Vec<String> = field_lines.into_iter().chain(belongs_to_lines).collect();
356
357    let builder_calls = all_lines.join("\n");
358
359    format!(
360        r#"use ferro::{{
361    DataType, FieldMeaning, ServiceDef,
362}};
363
364/// Build the {display_name} service projection.
365///
366/// Derived from the {display_name} model.
367/// Describes the {display_name} entity's fields, relationships,
368/// and behavioral semantics for intent derivation and UI rendering.
369pub fn {name}_service() -> ServiceDef {{
370    ServiceDef::new("{name}")
371        .display_name("{display_name}")
372{builder_calls}
373}}
374"#
375    )
376}
377
378// ---------------------------------------------------------------------------
379// Public API
380// ---------------------------------------------------------------------------
381
/// Generate a service projection for `name` and register it in
/// `src/projections/mod.rs`.
///
/// * `name` – the projection name as given by the user. Its snake_case form
///   is used as the file/module name and must be a valid identifier. Its
///   PascalCase form is used as the display name.
/// * `from_model` – when `true`, fields are generated from the model found by
///   `scan_models` whose name equals the snake_case name (ASCII
///   case-insensitive). When `false`, the skeleton `projection_template` is
///   written instead.
///
/// Paths are relative to the current working directory. The function ends
/// the process itself:
/// - `exit(1)` for an invalid name, a filesystem error, or a model that
///   cannot be found;
/// - `exit(0)` when the projection file already exists or `mod.rs` already
///   declares the module.
pub fn execute(name: &str, from_model: bool) {
    let file_name = to_snake_case(name);

    if !is_valid_identifier(&file_name) {
        eprintln!(
            "{} '{}' is not a valid projection name",
            style("Error:").red().bold(),
            name
        );
        std::process::exit(1);
    }

    let display_name = to_pascal_case(name);
    let projections_dir = Path::new("src/projections");
    let projection_file = projections_dir.join(format!("{file_name}.rs"));
    let mod_file = projections_dir.join("mod.rs");

    if !projections_dir.exists() {
        if let Err(e) = fs::create_dir_all(projections_dir) {
            eprintln!(
                "{} Failed to create src/projections directory: {}",
                style("Error:").red().bold(),
                e
            );
            std::process::exit(1);
        }
        println!("{} Created src/projections/", style("✓").green());
    }

    // Never overwrite an existing projection; report it and exit successfully.
    if projection_file.exists() {
        eprintln!(
            "{} Projection '{}' already exists at {}",
            style("Info:").yellow().bold(),
            file_name,
            projection_file.display()
        );
        std::process::exit(0);
    }

    // Likewise stop (successfully) when mod.rs already declares the module.
    // NOTE(review): `mod {file_name};` is a substring of `pub mod {file_name};`,
    // so the second `contains` is redundant. Both are plain substring checks,
    // so a commented-out declaration also counts as declared.
    if mod_file.exists() {
        let mod_content = fs::read_to_string(&mod_file).unwrap_or_default();
        let mod_decl = format!("mod {file_name};");
        let pub_mod_decl = format!("pub mod {file_name};");
        if mod_content.contains(&mod_decl) || mod_content.contains(&pub_mod_decl) {
            eprintln!(
                "{} Module '{}' is already declared in src/projections/mod.rs",
                style("Info:").yellow().bold(),
                file_name
            );
            std::process::exit(0);
        }
    }

    // NOTE(review): `src/projections/` may already have been created above, so
    // a failed model lookup below still leaves that directory behind.
    let content = if from_model {
        // Models are looked up relative to the working directory as well.
        let project_root = Path::new(".");
        let available = scan_models(project_root);

        let matched = available
            .iter()
            .find(|(sn, _)| sn.eq_ignore_ascii_case(&file_name));

        match matched {
            Some((_, fields)) => {
                println!(
                    "{} Found model '{}' with {} fields",
                    style("✓").green(),
                    display_name,
                    fields.len()
                );
                model_aware_template(&file_name, &display_name, fields)
            }
            None => {
                // List what was found so the user can fix a typo.
                let model_names: Vec<&str> = available.iter().map(|(n, _)| n.as_str()).collect();
                if model_names.is_empty() {
                    eprintln!(
                        "{} No models found in src/models/",
                        style("Error:").red().bold()
                    );
                } else {
                    eprintln!(
                        "{} Model '{}' not found. Available models: {}",
                        style("Error:").red().bold(),
                        file_name,
                        model_names.join(", ")
                    );
                }
                std::process::exit(1);
            }
        }
    } else {
        projection_template(&file_name, &display_name)
    };

    if let Err(e) = fs::write(&projection_file, &content) {
        eprintln!(
            "{} Failed to write projection file: {}",
            style("Error:").red().bold(),
            e
        );
        std::process::exit(1);
    }
    println!(
        "{} Created {}",
        style("✓").green(),
        projection_file.display()
    );

    // Register the new module: update an existing mod.rs, or create one.
    if mod_file.exists() {
        if let Err(e) = update_mod_file(&mod_file, &file_name) {
            eprintln!(
                "{} Failed to update mod.rs: {}",
                style("Error:").red().bold(),
                e
            );
            std::process::exit(1);
        }
        println!("{} Updated src/projections/mod.rs", style("✓").green());
    } else {
        let mod_content = format!("pub mod {file_name};\n");
        if let Err(e) = fs::write(&mod_file, mod_content) {
            eprintln!(
                "{} Failed to create mod.rs: {}",
                style("Error:").red().bold(),
                e
            );
            std::process::exit(1);
        }
        println!("{} Created src/projections/mod.rs", style("✓").green());
    }

    // Next-step hints for the user.
    println!();
    println!(
        "Projection {} created successfully!",
        style(&file_name).cyan().bold()
    );
    println!();
    println!("Usage:");
    println!(
        "  {} Define fields matching your model in src/projections/{file_name}.rs",
        style("1.").dim()
    );
    println!("  {} Use in a handler:", style("2.").dim());
    println!("     use crate::projections::{file_name};");
    println!();
    println!("     let service = {file_name}::{file_name}_service();");
    println!("     let intents = derive_intents(&service);");
    println!();
}
530
/// Render the skeleton projection module used when no model is consulted.
///
/// * `name` – snake_case module name. It is used for the `{name}_service`
///   function and as the `ServiceDef::new` key.
/// * `display_name` – human-readable name passed to `.display_name(...)` and
///   used in the doc comment.
///
/// The generated code declares only an `id` identifier field. Commented-out
/// example fields are left for the user to fill in.
fn projection_template(name: &str, display_name: &str) -> String {
    format!(
        r#"use ferro::{{
    DataType, FieldMeaning, ServiceDef,
}};

/// Build the {display_name} service projection.
///
/// Describes the {display_name} entity's fields, relationships,
/// and behavioral semantics for intent derivation and UI rendering.
pub fn {name}_service() -> ServiceDef {{
    ServiceDef::new("{name}")
        .display_name("{display_name}")
        .field("id", DataType::Integer, FieldMeaning::Identifier)
        // Add fields matching your model:
        // .field("name", DataType::String, FieldMeaning::EntityName)
        // .field("email", DataType::String, FieldMeaning::Email)
        // .field("status", DataType::String, FieldMeaning::Status)
        // .field("created_at", DataType::DateTime, FieldMeaning::CreatedAt)
        // .field("updated_at", DataType::DateTime, FieldMeaning::UpdatedAt)
}}
"#
    )
}
555
/// Convert `snake_case`, `kebab-case` or space-separated words to `PascalCase`.
///
/// `_`, `-` and ` ` act as word separators and are dropped. The first
/// character of each word is upper-cased and the rest are kept as-is
/// (`"user_ID"` → `"UserID"`).
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;

    for c in s.chars() {
        if matches!(c, '_' | '-' | ' ') {
            capitalize_next = true;
        } else if capitalize_next {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS"); keep them
            // all instead of silently truncating to the first one.
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}
572
/// Generate projection files at the given base directory.
/// Returns (projection_file_path, mod_file_path) on success.
///
/// This is the test-only, non-exiting counterpart of `execute`, and always
/// uses the skeleton template. It:
/// 1. creates `<base_dir>/src/projections/`;
/// 2. writes (overwriting) `{name}.rs`;
/// 3. declares the module in `mod.rs`, unless `mod {name};` or
///    `pub mod {name};` is already present — the same check `execute` uses.
///
/// # Errors
/// Returns a message describing the first filesystem operation that failed.
#[cfg(test)]
fn generate_in_dir(
    base_dir: &Path,
    name: &str,
) -> Result<(std::path::PathBuf, std::path::PathBuf), String> {
    let file_name = to_snake_case(name);
    let display_name = to_pascal_case(name);
    let projections_dir = base_dir.join("src/projections");
    let projection_file = projections_dir.join(format!("{file_name}.rs"));
    let mod_file = projections_dir.join("mod.rs");

    fs::create_dir_all(&projections_dir)
        .map_err(|e| format!("Failed to create projections directory: {e}"))?;

    let content = projection_template(&file_name, &display_name);
    fs::write(&projection_file, content)
        .map_err(|e| format!("Failed to write projection file: {e}"))?;

    if mod_file.exists() {
        let mod_content =
            fs::read_to_string(&mod_file).map_err(|e| format!("Failed to read mod.rs: {e}"))?;
        // `mod x;` is a substring of `pub mod x;`, so this detects both forms and
        // avoids adding a duplicate `pub mod x;` next to a private `mod x;`.
        let mod_decl = format!("mod {file_name};");
        if !mod_content.contains(&mod_decl) {
            update_mod_file(&mod_file, &file_name)?;
        }
    } else {
        let mod_content = format!("pub mod {file_name};\n");
        fs::write(&mod_file, mod_content).map_err(|e| format!("Failed to create mod.rs: {e}"))?;
    }

    Ok((projection_file, mod_file))
}
606
/// Insert `pub mod {file_name};` into the module file at `mod_file`.
///
/// Where the declaration goes:
/// - right after the last single-line `pub mod name;` declaration; or
/// - if there is none, after the leading `//!` doc lines and blank lines.
///
/// If the module is already declared (a line that is exactly `mod name;` or
/// `pub mod name;` after trimming), the file is left untouched.
///
/// The file's line-ending style (CRLF if present, otherwise LF) is kept, and
/// the result always ends with a newline.
///
/// # Errors
/// Returns a human-readable message if `mod_file` cannot be read or written.
pub(crate) fn update_mod_file(mod_file: &Path, file_name: &str) -> Result<(), String> {
    let content =
        fs::read_to_string(mod_file).map_err(|e| format!("Failed to read mod.rs: {e}"))?;

    let mod_decl = format!("mod {file_name};");
    let pub_mod_decl = format!("pub mod {file_name};");

    let mut lines: Vec<&str> = content.lines().collect();

    // Declaring the same module twice would break the build, so stay idempotent.
    let already_declared = lines.iter().any(|line| {
        let trimmed = line.trim();
        trimmed == mod_decl || trimmed == pub_mod_decl
    });
    if already_declared {
        return Ok(());
    }

    // Only single-line declarations count. Inserting after an inline
    // `pub mod name {` would put the new declaration inside that module.
    let last_pub_mod_idx = lines.iter().rposition(|line| {
        let trimmed = line.trim();
        trimmed.starts_with("pub mod ") && trimmed.ends_with(';')
    });

    let insert_idx = match last_pub_mod_idx {
        Some(idx) => idx + 1,
        // No declarations yet: go after the module docs and leading blank lines.
        None => lines
            .iter()
            .take_while(|line| line.starts_with("//!") || line.is_empty())
            .count(),
    };
    lines.insert(insert_idx, &pub_mod_decl);

    // `str::lines` strips line terminators; restore the original style plus a
    // trailing newline, which `join` alone would drop.
    let newline = if content.contains("\r\n") { "\r\n" } else { "\n" };
    let mut new_content = lines.join(newline);
    new_content.push_str(newline);

    fs::write(mod_file, new_content).map_err(|e| format!("Failed to write mod.rs: {e}"))?;

    Ok(())
}
643
// Unit tests for the projection generator. Filesystem tests work inside a
// `tempfile::TempDir`, so they never touch the real project tree.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // The skeleton template names the service after the module and includes
    // the `id` identifier field.
    #[test]
    fn test_template_generation() {
        let template = projection_template("user", "User");

        assert!(template.contains("pub fn user_service() -> ServiceDef"));
        assert!(template.contains("ServiceDef::new(\"user\")"));
        assert!(template.contains(".display_name(\"User\")"));
        assert!(template.contains("DataType::Integer, FieldMeaning::Identifier"));
        assert!(template.contains("use ferro::{"));
        assert!(template.contains("/// Build the User service projection."));
    }

    #[test]
    fn test_creates_directory_and_file() {
        let tmp = TempDir::new().unwrap();
        let (proj_file, _mod_file) = generate_in_dir(tmp.path(), "order").unwrap();

        assert!(tmp.path().join("src/projections").exists());
        assert!(proj_file.exists());

        let content = fs::read_to_string(&proj_file).unwrap();
        assert!(content.contains("pub fn order_service() -> ServiceDef"));
        assert!(content.contains(".display_name(\"Order\")"));
    }

    #[test]
    fn test_mod_rs_creation() {
        let tmp = TempDir::new().unwrap();
        let (_proj_file, mod_file) = generate_in_dir(tmp.path(), "product").unwrap();

        assert!(mod_file.exists());
        let mod_content = fs::read_to_string(&mod_file).unwrap();
        assert!(mod_content.contains("pub mod product;"));
    }

    // Adding a second module appends to mod.rs; regenerating an existing one
    // must not duplicate its declaration.
    #[test]
    fn test_mod_rs_append() {
        let tmp = TempDir::new().unwrap();

        generate_in_dir(tmp.path(), "user").unwrap();
        let mod_file = tmp.path().join("src/projections/mod.rs");
        let content = fs::read_to_string(&mod_file).unwrap();
        assert!(content.contains("pub mod user;"));

        generate_in_dir(tmp.path(), "order").unwrap();
        let content = fs::read_to_string(&mod_file).unwrap();
        assert!(content.contains("pub mod user;"));
        assert!(content.contains("pub mod order;"));

        generate_in_dir(tmp.path(), "order").unwrap();
        let content = fs::read_to_string(&mod_file).unwrap();
        let count = content.matches("pub mod order;").count();
        assert_eq!(count, 1, "pub mod order; should appear exactly once");
    }

    // -- Model-aware tests --

    /// Write `content` to `dir/filename`, creating `dir` if needed.
    fn write_mock_model(dir: &Path, filename: &str, content: &str) {
        fs::create_dir_all(dir).unwrap();
        fs::write(dir.join(filename), content).unwrap();
    }

    // End-to-end: scan a mock sea_orm model and check each generated builder call.
    #[test]
    fn test_model_aware_template_basic() {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(
            &models_dir,
            "user.rs",
            r#"
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "users")]
pub struct Model {
    #[sea_orm(primary_key)]
    pub id: i32,
    pub name: String,
    pub email: String,
    pub created_at: DateTime,
    pub updated_at: DateTime,
}
"#,
        );

        let models = scan_models_from_dir(&models_dir);
        assert_eq!(models.len(), 1);
        let (name, fields) = &models[0];
        assert_eq!(name, "user");

        let output = model_aware_template("user", "User", fields);

        assert!(output.contains("pub fn user_service() -> ServiceDef"));
        assert!(output.contains("Derived from the User model"));
        assert!(output
            .contains(r#".read_only_field("id", DataType::Integer, FieldMeaning::Identifier)"#));
        assert!(output
            .contains(r#".field("name", DataType::String, FieldMeaning::Custom("name".into()))"#));
        assert!(output.contains(r#".field("email", DataType::String, FieldMeaning::Email)"#));
        assert!(output.contains(
            r#".read_only_field("created_at", DataType::DateTime, FieldMeaning::CreatedAt)"#
        ));
        assert!(output.contains(
            r#".read_only_field("updated_at", DataType::DateTime, FieldMeaning::UpdatedAt)"#
        ));
    }

    // Fields matching sensitive name patterns must not appear in the output at all.
    #[test]
    fn test_model_aware_excludes_sensitive() {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(
            &models_dir,
            "user.rs",
            r#"
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "users")]
pub struct Model {
    #[sea_orm(primary_key)]
    pub id: i32,
    pub name: String,
    pub password_hash: String,
    pub remember_token: Option<String>,
}
"#,
        );

        let models = scan_models_from_dir(&models_dir);
        let (_, fields) = &models[0];
        let output = model_aware_template("user", "User", fields);

        assert!(output.contains(r#".read_only_field("id""#));
        assert!(output.contains(r#".field("name""#));
        assert!(!output.contains("password_hash"));
        assert!(!output.contains("remember_token"));
    }

    // `*_id` fields become read-only foreign keys plus a `belongs_to` relation.
    #[test]
    fn test_model_aware_foreign_keys() {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(
            &models_dir,
            "order.rs",
            r#"
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "orders")]
pub struct Model {
    #[sea_orm(primary_key)]
    pub id: i32,
    pub user_id: i32,
    pub total: f64,
}
"#,
        );

        let models = scan_models_from_dir(&models_dir);
        let (_, fields) = &models[0];
        let output = model_aware_template("order", "Order", fields);

        assert!(output.contains(
            r#".read_only_field("user_id", DataType::Integer, FieldMeaning::ForeignKey)"#
        ));
        assert!(output.contains(r#".belongs_to("user", "user")"#));
    }

    // `Option<T>` fields are emitted with `optional_field`.
    #[test]
    fn test_model_aware_optional_fields() {
        let tmp = TempDir::new().unwrap();
        let models_dir = tmp.path().join("models");
        write_mock_model(
            &models_dir,
            "post.rs",
            r#"
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "posts")]
pub struct Model {
    #[sea_orm(primary_key)]
    pub id: i32,
    pub title: String,
    pub notes: Option<String>,
}
"#,
        );

        let models = scan_models_from_dir(&models_dir);
        let (_, fields) = &models[0];
        let output = model_aware_template("post", "Post", fields);

        assert!(output.contains(
            r#".optional_field("notes", DataType::String, FieldMeaning::Custom("notes".into()))"#
        ));
    }

    #[test]
    fn test_rust_type_to_data_type() {
        assert_eq!(rust_type_to_data_type("String"), "DataType::String");
        assert_eq!(rust_type_to_data_type("i32"), "DataType::Integer");
        assert_eq!(rust_type_to_data_type("i64"), "DataType::Integer");
        assert_eq!(rust_type_to_data_type("u32"), "DataType::Integer");
        assert_eq!(rust_type_to_data_type("f32"), "DataType::Float");
        assert_eq!(rust_type_to_data_type("f64"), "DataType::Float");
        assert_eq!(rust_type_to_data_type("bool"), "DataType::Boolean");
        assert_eq!(rust_type_to_data_type("DateTime"), "DataType::DateTime");
        assert_eq!(
            rust_type_to_data_type("NaiveDateTime"),
            "DataType::DateTime"
        );
        assert_eq!(rust_type_to_data_type("NaiveDate"), "DataType::Date");
        assert_eq!(rust_type_to_data_type("Uuid"), "DataType::Uuid");
        assert_eq!(rust_type_to_data_type("Vec<u8>"), "DataType::Binary");
        assert_eq!(rust_type_to_data_type("Json"), "DataType::Json");
        assert_eq!(rust_type_to_data_type("Option<String>"), "DataType::String");
        assert_eq!(rust_type_to_data_type("Option<i32>"), "DataType::Integer");
        // Unknown types fall back to String.
        assert_eq!(
            rust_type_to_data_type("SomeUnknownType"),
            "DataType::String"
        );
    }

    #[test]
    fn test_infer_field_meaning() {
        assert_eq!(infer_meaning("id"), "FieldMeaning::Identifier");
        assert_eq!(infer_meaning("email"), "FieldMeaning::Email");
        assert_eq!(infer_meaning("created_at"), "FieldMeaning::CreatedAt");
        assert_eq!(infer_meaning("updated_at"), "FieldMeaning::UpdatedAt");
        assert_eq!(infer_meaning("user_id"), "FieldMeaning::ForeignKey");
        assert_eq!(infer_meaning("deleted_at"), "FieldMeaning::DateTime");
        assert_eq!(infer_meaning("is_active"), "FieldMeaning::Boolean");
        assert_eq!(infer_meaning("has_premium"), "FieldMeaning::Boolean");
        assert_eq!(infer_meaning("password"), "FieldMeaning::Sensitive");
        assert_eq!(infer_meaning("remember_token"), "FieldMeaning::Sensitive");
        // "custom" is the sentinel for names with no inferred meaning.
        assert_eq!(infer_meaning("title"), "custom");
    }
}