1mod discovery;
4pub(crate) mod merge;
5mod source;
6
7use std::path::Path;
8
9pub use discovery::{Discovered, discover};
10pub use merge::MergeConflict;
11pub use source::{SourceFile, SourceId, SourceLoc, SourceMap};
12
13use crate::ast::Schema;
14use crate::error::SchemaError;
15use crate::parser::parse_schema;
16use crate::validator::Validator;
17
/// A successfully loaded and validated schema, together with the source
/// files it was built from.
#[derive(Debug, Clone)]
pub struct LoadedSchema {
    /// The validated schema (merged from all files when a directory was loaded).
    pub schema: Schema,
    /// Every file read while loading. The `SourceId`s stamped onto schema
    /// items were issued by this map.
    pub sources: SourceMap,
}
25
/// A failure while loading a schema, paired with the source files that had
/// been read before the failure occurred.
#[derive(Debug)]
pub struct LoadError {
    /// The underlying schema error.
    pub error: SchemaError,
    /// Sources registered before the error. Empty if the failure happened
    /// before any file was read. Lets callers resolve `SourceId`s that the
    /// error refers to (e.g. `SchemaError::ParseInFile`).
    pub sources: SourceMap,
}
34
35impl std::fmt::Display for LoadError {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 self.error.fmt(f)
38 }
39}
40
41impl std::error::Error for LoadError {
42 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
43 Some(&self.error)
44 }
45}
46
47pub fn load(path: impl AsRef<Path>) -> Result<LoadedSchema, LoadError> {
53 let path = path.as_ref();
54 let meta = std::fs::metadata(path).map_err(|e| LoadError {
55 error: SchemaError::IoError {
56 path: path.display().to_string(),
57 source: e,
58 },
59 sources: SourceMap::new(),
60 })?;
61
62 if meta.is_file() {
63 load_single(path)
64 } else if meta.is_dir() {
65 load_directory(path)
66 } else {
67 Err(LoadError {
68 error: SchemaError::ConfigError {
69 message: format!(
70 "schema path `{}` is neither a file nor a directory",
71 path.display()
72 ),
73 },
74 sources: SourceMap::new(),
75 })
76 }
77}
78
79fn load_single(path: &Path) -> Result<LoadedSchema, LoadError> {
80 let mut sources = SourceMap::new();
81 let content = match std::fs::read_to_string(path) {
82 Ok(c) => c,
83 Err(e) => {
84 return Err(LoadError {
85 error: SchemaError::IoError {
86 path: path.display().to_string(),
87 source: e,
88 },
89 sources,
90 });
91 }
92 };
93
94 let mut schema = match parse_schema(&content) {
95 Ok(s) => s,
96 Err(e) => {
97 sources.insert(path.to_path_buf(), content);
100 return Err(LoadError { error: e, sources });
101 }
102 };
103 let sid = sources.insert(path.to_path_buf(), content);
104 stamp_source(&mut schema, sid);
105
106 let validated = match Validator::new().validate(schema) {
107 Ok(s) => s,
108 Err(e) => return Err(LoadError { error: e, sources }),
109 };
110
111 Ok(LoadedSchema {
112 schema: validated,
113 sources,
114 })
115}
116
117fn load_directory(root: &Path) -> Result<LoadedSchema, LoadError> {
118 let mut sources = SourceMap::new();
119
120 let files = match discovery::discover(root) {
121 Ok(v) => v,
122 Err(e) => return Err(LoadError { error: e, sources }),
123 };
124
125 if files.is_empty() {
126 return Err(LoadError {
127 error: SchemaError::EmptySchemaDirectory {
128 path: root.to_path_buf(),
129 },
130 sources,
131 });
132 }
133
134 let mut per_file: Vec<(SourceId, Schema)> = Vec::with_capacity(files.len());
135 for f in files {
136 let content = match std::fs::read_to_string(&f.absolute) {
137 Ok(c) => c,
138 Err(e) => {
139 return Err(LoadError {
140 error: SchemaError::IoError {
141 path: f.absolute.display().to_string(),
142 source: e,
143 },
144 sources,
145 });
146 }
147 };
148 let sid = sources.insert(f.absolute, content);
149 let file_content = &sources.get(sid).expect("just inserted").content;
152 let mut schema_i = match parse_schema(file_content) {
153 Ok(s) => s,
154 Err(inner) => {
155 return Err(LoadError {
156 error: SchemaError::ParseInFile {
157 source: sid,
158 inner: Box::new(inner),
159 },
160 sources,
161 });
162 }
163 };
164 stamp_source(&mut schema_i, sid);
165 per_file.push((sid, schema_i));
166 }
167
168 let mut merged = Schema::new();
169 let mut all_conflicts: Vec<MergeConflict> = Vec::new();
170 for (_, schema_i) in per_file {
171 if let Err(conflicts) = merged.try_merge(schema_i) {
172 all_conflicts.extend(conflicts);
173 }
174 }
175
176 if !all_conflicts.is_empty() {
177 return Err(LoadError {
178 error: from_conflicts(all_conflicts),
179 sources,
180 });
181 }
182
183 let validated = match Validator::new().validate(merged) {
184 Ok(s) => s,
185 Err(e) => return Err(LoadError { error: e, sources }),
186 };
187
188 Ok(LoadedSchema {
189 schema: validated,
190 sources,
191 })
192}
193
194fn from_conflicts(conflicts: Vec<MergeConflict>) -> SchemaError {
196 let mut errors: Vec<SchemaError> = conflicts.into_iter().map(conflict_to_error).collect();
197 if errors.len() == 1 {
198 errors.remove(0)
199 } else {
200 SchemaError::ValidationFailed {
201 count: errors.len(),
202 errors,
203 }
204 }
205}
206
207fn conflict_to_error(c: MergeConflict) -> SchemaError {
208 use crate::error::DuplicateKind;
209
210 macro_rules! dispatch {
211 ($($variant:ident => $kind:ident),+ $(,)?) => {
212 match c {
213 $(
214 MergeConflict::$variant { name, existing, incoming } => {
215 SchemaError::DuplicateAcrossFiles {
216 kind: DuplicateKind::$kind,
217 name: name.to_string(),
218 first: existing,
219 second: incoming,
220 }
221 }
222 ),+,
223 MergeConflict::MultipleDatasource { existing, incoming } => {
224 SchemaError::MultipleDatasource {
225 first: existing,
226 second: incoming,
227 }
228 }
229 }
230 };
231 }
232
233 dispatch! {
234 DuplicateModel => Model,
235 DuplicateEnum => Enum,
236 DuplicateType => Type,
237 DuplicateView => View,
238 DuplicateServerGroup => ServerGroup,
239 DuplicatePolicy => Policy,
240 DuplicateGenerator => Generator,
241 DuplicateRawSql => RawSql,
242 }
243}
244
245pub(crate) fn stamp_source(schema: &mut Schema, source: SourceId) {
249 for m in schema.models.values_mut() {
250 m.source_id = Some(source);
251 }
252 for e in schema.enums.values_mut() {
253 e.source_id = Some(source);
254 }
255 for t in schema.types.values_mut() {
256 t.source_id = Some(source);
257 }
258 for v in schema.views.values_mut() {
259 v.source_id = Some(source);
260 }
261 for sg in schema.server_groups.values_mut() {
262 sg.source_id = Some(source);
263 }
264 for p in &mut schema.policies {
265 p.source_id = Some(source);
266 }
267 for g in schema.generators.values_mut() {
268 g.source_id = Some(source);
269 }
270 if let Some(ds) = &mut schema.datasource {
271 ds.source_id = Some(source);
272 }
273 for r in &mut schema.raw_sql {
274 r.source_id = Some(source);
275 }
276}
277
#[cfg(test)]
mod tests {
    // `use super::*` already brings in `parse_schema` via the parent's
    // import, so no separate import is needed.
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn load_directory_merges_files_and_resolves_cross_file_relations() {
        let dir = tempdir().unwrap();
        std::fs::write(
            dir.path().join("datasource.prax"),
            r#"datasource db { provider = "postgresql" url = "x" }"#,
        )
        .unwrap();
        std::fs::create_dir_all(dir.path().join("models")).unwrap();
        std::fs::write(
            dir.path().join("models/user.prax"),
            "model User { id Int @id @auto email String @unique posts Post[] }",
        )
        .unwrap();
        std::fs::write(
            dir.path().join("models/post.prax"),
            "model Post { id Int @id @auto author_id Int author User @relation(fields: [author_id], references: [id]) }",
        )
        .unwrap();

        let loaded = load(dir.path()).expect("load should succeed");
        assert!(loaded.schema.get_model("User").is_some());
        assert!(loaded.schema.get_model("Post").is_some());
        assert!(loaded.schema.datasource.is_some());
        assert_eq!(loaded.sources.len(), 3);
    }

    #[test]
    fn load_directory_duplicate_model_errors() {
        let dir = tempdir().unwrap();
        std::fs::write(dir.path().join("a.prax"), "model User { id Int @id @auto }").unwrap();
        std::fs::write(dir.path().join("b.prax"), "model User { id Int @id @auto }").unwrap();

        let err = load(dir.path()).unwrap_err();
        let msg = format!("{}", err.error);
        assert!(msg.contains("duplicate model"), "got: {msg}");
        assert_eq!(err.sources.len(), 2);
    }

    #[test]
    fn load_empty_directory_errors() {
        let dir = tempdir().unwrap();
        let err = load(dir.path()).unwrap_err();
        assert!(matches!(
            err.error,
            crate::error::SchemaError::EmptySchemaDirectory { .. }
        ));
    }

    #[test]
    fn load_single_file_registers_one_source() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("schema.prax");
        std::fs::write(
            &path,
            "datasource db { provider = \"postgresql\" url = \"x\" }\nmodel User { id Int @id @auto }",
        )
        .unwrap();

        let loaded = load(&path).expect("single-file load should succeed");
        assert!(loaded.schema.get_model("User").is_some());
        assert_eq!(loaded.sources.len(), 1);
    }

    #[test]
    fn load_missing_path_reports_io_error() {
        let dir = tempdir().unwrap();
        let err = load(dir.path().join("does-not-exist.prax")).unwrap_err();
        assert!(
            matches!(err.error, crate::error::SchemaError::IoError { .. }),
            "got: {}",
            err.error
        );
        assert_eq!(err.sources.len(), 0);
    }

    #[test]
    fn stamp_marks_all_items() {
        let mut schema = parse_schema(
            r#"
            datasource db { provider = "postgresql" url = "x" }
            generator client { provider = "prax-client" }
            enum Role { User Admin }
            model User { id Int @id @auto role Role }
            "#,
        )
        .unwrap();
        stamp_source(&mut schema, SourceId(7));
        assert_eq!(schema.models["User"].source_id, Some(SourceId(7)));
        assert_eq!(schema.enums["Role"].source_id, Some(SourceId(7)));
        assert_eq!(schema.datasource.unwrap().source_id, Some(SourceId(7)));
        assert_eq!(schema.generators["client"].source_id, Some(SourceId(7)));
    }
}