Skip to main content

panproto_expr/
typecheck.rs

1//! Lightweight type inference for expressions.
2//!
3//! Provides best-effort type inference over the expression language without
4//! requiring a full dependent type system. Useful for validating coercion
5//! expressions and catching obvious type mismatches early.
6
7use std::collections::HashMap;
8use std::hash::BuildHasher;
9use std::sync::Arc;
10
11use crate::expr::{Expr, ExprType};
12use crate::literal::Literal;
13
14/// Errors from type inference.
15#[derive(Debug, thiserror::Error)]
16#[non_exhaustive]
17pub enum TypeError {
18    /// A variable was not found in the type environment.
19    #[error("unbound variable: {0}")]
20    UnboundVariable(String),
21
22    /// An expression produced a type that does not match the expected type.
23    #[error("type mismatch: expected {expected:?}, got {got:?}")]
24    TypeMismatch {
25        /// The type that was expected.
26        expected: ExprType,
27        /// The type that was inferred.
28        got: ExprType,
29    },
30
31    /// The type of an expression could not be determined.
32    #[error("cannot infer type of expression")]
33    CannotInfer,
34}
35
36/// Infer the output type of an expression given input variable types.
37///
38/// This performs a single pass over the expression tree, using builtin
39/// signatures where available and falling back to `ExprType::Any` for
40/// polymorphic or opaque constructs.
41///
42/// # Errors
43///
44/// Returns [`TypeError::UnboundVariable`] if a variable is not in the
45/// environment.
46pub fn infer_type<S: BuildHasher>(
47    expr: &Expr,
48    env: &HashMap<Arc<str>, ExprType, S>,
49) -> Result<ExprType, TypeError> {
50    match expr {
51        Expr::Var(name) => env
52            .get(name.as_ref())
53            .copied()
54            .ok_or_else(|| TypeError::UnboundVariable(name.to_string())),
55        Expr::Lit(lit) => Ok(literal_type(lit)),
56        Expr::Builtin(op, _) => {
57            if let Some((_, out)) = op.signature() {
58                Ok(out)
59            } else {
60                Ok(ExprType::Any)
61            }
62        }
63        Expr::Lam(..) | Expr::App(..) | Expr::Field(..) | Expr::Index(..) => Ok(ExprType::Any),
64        Expr::Record(..) => Ok(ExprType::Record),
65        Expr::List(..) => Ok(ExprType::List),
66        Expr::Match { arms, .. } => {
67            if let Some((_, body)) = arms.first() {
68                infer_type(body, env)
69            } else {
70                Ok(ExprType::Any)
71            }
72        }
73        Expr::Let { name, value, body } => {
74            let val_type = infer_type(value, env)?;
75            let mut inner_env: HashMap<Arc<str>, ExprType> =
76                env.iter().map(|(k, v)| (Arc::clone(k), *v)).collect();
77            inner_env.insert(Arc::clone(name), val_type);
78            infer_type(body, &inner_env)
79        }
80    }
81}
82
83/// Map a literal value to its expression type.
84const fn literal_type(lit: &Literal) -> ExprType {
85    match lit {
86        Literal::Int(_) => ExprType::Int,
87        Literal::Float(_) => ExprType::Float,
88        Literal::Str(_) => ExprType::Str,
89        Literal::Bool(_) => ExprType::Bool,
90        Literal::Null | Literal::Bytes(_) | Literal::Closure { .. } => ExprType::Any,
91        Literal::List(_) => ExprType::List,
92        Literal::Record(_) => ExprType::Record,
93    }
94}
95
96/// Validate that a coercion expression produces the expected target type
97/// when given a single input variable `"v"` of the specified source type.
98///
99/// # Errors
100///
101/// Returns [`TypeError::TypeMismatch`] if the inferred type does not match
102/// `target` (and the inferred type is not `Any`), or [`TypeError::UnboundVariable`]
103/// if the expression references variables other than `"v"`.
104pub fn validate_coercion(expr: &Expr, source: ExprType, target: ExprType) -> Result<(), TypeError> {
105    let env = HashMap::from([(Arc::from("v"), source)]);
106    let inferred = infer_type(expr, &env)?;
107    if inferred != ExprType::Any && inferred != target {
108        return Err(TypeError::TypeMismatch {
109            expected: target,
110            got: inferred,
111        });
112    }
113    Ok(())
114}
115
#[cfg(test)]
// Tests are expected to panic on unexpected results.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::expr::{BuiltinOp, Expr};
    use crate::literal::Literal;

    /// An environment with no bindings.
    fn empty_env() -> HashMap<Arc<str>, ExprType> {
        HashMap::new()
    }

    #[test]
    fn infer_literal_types() {
        let env = empty_env();
        let cases = [
            (Literal::Int(42), ExprType::Int),
            (Literal::Float(1.0), ExprType::Float),
            (Literal::Str("hi".into()), ExprType::Str),
            (Literal::Bool(true), ExprType::Bool),
        ];
        for (lit, expected) in cases {
            let inferred = infer_type(&Expr::Lit(lit), &env).unwrap();
            assert_eq!(inferred, expected);
        }
    }

    #[test]
    fn infer_var_from_env() {
        let mut env = empty_env();
        env.insert(Arc::from("x"), ExprType::Int);
        let inferred = infer_type(&Expr::var("x"), &env).unwrap();
        assert_eq!(inferred, ExprType::Int);
    }

    #[test]
    fn unbound_var_errors() {
        let outcome = infer_type(&Expr::var("missing"), &empty_env());
        assert!(matches!(
            outcome,
            Err(TypeError::UnboundVariable(ref name)) if name == "missing"
        ));
    }

    #[test]
    fn infer_builtin_with_signature() {
        let conversion = Expr::int_to_float(Expr::Lit(Literal::Int(1)));
        let inferred = infer_type(&conversion, &empty_env()).unwrap();
        assert_eq!(inferred, ExprType::Float);
    }

    #[test]
    fn infer_let_propagates_type() {
        let body = Expr::Builtin(BuiltinOp::IntToFloat, vec![Expr::var("x")]);
        let program = Expr::let_in("x", Expr::Lit(Literal::Int(42)), body);
        let inferred = infer_type(&program, &empty_env()).unwrap();
        assert_eq!(inferred, ExprType::Float);
    }

    #[test]
    fn validate_coercion_ok() {
        let coercion = Expr::Builtin(BuiltinOp::IntToFloat, vec![Expr::var("v")]);
        validate_coercion(&coercion, ExprType::Int, ExprType::Float).unwrap();
    }

    #[test]
    fn validate_coercion_mismatch() {
        let coercion = Expr::Builtin(BuiltinOp::IntToFloat, vec![Expr::var("v")]);
        let outcome = validate_coercion(&coercion, ExprType::Int, ExprType::Str);
        assert!(matches!(
            outcome,
            Err(TypeError::TypeMismatch {
                expected: ExprType::Str,
                got: ExprType::Float,
            })
        ));
    }
}