1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use vespertide_config::VespertideConfig;
6use vespertide_core::TableDef;
7use vespertide_planner::validate_schema;
8
9pub fn load_models(config: &VespertideConfig) -> Result<Vec<TableDef>> {
11 let models_dir = config.models_dir();
12 if !models_dir.exists() {
13 return Ok(Vec::new());
14 }
15
16 let mut tables = Vec::new();
17 load_models_recursive(models_dir, &mut tables)?;
18
19 let normalized_tables: Vec<TableDef> = tables
22 .into_iter()
23 .map(|t| {
24 t.normalize()
25 .map_err(|e| anyhow::anyhow!("Failed to normalize table '{}': {}", t.name, e))
26 })
27 .collect::<Result<Vec<_>, _>>()?;
28
29 if !normalized_tables.is_empty() {
31 validate_schema(&normalized_tables)
32 .map_err(|e| anyhow::anyhow!("schema validation failed: {}", e))?;
33 }
34
35 Ok(normalized_tables)
36}
37
38fn load_models_recursive(dir: &Path, tables: &mut Vec<TableDef>) -> Result<()> {
40 let entries =
41 fs::read_dir(dir).with_context(|| format!("read models directory: {}", dir.display()))?;
42
43 for entry in entries {
44 let entry = entry.context("read directory entry")?;
45 let path = entry.path();
46
47 if path.is_dir() {
48 load_models_recursive(&path, tables)?;
50 continue;
51 }
52
53 if path.is_file() {
54 let ext = path.extension().and_then(|s| s.to_str());
55 if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
56 let content = fs::read_to_string(&path)
57 .with_context(|| format!("read model file: {}", path.display()))?;
58
59 let table: TableDef = if ext == Some("json") {
60 serde_json::from_str(&content)
61 .with_context(|| format!("parse JSON model: {}", path.display()))?
62 } else {
63 serde_yaml::from_str(&content)
64 .with_context(|| format!("parse YAML model: {}", path.display()))?
65 };
66
67 tables.push(table);
68 }
69 }
70 }
71
72 Ok(())
73}
74
75pub fn load_models_from_dir(
77 project_root: Option<std::path::PathBuf>,
78) -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
79 use std::env;
80
81 let project_root = if let Some(root) = project_root {
83 root
84 } else {
85 std::path::PathBuf::from(
86 env::var("CARGO_MANIFEST_DIR")
87 .context("CARGO_MANIFEST_DIR environment variable not set")?,
88 )
89 };
90
91 let config = crate::config::load_config_or_default(Some(project_root.clone()))
93 .map_err(|e| format!("Failed to load config: {}", e))?;
94
95 let models_dir = project_root.join(config.models_dir());
97 if !models_dir.exists() {
98 return Ok(Vec::new());
99 }
100
101 let mut tables = Vec::new();
102 load_models_recursive_internal(&models_dir, &mut tables)
103 .map_err(|e| format!("Failed to load models: {}", e))?;
104
105 let normalized_tables: Vec<TableDef> = tables
107 .into_iter()
108 .map(|t| {
109 t.normalize()
110 .map_err(|e| format!("Failed to normalize table '{}': {}", t.name, e))
111 })
112 .collect::<Result<Vec<_>, _>>()
113 .map_err(|e| e.to_string())?;
114
115 Ok(normalized_tables)
116}
117
118fn load_models_recursive_internal(
120 dir: &Path,
121 tables: &mut Vec<TableDef>,
122) -> Result<(), Box<dyn std::error::Error>> {
123 use std::fs;
124
125 let entries = fs::read_dir(dir)
126 .map_err(|e| format!("Failed to read models directory {}: {}", dir.display(), e))?;
127
128 for entry in entries {
129 let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?;
130 let path = entry.path();
131
132 if path.is_dir() {
133 load_models_recursive_internal(&path, tables)?;
135 continue;
136 }
137
138 if path.is_file() {
139 let ext = path.extension().and_then(|s| s.to_str());
140 if matches!(ext, Some("json") | Some("yaml") | Some("yml")) {
141 let content = fs::read_to_string(&path)
142 .map_err(|e| format!("Failed to read model file {}: {}", path.display(), e))?;
143
144 let table: TableDef = if ext == Some("json") {
145 serde_json::from_str(&content).map_err(|e| {
146 format!("Failed to parse JSON model {}: {}", path.display(), e)
147 })?
148 } else {
149 serde_yaml::from_str(&content).map_err(|e| {
150 format!("Failed to parse YAML model {}: {}", path.display(), e)
151 })?
152 };
153
154 tables.push(table);
155 }
156 }
157 }
158
159 Ok(())
160}
161
162pub fn load_models_at_compile_time() -> Result<Vec<TableDef>, Box<dyn std::error::Error>> {
164 load_models_from_dir(None)
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::fs;
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, SimpleColumnType, TableConstraint,
        schema::foreign_key::ForeignKeySyntax,
    };

    /// Switches the process working directory for the guard's lifetime and
    /// restores the previous one on drop (best effort).
    struct CwdGuard {
        original: std::path::PathBuf,
    }

    impl CwdGuard {
        fn new(dir: &std::path::Path) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Removes an environment variable for the guard's lifetime and restores
    /// its previous value (or absence) on drop, even if the test panics.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn remove(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: tests that touch the environment are `#[serial]`, so no
            // other test in this module mutates or reads it concurrently.
            unsafe {
                std::env::remove_var(key);
            }
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same `#[serial]` reasoning as in `EnvVarGuard::remove`.
            unsafe {
                match &self.original {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Writes a default `vespertide.json` into the current working directory.
    fn write_config() {
        let cfg = VespertideConfig::default();
        let text = serde_json::to_string_pretty(&cfg).unwrap();
        fs::write("vespertide.json", text).unwrap();
    }

    /// A non-nullable integer column with an optional foreign key.
    fn integer_column(name: &'static str, foreign_key: Option<ForeignKeySyntax>) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key,
        }
    }

    /// A table with exactly one column and no indexes.
    fn single_column_table(
        name: &'static str,
        column: ColumnDef,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.into(),
            columns: vec![column],
            constraints,
            indexes: vec![],
        }
    }

    /// A table keyed by a non-auto-increment `id` primary key.
    fn table_with_id_pk(name: &'static str) -> TableDef {
        single_column_table(
            name,
            integer_column("id", None),
            vec![TableConstraint::PrimaryKey {
                auto_increment: false,
                columns: vec!["id".into()],
            }],
        )
    }

    /// A table with a single `id` column and no constraints.
    fn table_without_constraints(name: &'static str) -> TableDef {
        single_column_table(name, integer_column("id", None), vec![])
    }

    /// An `orders` table whose foreign key string cannot be normalized.
    fn orders_with_invalid_fk() -> TableDef {
        single_column_table(
            "orders",
            integer_column(
                "user_id",
                Some(ForeignKeySyntax::String("invalid_format".into())),
            ),
            vec![],
        )
    }

    #[test]
    #[serial]
    fn load_models_returns_empty_when_no_models_dir() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn load_models_reads_yaml_and_validates() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        let table = table_with_id_pk("users");
        fs::write("models/users.yaml", serde_yaml::to_string(&table).unwrap()).unwrap();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn load_models_recursive_processes_subdirectories() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models/subdir").unwrap();
        let table = table_with_id_pk("subtable");
        let content = serde_json::to_string_pretty(&table).unwrap();
        fs::write("models/subdir/subtable.json", content).unwrap();

        let models = load_models(&VespertideConfig::default()).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn load_models_fails_on_invalid_fk_format() {
        let tmp = tempdir().unwrap();
        let _guard = CwdGuard::new(tmp.path());
        write_config();

        fs::create_dir_all("models").unwrap();
        let table = orders_with_invalid_fk();
        fs::write(
            "models/orders.json",
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models(&VespertideConfig::default());
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_root() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = table_without_constraints("users");
        fs::write(
            models_dir.join("users.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_without_root() {
        // Restores CARGO_MANIFEST_DIR when the test ends so later tests still
        // see the original environment.
        let _env = EnvVarGuard::remove("CARGO_MANIFEST_DIR");

        let result = load_models_from_dir(None);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("CARGO_MANIFEST_DIR environment variable not set"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_no_models_dir() {
        let temp_dir = tempdir().unwrap();
        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 0);
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = table_without_constraints("users");
        fs::write(
            models_dir.join("users.yaml"),
            serde_yaml::to_string(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_yml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = table_without_constraints("users");
        fs::write(
            models_dir.join("users.yml"),
            serde_yaml::to_string(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "users");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_recursive() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        let subdir = models_dir.join("subdir");
        fs::create_dir_all(&subdir).unwrap();

        let table = table_without_constraints("subtable");
        fs::write(
            subdir.join("subtable.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "subtable");
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_json() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.json"), r#"{"invalid": json}"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse JSON model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_invalid_yaml() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        fs::write(models_dir.join("invalid.yaml"), r#"invalid: [yaml"#).unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to parse YAML model"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_normalization_error() {
        let temp_dir = tempdir().unwrap();
        let models_dir = temp_dir.path().join("models");
        fs::create_dir_all(&models_dir).unwrap();

        let table = orders_with_invalid_fk();
        fs::write(
            models_dir.join("orders.json"),
            serde_json::to_string_pretty(&table).unwrap(),
        )
        .unwrap();

        let result = load_models_from_dir(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to normalize table 'orders'"));
    }

    #[test]
    #[serial]
    fn test_load_models_from_dir_with_cargo_manifest_dir() {
        // Smoke test only: the outcome depends on whether the crate being
        // tested has a models directory, so the result is not asserted.
        let result = load_models_from_dir(None);
        let _ = result;
    }

    #[test]
    #[serial]
    fn test_load_models_at_compile_time() {
        // Smoke test only: see `test_load_models_from_dir_with_cargo_manifest_dir`.
        let result = load_models_at_compile_time();
        let _ = result;
    }
}