Skip to main content

crdt_codegen/
lib.rs

//! # crdt-codegen
//!
//! Code generation from TOML schema definitions for crdt-kit.
//!
//! Reads a TOML schema (conventionally `crdt-schema.toml`) and generates a complete
//! **persistence layer** organized for Clean Architecture / Hexagonal Architecture:
//!
//! - `models/` — Versioned Rust structs with `#[crdt_schema]` annotations
//! - `migrations/` — Migration functions with `#[migration]` annotations + helpers
//!   (a migration file is only emitted for entities with more than one version)
//! - `repositories/` — Repository traits (ports) and `CrdtDb`-backed implementations (adapters)
//! - `store.rs` — Unified `Persistence<S>` entry point with scoped repository access
//! - `events/` — Event sourcing types, snapshots, and policies (optional)
//! - `sync/` — Delta sync and state-based merge helpers (optional; only emitted for
//!   entities that declare at least one `crdt` field)
//!
//! All generated files contain a header marking them as auto-generated. The files are
//! returned in memory as a [`GeneratedOutput`]; writing them under `output_dir` is left
//! to the caller.
//!
//! # Example
//!
//! ```rust
//! use crdt_codegen::generate_from_str;
//!
//! let toml = r#"
//! [config]
//! output = "src/persistence"
//!
//! [[entity]]
//! name = "Task"
//! table = "tasks"
//!
//! [[entity.versions]]
//! version = 1
//! fields = [
//!     { name = "title", type = "String" },
//!     { name = "done", type = "bool" },
//! ]
//! "#;
//!
//! let output = generate_from_str(toml).unwrap();
//! assert_eq!(output.output_dir, "src/persistence");
//! assert!(!output.files.is_empty());
//! ```

43mod schema;
44pub mod templates;
45mod validator;
46
47pub use schema::{Entity, EntityVersion, Field, SchemaConfig, SchemaFile};
48pub use validator::{validate_schema, ValidationError};
49
50use std::fmt;
51use std::path::Path;
52
53/// Error type for the code-generation process.
54#[derive(Debug)]
55pub enum CodegenError {
56    /// Failed to read the schema file.
57    Io(std::io::Error),
58    /// Failed to parse the TOML schema.
59    Parse(String),
60    /// Schema validation failed.
61    Validation(Vec<ValidationError>),
62}
63
64impl fmt::Display for CodegenError {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            Self::Io(e) => write!(f, "IO error: {e}"),
68            Self::Parse(msg) => write!(f, "parse error: {msg}"),
69            Self::Validation(errs) => {
70                writeln!(f, "schema validation failed:")?;
71                for e in errs {
72                    writeln!(f, "  - {e}")?;
73                }
74                Ok(())
75            }
76        }
77    }
78}
79
80impl std::error::Error for CodegenError {}
81
/// A single generated file ready to be written to disk.
///
/// Equality compares both the path and the full content, which makes it easy
/// to diff two generation runs or assert on generated output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedFile {
    /// Path relative to the output directory, using `/` as the separator
    /// (e.g., `"models/task.rs"`).
    pub relative_path: String,
    /// Full file content including the auto-generated header.
    pub content: String,
}

/// The complete output of the code-generation process.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedOutput {
    /// Output directory, copied from `output` in the schema's `[config]` table.
    pub output_dir: String,
    /// All generated files, in the order they were produced.
    pub files: Vec<GeneratedFile>,
}
99
100/// Parse a TOML schema file from disk and generate all code.
101pub fn generate(schema_path: &Path) -> Result<GeneratedOutput, CodegenError> {
102    let toml_content = std::fs::read_to_string(schema_path).map_err(CodegenError::Io)?;
103    generate_from_str(&toml_content)
104}
105
106/// Parse a TOML string and generate all code.
107pub fn generate_from_str(toml_content: &str) -> Result<GeneratedOutput, CodegenError> {
108    let schema: SchemaFile =
109        toml::from_str(toml_content).map_err(|e| CodegenError::Parse(e.to_string()))?;
110    generate_from_schema(&schema)
111}
112
113/// Generate code from an already-parsed schema.
114///
115/// Produces a nested directory structure:
116///
117/// ```text
118/// {output}/
119///   mod.rs
120///   store.rs
121///   models/
122///     mod.rs
123///     {entity}.rs ...
124///   migrations/
125///     mod.rs
126///     helpers.rs
127///     {entity}_migrations.rs ...
128///   repositories/
129///     mod.rs
130///     traits.rs
131///     {entity}_repo.rs ...
132///   events/         (if config.events.enabled)
133///     mod.rs
134///     policies.rs
135///     {entity}_events.rs ...
136///   sync/           (if config.sync.enabled)
137///     mod.rs
138///     {entity}_sync.rs ...
139/// ```
140pub fn generate_from_schema(schema: &SchemaFile) -> Result<GeneratedOutput, CodegenError> {
141    validate_schema(schema).map_err(CodegenError::Validation)?;
142
143    let mut files = Vec::new();
144
145    // ── Models ────────────────────────────────────────────────────
146    for entity in &schema.entities {
147        let (filename, content) = templates::generate_entity_file(entity);
148        files.push(GeneratedFile {
149            relative_path: format!("models/{filename}"),
150            content,
151        });
152    }
153    files.push(GeneratedFile {
154        relative_path: "models/mod.rs".into(),
155        content: templates::generate_models_mod_file(&schema.entities),
156    });
157
158    // ── Migrations ────────────────────────────────────────────────
159    for entity in &schema.entities {
160        if entity.versions.len() > 1 {
161            let (filename, content) = templates::generate_migration_file(entity);
162            files.push(GeneratedFile {
163                relative_path: format!("migrations/{filename}"),
164                content,
165            });
166        }
167    }
168    files.push(GeneratedFile {
169        relative_path: "migrations/helpers.rs".into(),
170        content: templates::generate_helpers_file(&schema.entities),
171    });
172    files.push(GeneratedFile {
173        relative_path: "migrations/mod.rs".into(),
174        content: templates::generate_migrations_mod_file(&schema.entities),
175    });
176
177    // ── Repositories ──────────────────────────────────────────────
178    files.push(GeneratedFile {
179        relative_path: "repositories/traits.rs".into(),
180        content: templates::generate_repository_traits_file(&schema.entities),
181    });
182    for entity in &schema.entities {
183        let (filename, content) = templates::generate_repository_impl_file(entity);
184        files.push(GeneratedFile {
185            relative_path: format!("repositories/{filename}"),
186            content,
187        });
188    }
189    files.push(GeneratedFile {
190        relative_path: "repositories/mod.rs".into(),
191        content: templates::generate_repositories_mod_file(&schema.entities),
192    });
193
194    // ── Store ─────────────────────────────────────────────────────
195    files.push(GeneratedFile {
196        relative_path: "store.rs".into(),
197        content: templates::generate_store_file(&schema.entities),
198    });
199
200    // ── Events (conditional) ──────────────────────────────────────
201    let events_enabled = schema
202        .config
203        .events
204        .as_ref()
205        .map(|e| e.enabled)
206        .unwrap_or(false);
207
208    if events_enabled {
209        let threshold = schema
210            .config
211            .events
212            .as_ref()
213            .map(|e| e.snapshot_threshold)
214            .unwrap_or(100);
215
216        for entity in &schema.entities {
217            let (filename, content) = templates::generate_event_types_file(entity);
218            files.push(GeneratedFile {
219                relative_path: format!("events/{filename}"),
220                content,
221            });
222        }
223        files.push(GeneratedFile {
224            relative_path: "events/policies.rs".into(),
225            content: templates::generate_snapshot_policy_file(threshold),
226        });
227        files.push(GeneratedFile {
228            relative_path: "events/mod.rs".into(),
229            content: templates::generate_events_mod_file(&schema.entities),
230        });
231    }
232
233    // ── Sync (conditional) ────────────────────────────────────────
234    let sync_enabled = schema
235        .config
236        .sync
237        .as_ref()
238        .map(|s| s.enabled)
239        .unwrap_or(false);
240
241    if sync_enabled {
242        // Only generate sync files for entities with CRDT fields.
243        let crdt_entities: Vec<&schema::Entity> = schema
244            .entities
245            .iter()
246            .filter(|e| {
247                e.versions
248                    .iter()
249                    .any(|v| v.fields.iter().any(|f| f.crdt.is_some()))
250            })
251            .collect();
252
253        for entity in &crdt_entities {
254            let (filename, content) = templates::generate_sync_file(entity);
255            files.push(GeneratedFile {
256                relative_path: format!("sync/{filename}"),
257                content,
258            });
259        }
260        if !crdt_entities.is_empty() {
261            files.push(GeneratedFile {
262                relative_path: "sync/mod.rs".into(),
263                content: templates::generate_sync_mod_file(&schema.entities),
264            });
265        }
266    }
267
268    // ── Top-level mod.rs ──────────────────────────────────────────
269    files.push(GeneratedFile {
270        relative_path: "mod.rs".into(),
271        content: templates::generate_persistence_mod_file(&schema.entities, &schema.config),
272    });
273
274    Ok(GeneratedOutput {
275        output_dir: schema.config.output.clone(),
276        files,
277    })
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_TOML: &str = r#"
[config]
output = "src/persistence"

[[entity]]
name = "Task"
table = "tasks"

[[entity.versions]]
version = 1
fields = [
    { name = "title", type = "String" },
    { name = "done", type = "bool" },
]

[[entity.versions]]
version = 2
fields = [
    { name = "title", type = "String" },
    { name = "done", type = "bool" },
    { name = "priority", type = "Option<u8>", default = "None" },
    { name = "tags", type = "Vec<String>", default = "Vec::new()" },
]
"#;

    /// Relative paths of every generated file, in generation order.
    fn paths(output: &GeneratedOutput) -> Vec<&str> {
        output
            .files
            .iter()
            .map(|f| f.relative_path.as_str())
            .collect()
    }

    /// The generated file at `path`; panics naming the path if it is absent.
    fn file<'a>(output: &'a GeneratedOutput, path: &str) -> &'a GeneratedFile {
        output
            .files
            .iter()
            .find(|f| f.relative_path == path)
            .unwrap_or_else(|| panic!("expected generated file `{path}`"))
    }

    #[test]
    fn end_to_end_generation() {
        let output = generate_from_str(SAMPLE_TOML).unwrap();
        assert_eq!(output.output_dir, "src/persistence");

        let filenames = paths(&output);

        // Models.
        assert!(filenames.contains(&"models/task.rs"));
        assert!(filenames.contains(&"models/mod.rs"));
        // Migrations.
        assert!(filenames.contains(&"migrations/task_migrations.rs"));
        assert!(filenames.contains(&"migrations/helpers.rs"));
        assert!(filenames.contains(&"migrations/mod.rs"));
        // Repositories.
        assert!(filenames.contains(&"repositories/traits.rs"));
        assert!(filenames.contains(&"repositories/task_repo.rs"));
        assert!(filenames.contains(&"repositories/mod.rs"));
        // Store.
        assert!(filenames.contains(&"store.rs"));
        // Top-level.
        assert!(filenames.contains(&"mod.rs"));
        // No events or sync by default.
        assert!(!filenames.iter().any(|f| f.starts_with("events/")));
        assert!(!filenames.iter().any(|f| f.starts_with("sync/")));
    }

    #[test]
    fn generated_structs_compile_ready() {
        let output = generate_from_str(SAMPLE_TOML).unwrap();
        let task_file = file(&output, "models/task.rs");

        assert!(task_file.content.contains("pub struct TaskV1"));
        assert!(task_file.content.contains("pub struct TaskV2"));
        assert!(task_file.content.contains("pub type Task = TaskV2;"));
        assert!(task_file
            .content
            .contains("#[crdt_schema(version = 1, table = \"tasks\")]"));
        assert!(task_file
            .content
            .contains("#[crdt_schema(version = 2, table = \"tasks\")]"));
    }

    #[test]
    fn generated_migrations_compile_ready() {
        let output = generate_from_str(SAMPLE_TOML).unwrap();
        let mig_file = file(&output, "migrations/task_migrations.rs");

        assert!(mig_file.content.contains("#[migration(from = 1, to = 2)]"));
        assert!(mig_file.content.contains("pub fn migrate_task_v1_to_v2"));
        assert!(mig_file.content.contains("priority: None,"));
        assert!(mig_file.content.contains("tags: Vec::new(),"));
    }

    #[test]
    fn invalid_schema_returns_error() {
        let bad_toml = r#"
[config]
output = ""

[[entity]]
name = "task"
table = ""

[[entity.versions]]
version = 1
fields = []
"#;
        let err = generate_from_str(bad_toml).expect_err("schema should be rejected");
        match &err {
            CodegenError::Validation(errs) => assert!(!errs.is_empty()),
            other => panic!("expected validation error, got {other:?}"),
        }
        assert!(err.to_string().starts_with("schema validation failed:"));
    }

    #[test]
    fn malformed_toml_returns_parse_error() {
        let result = generate_from_str("[[entity");
        assert!(
            matches!(&result, Err(CodegenError::Parse(_))),
            "expected parse error, got {result:?}"
        );
    }

    #[test]
    fn missing_schema_file_returns_io_error() {
        let missing = std::path::Path::new("this/path/does/not/exist/crdt-schema.toml");
        let result = generate(missing);
        assert!(
            matches!(&result, Err(CodegenError::Io(_))),
            "expected IO error, got {result:?}"
        );
    }

    #[test]
    fn single_version_no_migrations() {
        let toml = r#"
[config]
output = "out"

[[entity]]
name = "Note"
table = "notes"

[[entity.versions]]
version = 1
fields = [
    { name = "text", type = "String" },
]
"#;
        let output = generate_from_str(toml).unwrap();
        let filenames = paths(&output);
        assert!(filenames.contains(&"models/note.rs"));
        assert!(!filenames.iter().any(|f| f.contains("note_migrations")));
        assert!(filenames.contains(&"migrations/helpers.rs"));
        assert!(filenames.contains(&"repositories/note_repo.rs"));
    }

    #[test]
    fn multiple_entities() {
        let toml = r#"
[config]
output = "out"

[[entity]]
name = "Task"
table = "tasks"

[[entity.versions]]
version = 1
fields = [
    { name = "title", type = "String" },
]

[[entity]]
name = "User"
table = "users"

[[entity.versions]]
version = 1
fields = [
    { name = "name", type = "String" },
]
"#;
        let output = generate_from_str(toml).unwrap();
        let filenames = paths(&output);
        assert!(filenames.contains(&"models/task.rs"));
        assert!(filenames.contains(&"models/user.rs"));
        assert!(filenames.contains(&"repositories/task_repo.rs"));
        assert!(filenames.contains(&"repositories/user_repo.rs"));

        // Two entities must never map to the same output path.
        let mut unique = filenames.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(
            unique.len(),
            filenames.len(),
            "duplicate output paths: {filenames:?}"
        );
    }

    #[test]
    fn events_and_sync_conditional() {
        let toml = r#"
[config]
output = "out"

[config.events]
enabled = true
snapshot_threshold = 50

[config.sync]
enabled = true

[[entity]]
name = "Project"
table = "projects"

[[entity.versions]]
version = 1
fields = [
    { name = "name", type = "String", crdt = "LWWRegister" },
    { name = "members", type = "String", crdt = "ORSet" },
]
"#;
        let output = generate_from_str(toml).unwrap();
        let filenames = paths(&output);

        // Events should be generated.
        assert!(filenames.contains(&"events/mod.rs"));
        assert!(filenames.contains(&"events/policies.rs"));
        assert!(filenames.contains(&"events/project_events.rs"));

        // Sync should be generated for CRDT entity.
        assert!(filenames.contains(&"sync/mod.rs"));
        assert!(filenames.contains(&"sync/project_sync.rs"));

        // Check snapshot threshold.
        let policies = file(&output, "events/policies.rs");
        assert!(policies.content.contains("event_threshold: 50,"));

        // Check top-level mod includes events/sync.
        let mod_file = file(&output, "mod.rs");
        assert!(mod_file.content.contains("pub mod events;"));
        assert!(mod_file.content.contains("pub mod sync;"));
    }
}