Skip to main content

prax_schema/loader/
mod.rs

1//! Multi-file schema loader.
2
3mod discovery;
4pub(crate) mod merge;
5mod source;
6
7use std::path::Path;
8
9pub use discovery::{Discovered, discover};
10pub use merge::MergeConflict;
11pub use source::{SourceFile, SourceId, SourceLoc, SourceMap};
12
13use crate::ast::Schema;
14use crate::error::SchemaError;
15use crate::parser::parse_schema;
16use crate::validator::Validator;
17
/// A successfully loaded multi-file (or single-file) schema, paired with the
/// source map needed for downstream diagnostics rendering.
#[derive(Debug, Clone)]
pub struct LoadedSchema {
    /// The validated schema. In directory mode this is the merge of every
    /// discovered file; in single-file mode it is that one file's schema.
    pub schema: Schema,
    /// Every file read while loading. Top-level items in `schema` were
    /// stamped with the [`SourceId`]s handed out by this map.
    pub sources: SourceMap,
}
25
/// Error returned by [`load`], carrying the partial source map built up to the
/// point of failure so the renderer can resolve [`SourceId`]s back to file
/// content.
#[derive(Debug)]
pub struct LoadError {
    /// The underlying failure: I/O, file discovery, parsing, cross-file merge
    /// conflicts, or validation of the final schema.
    pub error: SchemaError,
    /// Files registered before the failure occurred. This may be empty, e.g.
    /// when the path could not be stat'ed or no schema files were found.
    pub sources: SourceMap,
}
34
35impl std::fmt::Display for LoadError {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        self.error.fmt(f)
38    }
39}
40
41impl std::error::Error for LoadError {
42    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
43        Some(&self.error)
44    }
45}
46
47/// Load a schema from a file or directory.
48///
49/// - If `path` is a file: parse the single file.
50/// - If `path` is a directory: recursively find `*.prax`, parse each, merge
51///   with collision detection, then validate the merged AST.
52pub fn load(path: impl AsRef<Path>) -> Result<LoadedSchema, LoadError> {
53    let path = path.as_ref();
54    let meta = std::fs::metadata(path).map_err(|e| LoadError {
55        error: SchemaError::IoError {
56            path: path.display().to_string(),
57            source: e,
58        },
59        sources: SourceMap::new(),
60    })?;
61
62    if meta.is_file() {
63        load_single(path)
64    } else if meta.is_dir() {
65        load_directory(path)
66    } else {
67        Err(LoadError {
68            error: SchemaError::ConfigError {
69                message: format!(
70                    "schema path `{}` is neither a file nor a directory",
71                    path.display()
72                ),
73            },
74            sources: SourceMap::new(),
75        })
76    }
77}
78
79fn load_single(path: &Path) -> Result<LoadedSchema, LoadError> {
80    let mut sources = SourceMap::new();
81    let content = match std::fs::read_to_string(path) {
82        Ok(c) => c,
83        Err(e) => {
84            return Err(LoadError {
85                error: SchemaError::IoError {
86                    path: path.display().to_string(),
87                    source: e,
88                },
89                sources,
90            });
91        }
92    };
93
94    let mut schema = match parse_schema(&content) {
95        Ok(s) => s,
96        Err(e) => {
97            // Insert into the map before returning so the renderer can resolve
98            // SourceId(0) back to file content.
99            sources.insert(path.to_path_buf(), content);
100            return Err(LoadError { error: e, sources });
101        }
102    };
103    let sid = sources.insert(path.to_path_buf(), content);
104    stamp_source(&mut schema, sid);
105
106    let validated = match Validator::new().validate(schema) {
107        Ok(s) => s,
108        Err(e) => return Err(LoadError { error: e, sources }),
109    };
110
111    Ok(LoadedSchema {
112        schema: validated,
113        sources,
114    })
115}
116
117fn load_directory(root: &Path) -> Result<LoadedSchema, LoadError> {
118    let mut sources = SourceMap::new();
119
120    let files = match discovery::discover(root) {
121        Ok(v) => v,
122        Err(e) => return Err(LoadError { error: e, sources }),
123    };
124
125    if files.is_empty() {
126        return Err(LoadError {
127            error: SchemaError::EmptySchemaDirectory {
128                path: root.to_path_buf(),
129            },
130            sources,
131        });
132    }
133
134    let mut per_file: Vec<(SourceId, Schema)> = Vec::with_capacity(files.len());
135    for f in files {
136        let content = match std::fs::read_to_string(&f.absolute) {
137            Ok(c) => c,
138            Err(e) => {
139                return Err(LoadError {
140                    error: SchemaError::IoError {
141                        path: f.absolute.display().to_string(),
142                        source: e,
143                    },
144                    sources,
145                });
146            }
147        };
148        let sid = sources.insert(f.absolute, content);
149        // Borrow content back through the map; per-file syntax errors are
150        // fail-fast (no useful partial schema if file N of M is malformed).
151        let file_content = &sources.get(sid).expect("just inserted").content;
152        let mut schema_i = match parse_schema(file_content) {
153            Ok(s) => s,
154            Err(inner) => {
155                return Err(LoadError {
156                    error: SchemaError::ParseInFile {
157                        source: sid,
158                        inner: Box::new(inner),
159                    },
160                    sources,
161                });
162            }
163        };
164        stamp_source(&mut schema_i, sid);
165        per_file.push((sid, schema_i));
166    }
167
168    let mut merged = Schema::new();
169    let mut all_conflicts: Vec<MergeConflict> = Vec::new();
170    for (_, schema_i) in per_file {
171        if let Err(conflicts) = merged.try_merge(schema_i) {
172            all_conflicts.extend(conflicts);
173        }
174    }
175
176    if !all_conflicts.is_empty() {
177        return Err(LoadError {
178            error: from_conflicts(all_conflicts),
179            sources,
180        });
181    }
182
183    let validated = match Validator::new().validate(merged) {
184        Ok(s) => s,
185        Err(e) => return Err(LoadError { error: e, sources }),
186    };
187
188    Ok(LoadedSchema {
189        schema: validated,
190        sources,
191    })
192}
193
194/// Bundle a batch of [`MergeConflict`]s into a single [`SchemaError`].
195fn from_conflicts(conflicts: Vec<MergeConflict>) -> SchemaError {
196    let mut errors: Vec<SchemaError> = conflicts.into_iter().map(conflict_to_error).collect();
197    if errors.len() == 1 {
198        errors.remove(0)
199    } else {
200        SchemaError::ValidationFailed {
201            count: errors.len(),
202            errors,
203        }
204    }
205}
206
207fn conflict_to_error(c: MergeConflict) -> SchemaError {
208    use crate::error::DuplicateKind;
209
210    macro_rules! dispatch {
211        ($($variant:ident => $kind:ident),+ $(,)?) => {
212            match c {
213                $(
214                    MergeConflict::$variant { name, existing, incoming } => {
215                        SchemaError::DuplicateAcrossFiles {
216                            kind: DuplicateKind::$kind,
217                            name: name.to_string(),
218                            first: existing,
219                            second: incoming,
220                        }
221                    }
222                ),+,
223                MergeConflict::MultipleDatasource { existing, incoming } => {
224                    SchemaError::MultipleDatasource {
225                        first: existing,
226                        second: incoming,
227                    }
228                }
229            }
230        };
231    }
232
233    dispatch! {
234        DuplicateModel => Model,
235        DuplicateEnum => Enum,
236        DuplicateType => Type,
237        DuplicateView => View,
238        DuplicateServerGroup => ServerGroup,
239        DuplicatePolicy => Policy,
240        DuplicateGenerator => Generator,
241        DuplicateRawSql => RawSql,
242    }
243}
244
245/// Stamp every top-level item in `schema` with `source`.
246///
247/// Called by [`load`] right after parsing a per-file [`Schema`], before merging.
248pub(crate) fn stamp_source(schema: &mut Schema, source: SourceId) {
249    for m in schema.models.values_mut() {
250        m.source_id = Some(source);
251    }
252    for e in schema.enums.values_mut() {
253        e.source_id = Some(source);
254    }
255    for t in schema.types.values_mut() {
256        t.source_id = Some(source);
257    }
258    for v in schema.views.values_mut() {
259        v.source_id = Some(source);
260    }
261    for sg in schema.server_groups.values_mut() {
262        sg.source_id = Some(source);
263    }
264    for p in &mut schema.policies {
265        p.source_id = Some(source);
266    }
267    for g in schema.generators.values_mut() {
268        g.source_id = Some(source);
269    }
270    if let Some(ds) = &mut schema.datasource {
271        ds.source_id = Some(source);
272    }
273    for r in &mut schema.raw_sql {
274        r.source_id = Some(source);
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_schema;

    #[test]
    fn load_directory_merges_files_and_resolves_cross_file_relations() {
        use tempfile::tempdir;

        let dir = tempdir().unwrap();
        std::fs::write(
            dir.path().join("datasource.prax"),
            r#"datasource db { provider = "postgresql" url = "x" }"#,
        )
        .unwrap();
        std::fs::create_dir_all(dir.path().join("models")).unwrap();
        std::fs::write(
            dir.path().join("models/user.prax"),
            "model User { id Int @id @auto email String @unique posts Post[] }",
        )
        .unwrap();
        std::fs::write(
            dir.path().join("models/post.prax"),
            "model Post { id Int @id @auto author_id Int author User @relation(fields: [author_id], references: [id]) }",
        )
        .unwrap();

        let loaded = load(dir.path()).expect("load should succeed");
        assert!(loaded.schema.get_model("User").is_some());
        assert!(loaded.schema.get_model("Post").is_some());
        assert!(loaded.schema.datasource.is_some());
        assert_eq!(loaded.sources.len(), 3);
    }

    #[test]
    fn load_directory_duplicate_model_errors() {
        use tempfile::tempdir;

        let dir = tempdir().unwrap();
        std::fs::write(dir.path().join("a.prax"), "model User { id Int @id @auto }").unwrap();
        std::fs::write(dir.path().join("b.prax"), "model User { id Int @id @auto }").unwrap();

        let err = load(dir.path()).unwrap_err();
        let msg = format!("{}", err.error);
        assert!(msg.contains("duplicate model"), "got: {msg}");
        assert_eq!(err.sources.len(), 2);
    }

    #[test]
    fn load_empty_directory_errors() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let err = load(dir.path()).unwrap_err();
        assert!(matches!(
            err.error,
            crate::error::SchemaError::EmptySchemaDirectory { .. }
        ));
    }

    #[test]
    fn stamp_marks_all_items() {
        let mut schema = parse_schema(
            r#"
            datasource db { provider = "postgresql" url = "x" }
            generator client { provider = "prax-client" }
            enum Role { User Admin }
            model User { id Int @id @auto role Role }
            "#,
        )
        .unwrap();
        stamp_source(&mut schema, SourceId(7));
        assert_eq!(schema.models["User"].source_id, Some(SourceId(7)));
        assert_eq!(schema.enums["Role"].source_id, Some(SourceId(7)));
        assert_eq!(schema.datasource.unwrap().source_id, Some(SourceId(7)));
        assert_eq!(schema.generators["client"].source_id, Some(SourceId(7)));
    }

    /// The single-file path registers exactly one source and yields the model.
    #[test]
    fn load_single_file_registers_one_source() {
        use tempfile::tempdir;

        let dir = tempdir().unwrap();
        let file = dir.path().join("schema.prax");
        std::fs::write(
            &file,
            r#"
            datasource db { provider = "postgresql" url = "x" }
            model User { id Int @id @auto }
            "#,
        )
        .unwrap();

        let loaded = load(&file).expect("single-file load should succeed");
        assert!(loaded.schema.get_model("User").is_some());
        assert_eq!(loaded.sources.len(), 1);
    }

    /// A single-file parse error must still carry the file content so the
    /// renderer can resolve `SourceId(0)`.
    #[test]
    fn load_single_file_parse_error_keeps_source() {
        use tempfile::tempdir;

        let dir = tempdir().unwrap();
        let file = dir.path().join("broken.prax");
        std::fs::write(&file, "model User {").unwrap();

        let err = load(&file).unwrap_err();
        assert_eq!(err.sources.len(), 1);
    }

    /// In directory mode a syntax error is wrapped with the offending file's id.
    #[test]
    fn load_directory_parse_error_is_attributed_to_file() {
        use tempfile::tempdir;

        let dir = tempdir().unwrap();
        std::fs::write(dir.path().join("bad.prax"), "model User {").unwrap();

        let err = load(dir.path()).unwrap_err();
        assert!(
            matches!(err.error, crate::error::SchemaError::ParseInFile { .. }),
            "got: {:?}",
            err.error
        );
        assert_eq!(err.sources.len(), 1);
    }
}