1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use vespertide_config::VespertideConfig;
6use vespertide_core::TableDef;
7use vespertide_planner::validate_schema;
8
9pub fn load_models(config: &VespertideConfig) -> Result<Vec<TableDef>> {
11 let models_dir = config.models_dir();
12 if !models_dir.exists() {
13 return Ok(Vec::new());
14 }
15
16 let mut tables = Vec::new();
17 load_models_recursive(models_dir, &mut tables)?;
18
19 if !tables.is_empty() {
22 let normalized_tables: Vec<TableDef> = tables
23 .iter()
24 .map(|t| {
25 t.normalize()
26 .map_err(|e| anyhow::anyhow!("Failed to normalize table '{}': {}", t.name, e))
27 })
28 .collect::<Result<Vec<_>, _>>()?;
29
30 validate_schema(&normalized_tables)
31 .map_err(|e| anyhow::anyhow!("schema validation failed: {}", e))?;
32 }
33
34 Ok(tables)
35}
36
37fn load_models_recursive(dir: &Path, tables: &mut Vec<TableDef>) -> Result<()> {
39 let entries =
40 fs::read_dir(dir).with_context(|| format!("read models directory: {}", dir.display()))?;
41
42 for entry in entries {
43 let entry = entry.context("read directory entry")?;
44 let path = entry.path();
45
46 if path.is_dir() {
47 load_models_recursive(&path, tables)?;
49 continue;
50 }
51
52 if path.is_file() {
53 let ext = path.extension().and_then(|s| s.to_str());
54 if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
55 let content = fs::read_to_string(&path)
56 .with_context(|| format!("read model file: {}", path.display()))?;
57
58 let table: TableDef = if ext == Some("json") {
59 serde_json::from_str(&content)
60 .with_context(|| format!("parse JSON model: {}", path.display()))?
61 } else {
62 serde_yaml::from_str(&content)
63 .with_context(|| format!("parse YAML model: {}", path.display()))?
64 };
65
66 tables.push(table);
67 }
68 }
69 }
70
71 Ok(())
72}
73
74pub fn load_models_from_dir(
76 project_root: Option<std::path::PathBuf>,
77) -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
78 use std::env;
79
80 let project_root = if let Some(root) = project_root {
82 root
83 } else {
84 std::path::PathBuf::from(
85 env::var("CARGO_MANIFEST_DIR")
86 .context("CARGO_MANIFEST_DIR environment variable not set")?,
87 )
88 };
89
90 let config = crate::config::load_config_or_default(Some(project_root.clone()))
92 .map_err(|e| format!("Failed to load config: {}", e))?;
93
94 let models_dir = project_root.join(config.models_dir());
96 if !models_dir.exists() {
97 return Ok(Vec::new());
98 }
99
100 let mut tables = Vec::new();
101 load_models_recursive_internal(&models_dir, &mut tables)
102 .map_err(|e| format!("Failed to load models: {}", e))?;
103
104 let normalized_tables: Vec<TableDef> = tables
106 .into_iter()
107 .map(|t| {
108 t.normalize()
109 .map_err(|e| format!("Failed to normalize table '{}': {}", t.name, e))
110 })
111 .collect::<Result<Vec<_>, _>>()
112 .map_err(|e| e.to_string())?;
113
114 Ok(normalized_tables)
115}
116
117fn load_models_recursive_internal(
119 dir: &Path,
120 tables: &mut Vec<TableDef>,
121) -> Result<(), Box<dyn std::error::Error>> {
122 use std::fs;
123
124 let entries = fs::read_dir(dir)
125 .map_err(|e| format!("Failed to read models directory {}: {}", dir.display(), e))?;
126
127 for entry in entries {
128 let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?;
129 let path = entry.path();
130
131 if path.is_dir() {
132 load_models_recursive_internal(&path, tables)?;
134 continue;
135 }
136
137 if path.is_file() {
138 let ext = path.extension().and_then(|s| s.to_str());
139 if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
140 let content = fs::read_to_string(&path)
141 .map_err(|e| format!("Failed to read model file {}: {}", path.display(), e))?;
142
143 let table: TableDef = if ext == Some("json") {
144 serde_json::from_str(&content).map_err(|e| {
145 format!("Failed to parse JSON model {}: {}", path.display(), e)
146 })?
147 } else {
148 serde_yaml::from_str(&content).map_err(|e| {
149 format!("Failed to parse YAML model {}: {}", path.display(), e)
150 })?
151 };
152
153 tables.push(table);
154 }
155 }
156 }
157
158 Ok(())
159}
160
161pub fn load_models_at_compile_time() -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
163 load_models_from_dir(None)
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::fs;
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, SimpleColumnType, TableConstraint,
        schema::foreign_key::ForeignKeySyntax,
    };

    /// Changes the process working directory for the guard's lifetime and restores the
    /// previous one on drop, including when the test panics.
    struct CwdGuard {
        original: std::path::PathBuf,
    }

    impl CwdGuard {
        fn new(dir: &std::path::Path) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Removes an environment variable for the guard's lifetime and restores its previous
    /// value (if any) on drop, so a failing assertion cannot leak the change into later tests.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn remove(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: every test in this module is `#[serial]`; this assumes no other
            // concurrently running test reads or writes the process environment.
            unsafe {
                std::env::remove_var(key);
            }
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            if let Some(value) = self.original.take() {
                // SAFETY: same assumption as in `EnvVarGuard::remove`.
                unsafe {
                    std::env::set_var(self.key, value);
                }
            }
        }
    }

    /// Writes a default `vespertide.json` into the current working directory.
    fn write_config() {
        let cfg = VespertideConfig::default();
        let text = serde_json::to_string_pretty(&cfg).unwrap();
        fs::write("vespertide.json", text).unwrap();
    }

    /// Builds a table containing a single non-nullable integer column.
    fn int_table(
        table: &str,
        column: &str,
        foreign_key: Option<ForeignKeySyntax>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: table.into(),
            description: None,
            columns: vec![ColumnDef {
                name: column.into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key,
            }],
            constraints,
        }
    }

    /// Non-auto-increment primary-key constraint on the `id` column.
    fn id_primary_key() -> TableConstraint {
        TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
        }
    }

    #[test]
    #[serial]
    fn load_models_returns_empty_when_no_models_dir() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn load_models_reads_yaml_and_validates() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        let table = int_table("users", "id", None, vec![id_primary_key()]);
        fs::write("models/users.yaml", serde_yaml::to_string(&table).unwrap()).unwrap();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn load_models_recursive_processes_subdirectories() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models/subdir").unwrap();
        let table = int_table("subtable", "id", None, vec![id_primary_key()]);
        let content = serde_json::to_string_pretty(&table).unwrap();
        fs::write("models/subdir/subtable.json", content).unwrap();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn load_models_fails_on_invalid_fk_format() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        let table = int_table(
            "orders",
            "user_id",
            Some(ForeignKeySyntax::String("invalid_format".into())),
            vec![],
        );
        fs::write(
            "models/orders.json",
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models(&VespertideConfig::default());
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_root() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = int_table("users", "id", None, vec![]);
        fs::write(
            models_dir.join("users.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_without_root() {
        // Restored on drop even if an assertion below fails.
        let _env = EnvVarGuard::remove("CARGO_MANIFEST_DIR");

        let result = load_models_from_dir(None);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("CARGO_MANIFEST_DIR environment variable not set"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_no_models_dir() {
        let temp_dir = tempdir().unwrap();
        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = int_table("users", "id", None, vec![]);
        fs::write(
            models_dir.join("users.yaml"),
            serde_yaml::to_string(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = int_table("users", "id", None, vec![]);
        fs::write(
            models_dir.join("users.yml"),
            serde_yaml::to_string(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_recursive() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        let subdir = models_dir.join("subdir");
        fs::create_dir_all(&subdir).unwrap();

        let table = int_table("subtable", "id", None, vec![]);
        fs::write(
            subdir.join("subtable.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_json() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.json"), r#"{"invalid": json}"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse JSON model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.yaml"), r#"invalid: [yaml"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse YAML model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_normalization_error() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = int_table(
            "orders",
            "user_id",
            Some(ForeignKeySyntax::String("invalid_format".into())),
            vec![],
        );
        fs::write(
            models_dir.join("orders.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_cargo_manifest_dir() {
        // Smoke test: the outcome depends on this crate's own manifest directory, so only
        // check that the call returns rather than panicking.
        let _ = load_models_from_dir(None);
    }

    #[test]
    #[serial]
    fn test_load_models_at_compile_time() {
        // Smoke test: see `test_load_models_from_dir_with_cargo_manifest_dir`.
        let _ = load_models_at_compile_time();
    }
}