1mod schema;
44pub mod templates;
45mod validator;
46
47pub use schema::{Entity, EntityVersion, Field, SchemaConfig, SchemaFile};
48pub use validator::{validate_schema, ValidationError};
49
50use std::fmt;
51use std::path::Path;
52
53#[derive(Debug)]
55pub enum CodegenError {
56 Io(std::io::Error),
58 Parse(String),
60 Validation(Vec<ValidationError>),
62}
63
64impl fmt::Display for CodegenError {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 match self {
67 Self::Io(e) => write!(f, "IO error: {e}"),
68 Self::Parse(msg) => write!(f, "parse error: {msg}"),
69 Self::Validation(errs) => {
70 writeln!(f, "schema validation failed:")?;
71 for e in errs {
72 writeln!(f, " - {e}")?;
73 }
74 Ok(())
75 }
76 }
77 }
78}
79
80impl std::error::Error for CodegenError {}
81
/// A single generated source file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedFile {
    /// Path relative to [`GeneratedOutput::output_dir`], using `/` separators
    /// (e.g. `models/task.rs`).
    pub relative_path: String,
    /// Full text of the generated file.
    pub content: String,
}

/// Everything produced by one code-generation run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedOutput {
    /// Destination directory, copied from the schema's `config.output`.
    pub output_dir: String,
    /// Generated files, in the order the generator emitted them.
    pub files: Vec<GeneratedFile>,
}
99
100pub fn generate(schema_path: &Path) -> Result<GeneratedOutput, CodegenError> {
102 let toml_content = std::fs::read_to_string(schema_path).map_err(CodegenError::Io)?;
103 generate_from_str(&toml_content)
104}
105
106pub fn generate_from_str(toml_content: &str) -> Result<GeneratedOutput, CodegenError> {
108 let schema: SchemaFile =
109 toml::from_str(toml_content).map_err(|e| CodegenError::Parse(e.to_string()))?;
110 generate_from_schema(&schema)
111}
112
113pub fn generate_from_schema(schema: &SchemaFile) -> Result<GeneratedOutput, CodegenError> {
141 validate_schema(schema).map_err(CodegenError::Validation)?;
142
143 let mut files = Vec::new();
144
145 for entity in &schema.entities {
147 let (filename, content) = templates::generate_entity_file(entity);
148 files.push(GeneratedFile {
149 relative_path: format!("models/{filename}"),
150 content,
151 });
152 }
153 files.push(GeneratedFile {
154 relative_path: "models/mod.rs".into(),
155 content: templates::generate_models_mod_file(&schema.entities),
156 });
157
158 for entity in &schema.entities {
160 if entity.versions.len() > 1 {
161 let (filename, content) = templates::generate_migration_file(entity);
162 files.push(GeneratedFile {
163 relative_path: format!("migrations/{filename}"),
164 content,
165 });
166 }
167 }
168 files.push(GeneratedFile {
169 relative_path: "migrations/helpers.rs".into(),
170 content: templates::generate_helpers_file(&schema.entities),
171 });
172 files.push(GeneratedFile {
173 relative_path: "migrations/mod.rs".into(),
174 content: templates::generate_migrations_mod_file(&schema.entities),
175 });
176
177 files.push(GeneratedFile {
179 relative_path: "repositories/traits.rs".into(),
180 content: templates::generate_repository_traits_file(&schema.entities),
181 });
182 for entity in &schema.entities {
183 let (filename, content) = templates::generate_repository_impl_file(entity);
184 files.push(GeneratedFile {
185 relative_path: format!("repositories/{filename}"),
186 content,
187 });
188 }
189 files.push(GeneratedFile {
190 relative_path: "repositories/mod.rs".into(),
191 content: templates::generate_repositories_mod_file(&schema.entities),
192 });
193
194 files.push(GeneratedFile {
196 relative_path: "store.rs".into(),
197 content: templates::generate_store_file(&schema.entities),
198 });
199
200 let events_enabled = schema
202 .config
203 .events
204 .as_ref()
205 .map(|e| e.enabled)
206 .unwrap_or(false);
207
208 if events_enabled {
209 let threshold = schema
210 .config
211 .events
212 .as_ref()
213 .map(|e| e.snapshot_threshold)
214 .unwrap_or(100);
215
216 for entity in &schema.entities {
217 let (filename, content) = templates::generate_event_types_file(entity);
218 files.push(GeneratedFile {
219 relative_path: format!("events/{filename}"),
220 content,
221 });
222 }
223 files.push(GeneratedFile {
224 relative_path: "events/policies.rs".into(),
225 content: templates::generate_snapshot_policy_file(threshold),
226 });
227 files.push(GeneratedFile {
228 relative_path: "events/mod.rs".into(),
229 content: templates::generate_events_mod_file(&schema.entities),
230 });
231 }
232
233 let sync_enabled = schema
235 .config
236 .sync
237 .as_ref()
238 .map(|s| s.enabled)
239 .unwrap_or(false);
240
241 if sync_enabled {
242 let crdt_entities: Vec<&schema::Entity> = schema
244 .entities
245 .iter()
246 .filter(|e| {
247 e.versions
248 .iter()
249 .any(|v| v.fields.iter().any(|f| f.crdt.is_some()))
250 })
251 .collect();
252
253 for entity in &crdt_entities {
254 let (filename, content) = templates::generate_sync_file(entity);
255 files.push(GeneratedFile {
256 relative_path: format!("sync/{filename}"),
257 content,
258 });
259 }
260 if !crdt_entities.is_empty() {
261 files.push(GeneratedFile {
262 relative_path: "sync/mod.rs".into(),
263 content: templates::generate_sync_mod_file(&schema.entities),
264 });
265 }
266 }
267
268 files.push(GeneratedFile {
270 relative_path: "mod.rs".into(),
271 content: templates::generate_persistence_mod_file(&schema.entities, &schema.config),
272 });
273
274 Ok(GeneratedOutput {
275 output_dir: schema.config.output.clone(),
276 files,
277 })
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_TOML: &str = r#"
[config]
output = "src/persistence"

[[entity]]
name = "Task"
table = "tasks"

[[entity.versions]]
version = 1
fields = [
    { name = "title", type = "String" },
    { name = "done", type = "bool" },
]

[[entity.versions]]
version = 2
fields = [
    { name = "title", type = "String" },
    { name = "done", type = "bool" },
    { name = "priority", type = "Option<u8>", default = "None" },
    { name = "tags", type = "Vec<String>", default = "Vec::new()" },
]
"#;

    /// Runs the full pipeline on `toml`, panicking with the error's message on failure.
    fn generate_ok(toml: &str) -> GeneratedOutput {
        generate_from_str(toml).unwrap_or_else(|e| panic!("generation failed: {e}"))
    }

    /// Relative paths of all generated files, in emission order.
    fn paths(output: &GeneratedOutput) -> Vec<&str> {
        output
            .files
            .iter()
            .map(|f| f.relative_path.as_str())
            .collect()
    }

    /// Returns the generated file at `path`, panicking with the full path list if absent.
    fn file<'a>(output: &'a GeneratedOutput, path: &str) -> &'a GeneratedFile {
        output
            .files
            .iter()
            .find(|f| f.relative_path == path)
            .unwrap_or_else(|| panic!("`{path}` was not generated; got {:?}", paths(output)))
    }

    /// Asserts that every path in `expected` was generated.
    fn assert_generated(output: &GeneratedOutput, expected: &[&str]) {
        let generated = paths(output);
        for path in expected {
            assert!(
                generated.contains(path),
                "`{path}` was not generated; got {generated:?}"
            );
        }
    }

    #[test]
    fn end_to_end_generation() {
        let output = generate_ok(SAMPLE_TOML);
        assert_eq!(output.output_dir, "src/persistence");

        assert_generated(
            &output,
            &[
                "models/task.rs",
                "models/mod.rs",
                "migrations/task_migrations.rs",
                "migrations/helpers.rs",
                "migrations/mod.rs",
                "repositories/traits.rs",
                "repositories/task_repo.rs",
                "repositories/mod.rs",
                "store.rs",
                "mod.rs",
            ],
        );

        // Events and sync are opt-in and not configured in SAMPLE_TOML.
        let generated = paths(&output);
        assert!(!generated.iter().any(|f| f.starts_with("events/")));
        assert!(!generated.iter().any(|f| f.starts_with("sync/")));
    }

    #[test]
    fn generated_structs_compile_ready() {
        let output = generate_ok(SAMPLE_TOML);
        let task = &file(&output, "models/task.rs").content;

        assert!(task.contains("pub struct TaskV1"));
        assert!(task.contains("pub struct TaskV2"));
        assert!(task.contains("pub type Task = TaskV2;"));
        assert!(task.contains("#[crdt_schema(version = 1, table = \"tasks\")]"));
        assert!(task.contains("#[crdt_schema(version = 2, table = \"tasks\")]"));
    }

    #[test]
    fn generated_migrations_compile_ready() {
        let output = generate_ok(SAMPLE_TOML);
        let migrations = &file(&output, "migrations/task_migrations.rs").content;

        assert!(migrations.contains("#[migration(from = 1, to = 2)]"));
        assert!(migrations.contains("pub fn migrate_task_v1_to_v2"));
        assert!(migrations.contains("priority: None,"));
        assert!(migrations.contains("tags: Vec::new(),"));
    }

    #[test]
    fn invalid_schema_returns_error() {
        let bad_toml = r#"
[config]
output = ""

[[entity]]
name = "task"
table = ""

[[entity.versions]]
version = 1
fields = []
"#;
        match generate_from_str(bad_toml) {
            Err(CodegenError::Validation(errs)) => assert!(!errs.is_empty()),
            other => panic!("expected a validation error, got {other:?}"),
        }
    }

    #[test]
    fn single_version_no_migrations() {
        let toml = r#"
[config]
output = "out"

[[entity]]
name = "Note"
table = "notes"

[[entity.versions]]
version = 1
fields = [
    { name = "text", type = "String" },
]
"#;
        let output = generate_ok(toml);
        assert_generated(
            &output,
            &[
                "models/note.rs",
                "migrations/helpers.rs",
                "repositories/note_repo.rs",
            ],
        );
        assert!(!paths(&output).iter().any(|f| f.contains("note_migrations")));
    }

    #[test]
    fn multiple_entities() {
        let toml = r#"
[config]
output = "out"

[[entity]]
name = "Task"
table = "tasks"

[[entity.versions]]
version = 1
fields = [
    { name = "title", type = "String" },
]

[[entity]]
name = "User"
table = "users"

[[entity.versions]]
version = 1
fields = [
    { name = "name", type = "String" },
]
"#;
        let output = generate_ok(toml);
        assert_generated(
            &output,
            &[
                "models/task.rs",
                "models/user.rs",
                "repositories/task_repo.rs",
                "repositories/user_repo.rs",
            ],
        );
    }

    #[test]
    fn events_and_sync_conditional() {
        let toml = r#"
[config]
output = "out"

[config.events]
enabled = true
snapshot_threshold = 50

[config.sync]
enabled = true

[[entity]]
name = "Project"
table = "projects"

[[entity.versions]]
version = 1
fields = [
    { name = "name", type = "String", crdt = "LWWRegister" },
    { name = "members", type = "String", crdt = "ORSet" },
]
"#;
        let output = generate_ok(toml);
        assert_generated(
            &output,
            &[
                "events/mod.rs",
                "events/policies.rs",
                "events/project_events.rs",
                "sync/mod.rs",
                "sync/project_sync.rs",
            ],
        );

        assert!(file(&output, "events/policies.rs")
            .content
            .contains("event_threshold: 50,"));

        let root_mod = &file(&output, "mod.rs").content;
        assert!(root_mod.contains("pub mod events;"));
        assert!(root_mod.contains("pub mod sync;"));
    }
}
521}