// kodept_interpret/type_checker.rs

1use derive_more::From;
2use nonempty_collections::{nev, IteratorExt, NEVec, NonEmptyIterator};
3use std::borrow::Cow;
4use std::cell::Cell;
5use std::collections::HashSet;
6use std::num::NonZeroU16;
7use std::rc::Rc;
8
9use kodept_ast::graph::{AnyNodeD, ChangeSet, GenericNodeId, GenericNodeKey, PermTkn, SyntaxTree};
10use kodept_ast::traits::Identifiable;
11use kodept_ast::utils::Execution;
12use kodept_ast::visit_side::{VisitGuard, VisitSide};
13use kodept_ast::BodyFnDecl;
14use kodept_core::structure::{rlt, Located};
15use kodept_inference::algorithm_w::{AlgorithmWError, CompoundInferError};
16use kodept_inference::language::{Language, Var};
17use kodept_inference::r#type::PolymorphicType;
18use kodept_inference::traits::EnvironmentProvider;
19use kodept_macros::error::report::{ReportMessage, Severity};
20use kodept_macros::traits::Context;
21use kodept_macros::Macro;
22use RecursiveTypeCheckingError::{AlgoWError, MutuallyRecursive, NodeNotFound};
23
24use crate::convert_model::ModelConvertibleNode;
25use crate::node_family::TypeRestrictedNode;
26use crate::scope::{ScopeError, ScopeSearch, ScopeTree};
27use crate::type_checker::InferError::Unknown;
28use crate::type_checker::RecursiveTypeCheckingError::InconvertibleToModel;
29use crate::{Cache, Witness};
30
/// Diagnostic payload reported when no type could be inferred for a function.
///
/// Converts into a `TC002` warning [`ReportMessage`].
#[derive(Debug)]
pub struct CannotInfer;
32
33pub struct TypeInfo<'a> {
34    name: &'a str,
35    ty: &'a PolymorphicType,
36}
37
38impl From<TypeInfo<'_>> for ReportMessage {
39    fn from(value: TypeInfo<'_>) -> Self {
40        Self::new(
41            Severity::Note,
42            "TC001",
43            format!(
44                "Type of function `{}` inferred to: {}",
45                value.name, value.ty
46            ),
47        )
48    }
49}
50
51impl From<CannotInfer> for ReportMessage {
52    fn from(_: CannotInfer) -> Self {
53        Self::new(Severity::Warning, "TC002", "Cannot infer type".to_string())
54    }
55}
56
57pub struct TypeChecker<'a> {
58    pub(crate) symbols: &'a ScopeTree,
59    models: Cache<Rc<Language>>,
60    evidence: Witness,
61    recursion_depth: NonZeroU16,
62}
63
64struct RecursiveTypeChecker<'a> {
65    search: ScopeSearch<'a>,
66    token: &'a PermTkn,
67    tree: &'a SyntaxTree,
68    models: &'a Cache<Rc<Language>>,
69    evidence: Witness,
70    current_recursion_depth: Cell<u16>,
71}
72
73#[derive(From, Debug)]
74pub enum InferError {
75    AlgorithmW(AlgorithmWError),
76    Scope(ScopeError),
77    Unknown,
78}
79
80#[derive(Debug, From)]
81enum RecursiveTypeCheckingError {
82    NodeNotFound(GenericNodeId),
83    InconvertibleToModel(AnyNodeD),
84    MutuallyRecursive,
85    #[from]
86    ScopeError(ScopeError),
87    #[from]
88    AlgoWError(AlgorithmWError),
89}
90
91#[derive(From, Debug)]
92struct RecursiveTypeCheckingErrors {
93    errors: NEVec<RecursiveTypeCheckingError>,
94}
95
96impl From<RecursiveTypeCheckingError> for RecursiveTypeCheckingErrors {
97    fn from(value: RecursiveTypeCheckingError) -> Self {
98        Self {
99            errors: nev![value],
100        }
101    }
102}
103
104impl From<ScopeError> for RecursiveTypeCheckingErrors {
105    fn from(value: ScopeError) -> Self {
106        Self {
107            errors: nev![value.into()],
108        }
109    }
110}
111
112impl From<InferError> for RecursiveTypeCheckingErrors {
113    fn from(value: InferError) -> Self {
114        Self {
115            errors: match value {
116                InferError::AlgorithmW(e) => nev![e.into()],
117                InferError::Scope(e) => nev![e.into()],
118                Unknown => panic!("Unknown error happened"),
119            },
120        }
121    }
122}
123
124impl From<AlgorithmWError> for RecursiveTypeCheckingErrors {
125    fn from(value: AlgorithmWError) -> Self {
126        Self {
127            errors: nev![value.into()],
128        }
129    }
130}
131
132impl From<RecursiveTypeCheckingError> for ReportMessage {
133    fn from(value: RecursiveTypeCheckingError) -> Self {
134        match value {
135            NodeNotFound(id) => Self::new(
136                Severity::Bug,
137                "TC005",
138                format!("Cannot find node with given id: {id}"),
139            ),
140            InconvertibleToModel(desc) => Self::new(
141                Severity::Bug,
142                "TC006",
143                format!("Cannot convert node with description `{desc}` to model"),
144            ),
145            RecursiveTypeCheckingError::ScopeError(e) => e.into(),
146            AlgoWError(e) => InferError::from(e).into(),
147            MutuallyRecursive => Self::new(
148                Severity::Error,
149                "TC007",
150                "Cannot type check due to mutual recursion".to_string(),
151            )
152            .with_notes(vec![
153                "Adjust `recursion_depth` CLI option if needed".to_string()
154            ]),
155        }
156    }
157}
158
159fn flatten(
160    value: CompoundInferError<RecursiveTypeCheckingErrors>,
161) -> NEVec<RecursiveTypeCheckingError> {
162    match value {
163        CompoundInferError::AlgoW(e) => nev![e.into()],
164        CompoundInferError::Both(e, es) => {
165            let errors: Vec<_> = es.into_iter().flat_map(|it| it.errors).collect();
166            if let Some(mut errors) = NEVec::from_vec(errors) {
167                errors.push(e.into());
168                errors
169            } else {
170                nev![e.into()]
171            }
172        }
173        CompoundInferError::Foreign(es) => es
174            .into_iter()
175            .flat_map(|it| it.errors)
176            .to_nonempty_iter()
177            .unwrap()
178            .collect(),
179    }
180}
181
182impl RecursiveTypeCheckingErrors {
183    fn into_report_messages(self) -> Vec<ReportMessage> {
184        self.errors.into_iter().map(ReportMessage::from).collect()
185    }
186}
187
188impl EnvironmentProvider<GenericNodeKey> for RecursiveTypeChecker<'_> {
189    type Error = RecursiveTypeCheckingErrors;
190
191    fn maybe_get(&self, key: &GenericNodeKey) -> Result<Option<Cow<PolymorphicType>>, Self::Error> {
192        let id: GenericNodeId = (*key).into();
193        let node = self.tree.get(id, self.token).ok_or(NodeNotFound(id))?;
194
195        if let Some(node) = node.try_cast::<TypeRestrictedNode>() {
196            let search = self.search.as_tree().lookup(node, self.tree, self.token)?;
197            match node.type_of(&search, self.tree, self.token) {
198                Execution::Failed(e) => return Err(e.into()),
199                Execution::Completed(x) => {
200                    return Ok(Some(Cow::Owned(x.generalize(&HashSet::new()))))
201                }
202                Execution::Skipped => {}
203            };
204        }
205
206        let depth = self.current_recursion_depth.get();
207        match depth.checked_sub(1) {
208            None => return Err(MutuallyRecursive.into()),
209            Some(x) => self.current_recursion_depth.set(x)
210        }
211        
212        let model = match self.models.get(*key) {
213            Some(x) => x.clone(),
214            None => {
215                let model = node
216                    .try_cast::<ModelConvertibleNode>()
217                    .ok_or(InconvertibleToModel(node.describe()))?
218                    .to_model(self.search.as_tree(), self.tree, self.token, self.evidence)?;
219                let model = Rc::new(model);
220                self.models.insert(*key, model.clone());
221                model
222            }
223        };
224
225        match model.infer(self) {
226            Ok(x) => Ok(Some(Cow::Owned(x))),
227            Err(e) => Err(RecursiveTypeCheckingErrors { errors: flatten(e) }),
228        }
229    }
230}
231
232impl EnvironmentProvider<Var> for RecursiveTypeChecker<'_> {
233    type Error = RecursiveTypeCheckingErrors;
234
235    fn maybe_get(&self, key: &Var) -> Result<Option<Cow<PolymorphicType>>, Self::Error> {
236        let Some(id) = self.search.id_of_var(&key.name) else {
237            return Ok(None);
238        };
239        let key: GenericNodeKey = id.into();
240        self.maybe_get(&key)
241    }
242}
243
244impl<'a> TypeChecker<'a> {
245    pub fn new(symbols: &'a ScopeTree, recursion_depth: NonZeroU16, evidence: Witness) -> Self {
246        Self {
247            symbols,
248            models: Default::default(),
249            evidence,
250            recursion_depth,
251        }
252    }
253
254    pub fn into_inner(self) -> Cache<Rc<Language>> {
255        self.models
256    }
257}
258
259impl From<InferError> for ReportMessage {
260    fn from(value: InferError) -> Self {
261        match value {
262            InferError::AlgorithmW(x) => Self::new(Severity::Error, "TI001", x.to_string()),
263            InferError::Scope(x) => x.into(),
264            Unknown => Self::new(Severity::Bug, "TI004", "Bug in implementation".to_string()),
265        }
266    }
267}
268
269impl Macro for TypeChecker<'_> {
270    type Error = InferError;
271    type Node = BodyFnDecl;
272
273    fn transform(
274        &mut self,
275        guard: VisitGuard<Self::Node>,
276        context: &mut impl Context,
277    ) -> Execution<Self::Error, ChangeSet> {
278        let (node, side) = guard.allow_all();
279        let Some(tree) = context.tree().upgrade() else {
280            return Execution::Skipped;
281        };
282        if !matches!(side, VisitSide::Leaf | VisitSide::Exiting) {
283            return Execution::Skipped;
284        }
285
286        let search = self.symbols.lookup(&*node, &tree, node.token())?;
287        let rec = RecursiveTypeChecker {
288            search,
289            token: node.token(),
290            tree: &tree,
291            models: &self.models,
292            evidence: self.evidence,
293            current_recursion_depth: Cell::new(self.recursion_depth.get()),
294        };
295        let fn_location = context
296            .access(&*node)
297            .map_or(vec![], |it: &rlt::BodiedFunction| vec![it.id.location()]);
298        let key: GenericNodeKey = node.get_id().widen().into();
299        match rec.maybe_get(&key) {
300            Ok(Some(ty)) => {
301                context.add_report(
302                    fn_location,
303                    TypeInfo {
304                        name: &node.name,
305                        ty: &ty,
306                    },
307                );
308            }
309            Ok(None) => context.add_report(fn_location.clone(), CannotInfer),
310            Err(e) => e
311                .into_report_messages()
312                .into_iter()
313                .for_each(|it| context.add_report(fn_location.clone(), it)),
314        }
315
316        Execution::Completed(ChangeSet::new())
317    }
318}