vespertide_loader/
models.rs

1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use vespertide_config::VespertideConfig;
6use vespertide_core::TableDef;
7use vespertide_planner::validate_schema;
8
9/// Load all model definitions from the models directory (recursively).
10pub fn load_models(config: &VespertideConfig) -> Result<Vec<TableDef>> {
11    let models_dir = config.models_dir();
12    if !models_dir.exists() {
13        return Ok(Vec::new());
14    }
15
16    let mut tables = Vec::new();
17    load_models_recursive(models_dir, &mut tables)?;
18
19    // Validate schema integrity using normalized version
20    // But return the original tables to preserve inline constraints
21    if !tables.is_empty() {
22        let normalized_tables: Vec<TableDef> = tables
23            .iter()
24            .map(|t| {
25                t.normalize()
26                    .map_err(|e| anyhow::anyhow!("Failed to normalize table '{}': {}", t.name, e))
27            })
28            .collect::<Result<Vec<_>, _>>()?;
29
30        validate_schema(&normalized_tables)
31            .map_err(|e| anyhow::anyhow!("schema validation failed: {}", e))?;
32    }
33
34    Ok(tables)
35}
36
37/// Recursively walk directory and load model files.
38fn load_models_recursive(dir: &Path, tables: &mut Vec<TableDef>) -> Result<()> {
39    let entries =
40        fs::read_dir(dir).with_context(|| format!("read models directory: {}", dir.display()))?;
41
42    for entry in entries {
43        let entry = entry.context("read directory entry")?;
44        let path = entry.path();
45
46        if path.is_dir() {
47            // Recursively process subdirectories
48            load_models_recursive(&path, tables)?;
49            continue;
50        }
51
52        if path.is_file() {
53            let ext = path.extension().and_then(|s| s.to_str());
54            if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
55                let content = fs::read_to_string(&path)
56                    .with_context(|| format!("read model file: {}", path.display()))?;
57
58                let table: TableDef = if ext == Some("json") {
59                    serde_json::from_str(&content)
60                        .with_context(|| format!("parse JSON model: {}", path.display()))?
61                } else {
62                    serde_yaml::from_str(&content)
63                        .with_context(|| format!("parse YAML model: {}", path.display()))?
64                };
65
66                tables.push(table);
67            }
68        }
69    }
70
71    Ok(())
72}
73
74/// Load models from a specific directory (for compile-time use in macros).
75pub fn load_models_from_dir(
76    project_root: Option<std::path::PathBuf>,
77) -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
78    use std::env;
79
80    // Locate project root from CARGO_MANIFEST_DIR or use provided path
81    let project_root = if let Some(root) = project_root {
82        root
83    } else {
84        std::path::PathBuf::from(
85            env::var("CARGO_MANIFEST_DIR")
86                .context("CARGO_MANIFEST_DIR environment variable not set")?,
87        )
88    };
89
90    // Read vespertide.json or use defaults
91    let config = crate::config::load_config_or_default(Some(project_root.clone()))
92        .map_err(|e| format!("Failed to load config: {}", e))?;
93
94    // Read models directory
95    let models_dir = project_root.join(config.models_dir());
96    if !models_dir.exists() {
97        return Ok(Vec::new());
98    }
99
100    let mut tables = Vec::new();
101    load_models_recursive_internal(&models_dir, &mut tables)
102        .map_err(|e| format!("Failed to load models: {}", e))?;
103
104    // Normalize tables
105    let normalized_tables: Vec<TableDef> = tables
106        .into_iter()
107        .map(|t| {
108            t.normalize()
109                .map_err(|e| format!("Failed to normalize table '{}': {}", t.name, e))
110        })
111        .collect::<Result<Vec<_>, _>>()
112        .map_err(|e| e.to_string())?;
113
114    Ok(normalized_tables)
115}
116
117/// Internal recursive function for loading models (used by both runtime and compile-time).
118fn load_models_recursive_internal(
119    dir: &Path,
120    tables: &mut Vec<TableDef>,
121) -> Result<(), Box<dyn std::error::Error>> {
122    use std::fs;
123
124    let entries = fs::read_dir(dir)
125        .map_err(|e| format!("Failed to read models directory {}: {}", dir.display(), e))?;
126
127    for entry in entries {
128        let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?;
129        let path = entry.path();
130
131        if path.is_dir() {
132            // Recursively process subdirectories
133            load_models_recursive_internal(&path, tables)?;
134            continue;
135        }
136
137        if path.is_file() {
138            let ext = path.extension().and_then(|s| s.to_str());
139            if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
140                let content = fs::read_to_string(&path)
141                    .map_err(|e| format!("Failed to read model file {}: {}", path.display(), e))?;
142
143                let table: TableDef = if ext == Some("json") {
144                    serde_json::from_str(&content).map_err(|e| {
145                        format!("Failed to parse JSON model {}: {}", path.display(), e)
146                    })?
147                } else {
148                    serde_yaml::from_str(&content).map_err(|e| {
149                        format!("Failed to parse YAML model {}: {}", path.display(), e)
150                    })?
151                };
152
153                tables.push(table);
154            }
155        }
156    }
157
158    Ok(())
159}
160
161/// Load models at compile time (for macro use).
162pub fn load_models_at_compile_time() -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
163    load_models_from_dir(None)
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::fs;
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, SimpleColumnType, TableConstraint,
        schema::foreign_key::ForeignKeySyntax,
    };

    /// Switches the process working directory for the guard's lifetime and switches back on
    /// drop, including when the test panics.
    struct CwdGuard {
        original: std::path::PathBuf,
    }

    impl CwdGuard {
        fn new(dir: &std::path::Path) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Best effort: a failed restore must not panic while unwinding.
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Removes an environment variable for the guard's lifetime and restores its previous
    /// value (if any) on drop. A failing assertion therefore cannot leak the removal into
    /// later tests.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn remove(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: mutating the environment is only sound while no other thread reads or
            // writes it; tests using this guard are `#[serial]`.
            // NOTE(review): `#[serial]` only excludes other `#[serial]` tests — confirm no
            // non-serial test in this crate touches the environment concurrently.
            unsafe {
                std::env::remove_var(key);
            }
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same single-threaded-environment assumption as `EnvVarGuard::remove`.
            unsafe {
                match &self.original {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    fn write_config() {
        let cfg = VespertideConfig::default();
        let text = serde_json::to_string_pretty(&cfg).unwrap();
        fs::write("vespertide.json", text).unwrap();
    }

    /// A non-nullable integer column with no attributes other than an optional FK.
    fn int_column(name: &'static str, foreign_key: Option<ForeignKeySyntax>) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key,
        }
    }

    /// A table with a single integer `id` column and no table constraints.
    fn id_table(name: &'static str) -> TableDef {
        TableDef {
            name: name.into(),
            description: None,
            columns: vec![int_column("id", None)],
            constraints: vec![],
        }
    }

    /// Same as `id_table`, plus a primary-key constraint on `id`.
    fn id_table_with_pk(name: &'static str) -> TableDef {
        TableDef {
            constraints: vec![TableConstraint::PrimaryKey {
                auto_increment: false,
                columns: vec!["id".into()],
            }],
            ..id_table(name)
        }
    }

    /// A table whose only column carries a malformed string FK.
    fn table_with_invalid_fk(name: &'static str) -> TableDef {
        TableDef {
            name: name.into(),
            description: None,
            columns: vec![int_column(
                "user_id",
                // Invalid FK format: should be "table.column" but missing the dot
                Some(ForeignKeySyntax::String("invalid_format".into())),
            )],
            constraints: vec![],
        }
    }

    fn write_json(path: impl AsRef<std::path::Path>, table: &TableDef) {
        fs::write(path, serde_json::to_string_pretty(table).unwrap()).unwrap();
    }

    fn write_yaml(path: impl AsRef<std::path::Path>, table: &TableDef) {
        fs::write(path, serde_yaml::to_string(table).unwrap()).unwrap();
    }

    #[test]
    #[serial]
    fn load_models_returns_empty_when_no_models_dir() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        // Don't create models directory
        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn load_models_reads_yaml_and_validates() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        write_yaml("models/users.yaml", &id_table_with_pk("users"));

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn load_models_recursive_processes_subdirectories() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models/subdir").unwrap();
        write_json("models/subdir/subtable.json", &id_table_with_pk("subtable"));

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn load_models_fails_on_invalid_fk_format() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        write_json("models/orders.json", &table_with_invalid_fk("orders"));

        let result = load_models(&VespertideConfig::default());
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_root() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_json(models_dir.join("users.json"), &id_table("users"));

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_without_root() {
        // Removed for the duration of this test; restored on drop even if an assert fails.
        let _env = EnvVarGuard::remove("CARGO_MANIFEST_DIR");

        let result = load_models_from_dir(None);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("CARGO_MANIFEST_DIR environment variable not set"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_no_models_dir() {
        let temp_dir = tempdir().unwrap();
        // Don't create models directory

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_yaml(models_dir.join("users.yaml"), &id_table("users"));

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_yaml(models_dir.join("users.yml"), &id_table("users"));

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_recursive() {
        let temp_dir = tempdir().unwrap();
        let subdir = temp_dir.path().join("models").join("subdir");
        fs::create_dir_all(&subdir).unwrap();
        write_json(subdir.join("subtable.json"), &id_table("subtable"));

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_json() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.json"), r#"{"invalid": json}"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse JSON model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.yaml"), r#"invalid: [yaml"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse YAML model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_normalization_error() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();
        write_json(models_dir.join("orders.json"), &table_with_invalid_fk("orders"));

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_cargo_manifest_dir() {
        // Exercises the branch that reads CARGO_MANIFEST_DIR (usually set under cargo test).
        // The contents of that directory are not controlled here, so only the env lookup
        // itself is checked: when the variable is readable, it must not be reported missing.
        let result = load_models_from_dir(None);
        if std::env::var("CARGO_MANIFEST_DIR").is_ok() {
            if let Err(e) = result {
                assert!(
                    !e.to_string()
                        .contains("CARGO_MANIFEST_DIR environment variable not set")
                );
            }
        }
    }

    #[test]
    #[serial]
    fn test_load_models_at_compile_time() {
        // `load_models_at_compile_time` delegates to `load_models_from_dir(None)`, so both
        // must agree on success/failure in the same environment.
        let via_wrapper = load_models_at_compile_time();
        let direct = load_models_from_dir(None);
        assert_eq!(via_wrapper.is_ok(), direct.is_ok());
    }
}