use csw_derive::TypeSystem;
use std::path::Path;
use thiserror::Error;
use crate::Generator;
#[derive(Debug, Error)]
pub enum RustGeneratorError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("failed to create output directory: {0}")]
CreateDir(std::io::Error),
}
/// Generator that emits a standalone Cargo project (manifest, sources and
/// README) for a derived type system.
#[derive(Debug, Clone, Copy, Default)]
pub struct RustGenerator;
impl Generator for RustGenerator {
    type Error = RustGeneratorError;

    /// Writes a complete Cargo project for `ts` into `output_dir`.
    ///
    /// `output_dir` and its `src/` subdirectory are created first. The
    /// project files are then rendered one after another; the first failure
    /// stops the run and is returned.
    ///
    /// # Errors
    /// `RustGeneratorError::CreateDir` if `output_dir` cannot be created;
    /// otherwise whatever error the first failing step returns (I/O failures
    /// on `src/` become `RustGeneratorError::Io`).
    fn generate(ts: &TypeSystem, output_dir: &Path) -> Result<(), Self::Error> {
        if let Err(err) = std::fs::create_dir_all(output_dir) {
            return Err(RustGeneratorError::CreateDir(err));
        }
        let src_dir = output_dir.join("src");
        std::fs::create_dir_all(&src_dir)?;

        // Each step renders one file; order matches the project layout.
        type Step = fn(&TypeSystem, &Path) -> Result<(), RustGeneratorError>;
        let steps: [Step; 7] = [
            Self::generate_cargo_toml,
            Self::generate_lib_rs,
            Self::generate_types_rs,
            Self::generate_terms_rs,
            Self::generate_checker_rs,
            Self::generate_interpreter_rs,
            Self::generate_readme,
        ];
        steps.iter().try_for_each(|step| step(ts, output_dir))
    }
}
impl RustGenerator {
    /// Derives a Cargo package name from a type-system name.
    ///
    /// The name is lowercased. Every character Cargo does not accept in a
    /// package name (anything other than ASCII letters, digits, `-` and `_`)
    /// is replaced with `-`. For names made only of ASCII letters, digits and
    /// spaces, this equals "lowercase, then replace spaces with `-`".
    fn package_name(name: &str) -> String {
        name.to_lowercase()
            .chars()
            .map(|c| {
                if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
                    c
                } else {
                    '-'
                }
            })
            .collect()
    }

    /// Escapes `s` so it can be placed inside a TOML basic (double-quoted)
    /// string: quotes, backslashes and control characters are escaped.
    fn toml_escape(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        for c in s.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if c.is_control() => out.push_str(&format!("\\u{:04X}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }

    /// Writes `Cargo.toml` for the generated crate.
    ///
    /// The package name comes from [`Self::package_name`]. The description
    /// embeds `ts.name`, escaped so that quotes or backslashes in the name
    /// cannot break the TOML string.
    fn generate_cargo_toml(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
        let name = Self::package_name(&ts.name);
        let description = Self::toml_escape(&ts.name);
        let content = format!(
            r#"[package]
name = "{name}"
version = "0.1.0"
edition = "2021"
description = "Generated type system: {description}"
[dependencies]
thiserror = "1.0"
[dev-dependencies]
"#
        );
        std::fs::write(output_dir.join("Cargo.toml"), content)?;
        Ok(())
    }

    /// Writes `src/lib.rs`, which declares the four generated modules and
    /// re-exports their contents.
    fn generate_lib_rs(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
        let content = format!(
            r#"//! # {}
//!
//! Auto-generated type system from categorical specification.
//!
//! This crate provides:
//! - Type definitions
//! - Term definitions
//! - Type checker
//! - Interpreter/evaluator
mod types;
mod terms;
mod checker;
mod interpreter;
pub use types::*;
pub use terms::*;
pub use checker::*;
pub use interpreter::*;
"#,
            ts.name
        );
        std::fs::write(output_dir.join("src/lib.rs"), content)?;
        Ok(())
    }

    /// Writes `src/types.rs`: one `Type` enum variant per type constructor.
    ///
    /// A nullary constructor becomes a unit variant. A constructor of arity
    /// `n > 0` becomes a tuple variant with `n` `Box<Type>` fields, so every
    /// type argument can be represented.
    fn generate_types_rs(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
        let mut variants = String::new();
        for tc in &ts.type_constructors {
            let variant = match tc.arity {
                0 => format!(" /// {} type\n {},\n", tc.name, tc.name),
                2 => format!(
                    " /// {} type ({})\n {}(Box<Type>, Box<Type>),\n",
                    tc.name, tc.symbol, tc.name
                ),
                // Previously emitted as a field-less variant, which silently
                // dropped the constructor's type arguments.
                _ => {
                    let fields = (0..tc.arity)
                        .map(|_| "Box<Type>")
                        .collect::<Vec<_>>()
                        .join(", ");
                    format!(
                        " /// {} type ({})\n {}({}),\n",
                        tc.name, tc.symbol, tc.name, fields
                    )
                }
            };
            variants.push_str(&variant);
        }
        let content = format!(
            r#"//! Type definitions for {}.
/// Types in the {} type system.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Type {{
{variants}}}
impl std::fmt::Display for Type {{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
// TODO: Implement pretty printing
write!(f, "{{:?}}", self)
}}
}}
"#,
            ts.name, ts.name
        );
        std::fs::write(output_dir.join("src/types.rs"), content)?;
        Ok(())
    }

    /// Writes `src/terms.rs`, the `Term` enum of the generated language.
    ///
    /// NOTE(review): the term set is fixed (unit, pairs, functions, sums) and
    /// does not depend on `ts`; only the name is interpolated.
    fn generate_terms_rs(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
        let content = format!(
            r#"//! Term definitions for {}.
use crate::Type;
/// Terms in the {} type system.
#[derive(Clone, Debug)]
pub enum Term {{
/// Variable reference
Var(String),
/// Unit value
Unit,
/// Pair construction
Pair(Box<Term>, Box<Term>),
/// First projection
Fst(Box<Term>),
/// Second projection
Snd(Box<Term>),
/// Lambda abstraction
Abs(String, Box<Type>, Box<Term>),
/// Function application
App(Box<Term>, Box<Term>),
/// Left injection (sum types)
Inl(Box<Term>, Box<Type>),
/// Right injection (sum types)
Inr(Box<Term>, Box<Type>),
/// Case analysis
Case(Box<Term>, String, Box<Term>, String, Box<Term>),
}}
"#,
            ts.name, ts.name
        );
        std::fs::write(output_dir.join("src/terms.rs"), content)?;
        Ok(())
    }

    /// Writes `src/checker.rs`, a syntax-directed type checker over `Term`.
    ///
    /// NOTE(review): the template references `Type::Unit`, `Type::Product`,
    /// `Type::Arrow` and `Type::Coproduct` unconditionally. The generated
    /// crate only compiles if `ts.type_constructors` yields variants with
    /// exactly those names — confirm against the deriver's naming.
    fn generate_checker_rs(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
        let content = format!(
            r#"//! Type checker for {}.
use crate::{{Term, Type}};
use std::collections::HashMap;
use thiserror::Error;
/// Type checking errors.
#[derive(Debug, Error)]
pub enum TypeError {{
#[error("unbound variable: {{0}}")]
UnboundVar(String),
#[error("type mismatch: expected {{expected}}, got {{actual}}")]
TypeMismatch {{ expected: Type, actual: Type }},
#[error("expected function type, got {{0}}")]
ExpectedFunction(Type),
#[error("expected product type, got {{0}}")]
ExpectedProduct(Type),
#[error("expected sum type, got {{0}}")]
ExpectedSum(Type),
}}
/// Type checking context.
pub type Context = HashMap<String, Type>;
/// Type checker for the {} type system.
pub struct Checker;
impl Checker {{
/// Check the type of a term in a given context.
pub fn check(ctx: &Context, term: &Term) -> Result<Type, TypeError> {{
match term {{
Term::Var(x) => ctx
.get(x)
.cloned()
.ok_or_else(|| TypeError::UnboundVar(x.clone())),
Term::Unit => Ok(Type::Unit),
Term::Pair(a, b) => {{
let ta = Self::check(ctx, a)?;
let tb = Self::check(ctx, b)?;
Ok(Type::Product(Box::new(ta), Box::new(tb)))
}}
Term::Fst(p) => {{
match Self::check(ctx, p)? {{
Type::Product(a, _) => Ok(*a),
t => Err(TypeError::ExpectedProduct(t)),
}}
}}
Term::Snd(p) => {{
match Self::check(ctx, p)? {{
Type::Product(_, b) => Ok(*b),
t => Err(TypeError::ExpectedProduct(t)),
}}
}}
Term::Abs(x, ty, body) => {{
let mut new_ctx = ctx.clone();
new_ctx.insert(x.clone(), (**ty).clone());
let body_ty = Self::check(&new_ctx, body)?;
Ok(Type::Arrow(ty.clone(), Box::new(body_ty)))
}}
Term::App(f, a) => {{
match Self::check(ctx, f)? {{
Type::Arrow(param_ty, ret_ty) => {{
let arg_ty = Self::check(ctx, a)?;
if *param_ty == arg_ty {{
Ok(*ret_ty)
}} else {{
Err(TypeError::TypeMismatch {{
expected: *param_ty,
actual: arg_ty,
}})
}}
}}
t => Err(TypeError::ExpectedFunction(t)),
}}
}}
Term::Inl(a, ty_b) => {{
let ty_a = Self::check(ctx, a)?;
Ok(Type::Coproduct(Box::new(ty_a), ty_b.clone()))
}}
Term::Inr(b, ty_a) => {{
let ty_b = Self::check(ctx, b)?;
Ok(Type::Coproduct(ty_a.clone(), Box::new(ty_b)))
}}
Term::Case(e, x, e1, y, e2) => {{
match Self::check(ctx, e)? {{
Type::Coproduct(ty_a, ty_b) => {{
let mut ctx1 = ctx.clone();
ctx1.insert(x.clone(), *ty_a);
let ty1 = Self::check(&ctx1, e1)?;
let mut ctx2 = ctx.clone();
ctx2.insert(y.clone(), *ty_b);
let ty2 = Self::check(&ctx2, e2)?;
if ty1 == ty2 {{
Ok(ty1)
}} else {{
Err(TypeError::TypeMismatch {{
expected: ty1,
actual: ty2,
}})
}}
}}
t => Err(TypeError::ExpectedSum(t)),
}}
}}
}}
}}
}}
"#,
            ts.name, ts.name
        );
        std::fs::write(output_dir.join("src/checker.rs"), content)?;
        Ok(())
    }

    /// Writes `src/interpreter.rs`, an environment-based evaluator that
    /// panics on ill-typed programs (it assumes the checker ran first).
    fn generate_interpreter_rs(
        ts: &TypeSystem,
        output_dir: &Path,
    ) -> Result<(), RustGeneratorError> {
        let content = format!(
            r#"//! Interpreter for {}.
use crate::Term;
use std::collections::HashMap;
/// Runtime values.
#[derive(Clone, Debug)]
pub enum Value {{
/// Unit value
Unit,
/// Pair of values
Pair(Box<Value>, Box<Value>),
/// Closure (captured environment + parameter + body)
Closure(Env, String, Box<Term>),
/// Left injection
Inl(Box<Value>),
/// Right injection
Inr(Box<Value>),
}}
/// Runtime environment.
pub type Env = HashMap<String, Value>;
/// Interpreter for the {} type system.
pub struct Interpreter;
impl Interpreter {{
/// Evaluate a term in a given environment.
pub fn eval(env: &Env, term: &Term) -> Value {{
match term {{
Term::Var(x) => env.get(x).cloned().expect("unbound variable"),
Term::Unit => Value::Unit,
Term::Pair(a, b) => {{
let va = Self::eval(env, a);
let vb = Self::eval(env, b);
Value::Pair(Box::new(va), Box::new(vb))
}}
Term::Fst(p) => {{
match Self::eval(env, p) {{
Value::Pair(a, _) => *a,
_ => panic!("fst of non-pair"),
}}
}}
Term::Snd(p) => {{
match Self::eval(env, p) {{
Value::Pair(_, b) => *b,
_ => panic!("snd of non-pair"),
}}
}}
Term::Abs(x, _, body) => {{
Value::Closure(env.clone(), x.clone(), body.clone())
}}
Term::App(f, a) => {{
let vf = Self::eval(env, f);
let va = Self::eval(env, a);
match vf {{
Value::Closure(mut cenv, x, body) => {{
cenv.insert(x, va);
Self::eval(&cenv, &body)
}}
_ => panic!("application of non-function"),
}}
}}
Term::Inl(a, _) => Value::Inl(Box::new(Self::eval(env, a))),
Term::Inr(b, _) => Value::Inr(Box::new(Self::eval(env, b))),
Term::Case(e, x, e1, y, e2) => {{
match Self::eval(env, e) {{
Value::Inl(va) => {{
let mut new_env = env.clone();
new_env.insert(x.clone(), *va);
Self::eval(&new_env, e1)
}}
Value::Inr(vb) => {{
let mut new_env = env.clone();
new_env.insert(y.clone(), *vb);
Self::eval(&new_env, e2)
}}
_ => panic!("case on non-sum"),
}}
}}
}}
}}
}}
"#,
            ts.name, ts.name
        );
        std::fs::write(output_dir.join("src/interpreter.rs"), content)?;
        Ok(())
    }

    /// Writes `README.md`, which lists the structural rules and shows usage.
    ///
    /// The crate path in the usage example is the Cargo package name with
    /// `-` mapped to `_`, the same mapping Cargo applies. The example
    /// therefore always names the crate declared in `Cargo.toml`.
    ///
    /// NOTE(review): the example uses `Type::Int`, which exists only if the
    /// type system has a base type named `Int`.
    fn generate_readme(ts: &TypeSystem, output_dir: &Path) -> Result<(), RustGeneratorError> {
        let crate_ident = Self::package_name(&ts.name).replace('-', "_");
        let content = format!(
            r#"# {}
Auto-generated type system from categorical specification.
## Structural Rules
- Weakening: {}
- Contraction: {}
- Exchange: {}
## Usage
```rust
use {}::*;
// Create a context
let mut ctx = Context::new();
ctx.insert("x".to_string(), Type::Int);
// Type check a term
let term = Term::Var("x".to_string());
let ty = Checker::check(&ctx, &term).unwrap();
// Evaluate a term
let mut env = Env::new();
// ... add bindings ...
let value = Interpreter::eval(&env, &term);
```
## Generated from
This type system was derived from a categorical specification using the
[Categorical Semantics Workbench](https://github.com/ibrahimcesar/categorical-semantics-workbench).
"#,
            ts.name,
            if ts.structural.weakening { "✓" } else { "✗" },
            if ts.structural.contraction { "✓" } else { "✗" },
            if ts.structural.exchange { "✓" } else { "✗" },
            crate_ident
        );
        std::fs::write(output_dir.join("README.md"), content)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use csw_core::CategoryBuilder;
    use csw_derive::Deriver;
    use std::path::PathBuf;

    /// Generates the STLC crate into a fresh directory and checks that every
    /// file the generator is responsible for was written.
    #[test]
    fn test_generate_stlc() {
        let ccc = CategoryBuilder::new("STLC")
            .with_base("Int")
            .with_terminal()
            .with_products()
            .with_exponentials()
            .cartesian()
            .build()
            .unwrap();
        let ts = Deriver::derive(&ccc);
        let temp_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("target")
            .join("test-output")
            .join("stlc");

        // The output lives in a persistent directory; clear it first so files
        // left by a previous run cannot satisfy the existence checks below.
        if temp_dir.exists() {
            std::fs::remove_dir_all(&temp_dir).unwrap();
        }

        RustGenerator::generate(&ts, &temp_dir).unwrap();

        for file in [
            "Cargo.toml",
            "README.md",
            "src/lib.rs",
            "src/types.rs",
            "src/terms.rs",
            "src/checker.rs",
            "src/interpreter.rs",
        ] {
            assert!(
                temp_dir.join(file).exists(),
                "missing generated file: {file}"
            );
        }
    }
}