1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use vespertide_config::VespertideConfig;
6use vespertide_core::TableDef;
7use vespertide_planner::validate_schema;
8
9pub fn load_models(config: &VespertideConfig) -> Result<Vec<TableDef>> {
11 let models_dir = config.models_dir();
12 if !models_dir.exists() {
13 return Ok(Vec::new());
14 }
15
16 let mut tables = Vec::new();
17 load_models_recursive(models_dir, &mut tables)?;
18
19 if !tables.is_empty() {
22 let normalized_tables: Vec<TableDef> = tables
23 .iter()
24 .map(|t| {
25 t.normalize()
26 .map_err(|e| anyhow::anyhow!("Failed to normalize table '{}': {}", t.name, e))
27 })
28 .collect::<Result<Vec<_>, _>>()?;
29
30 validate_schema(&normalized_tables)
31 .map_err(|e| anyhow::anyhow!("schema validation failed: {}", e))?;
32 }
33
34 Ok(tables)
35}
36
37fn load_models_recursive(dir: &Path, tables: &mut Vec<TableDef>) -> Result<()> {
39 let entries =
40 fs::read_dir(dir).with_context(|| format!("read models directory: {}", dir.display()))?;
41
42 for entry in entries {
43 let entry = entry.context("read directory entry")?;
44 let path = entry.path();
45
46 if path.is_dir() {
47 load_models_recursive(&path, tables)?;
49 continue;
50 }
51
52 if path.is_file() {
53 let ext = path.extension().and_then(|s| s.to_str());
54 if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
55 let content = fs::read_to_string(&path)
56 .with_context(|| format!("read model file: {}", path.display()))?;
57
58 let table: TableDef = if ext == Some("json") {
59 serde_json::from_str(&content)
60 .with_context(|| format!("parse JSON model: {}", path.display()))?
61 } else {
62 serde_yaml::from_str(&content)
63 .with_context(|| format!("parse YAML model: {}", path.display()))?
64 };
65
66 tables.push(table);
67 }
68 }
69 }
70
71 Ok(())
72}
73
74pub fn load_models_from_dir(
76 project_root: Option<std::path::PathBuf>,
77) -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
78 use std::env;
79
80 let project_root = if let Some(root) = project_root {
82 root
83 } else {
84 std::path::PathBuf::from(
85 env::var("CARGO_MANIFEST_DIR")
86 .context("CARGO_MANIFEST_DIR environment variable not set")?,
87 )
88 };
89
90 let config = crate::config::load_config_or_default(Some(project_root.clone()))
92 .map_err(|e| format!("Failed to load config: {}", e))?;
93
94 let models_dir = project_root.join(config.models_dir());
96 if !models_dir.exists() {
97 return Ok(Vec::new());
98 }
99
100 let mut tables = Vec::new();
101 load_models_recursive_internal(&models_dir, &mut tables)
102 .map_err(|e| format!("Failed to load models: {}", e))?;
103
104 let normalized_tables: Vec<TableDef> = tables
106 .into_iter()
107 .map(|t| {
108 t.normalize()
109 .map_err(|e| format!("Failed to normalize table '{}': {}", t.name, e))
110 })
111 .collect::<Result<Vec<_>, _>>()
112 .map_err(|e| e.to_string())?;
113
114 Ok(normalized_tables)
115}
116
117fn load_models_recursive_internal(
119 dir: &Path,
120 tables: &mut Vec<TableDef>,
121) -> Result<(), Box<dyn std::error::Error>> {
122 use std::fs;
123
124 let entries = fs::read_dir(dir)
125 .map_err(|e| format!("Failed to read models directory {}: {}", dir.display(), e))?;
126
127 for entry in entries {
128 let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?;
129 let path = entry.path();
130
131 if path.is_dir() {
132 load_models_recursive_internal(&path, tables)?;
134 continue;
135 }
136
137 if path.is_file() {
138 let ext = path.extension().and_then(|s| s.to_str());
139 if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
140 let content = fs::read_to_string(&path)
141 .map_err(|e| format!("Failed to read model file {}: {}", path.display(), e))?;
142
143 let table: TableDef = if ext == Some("json") {
144 serde_json::from_str(&content).map_err(|e| {
145 format!("Failed to parse JSON model {}: {}", path.display(), e)
146 })?
147 } else {
148 serde_yaml::from_str(&content).map_err(|e| {
149 format!("Failed to parse YAML model {}: {}", path.display(), e)
150 })?
151 };
152
153 tables.push(table);
154 }
155 }
156 }
157
158 Ok(())
159}
160
161pub fn load_models_at_compile_time() -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
163 load_models_from_dir(None)
164}
165
#[cfg(test)]
mod tests {
    //! Tests for the model loaders.
    //!
    //! Every test is `#[serial]` because several of them change process-global
    //! state (the working directory or `CARGO_MANIFEST_DIR`).

    use super::*;
    use serial_test::serial;
    use std::fs;
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, SimpleColumnType, TableConstraint,
        schema::foreign_key::ForeignKeySyntax,
    };

    /// Switches the working directory on creation and restores the previous
    /// one when dropped.
    struct CwdGuard {
        original: std::path::PathBuf,
    }

    impl CwdGuard {
        /// Remembers the current working directory, then changes into `dir`.
        fn new(dir: &std::path::Path) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Best effort: a failure to switch back is deliberately ignored.
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Writes a default configuration to `vespertide.json` in the current
    /// working directory.
    fn write_config() {
        let cfg = VespertideConfig::default();
        let text = serde_json::to_string_pretty(&cfg).unwrap();
        fs::write("vespertide.json", text).unwrap();
    }

    /// Builds a non-nullable integer column with the given name and optional
    /// inline foreign key; every other attribute is left unset.
    fn integer_column(name: &str, foreign_key: Option<ForeignKeySyntax>) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key,
        }
    }

    /// Table with a single `id` column and a primary key constraint on it.
    fn table_with_primary_key(name: &str) -> TableDef {
        TableDef {
            name: name.into(),
            columns: vec![integer_column("id", None)],
            constraints: vec![TableConstraint::PrimaryKey {
                auto_increment: false,
                columns: vec!["id".into()],
            }],
        }
    }

    /// Table with a single `id` column and no constraints.
    fn table_without_constraints(name: &str) -> TableDef {
        TableDef {
            name: name.into(),
            columns: vec![integer_column("id", None)],
            constraints: vec![],
        }
    }

    /// `orders` table whose `user_id` column carries a foreign key string that
    /// the normalization tests expect to be rejected.
    fn orders_with_invalid_foreign_key() -> TableDef {
        TableDef {
            name: "orders".into(),
            columns: vec![integer_column(
                "user_id",
                Some(ForeignKeySyntax::String("invalid_format".into())),
            )],
            constraints: vec![],
        }
    }

    #[test]
    #[serial]
    fn load_models_returns_empty_when_no_models_dir() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn load_models_reads_yaml_and_validates() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        let table = table_with_primary_key("users");
        fs::write("models/users.yaml", serde_yaml::to_string(&table).unwrap()).unwrap();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn load_models_recursive_processes_subdirectories() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models/subdir").unwrap();

        let table = table_with_primary_key("subtable");
        let content = serde_json::to_string_pretty(&table).unwrap();
        fs::write("models/subdir/subtable.json", content).unwrap();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn load_models_fails_on_invalid_fk_format() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();

        let table = orders_with_invalid_foreign_key();
        fs::write(
            "models/orders.json",
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models(&VespertideConfig::default());
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_root() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = table_without_constraints("users");
        fs::write(
            models_dir.join("users.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_without_root() {
        use std::env;

        // Remember the current value so it can be put back afterwards.
        let original = env::var("CARGO_MANIFEST_DIR").ok();

        // SAFETY: every test in this module is `#[serial]`, so no other test
        // here touches the environment at the same time. (Assumes tests
        // elsewhere in the crate do not access this variable concurrently.)
        unsafe {
            env::remove_var("CARGO_MANIFEST_DIR");
        }

        let result = load_models_from_dir(None);

        // Restore before asserting so a failed assertion cannot leak the
        // modified environment into later tests.
        if let Some(value) = original {
            // SAFETY: same reasoning as for the removal above.
            unsafe {
                env::set_var("CARGO_MANIFEST_DIR", value);
            }
        }

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("CARGO_MANIFEST_DIR environment variable not set"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_no_models_dir() {
        let temp_dir = tempdir().unwrap();
        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = table_without_constraints("users");
        fs::write(
            models_dir.join("users.yaml"),
            serde_yaml::to_string(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = table_without_constraints("users");
        fs::write(
            models_dir.join("users.yml"),
            serde_yaml::to_string(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_recursive() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        let subdir = models_dir.join("subdir");
        fs::create_dir_all(&subdir).unwrap();

        let table = table_without_constraints("subtable");
        fs::write(
            subdir.join("subtable.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_json() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.json"), r#"{"invalid": json}"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse JSON model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.yaml"), r#"invalid: [yaml"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse YAML model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_normalization_error() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = orders_with_invalid_foreign_key();
        fs::write(
            models_dir.join("orders.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_cargo_manifest_dir() {
        // Smoke test: only checks that the call does not panic. The outcome
        // depends on the environment the tests run in, so it is not asserted.
        let result = load_models_from_dir(None);
        let _ = result;
    }

    #[test]
    #[serial]
    fn test_load_models_at_compile_time() {
        // Smoke test: only checks that the call does not panic. The outcome
        // depends on the environment the tests run in, so it is not asserted.
        let result = load_models_at_compile_time();
        let _ = result;
    }
}