csw_generate/
rust.rs

//! Rust code generation.
//!
//! This module generates Rust source code for type checkers and interpreters
//! from derived type systems.
5
6use csw_derive::TypeSystem;
7use std::path::Path;
8use thiserror::Error;
9
10use crate::Generator;
11
12/// Errors that can occur during Rust code generation.
13#[derive(Debug, Error)]
14pub enum RustGeneratorError {
15    /// IO error while writing files.
16    #[error("IO error: {0}")]
17    Io(#[from] std::io::Error),
18
19    /// The output directory doesn't exist and couldn't be created.
20    #[error("failed to create output directory: {0}")]
21    CreateDir(std::io::Error),
22}
23
/// Rust code generator.
///
/// Emits a Cargo crate (manifest, README and `src/` modules for types, terms,
/// a type checker and an interpreter) from a derived [`TypeSystem`].
///
/// The generator carries no state: every entry point is an associated
/// function that receives the type system explicitly. The common traits are
/// derived so the value can be printed, defaulted, copied and compared.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RustGenerator;
28
29impl Generator for RustGenerator {
30    type Error = RustGeneratorError;
31
32    fn generate(ts: &TypeSystem, output_dir: &Path) -> Result<(), Self::Error> {
33        // Create output directory
34        std::fs::create_dir_all(output_dir).map_err(RustGeneratorError::CreateDir)?;
35        std::fs::create_dir_all(output_dir.join("src"))?;
36
37        // Generate files
38        Self::generate_cargo_toml(ts, output_dir)?;
39        Self::generate_lib_rs(ts, output_dir)?;
40        Self::generate_types_rs(ts, output_dir)?;
41        Self::generate_terms_rs(ts, output_dir)?;
42        Self::generate_checker_rs(ts, output_dir)?;
43        Self::generate_interpreter_rs(ts, output_dir)?;
44        Self::generate_readme(ts, output_dir)?;
45
46        Ok(())
47    }
48}
49
50impl RustGenerator {
51    fn generate_cargo_toml(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
52        let name = ts.name.to_lowercase().replace(' ', "-");
53        let content = format!(
54            r#"[package]
55name = "{name}"
56version = "0.1.0"
57edition = "2021"
58description = "Generated type system: {}"
59
60[dependencies]
61thiserror = "1.0"
62
63[dev-dependencies]
64"#,
65            ts.name
66        );
67
68        std::fs::write(output_dir.join("Cargo.toml"), content)?;
69        Ok(())
70    }
71
72    fn generate_lib_rs(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
73        let content = format!(
74            r#"//! # {}
75//!
76//! Auto-generated type system from categorical specification.
77//!
78//! This crate provides:
79//! - Type definitions
80//! - Term definitions
81//! - Type checker
82//! - Interpreter/evaluator
83
84mod types;
85mod terms;
86mod checker;
87mod interpreter;
88
89pub use types::*;
90pub use terms::*;
91pub use checker::*;
92pub use interpreter::*;
93"#,
94            ts.name
95        );
96
97        std::fs::write(output_dir.join("src/lib.rs"), content)?;
98        Ok(())
99    }
100
101    fn generate_types_rs(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
102        let mut variants = String::new();
103
104        for tc in &ts.type_constructors {
105            let variant = match tc.arity {
106                0 => format!("    /// {} type\n    {},\n", tc.name, tc.name),
107                2 => format!(
108                    "    /// {} type ({})\n    {}(Box<Type>, Box<Type>),\n",
109                    tc.name, tc.symbol, tc.name
110                ),
111                _ => format!("    {},\n", tc.name),
112            };
113            variants.push_str(&variant);
114        }
115
116        let content = format!(
117            r#"//! Type definitions for {}.
118
119/// Types in the {} type system.
120#[derive(Clone, Debug, PartialEq, Eq)]
121pub enum Type {{
122{variants}}}
123
124impl std::fmt::Display for Type {{
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
126        match self {{
127            // TODO: Implement pretty printing
128            _ => write!(f, "{{:?}}", self),
129        }}
130    }}
131}}
132"#,
133            ts.name, ts.name
134        );
135
136        std::fs::write(output_dir.join("src/types.rs"), content)?;
137        Ok(())
138    }
139
140    fn generate_terms_rs(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
141        let content = format!(
142            r#"//! Term definitions for {}.
143
144use crate::Type;
145
146/// Terms in the {} type system.
147#[derive(Clone, Debug)]
148pub enum Term {{
149    /// Variable reference
150    Var(String),
151
152    /// Unit value
153    Unit,
154
155    /// Pair construction
156    Pair(Box<Term>, Box<Term>),
157
158    /// First projection
159    Fst(Box<Term>),
160
161    /// Second projection
162    Snd(Box<Term>),
163
164    /// Lambda abstraction
165    Abs(String, Box<Type>, Box<Term>),
166
167    /// Function application
168    App(Box<Term>, Box<Term>),
169
170    /// Left injection (sum types)
171    Inl(Box<Term>, Box<Type>),
172
173    /// Right injection (sum types)
174    Inr(Box<Term>, Box<Type>),
175
176    /// Case analysis
177    Case(Box<Term>, String, Box<Term>, String, Box<Term>),
178}}
179"#,
180            ts.name, ts.name
181        );
182
183        std::fs::write(output_dir.join("src/terms.rs"), content)?;
184        Ok(())
185    }
186
187    fn generate_checker_rs(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
188        let content = format!(
189            r#"//! Type checker for {}.
190
191use crate::{{Term, Type}};
192use std::collections::HashMap;
193use thiserror::Error;
194
195/// Type checking errors.
196#[derive(Debug, Error)]
197pub enum TypeError {{
198    #[error("unbound variable: {{0}}")]
199    UnboundVar(String),
200
201    #[error("type mismatch: expected {{expected}}, got {{actual}}")]
202    TypeMismatch {{ expected: Type, actual: Type }},
203
204    #[error("expected function type, got {{0}}")]
205    ExpectedFunction(Type),
206
207    #[error("expected product type, got {{0}}")]
208    ExpectedProduct(Type),
209
210    #[error("expected sum type, got {{0}}")]
211    ExpectedSum(Type),
212}}
213
214/// Type checking context.
215pub type Context = HashMap<String, Type>;
216
217/// Type checker for the {} type system.
218pub struct Checker;
219
220impl Checker {{
221    /// Check the type of a term in a given context.
222    pub fn check(ctx: &Context, term: &Term) -> Result<Type, TypeError> {{
223        match term {{
224            Term::Var(x) => ctx
225                .get(x)
226                .cloned()
227                .ok_or_else(|| TypeError::UnboundVar(x.clone())),
228
229            Term::Unit => Ok(Type::Unit),
230
231            Term::Pair(a, b) => {{
232                let ta = Self::check(ctx, a)?;
233                let tb = Self::check(ctx, b)?;
234                Ok(Type::Product(Box::new(ta), Box::new(tb)))
235            }}
236
237            Term::Fst(p) => {{
238                match Self::check(ctx, p)? {{
239                    Type::Product(a, _) => Ok(*a),
240                    t => Err(TypeError::ExpectedProduct(t)),
241                }}
242            }}
243
244            Term::Snd(p) => {{
245                match Self::check(ctx, p)? {{
246                    Type::Product(_, b) => Ok(*b),
247                    t => Err(TypeError::ExpectedProduct(t)),
248                }}
249            }}
250
251            Term::Abs(x, ty, body) => {{
252                let mut new_ctx = ctx.clone();
253                new_ctx.insert(x.clone(), (**ty).clone());
254                let body_ty = Self::check(&new_ctx, body)?;
255                Ok(Type::Arrow(ty.clone(), Box::new(body_ty)))
256            }}
257
258            Term::App(f, a) => {{
259                match Self::check(ctx, f)? {{
260                    Type::Arrow(param_ty, ret_ty) => {{
261                        let arg_ty = Self::check(ctx, a)?;
262                        if *param_ty == arg_ty {{
263                            Ok(*ret_ty)
264                        }} else {{
265                            Err(TypeError::TypeMismatch {{
266                                expected: *param_ty,
267                                actual: arg_ty,
268                            }})
269                        }}
270                    }}
271                    t => Err(TypeError::ExpectedFunction(t)),
272                }}
273            }}
274
275            Term::Inl(a, ty_b) => {{
276                let ty_a = Self::check(ctx, a)?;
277                Ok(Type::Coproduct(Box::new(ty_a), ty_b.clone()))
278            }}
279
280            Term::Inr(b, ty_a) => {{
281                let ty_b = Self::check(ctx, b)?;
282                Ok(Type::Coproduct(ty_a.clone(), Box::new(ty_b)))
283            }}
284
285            Term::Case(e, x, e1, y, e2) => {{
286                match Self::check(ctx, e)? {{
287                    Type::Coproduct(ty_a, ty_b) => {{
288                        let mut ctx1 = ctx.clone();
289                        ctx1.insert(x.clone(), *ty_a);
290                        let ty1 = Self::check(&ctx1, e1)?;
291
292                        let mut ctx2 = ctx.clone();
293                        ctx2.insert(y.clone(), *ty_b);
294                        let ty2 = Self::check(&ctx2, e2)?;
295
296                        if ty1 == ty2 {{
297                            Ok(ty1)
298                        }} else {{
299                            Err(TypeError::TypeMismatch {{
300                                expected: ty1,
301                                actual: ty2,
302                            }})
303                        }}
304                    }}
305                    t => Err(TypeError::ExpectedSum(t)),
306                }}
307            }}
308        }}
309    }}
310}}
311"#,
312            ts.name, ts.name
313        );
314
315        std::fs::write(output_dir.join("src/checker.rs"), content)?;
316        Ok(())
317    }
318
319    fn generate_interpreter_rs(
320        ts: &TypeSystem,
321        output_dir: &Path,
322    ) -> Result<(), RustGeneratorError> {
323        let content = format!(
324            r#"//! Interpreter for {}.
325
326use crate::Term;
327use std::collections::HashMap;
328
329/// Runtime values.
330#[derive(Clone, Debug)]
331pub enum Value {{
332    /// Unit value
333    Unit,
334
335    /// Pair of values
336    Pair(Box<Value>, Box<Value>),
337
338    /// Closure (captured environment + parameter + body)
339    Closure(Env, String, Box<Term>),
340
341    /// Left injection
342    Inl(Box<Value>),
343
344    /// Right injection
345    Inr(Box<Value>),
346}}
347
348/// Runtime environment.
349pub type Env = HashMap<String, Value>;
350
351/// Interpreter for the {} type system.
352pub struct Interpreter;
353
354impl Interpreter {{
355    /// Evaluate a term in a given environment.
356    pub fn eval(env: &Env, term: &Term) -> Value {{
357        match term {{
358            Term::Var(x) => env.get(x).cloned().expect("unbound variable"),
359
360            Term::Unit => Value::Unit,
361
362            Term::Pair(a, b) => {{
363                let va = Self::eval(env, a);
364                let vb = Self::eval(env, b);
365                Value::Pair(Box::new(va), Box::new(vb))
366            }}
367
368            Term::Fst(p) => {{
369                match Self::eval(env, p) {{
370                    Value::Pair(a, _) => *a,
371                    _ => panic!("fst of non-pair"),
372                }}
373            }}
374
375            Term::Snd(p) => {{
376                match Self::eval(env, p) {{
377                    Value::Pair(_, b) => *b,
378                    _ => panic!("snd of non-pair"),
379                }}
380            }}
381
382            Term::Abs(x, _, body) => {{
383                Value::Closure(env.clone(), x.clone(), body.clone())
384            }}
385
386            Term::App(f, a) => {{
387                let vf = Self::eval(env, f);
388                let va = Self::eval(env, a);
389                match vf {{
390                    Value::Closure(mut cenv, x, body) => {{
391                        cenv.insert(x, va);
392                        Self::eval(&cenv, &body)
393                    }}
394                    _ => panic!("application of non-function"),
395                }}
396            }}
397
398            Term::Inl(a, _) => Value::Inl(Box::new(Self::eval(env, a))),
399
400            Term::Inr(b, _) => Value::Inr(Box::new(Self::eval(env, b))),
401
402            Term::Case(e, x, e1, y, e2) => {{
403                match Self::eval(env, e) {{
404                    Value::Inl(va) => {{
405                        let mut new_env = env.clone();
406                        new_env.insert(x.clone(), *va);
407                        Self::eval(&new_env, e1)
408                    }}
409                    Value::Inr(vb) => {{
410                        let mut new_env = env.clone();
411                        new_env.insert(y.clone(), *vb);
412                        Self::eval(&new_env, e2)
413                    }}
414                    _ => panic!("case on non-sum"),
415                }}
416            }}
417        }}
418    }}
419}}
420"#,
421            ts.name, ts.name
422        );
423
424        std::fs::write(output_dir.join("src/interpreter.rs"), content)?;
425        Ok(())
426    }
427
428    fn generate_readme(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
429        let content = format!(
430            r#"# {}
431
432Auto-generated type system from categorical specification.
433
434## Structural Rules
435
436- Weakening: {}
437- Contraction: {}
438- Exchange: {}
439
440## Usage
441
442```rust
443use {}::*;
444
445// Create a context
446let mut ctx = Context::new();
447ctx.insert("x".to_string(), Type::Int);
448
449// Type check a term
450let term = Term::Var("x".to_string());
451let ty = Checker::check(&ctx, &term).unwrap();
452
453// Evaluate a term
454let mut env = Env::new();
455// ... add bindings ...
456let value = Interpreter::eval(&env, &term);
457```
458
459## Generated from
460
461This type system was derived from a categorical specification using the
462[Categorical Semantics Workbench](https://github.com/ibrahimcesar/categorical-semantics-workbench).
463"#,
464            ts.name,
465            if ts.structural.weakening { "✓" } else { "✗" },
466            if ts.structural.contraction { "✓" } else { "✗" },
467            if ts.structural.exchange { "✓" } else { "✗" },
468            ts.name.to_lowercase().replace(' ', "_")
469        );
470
471        std::fs::write(output_dir.join("README.md"), content)?;
472        Ok(())
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use csw_core::CategoryBuilder;
    use csw_derive::Deriver;
    use std::path::PathBuf;

    /// Every file `RustGenerator::generate` writes, relative to the output directory.
    const EXPECTED_FILES: &[&str] = &[
        "Cargo.toml",
        "README.md",
        "src/lib.rs",
        "src/types.rs",
        "src/terms.rs",
        "src/checker.rs",
        "src/interpreter.rs",
    ];

    #[test]
    fn test_generate_stlc() {
        let ccc = CategoryBuilder::new("STLC")
            .with_base("Int")
            .with_terminal()
            .with_products()
            .with_exponentials()
            .cartesian()
            .build()
            .unwrap();

        let ts = Deriver::derive(&ccc);

        let temp_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("target")
            .join("test-output")
            .join("stlc");

        // The directory lives under `target/` and survives between runs. Wipe it
        // so the existence checks below cannot pass on files left by an earlier run.
        match std::fs::remove_dir_all(&temp_dir) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => panic!("failed to clear {}: {e}", temp_dir.display()),
        }

        RustGenerator::generate(&ts, &temp_dir).unwrap();

        for file in EXPECTED_FILES {
            assert!(
                temp_dir.join(file).exists(),
                "missing generated file: {file}"
            );
        }

        // lib.rs must declare every generated module, or the crate won't build.
        let lib_rs = std::fs::read_to_string(temp_dir.join("src/lib.rs")).unwrap();
        for module in ["types", "terms", "checker", "interpreter"] {
            assert!(
                lib_rs.contains(&format!("mod {module};")),
                "lib.rs does not declare `mod {module};`"
            );
        }
    }
}
509}