panproto_expr/
typecheck.rs1use std::collections::HashMap;
8use std::hash::BuildHasher;
9use std::sync::Arc;
10
11use crate::expr::{Expr, ExprType};
12use crate::literal::Literal;
13
14#[derive(Debug, thiserror::Error)]
16#[non_exhaustive]
17pub enum TypeError {
18 #[error("unbound variable: {0}")]
20 UnboundVariable(String),
21
22 #[error("type mismatch: expected {expected:?}, got {got:?}")]
24 TypeMismatch {
25 expected: ExprType,
27 got: ExprType,
29 },
30
31 #[error("cannot infer type of expression")]
33 CannotInfer,
34}
35
36pub fn infer_type<S: BuildHasher>(
47 expr: &Expr,
48 env: &HashMap<Arc<str>, ExprType, S>,
49) -> Result<ExprType, TypeError> {
50 match expr {
51 Expr::Var(name) => env
52 .get(name.as_ref())
53 .copied()
54 .ok_or_else(|| TypeError::UnboundVariable(name.to_string())),
55 Expr::Lit(lit) => Ok(literal_type(lit)),
56 Expr::Builtin(op, _) => {
57 if let Some((_, out)) = op.signature() {
58 Ok(out)
59 } else {
60 Ok(ExprType::Any)
61 }
62 }
63 Expr::Lam(..) | Expr::App(..) | Expr::Field(..) | Expr::Index(..) => Ok(ExprType::Any),
64 Expr::Record(..) => Ok(ExprType::Record),
65 Expr::List(..) => Ok(ExprType::List),
66 Expr::Match { arms, .. } => {
67 if let Some((_, body)) = arms.first() {
68 infer_type(body, env)
69 } else {
70 Ok(ExprType::Any)
71 }
72 }
73 Expr::Let { name, value, body } => {
74 let val_type = infer_type(value, env)?;
75 let mut inner_env: HashMap<Arc<str>, ExprType> =
76 env.iter().map(|(k, v)| (Arc::clone(k), *v)).collect();
77 inner_env.insert(Arc::clone(name), val_type);
78 infer_type(body, &inner_env)
79 }
80 }
81}
82
83const fn literal_type(lit: &Literal) -> ExprType {
85 match lit {
86 Literal::Int(_) => ExprType::Int,
87 Literal::Float(_) => ExprType::Float,
88 Literal::Str(_) => ExprType::Str,
89 Literal::Bool(_) => ExprType::Bool,
90 Literal::Null | Literal::Bytes(_) | Literal::Closure { .. } => ExprType::Any,
91 Literal::List(_) => ExprType::List,
92 Literal::Record(_) => ExprType::Record,
93 }
94}
95
96pub fn validate_coercion(expr: &Expr, source: ExprType, target: ExprType) -> Result<(), TypeError> {
105 let env = HashMap::from([(Arc::from("v"), source)]);
106 let inferred = infer_type(expr, &env)?;
107 if inferred != ExprType::Any && inferred != target {
108 return Err(TypeError::TypeMismatch {
109 expected: target,
110 got: inferred,
111 });
112 }
113 Ok(())
114}
115
#[cfg(test)]
// A failed unwrap is exactly the test failure we want reported, so the lint is moot here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::expr::{BuiltinOp, Expr};
    use crate::literal::Literal;

    /// Infers `expr` with no variables in scope.
    fn infer_closed(expr: &Expr) -> Result<ExprType, TypeError> {
        infer_type(expr, &HashMap::new())
    }

    /// `int_to_float(v)`, the coercion body shared by the `validate_coercion` tests.
    fn int_to_float_of_v() -> Expr {
        Expr::Builtin(BuiltinOp::IntToFloat, vec![Expr::var("v")])
    }

    #[test]
    fn infer_literal_types() {
        let cases = [
            (Literal::Int(42), ExprType::Int),
            (Literal::Float(1.0), ExprType::Float),
            (Literal::Str("hi".into()), ExprType::Str),
            (Literal::Bool(true), ExprType::Bool),
        ];
        for (lit, expected) in cases {
            assert_eq!(infer_closed(&Expr::Lit(lit)).unwrap(), expected);
        }
    }

    #[test]
    fn infer_var_from_env() {
        let mut env = HashMap::new();
        env.insert(Arc::from("x"), ExprType::Int);
        let inferred = infer_type(&Expr::var("x"), &env).unwrap();
        assert_eq!(inferred, ExprType::Int);
    }

    #[test]
    fn unbound_var_errors() {
        let outcome = infer_closed(&Expr::var("missing"));
        assert!(matches!(outcome, Err(TypeError::UnboundVariable(_))));
    }

    #[test]
    fn infer_builtin_with_signature() {
        let converted = Expr::int_to_float(Expr::Lit(Literal::Int(1)));
        assert_eq!(infer_closed(&converted).unwrap(), ExprType::Float);
    }

    #[test]
    fn infer_let_propagates_type() {
        let bound_body = Expr::Builtin(BuiltinOp::IntToFloat, vec![Expr::var("x")]);
        let program = Expr::let_in("x", Expr::Lit(Literal::Int(42)), bound_body);
        assert_eq!(infer_closed(&program).unwrap(), ExprType::Float);
    }

    #[test]
    fn validate_coercion_ok() {
        validate_coercion(&int_to_float_of_v(), ExprType::Int, ExprType::Float).unwrap();
    }

    #[test]
    fn validate_coercion_mismatch() {
        let outcome = validate_coercion(&int_to_float_of_v(), ExprType::Int, ExprType::Str);
        assert!(matches!(outcome, Err(TypeError::TypeMismatch { .. })));
    }
}