kodept_inference/
language.rs

1use std::fmt::{Debug, Display, Formatter};
2
3use derive_more::{Display, From};
4use itertools::Itertools;
5use crate::r#type::MonomorphicType;
6
7#[derive(PartialEq, Eq, Hash)]
8pub struct  BVar {
9    pub var: Var,
10    pub ty: Option<MonomorphicType>
11}
12
13#[derive(Display, Clone, PartialEq, Eq, Hash)]
14#[display(fmt = "{name}")]
15pub struct Var {
16    pub name: String,
17}
18
19#[derive(PartialEq, Eq, Hash)]
20pub struct App {
21    pub arg: Box<Language>,
22    pub func: Box<Language>,
23}
24
25#[derive(PartialEq, Eq, Hash)]
26pub struct Lambda {
27    pub bind: BVar,
28    pub expr: Box<Language>,
29}
30
31#[derive(PartialEq, Eq, Hash)]
32pub struct Let {
33    pub binder: Box<Language>,
34    pub bind: BVar,
35    pub usage: Box<Language>,
36}
37
38#[derive(PartialEq, Eq, Hash)]
39pub enum Literal {
40    Integral(String),
41    Floating(String),
42    Tuple(Vec<Language>),
43}
44
45#[derive(PartialEq, Eq, Hash)]
46pub enum Special {
47    If {
48        condition: Box<Language>,
49        body: Box<Language>,
50        otherwise: Box<Language>,
51    },
52}
53
54#[derive(Debug, From, Display, PartialEq, Eq, Hash)]
55pub enum Language {
56    Var(Var),
57    App(App),
58    Lambda(Lambda),
59    Let(Let),
60    Literal(Literal),
61    Special(Special),
62}
63
64pub fn var<V: Into<Var>>(id: V) -> Var {
65    id.into()
66}
67
68pub fn app<N: Into<Language>, M: Into<Language>>(arg: N, func: M) -> App {
69    App {
70        arg: Box::new(arg.into()),
71        func: Box::new(func.into()),
72    }
73}
74
75pub fn lambda<B, E>(bind: B, expr: E) -> Lambda
76where
77    BVar: From<B>,
78    E: Into<Language>,
79{
80    Lambda {
81        bind: bind.into(),
82        expr: Box::new(expr.into()),
83    }
84}
85
86pub fn r#let<V, B, U>(bind: V, binder: B, usage: U) -> Let
87where
88    B: Into<Language>,
89    U: Into<Language>,
90    BVar: From<V>,
91{
92    Let {
93        binder: Box::new(binder.into()),
94        bind: bind.into(),
95        usage: Box::new(usage.into()),
96    }
97}
98
99pub fn r#if(
100    condition: impl Into<Language>,
101    body: impl Into<Language>,
102    otherwise: impl Into<Language>,
103) -> Special {
104    Special::If {
105        condition: Box::new(condition.into()),
106        body: Box::new(body.into()),
107        otherwise: Box::new(otherwise.into()),
108    }
109}
110
111pub fn bounded(v: impl Into<Var>, t: impl Into<MonomorphicType>) -> BVar {
112    BVar {
113        var: v.into(),
114        ty: Some(t.into()),
115    }
116}
117
118impl<S: Into<Var>> From<S> for BVar {
119    fn from(value: S) -> Self {
120        Self {
121            var: value.into(),
122            ty: None,
123        }
124    }
125}
126
127impl Display for App {
128    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
129        match self.arg.as_ref() {
130            Language::Var(_) | Language::Literal(_) => match self.func.as_ref() {
131                Language::Var(_) | Language::Literal(_) => write!(f, "{} {}", self.func, self.arg),
132                Language::App(App { func, .. }) => match func.as_ref() {
133                    Language::App(_) => write!(f, "{} ({})", self.func, self.arg),
134                    _ => write!(f, "{} {}", self.func, self.arg),
135                },
136                _ => write!(f, "{} ({})", self.func, self.arg),
137            },
138            _ => write!(f, "{} ({})", self.func, self.arg),
139        }
140    }
141}
142
143impl Display for Lambda {
144    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
145        write!(f, "λ{}. {}", self.bind, self.expr)
146    }
147}
148
149impl Display for Let {
150    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
151        write!(f, "(let {} = {} in {})", self.bind, self.binder, self.usage)
152    }
153}
154
155impl Display for Literal {
156    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
157        match self {
158            Literal::Integral(n) | Literal::Floating(n) => write!(f, "{n}"),
159            Literal::Tuple(t) => write!(f, "({})", t.iter().join(", ")),
160        }
161    }
162}
163
164impl Display for Special {
165    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
166        match self {
167            Special::If {
168                condition,
169                body,
170                otherwise,
171            } => write!(f, "if ({condition}) ({body}) ({otherwise})"),
172        }
173    }
174}
175
176impl Display for BVar {
177    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
178        match &self.ty {
179            None => write!(f, "{}", self.var),
180            Some(ty) => write!(f, "{} :: {}", self.var, ty)
181        }
182    }
183}
184
185impl Debug for App {
186    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187        write!(f, "{self}")
188    }
189}
190
191impl Debug for Lambda {
192    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
193        write!(f, "{self}")
194    }
195}
196
197impl Debug for Let {
198    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
199        write!(f, "{self}")
200    }
201}
202
203impl Debug for Literal {
204    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
205        write!(f, "{self}")
206    }
207}
208
209impl Debug for Special {
210    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
211        write!(f, "{self}")
212    }
213}
214
215impl Debug for BVar {
216    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
217        write!(f, "{self}")
218    }
219}
220
221impl Debug for Var {
222    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
223        write!(f, "{self}")
224    }
225}
226
227impl<S: Into<String>> From<S> for Var {
228    fn from(value: S) -> Self {
229        Self { name: value.into() }
230    }
231}
232
#[cfg(test)]
// Tests unwrap freely: a failed inference should fail the test immediately.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashSet;
    use crate::assumption::Environment;
    use crate::language::{app, lambda, Language, Literal, r#let, var};
    use crate::r#type::{fun1, Tuple, var as t_var};

    #[test]
    fn test_infer_language() {
        // λz. let x = (z, z) in (λy. (y, y)) x
        // ∀a => a -> ((a, a), (a, a))
        // `x` is the pair `(z, z)` and the inner lambda duplicates its argument,
        // so the only type variable is the type of `z`.
        let expr: Language = lambda(
            "z",
            r#let(
                "x",
                Literal::Tuple(vec![var("z").into(), var("z").into()]),
                app(
                    var("x"),
                    lambda("y", Literal::Tuple(vec![var("y").into(), var("y").into()])),
                ),
            ),
        )
        .into();

        let t = expr.infer(&Environment::empty()).unwrap();

        println!("{}\n{}", expr, t);
        assert_eq!(
            t,
            fun1(
                t_var(0),
                Tuple(vec![
                    Tuple(vec![t_var(0), t_var(0)]).into(),
                    Tuple(vec![t_var(0), t_var(0)]).into()
                ])
            ).generalize(&HashSet::new())
        );
    }

    #[test]
    fn test_church_encoding() {
        // zero = \f. \x. x                   :: a -> b -> b
        // one  = \f. \x. f x                 :: (a -> b) -> a -> b
        // plus = \m. \n. \f. \x. m f (n f x) :: (a -> b -> c) -> (a -> d -> b) -> a -> d -> c
        //
        // `app(arg, func)` takes the argument first, so `app(var("x"), var("f"))` is `f x`.
        //
        // NOTE(review): this test only checks that inference succeeds (via the
        // `unwrap`s) and prints the results; the types listed above are not asserted.

        let zero: Language = lambda("f", lambda("x", var("x"))).into();
        let one: Language = lambda("f", lambda("x", app(var("x"), var("f")))).into();
        let plus: Language = lambda(
            "m",
            lambda(
                "n",
                lambda(
                    "f",
                    lambda(
                        "x",
                        app(
                            app(var("x"), app(var("f"), var("n"))),
                            app(var("f"), var("m")),
                        ),
                    ),
                ),
            ),
        )
        .into();

        let zt = zero.infer(&Environment::empty()).unwrap();
        let ot = one.infer(&Environment::empty()).unwrap();
        let pt = plus.infer(&Environment::empty()).unwrap();

        println!("{}\n{}\n\n{}\n{}\n\n{}\n{}", zero, zt, one, ot, plus, pt);
    }
}