1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use vespertide_config::VespertideConfig;
6use vespertide_core::TableDef;
7use vespertide_planner::validate_schema;
8
9pub fn load_models(config: &VespertideConfig) -> Result<Vec<TableDef>> {
11 let models_dir = config.models_dir();
12 if !models_dir.exists() {
13 return Ok(Vec::new());
14 }
15
16 let mut tables = Vec::new();
17 load_models_recursive(models_dir, &mut tables)?;
18
19 if !tables.is_empty() {
22 let normalized_tables: Vec<TableDef> = tables
23 .iter()
24 .map(|t| {
25 t.normalize()
26 .map_err(|e| anyhow::anyhow!("Failed to normalize table '{}': {}", t.name, e))
27 })
28 .collect::<Result<Vec<_>, _>>()?;
29
30 validate_schema(&normalized_tables)
31 .map_err(|e| anyhow::anyhow!("schema validation failed: {}", e))?;
32 }
33
34 Ok(tables)
35}
36
37fn load_models_recursive(dir: &Path, tables: &mut Vec<TableDef>) -> Result<()> {
39 let entries =
40 fs::read_dir(dir).with_context(|| format!("read models directory: {}", dir.display()))?;
41
42 for entry in entries {
43 let entry = entry.context("read directory entry")?;
44 let path = entry.path();
45
46 if path.is_dir() {
47 load_models_recursive(&path, tables)?;
49 continue;
50 }
51
52 if path.is_file() {
53 let ext = path.extension().and_then(|s| s.to_str());
54 if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
55 let content = fs::read_to_string(&path)
56 .with_context(|| format!("read model file: {}", path.display()))?;
57
58 let table: TableDef = if ext == Some("json") {
59 serde_json::from_str(&content)
60 .with_context(|| format!("parse JSON model: {}", path.display()))?
61 } else {
62 serde_yaml::from_str(&content)
63 .with_context(|| format!("parse YAML model: {}", path.display()))?
64 };
65
66 tables.push(table);
67 }
68 }
69 }
70
71 Ok(())
72}
73
74pub fn load_models_from_dir(
76 project_root: Option<std::path::PathBuf>,
77) -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
78 use std::env;
79
80 let project_root = if let Some(root) = project_root {
82 root
83 } else {
84 std::path::PathBuf::from(
85 env::var("CARGO_MANIFEST_DIR")
86 .context("CARGO_MANIFEST_DIR environment variable not set")?,
87 )
88 };
89
90 let config = crate::config::load_config_or_default(Some(project_root.clone()))
92 .map_err(|e| format!("Failed to load config: {}", e))?;
93
94 let models_dir = project_root.join(config.models_dir());
96 if !models_dir.exists() {
97 return Ok(Vec::new());
98 }
99
100 let mut tables = Vec::new();
101 load_models_recursive_internal(&models_dir, &mut tables)
102 .map_err(|e| format!("Failed to load models: {}", e))?;
103
104 let normalized_tables: Vec<TableDef> = tables
106 .into_iter()
107 .map(|t| {
108 t.normalize()
109 .map_err(|e| format!("Failed to normalize table '{}': {}", t.name, e))
110 })
111 .collect::<Result<Vec<_>, _>>()
112 .map_err(|e| e.to_string())?;
113
114 Ok(normalized_tables)
115}
116
117fn load_models_recursive_internal(
119 dir: &Path,
120 tables: &mut Vec<TableDef>,
121) -> Result<(), Box<dyn std::error::Error>> {
122 use std::fs;
123
124 let entries = fs::read_dir(dir)
125 .map_err(|e| format!("Failed to read models directory {}: {}", dir.display(), e))?;
126
127 for entry in entries {
128 let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?;
129 let path = entry.path();
130
131 if path.is_dir() {
132 load_models_recursive_internal(&path, tables)?;
134 continue;
135 }
136
137 if path.is_file() {
138 let ext = path.extension().and_then(|s| s.to_str());
139 if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
140 let content = fs::read_to_string(&path)
141 .map_err(|e| format!("Failed to read model file {}: {}", path.display(), e))?;
142
143 let table: TableDef = if ext == Some("json") {
144 serde_json::from_str(&content).map_err(|e| {
145 format!("Failed to parse JSON model {}: {}", path.display(), e)
146 })?
147 } else {
148 serde_yaml::from_str(&content).map_err(|e| {
149 format!("Failed to parse YAML model {}: {}", path.display(), e)
150 })?
151 };
152
153 tables.push(table);
154 }
155 }
156 }
157
158 Ok(())
159}
160
161pub fn load_models_at_compile_time() -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
163 load_models_from_dir(None)
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::fs;
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, SimpleColumnType, TableConstraint,
        schema::foreign_key::ForeignKeySyntax,
    };

    /// Switches the process working directory and restores the previous one on drop.
    struct CwdGuard {
        original: std::path::PathBuf,
    }

    impl CwdGuard {
        fn new(dir: &std::path::Path) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Removes an environment variable and restores its previous value on drop, even if the
    /// test panics, so later tests in the same process see an unchanged environment.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Removes `key` from the process environment, remembering its current value.
        fn remove(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: mutating the environment races with concurrent reads on other threads.
            // Every test in this module is `#[serial]`, so none of them runs alongside this
            // one. NOTE(review): non-serial tests elsewhere in the crate are not covered by
            // that guarantee.
            unsafe {
                std::env::remove_var(key);
            }
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::remove`.
            unsafe {
                match &self.original {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Writes a default `vespertide.json` into the current working directory.
    fn write_config() {
        let cfg = VespertideConfig::default();
        let text = serde_json::to_string_pretty(&cfg).unwrap();
        fs::write("vespertide.json", text).unwrap();
    }

    /// A non-nullable integer column whose only inline constraint is the optional foreign key.
    fn int_column(name: &'static str, foreign_key: Option<ForeignKeySyntax>) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key,
        }
    }

    /// A table with exactly one column and the given table-level constraints.
    fn single_column_table(
        name: &'static str,
        column: ColumnDef,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.into(),
            description: None,
            columns: vec![column],
            constraints,
        }
    }

    /// A table-level primary key on the `id` column.
    fn id_primary_key() -> Vec<TableConstraint> {
        vec![TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
        }]
    }

    /// An `orders` table whose `user_id` foreign key is a malformed reference string.
    fn orders_with_invalid_fk() -> TableDef {
        single_column_table(
            "orders",
            int_column(
                "user_id",
                Some(ForeignKeySyntax::String("invalid_format".into())),
            ),
            vec![],
        )
    }

    #[test]
    #[serial]
    fn load_models_returns_empty_when_no_models_dir() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert!(models.is_empty());
    }

    #[test]
    #[serial]
    fn load_models_reads_yaml_and_validates() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        let table = single_column_table("users", int_column("id", None), id_primary_key());
        fs::write("models/users.yaml", serde_yaml::to_string(&table).unwrap()).unwrap();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn load_models_recursive_processes_subdirectories() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models/subdir").unwrap();
        let table = single_column_table("subtable", int_column("id", None), id_primary_key());
        let content = serde_json::to_string_pretty(&table).unwrap();
        fs::write("models/subdir/subtable.json", content).unwrap();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn load_models_fails_on_invalid_fk_format() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        fs::write(
            "models/orders.json",
            serde_json::to_string_pretty(&orders_with_invalid_fk()).unwrap(),
        )
        .unwrap();

        let result = load_models(&VespertideConfig::default());
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_root() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = single_column_table("users", int_column("id", None), vec![]);
        fs::write(
            models_dir.join("users.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_without_root() {
        // Restored when the guard drops, even if an assertion below fails.
        let _env = EnvVarGuard::remove("CARGO_MANIFEST_DIR");

        let result = load_models_from_dir(None);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("CARGO_MANIFEST_DIR environment variable not set"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_no_models_dir() {
        let temp_dir = tempdir().unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert!(models.is_empty());
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = single_column_table("users", int_column("id", None), vec![]);
        fs::write(
            models_dir.join("users.yaml"),
            serde_yaml::to_string(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = single_column_table("users", int_column("id", None), vec![]);
        fs::write(
            models_dir.join("users.yml"),
            serde_yaml::to_string(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_recursive() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        let subdir = models_dir.join("subdir");
        fs::create_dir_all(&subdir).unwrap();

        let table = single_column_table("subtable", int_column("id", None), vec![]);
        fs::write(
            subdir.join("subtable.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_json() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.json"), r#"{"invalid": json}"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse JSON model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.yaml"), r#"invalid: [yaml"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse YAML model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_normalization_error() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(
            models_dir.join("orders.json"),
            serde_json::to_string_pretty(&orders_with_invalid_fk()).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_cargo_manifest_dir() {
        // The outcome depends on what the crate's manifest directory contains, so this only
        // checks that the call completes without panicking.
        let result = load_models_from_dir(None);
        let _ = result;
    }

    #[test]
    #[serial]
    fn test_load_models_at_compile_time() {
        // Same as above: the result depends on the environment, so only the absence of a
        // panic is checked.
        let result = load_models_at_compile_time();
        let _ = result;
    }
}