nessa/
inference.rs

1use colored::Colorize;
2
3use crate::compilation::NessaError;
4use crate::context::NessaContext;
5use crate::interfaces::ITERABLE_ID;
6use crate::parser::Location;
7use crate::parser::NessaExpr;
8use crate::functions::*;
9use crate::operations::*;
10use crate::types::Type;
11
12impl NessaContext {
13    pub fn get_first_unary_op(&self, id: usize, arg_type: Type, call_templates: Option<Vec<Type>>, sub_t: bool, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), NessaError> {
14        if let Operator::Unary{operations, ..} = &self.unary_ops[id] {
15            'outer: for (i, op_ov) in operations.iter().enumerate() {
16                if let (true, subs) = arg_type.bindable_to_subtitutions(&op_ov.args, self) { // Take first that matches
17                    if let Some(call_t) = call_templates {
18                        for (i, t) in call_t.iter().enumerate() {
19                            if let Some(s_t) = subs.get(&i) {
20                                if t != s_t {
21                                    break 'outer;
22                                }   
23                            }
24                        }
25                    }
26                    
27                    let t_args = (0..op_ov.templates).map(|i| subs.get(&i).cloned().unwrap_or(Type::TemplateParam(i, vec!()))).collect();
28                    return Ok((i, if sub_t { op_ov.ret.sub_templates(&subs) } else { op_ov.ret.clone() }, op_ov.operation.is_some(), t_args));
29                }
30            }
31        }
32
33        if let Operator::Unary{representation, prefix, ..} = &self.unary_ops[id] {
34            if *prefix {
35                Err(NessaError::compiler_error(format!(
36                    "Unable to get unary operator overload for {}({})",
37                    representation,
38                    arg_type.get_name(self)
39                ), l, vec!()))
40
41            } else {
42                Err(NessaError::compiler_error(format!(
43                    "Unable to get unary operator overload for ({}){}",
44                    arg_type.get_name(self),
45                    representation
46                ), l, vec!()))
47            }
48
49        } else {
50            unreachable!()
51        }
52    }
53
54    pub fn is_unary_op_ambiguous(&self, id: usize, arg_type: Type) -> Option<Vec<(Type, Type)>> {
55        if let Operator::Unary{operations, ..} = &self.unary_ops[id] {
56            let overloads = operations.iter()
57                            .map(|op_ov| (op_ov.args.clone(), op_ov.ret.clone()))
58                            .filter(|(a, _)| arg_type.bindable_to(a, self)).collect::<Vec<_>>();
59
60            // Return Some(overloads) if the call is ambiguous, else return None
61            if overloads.len() > 1 {
62                return Some(overloads);
63
64            } else {
65                return None;
66            }
67        }
68
69        unreachable!();
70    }
71
72    pub fn get_first_binary_op(&self, id: usize, a_type: Type, b_type: Type, call_templates: Option<Vec<Type>>, sub_t: bool, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), NessaError> {
73        let t = Type::And(vec!(a_type.clone(), b_type.clone()));
74
75        if let Operator::Binary{operations, ..} = &self.binary_ops[id] {
76            'outer: for (i, op_ov) in operations.iter().enumerate() {
77                if let (true, subs) = t.bindable_to_subtitutions(&op_ov.args, self) { // Take first that matches
78                    if let Some(call_t) = call_templates {
79                        for (i, t) in call_t.iter().enumerate() {
80                            if let Some(s_t) = subs.get(&i) {
81                                if t != s_t {
82                                    break 'outer;
83                                }   
84                            }
85                        }
86                    }
87
88                    let t_args = (0..op_ov.templates).map(|i| subs.get(&i).cloned().unwrap_or(Type::TemplateParam(i, vec!()))).collect();
89                    return Ok((i, if sub_t { op_ov.ret.sub_templates(&subs) } else { op_ov.ret.clone() }, op_ov.operation.is_some(), t_args));
90                }
91            }
92        }
93
94        if let Operator::Binary{representation, ..} = &self.binary_ops[id] {
95            Err(NessaError::compiler_error(format!(
96                "Unable to get binary operator overload for ({}){}({})",
97                a_type.get_name(self),
98                representation,
99                b_type.get_name(self)
100            ), l, vec!()))
101
102        } else {
103            unreachable!()
104        }
105    }
106
107    pub fn is_binary_op_ambiguous(&self, id: usize, a_type: Type, b_type: Type) -> Option<Vec<(Type, Type, Type)>> {
108        let t = Type::And(vec!(a_type, b_type));
109
110        if let Operator::Binary{operations, ..} = &self.binary_ops[id] {
111            let overloads = operations.iter()
112                            .filter(|op_ov| t.bindable_to(&op_ov.args, self))
113                            .map(|op_ov| {
114                                if let Type::And(t) = &op_ov.args {
115                                    (t[0].clone(), t[1].clone(), op_ov.ret.clone())
116
117                                } else {
118                                    unreachable!()
119                                }
120                            })
121                            .collect::<Vec<_>>();
122
123            // Return Some(overloads) if the call is ambiguous, else return None
124            if overloads.len() > 1 {
125                return Some(overloads);
126
127            } else {
128                return None;
129            }
130        }
131
132        unreachable!();
133    }
134
135    pub fn get_first_nary_op(&self, id: usize, a_type: Type, b_type: Vec<Type>, call_templates: Option<Vec<Type>>, sub_t: bool, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), NessaError> {
136        let mut arg_types = vec!(a_type.clone());
137        arg_types.extend(b_type.iter().cloned());
138
139        let t = Type::And(arg_types.clone());
140
141        if let Operator::Nary{operations, ..} = &self.nary_ops[id] {
142            'outer: for (i, op_ov) in operations.iter().enumerate() {
143                if let (true, subs) = t.bindable_to_subtitutions(&op_ov.args, self) { // Take first that matches
144                    if let Some(call_t) = call_templates {
145                        for (i, t) in call_t.iter().enumerate() {
146                            if let Some(s_t) = subs.get(&i) {
147                                if t != s_t {
148                                    break 'outer;
149                                }   
150                            }
151                        }
152                    }
153
154                    let t_args = (0..op_ov.templates).map(|i| subs.get(&i).cloned().unwrap_or(Type::TemplateParam(i, vec!()))).collect();
155                    return Ok((i, if sub_t { op_ov.ret.sub_templates(&subs) } else { op_ov.ret.clone() }, op_ov.operation.is_some(), t_args));
156                }
157            }
158        }
159
160        if let Operator::Nary{open_rep, close_rep, ..} = &self.nary_ops[id] {
161            Err(NessaError::compiler_error(format!(
162                "Unable to get n-ary operator overload for {}{}{}{}",
163                a_type.get_name(self),
164                open_rep,
165                b_type.iter().map(|i| i.get_name(self)).collect::<Vec<_>>().join(", "),
166                close_rep
167            ), l, vec!()))
168
169        } else {
170            unreachable!()
171        }
172    }
173
174    pub fn is_nary_op_ambiguous(&self, id: usize, a_type: Type, b_type: Vec<Type>) -> Option<Vec<(Type, Vec<Type>, Type)>> {
175        let mut arg_types = vec!(a_type.clone());
176        arg_types.extend(b_type.iter().cloned());
177
178        let t = Type::And(arg_types);
179        
180        if let Operator::Nary{operations, ..} = &self.nary_ops[id] {
181            let overloads = operations.iter()
182                            .filter(|op_ov| t.bindable_to(&op_ov.args, self))
183                            .map(|op_ov| {
184                                if let Type::And(t) = &op_ov.args {
185                                    (t[0].clone(), t[1..].to_vec(), op_ov.ret.clone())
186
187                                } else {
188                                    unreachable!()
189                                }
190                            })
191                            .collect::<Vec<_>>();
192
193            // Return Some(overloads) if the call is ambiguous, else return None
194            if overloads.len() > 1 {
195                return Some(overloads);
196
197            } else {
198                return None;
199            }
200        }
201
202        unreachable!();
203    }
204
205    pub fn get_first_function_overload(&self, id: usize, arg_type: Vec<Type>, call_templates: Option<Vec<Type>>, sub_t: bool, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), NessaError> {
206        let t = Type::And(arg_type.clone());
207
208        'outer: for (i, f_ov) in self.functions[id].overloads.iter().enumerate() {
209            if let (true, subs) = t.bindable_to_subtitutions(&f_ov.args, self) { // Take first that matches
210                if let Some(call_t) = &call_templates {
211                    for (i, t) in call_t.iter().enumerate() {
212                        if let Some(s_t) = subs.get(&i) {
213                            if t != s_t {
214                                break 'outer;
215                            }   
216                        }
217                    }
218                }
219                
220                let t_args = (0..f_ov.templates).map(|i| subs.get(&i).cloned().unwrap_or(Type::TemplateParam(i, vec!()))).collect();
221                return Ok((i, if sub_t { f_ov.ret.sub_templates(&subs) } else { f_ov.ret.clone() }, f_ov.function.is_some(), t_args));
222            }
223        }
224
225        Err(NessaError::compiler_error(format!(
226            "Unable to get function overload for {}{}({})",
227            self.functions[id].name.green(),
228            if call_templates.is_none() || call_templates.as_ref().unwrap().is_empty() { 
229                "".into() 
230            } else { 
231                format!("<{}>", call_templates.unwrap().iter().map(|i| i.get_name(self)).collect::<Vec<_>>().join(", ")) 
232            },
233            arg_type.iter().map(|i| i.get_name(self)).collect::<Vec<_>>().join(", ")
234        ), l, vec!()))
235    }
236
237    pub fn is_function_overload_ambiguous(&self, id: usize, arg_type: Vec<Type>) -> Option<Vec<(Type, Type)>> {
238        let t = Type::And(arg_type);
239
240        let overloads = self.functions[id].overloads.iter()
241                            .map(|f_ov| (f_ov.args.clone(), f_ov.ret.clone()))
242                            .filter(|(a, _)| t.bindable_to(a, self)).collect::<Vec<_>>();
243
244        // Return Some(overloads) if the call is ambiguous, else return None
245        if overloads.len() > 1 {
246            Some(overloads)
247
248        } else {
249            None
250        }
251    }
252
253    pub fn implements_iterable(&self, container_type: &Type) -> bool {
254        for i in &self.interface_impls {
255            if i.interface_id == ITERABLE_ID && container_type.bindable_to(&i.interface_type, self) {
256                return true;
257            }
258        }
259
260        false
261    }
262
263    pub fn get_iterator_type(&self, container_type: &Type, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), NessaError> {
264        self.get_first_function_overload(ITERATOR_FUNC_ID, vec!(container_type.clone()), None, true, l)
265    }
266
267    pub fn get_iterator_output_type(&self, iterator_type: &Type, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), NessaError> {
268        let it_mut = Type::MutRef(Box::new(iterator_type.clone()));
269
270        self.get_first_function_overload(NEXT_FUNC_ID, vec!(it_mut.clone()), None, true, l)
271    }
272
273    pub fn infer_type(&self, expr: &NessaExpr) -> Result<Type, NessaError> {
274        return match expr {
275            NessaExpr::Literal(_, obj) => Ok(obj.get_type()),
276
277            NessaExpr::DoBlock(_, _, t) => Ok(t.clone()),
278
279            NessaExpr::AttributeAccess(_, e, att_idx) => {
280                use Type::*;
281
282                let arg_type = self.infer_type(e)?;
283
284                if let Basic(id) | Template(id, _) = arg_type.deref_type() {
285                    let mut att_type = self.type_templates[*id].attributes[*att_idx].1.clone();
286
287                    // Subtitute template parameters if needed
288                    if let Template(_, ts) = arg_type.deref_type() {
289                        att_type = att_type.sub_templates(&ts.iter().cloned().enumerate().collect());
290                    }
291                    
292                    return match (&arg_type, &att_type) {
293                        (MutRef(_), Ref(_) | MutRef(_)) => Ok(att_type.clone()),
294                        (MutRef(_), _) => Ok(MutRef(Box::new(att_type.clone()))),
295
296                        (Ref(_), MutRef(i)) => Ok(Ref(i.clone())),
297                        (Ref(_), Ref(_)) => Ok(att_type.clone()),
298                        (Ref(_), _) => Ok(Ref(Box::new(att_type.clone()))),
299
300                        (_, _) => Ok(att_type.clone())
301                    };
302
303                } else {
304                    unreachable!()
305                }
306            }
307
308            NessaExpr::CompiledLambda(_, _, _, a, r, _) => Ok(
309                if a.len() == 1 {
310                    Type::Function(
311                        Box::new(a[0].1.clone()),
312                        Box::new(r.clone())
313                    )
314
315                } else {
316                    Type::Function(
317                        Box::new(Type::And(a.iter().map(|(_, t)| t).cloned().collect())),
318                        Box::new(r.clone())
319                    )
320                }
321            ),
322            
323            NessaExpr::Tuple(_, e) => {
324                let mut args = vec!();
325
326                for i in e {
327                    args.push(self.infer_type(i)?);
328                }
329
330                Ok(Type::And(args))
331            },
332
333            NessaExpr::Variable(_, _, _, t) => {
334                match t {
335                    Type::Ref(_) | Type::MutRef(_) => Ok(t.clone()),
336                    t => Ok(Type::MutRef(Box::new(t.clone())))
337                }
338            },
339
340            NessaExpr::UnaryOperation(l, id, t, a) => {
341                let t_sub_call = t.iter().cloned().enumerate().collect();
342                let args_type = self.infer_type(a)?.sub_templates(&t_sub_call);
343
344                let (_, r, _, subs) = self.get_first_unary_op(*id, args_type, None, false, l)?;
345
346                let t_sub_ov = subs.iter().cloned().enumerate().collect();
347
348                return Ok(r.sub_templates(&t_sub_ov).sub_templates(&t_sub_call));
349            },
350
351            NessaExpr::BinaryOperation(l, id, t, a, b) => {
352                let t_sub_call = t.iter().cloned().enumerate().collect();
353                let a_type = self.infer_type(a)?.sub_templates(&t_sub_call);
354                let b_type = self.infer_type(b)?.sub_templates(&t_sub_call);
355
356                let (_, r, _, subs) = self.get_first_binary_op(*id, a_type, b_type, None, false, l)?;
357
358                let t_sub_ov = subs.iter().cloned().enumerate().collect();
359
360                return Ok(r.sub_templates(&t_sub_ov).sub_templates(&t_sub_call));
361            },
362
363            NessaExpr::NaryOperation(l, id, t, a, b) => {
364                let t_sub_call = t.iter().cloned().enumerate().collect();
365                let a_type = self.infer_type(a)?.sub_templates(&t_sub_call);
366                let b_type = b.iter().map(|i| self.infer_type(i))
367                                     .collect::<Result<Vec<_>, NessaError>>()?
368                                     .into_iter()
369                                     .map(|i| i.sub_templates(&t_sub_call))
370                                     .collect();
371
372                let (_, r, _, subs) = self.get_first_nary_op(*id, a_type, b_type, None, false, l)?;
373
374                let t_sub_ov = subs.iter().cloned().enumerate().collect();
375
376                return Ok(r.sub_templates(&t_sub_ov).sub_templates(&t_sub_call));
377            },
378
379            NessaExpr::FunctionCall(l, id, t, args) => {
380                let t_sub_call = t.iter().cloned().enumerate().collect();
381                let arg_types = args.iter().map(|i| self.infer_type(i))
382                                           .collect::<Result<Vec<_>, NessaError>>()?
383                                           .into_iter()
384                                           .map(|i| i.sub_templates(&t_sub_call))
385                                           .collect();
386
387                let (_, r, _, subs) = self.get_first_function_overload(*id, arg_types, None, true, l)?;
388
389                let t_sub_ov = subs.iter().cloned().enumerate().collect();
390
391                return Ok(r.sub_templates(&t_sub_ov).sub_templates(&t_sub_call));
392            }
393
394            NessaExpr::QualifiedName(l, _, Some(id)) => {
395                let func = &self.functions[*id];
396
397                if func.overloads.len() == 1 {
398                    let ov = &func.overloads[0];
399
400                    if ov.templates != 0 {
401                        return Err(NessaError::compiler_error(
402                            format!(
403                                "Implicit lambda for function with name {} cannot be formed from generic overload",
404                                func.name.green()
405                            ), 
406                            l, vec!()
407                        ));
408                    }
409                    
410                    if let Type::And(a) = &ov.args {
411                        if a.len() == 1 {
412                            return Ok(Type::Function(
413                                Box::new(a[0].clone()),
414                                Box::new(ov.ret.clone())
415                            ))
416        
417                        } else {
418                            return Ok(Type::Function(
419                                Box::new(Type::And(a.clone())),
420                                Box::new(ov.ret.clone())
421                            ))
422                        }
423                    }
424
425                    return Ok(Type::Function(
426                        Box::new(ov.args.clone()),
427                        Box::new(ov.ret.clone())
428                    ))
429                }
430
431                return Err(NessaError::compiler_error(
432                    format!(
433                        "Implicit lambda for function with name {} is ambiguous (found {} overloads)",
434                        func.name.green(),
435                        func.overloads.len()
436                    ), 
437                    l, vec!()
438                ));
439            }
440
441            NessaExpr::QualifiedName(l, _, _) |
442            NessaExpr::AttributeAssignment(l, _, _, _) |
443            NessaExpr::CompiledVariableDefinition(l, _, _, _, _) |
444            NessaExpr::CompiledVariableAssignment(l, _, _, _, _) |
445            NessaExpr::CompiledFor(l, _, _, _, _, _) |
446            NessaExpr::Macro(l, _, _, _, _, _) |
447            NessaExpr::Lambda(l, _, _, _, _) |
448            NessaExpr::NameReference(l, _) |
449            NessaExpr::VariableDefinition(l, _, _, _) |
450            NessaExpr::VariableAssignment(l, _, _) |
451            NessaExpr::FunctionDefinition(l, _, _, _, _, _, _) |
452            NessaExpr::PrefixOperatorDefinition(l, _, _) |
453            NessaExpr::PostfixOperatorDefinition(l, _, _) |
454            NessaExpr::BinaryOperatorDefinition(l, _, _, _) |
455            NessaExpr::NaryOperatorDefinition(l, _, _, _) |
456            NessaExpr::ClassDefinition(l, _, _, _, _, _, _) |
457            NessaExpr::InterfaceDefinition(l, _, _, _, _, _, _, _) |
458            NessaExpr::InterfaceImplementation(l, _, _, _, _) |
459            NessaExpr::PrefixOperationDefinition(l, _, _, _, _, _, _, _) |
460            NessaExpr::PostfixOperationDefinition(l, _, _, _, _, _, _, _) |
461            NessaExpr::BinaryOperationDefinition(l, _, _, _, _, _, _, _) |
462            NessaExpr::NaryOperationDefinition(l, _, _, _, _, _, _, _) |
463            NessaExpr::If(l, _, _, _, _) |
464            NessaExpr::Break(l) |
465            NessaExpr::Continue(l) |
466            NessaExpr::While(l, _, _) |
467            NessaExpr::For(l, _, _, _) |
468            NessaExpr::Return(l, _) => Err(NessaError::compiler_error(
469                "Expression cannot be evaluated to a type".into(), 
470                l, vec!()
471            ))
472        };
473    }
474}