deskc_thirgen/
lib.rs

1mod builtin;
2
3use std::cell::RefCell;
4
5use hir::{
6    expr::{Expr, Literal, MatchCase},
7    meta::{Meta, WithMeta},
8};
9use thir::{Handler, TypedHir};
10use types::{Effect, IdGen, Type, Types};
11
12use crate::builtin::find_builtin;
13
14pub fn gen_typed_hir(next_id: usize, types: Types, expr: &WithMeta<Expr>) -> TypedHir {
15    TypedHirGen {
16        types,
17        id_gen: RefCell::new(IdGen { next_id }),
18    }
19    .gen(expr)
20}
21
22#[derive(Debug, Default, Clone)]
23pub struct TypedHirGen {
24    types: Types,
25    id_gen: RefCell<IdGen>,
26}
27
28impl TypedHirGen {
29    pub fn gen(&self, expr: &WithMeta<Expr>) -> TypedHir {
30        let Meta { id: expr_id, .. } = &expr.meta;
31        let ty = self.types.get(expr_id).expect("must have type").clone();
32        let expr = match &expr.value {
33            Expr::Literal(Literal::Hole) => todo!(),
34            Expr::Literal(Literal::Int(value)) => thir::Expr::Literal(thir::Literal::Int(*value)),
35            Expr::Literal(Literal::Float(value)) => {
36                thir::Expr::Literal(thir::Literal::Float(*value))
37            }
38            Expr::Literal(Literal::Rational(a, b)) => {
39                thir::Expr::Literal(thir::Literal::Rational(*a, *b))
40            }
41            Expr::Literal(Literal::String(value)) => {
42                thir::Expr::Literal(thir::Literal::String(value.clone()))
43            }
44            Expr::Let {
45                ty: _,
46                definition,
47                expression,
48            } => thir::Expr::Let {
49                definition: Box::new(self.gen(&*definition)),
50                body: Box::new(self.gen(&*expression)),
51            },
52            Expr::Perform { input, output: _ } => thir::Expr::Perform(Box::new(self.gen(&*input))),
53            Expr::Continue { input, output: _ } => thir::Expr::Perform(Box::new(self.gen(&*input))),
54            Expr::Handle { handlers, expr } => thir::Expr::Handle {
55                handlers: handlers
56                    .iter()
57                    .map(
58                        |hir::expr::Handler {
59                             input,
60                             output,
61                             handler,
62                         }| Handler {
63                            effect: Effect {
64                                input: self.get_type(input),
65                                output: self.get_type(output),
66                            },
67                            handler: self.gen(&*handler),
68                        },
69                    )
70                    .collect(),
71                expr: Box::new(self.gen(&*expr)),
72            },
73            Expr::Apply {
74                function,
75                link_name,
76                arguments,
77            } => {
78                // TODO: lookup imported uuid to allow overwrite the builtin functions
79                if let Some(builtin) = find_builtin(&self.get_type(function)) {
80                    match builtin {
81                        builtin::Builtin::Normal { op, params } => {
82                            let op = thir::Expr::Op {
83                                op,
84                                operands: arguments.iter().map(|arg| self.gen(arg)).collect(),
85                            };
86                            // TODO wrap by function
87                            if arguments.len() < params {}
88                            op
89                        }
90                        builtin::Builtin::Custom(expr) => expr(self, arguments),
91                    }
92                } else {
93                    thir::Expr::Apply {
94                        function: self.get_type(function),
95                        link_name: link_name.clone(),
96                        arguments: arguments.iter().map(|arg| self.gen(arg)).collect(),
97                    }
98                }
99            }
100            Expr::Product(values) => {
101                thir::Expr::Product(values.iter().map(|value| self.gen(&*value)).collect())
102            }
103            // one ID disappeared here, but fine
104            Expr::Typed { ty: _, item: expr } => self.gen(expr).expr,
105            Expr::Function { parameter: _, body } => {
106                // get type from whole function is more accurate than from parameter.
107                let function_ty = self.get_type(expr);
108                if let Type::Function {
109                    parameters,
110                    body: _,
111                } = function_ty
112                {
113                    // Flatten the function
114                    match self.gen(&*body) {
115                        TypedHir {
116                            expr: thir::Expr::Function { body, .. },
117                            ..
118                        } => thir::Expr::Function { parameters, body },
119                        inner => thir::Expr::Function {
120                            parameters,
121                            body: Box::new(inner),
122                        },
123                    }
124                } else {
125                    panic!("function is inferred to not function??");
126                }
127            }
128            Expr::Array(values) => {
129                thir::Expr::Array(values.iter().map(|value| self.gen(&*value)).collect())
130            }
131            Expr::Set(values) => {
132                thir::Expr::Set(values.iter().map(|value| self.gen(&*value)).collect())
133            }
134            Expr::Match { of, cases } => thir::Expr::Match {
135                input: Box::new(self.gen(&*of)),
136                cases: cases
137                    .iter()
138                    .map(|MatchCase { ty, expr }| thir::MatchCase {
139                        ty: self.get_type(ty),
140                        expr: self.gen(expr),
141                    })
142                    .collect(),
143            },
144            Expr::Label {
145                label, item: body, ..
146            }
147            | Expr::Brand {
148                brand: label,
149                item: body,
150                ..
151            } => thir::Expr::Label {
152                label: label.clone(),
153                item: Box::new(self.gen(&*body)),
154            },
155        };
156        TypedHir {
157            id: *expr_id,
158            ty,
159            expr,
160        }
161    }
162
163    fn get_type<T>(&self, expr: &WithMeta<T>) -> Type {
164        self.types
165            .get(&expr.meta.id)
166            .expect("must have type")
167            .clone()
168    }
169
170    pub fn next_id(&self) -> usize {
171        self.id_gen.borrow_mut().next_id()
172    }
173}
174
#[cfg(test)]
mod tests {
    use file::FileId;
    use thir::BuiltinOp;

    use super::*;
    use pretty_assertions::assert_eq;

    /// Scans, parses and lowers `input` into HIR.
    fn parse(input: &str) -> WithMeta<Expr> {
        let tokens = lexer::scan(input).unwrap();
        let ast = parser::parse(tokens).unwrap();
        let generated = hirgen::gen_hir(FileId(0), &ast, Default::default()).unwrap();
        generated.1
    }

    /// Runs type inference over `expr` and returns the types it recorded.
    fn infer(expr: &WithMeta<Expr>) -> Types {
        let ctx = typeinfer::ctx::Ctx::default();
        let _ = ctx.synth(expr).unwrap();
        ctx.get_types()
    }

    /// Parses, infers and lowers `input` using a default id generator.
    fn lower(input: &str) -> TypedHir {
        let expr = parse(input);
        let generator = TypedHirGen {
            types: infer(&expr),
            ..Default::default()
        };
        generator.gen(&expr)
    }

    #[test]
    fn literal() {
        let expected = TypedHir {
            id: 0,
            ty: Type::Number,
            expr: thir::Expr::Literal(thir::Literal::Int(1)),
        };
        assert_eq!(lower("1"), expected);
    }

    #[test]
    fn function_and_reference() {
        let expected = TypedHir {
            id: 5,
            ty: Type::Function {
                parameters: vec![Type::Number, Type::String],
                body: Box::new(Type::Number),
            },
            expr: thir::Expr::Function {
                parameters: vec![Type::Number, Type::String],
                body: Box::new(TypedHir {
                    id: 3,
                    ty: Type::Number,
                    expr: thir::Expr::Apply {
                        function: Type::Number,
                        link_name: None,
                        arguments: vec![],
                    },
                }),
            },
        };
        assert_eq!(lower(r#"\ 'number, 'string -> &'number"#), expected);
    }

    #[test]
    fn builtin() {
        let expected = TypedHir {
            id: 8,
            ty: Type::Label {
                label: "sum".to_string(),
                item: Box::new(Type::Number),
            },
            expr: thir::Expr::Op {
                op: BuiltinOp::Add,
                operands: vec![
                    TypedHir {
                        id: 6,
                        ty: Type::Number,
                        expr: thir::Expr::Literal(thir::Literal::Int(1)),
                    },
                    TypedHir {
                        id: 7,
                        ty: Type::Number,
                        expr: thir::Expr::Literal(thir::Literal::Int(2)),
                    },
                ],
            },
        };
        assert_eq!(
            lower(r#"> \'number, 'number -> @sum 'number ~ 1, 2"#),
            expected
        );
    }

    #[test]
    fn builtin_curried() {
        let expr = parse(r#"> \'number, 'number -> @sum 'number ~ 1"#);
        let _generator = TypedHirGen {
            types: infer(&expr),
            id_gen: RefCell::new(IdGen { next_id: 100 }),
        };
        // TODO: partial application of builtins is not lowered yet. The sketch below
        // describes the intended output.
        // NOTE(review): it uses `thir::Expr::BuiltinOp { arguments }` and
        // `thir::Expr::Reference`, which appear nowhere else in this file (elsewhere
        // builtins are `thir::Expr::Op { operands }`), so update it before enabling.
        // assert_eq!(
        //     generator.gen(&expr),
        //     TypedHir {
        //         id: 8,
        //         ty: Type::Label {
        //             label: "sum".to_string(),
        //             item: Box::new(Type::Number),
        //         },
        //         expr: thir::Expr::Function {
        //             parameters: vec![
        //                 Type::Label {
        //                     label: "$$deskc 1".to_string(),
        //                     item: Box::new(Type::Number)
        //                 },
        //                 Type::Label {
        //                     label: "$$deskc 2".to_string(),
        //                     item: Box::new(Type::Number)
        //                 },
        //             ],
        //             body: Box::new(TypedHir {
        //                 id: 100,
        //                 ty: Type::Label {
        //                     label: "sum".to_string(),
        //                     item: Box::new(Type::Number),
        //                 },
        //                 expr: thir::Expr::BuiltinOp {
        //                     op: BuiltinOp::Add,
        //                     arguments: vec![
        //                         TypedHir {
        //                             id: 6,
        //                             ty: Type::Label {
        //                                 label: "$$deskc 1".to_string(),
        //                                 item: Box::new(Type::Number)
        //                             },
        //                             expr: thir::Expr::Reference,
        //                         },
        //                         TypedHir {
        //                             id: 7,
        //                             ty: Type::Label {
        //                                 label: "$$deskc 2".to_string(),
        //                                 item: Box::new(Type::Number)
        //                             },
        //                             expr: thir::Expr::Reference,
        //                         }
        //                     ]
        //                 }
        //             })
        //         },
        //     }
        // );
    }

    #[test]
    fn match_() {
        let source = r#"
        + 3 ~
          'number -> 1,
          'string -> "2".
        "#;
        let expected = TypedHir {
            id: 5,
            ty: Type::Sum(vec![Type::Number, Type::String]),
            expr: thir::Expr::Match {
                input: Box::new(TypedHir {
                    id: 0,
                    ty: Type::Number,
                    expr: thir::Expr::Literal(thir::Literal::Int(3)),
                }),
                cases: vec![
                    thir::MatchCase {
                        ty: Type::Number,
                        expr: TypedHir {
                            id: 2,
                            ty: Type::Number,
                            expr: thir::Expr::Literal(thir::Literal::Int(1)),
                        },
                    },
                    thir::MatchCase {
                        ty: Type::String,
                        expr: TypedHir {
                            id: 4,
                            ty: Type::String,
                            expr: thir::Expr::Literal(thir::Literal::String("2".into())),
                        },
                    },
                ],
            },
        };
        assert_eq!(lower(source), expected);
    }

    // TODO: match exhaustive check
}
385
386    // TODO: match exhaustive check
387}