1use derive_more::From;
2use nonempty_collections::{nev, IteratorExt, NEVec, NonEmptyIterator};
3use std::borrow::Cow;
4use std::cell::Cell;
5use std::collections::HashSet;
6use std::num::NonZeroU16;
7use std::rc::Rc;
8
9use kodept_ast::graph::{AnyNodeD, ChangeSet, GenericNodeId, GenericNodeKey, PermTkn, SyntaxTree};
10use kodept_ast::traits::Identifiable;
11use kodept_ast::utils::Execution;
12use kodept_ast::visit_side::{VisitGuard, VisitSide};
13use kodept_ast::BodyFnDecl;
14use kodept_core::structure::{rlt, Located};
15use kodept_inference::algorithm_w::{AlgorithmWError, CompoundInferError};
16use kodept_inference::language::{Language, Var};
17use kodept_inference::r#type::PolymorphicType;
18use kodept_inference::traits::EnvironmentProvider;
19use kodept_macros::error::report::{ReportMessage, Severity};
20use kodept_macros::traits::Context;
21use kodept_macros::Macro;
22use RecursiveTypeCheckingError::{AlgoWError, MutuallyRecursive, NodeNotFound};
23
24use crate::convert_model::ModelConvertibleNode;
25use crate::node_family::TypeRestrictedNode;
26use crate::scope::{ScopeError, ScopeSearch, ScopeTree};
27use crate::type_checker::InferError::Unknown;
28use crate::type_checker::RecursiveTypeCheckingError::InconvertibleToModel;
29use crate::{Cache, Witness};
30
/// Marker diagnostic reported when inference completes without producing a type
/// for a function; converted into a `TC002` warning.
pub struct CannotInfer;
32
/// Informational diagnostic carrying the inferred type of a function;
/// converted into a `TC001` note.
pub struct TypeInfo<'a> {
    /// Name of the function whose type was inferred.
    name: &'a str,
    /// The type inferred for that function.
    ty: &'a PolymorphicType,
}
37
38impl From<TypeInfo<'_>> for ReportMessage {
39 fn from(value: TypeInfo<'_>) -> Self {
40 Self::new(
41 Severity::Note,
42 "TC001",
43 format!(
44 "Type of function `{}` inferred to: {}",
45 value.name, value.ty
46 ),
47 )
48 }
49}
50
51impl From<CannotInfer> for ReportMessage {
52 fn from(_: CannotInfer) -> Self {
53 Self::new(Severity::Warning, "TC002", "Cannot infer type".to_string())
54 }
55}
56
/// Macro that infers the type of bodied function declarations (`BodyFnDecl`)
/// and reports the result as diagnostics.
pub struct TypeChecker<'a> {
    /// Scope tree used to resolve names while checking.
    pub(crate) symbols: &'a ScopeTree,
    /// Language models built from AST nodes, reused across functions and
    /// handed back by `into_inner`.
    models: Cache<Rc<Language>>,
    /// Passed through to model conversion.
    evidence: Witness,
    /// Initial recursion budget given to each per-function `RecursiveTypeChecker`.
    recursion_depth: NonZeroU16,
}
63
/// Per-function environment provider: resolves identifiers to types, converting
/// referenced declarations to models and inferring them on demand.
struct RecursiveTypeChecker<'a> {
    /// Scope lookup rooted at the function being checked.
    search: ScopeSearch<'a>,
    /// Permission token required to read nodes from `tree`.
    token: &'a PermTkn,
    /// Syntax tree the checked function belongs to.
    tree: &'a SyntaxTree,
    /// Model cache shared with the owning `TypeChecker`.
    models: &'a Cache<Rc<Language>>,
    /// Passed through to model conversion.
    evidence: Witness,
    /// Remaining recursion budget; `maybe_get` fails with `MutuallyRecursive`
    /// once it reaches zero. A `Cell` because lookups only get `&self`.
    current_recursion_depth: Cell<u16>,
}
72
/// Error type of the `TypeChecker` macro.
#[derive(From, Debug)]
pub enum InferError {
    /// Failure reported by the Algorithm W implementation (reported as `TI001`).
    AlgorithmW(AlgorithmWError),
    /// Failure while resolving scopes.
    Scope(ScopeError),
    /// Unexpected internal failure (reported as `TI004`, "Bug in implementation").
    Unknown,
}
79
/// A single failure encountered while recursively type checking a function.
// Only the variants marked `#[from]` get derived `From` conversions.
#[derive(Debug, From)]
enum RecursiveTypeCheckingError {
    /// No node with this id exists in the syntax tree (reported as `TC005`).
    NodeNotFound(GenericNodeId),
    /// The node cannot be converted into a language model (reported as `TC006`).
    InconvertibleToModel(AnyNodeD),
    /// The recursion budget was exhausted (reported as `TC007`).
    MutuallyRecursive,
    /// Scope resolution failure.
    #[from]
    ScopeError(ScopeError),
    /// Algorithm W inference failure.
    #[from]
    AlgoWError(AlgorithmWError),
}
90
/// Non-empty collection of errors produced by a `RecursiveTypeChecker` lookup;
/// the error type of its `EnvironmentProvider` impls.
#[derive(From, Debug)]
struct RecursiveTypeCheckingErrors {
    errors: NEVec<RecursiveTypeCheckingError>,
}
95
96impl From<RecursiveTypeCheckingError> for RecursiveTypeCheckingErrors {
97 fn from(value: RecursiveTypeCheckingError) -> Self {
98 Self {
99 errors: nev![value],
100 }
101 }
102}
103
104impl From<ScopeError> for RecursiveTypeCheckingErrors {
105 fn from(value: ScopeError) -> Self {
106 Self {
107 errors: nev![value.into()],
108 }
109 }
110}
111
112impl From<InferError> for RecursiveTypeCheckingErrors {
113 fn from(value: InferError) -> Self {
114 Self {
115 errors: match value {
116 InferError::AlgorithmW(e) => nev![e.into()],
117 InferError::Scope(e) => nev![e.into()],
118 Unknown => panic!("Unknown error happened"),
119 },
120 }
121 }
122}
123
124impl From<AlgorithmWError> for RecursiveTypeCheckingErrors {
125 fn from(value: AlgorithmWError) -> Self {
126 Self {
127 errors: nev![value.into()],
128 }
129 }
130}
131
132impl From<RecursiveTypeCheckingError> for ReportMessage {
133 fn from(value: RecursiveTypeCheckingError) -> Self {
134 match value {
135 NodeNotFound(id) => Self::new(
136 Severity::Bug,
137 "TC005",
138 format!("Cannot find node with given id: {id}"),
139 ),
140 InconvertibleToModel(desc) => Self::new(
141 Severity::Bug,
142 "TC006",
143 format!("Cannot convert node with description `{desc}` to model"),
144 ),
145 RecursiveTypeCheckingError::ScopeError(e) => e.into(),
146 AlgoWError(e) => InferError::from(e).into(),
147 MutuallyRecursive => Self::new(
148 Severity::Error,
149 "TC007",
150 "Cannot type check due to mutual recursion".to_string(),
151 )
152 .with_notes(vec![
153 "Adjust `recursion_depth` CLI option if needed".to_string()
154 ]),
155 }
156 }
157}
158
159fn flatten(
160 value: CompoundInferError<RecursiveTypeCheckingErrors>,
161) -> NEVec<RecursiveTypeCheckingError> {
162 match value {
163 CompoundInferError::AlgoW(e) => nev![e.into()],
164 CompoundInferError::Both(e, es) => {
165 let errors: Vec<_> = es.into_iter().flat_map(|it| it.errors).collect();
166 if let Some(mut errors) = NEVec::from_vec(errors) {
167 errors.push(e.into());
168 errors
169 } else {
170 nev![e.into()]
171 }
172 }
173 CompoundInferError::Foreign(es) => es
174 .into_iter()
175 .flat_map(|it| it.errors)
176 .to_nonempty_iter()
177 .unwrap()
178 .collect(),
179 }
180}
181
182impl RecursiveTypeCheckingErrors {
183 fn into_report_messages(self) -> Vec<ReportMessage> {
184 self.errors.into_iter().map(ReportMessage::from).collect()
185 }
186}
187
188impl EnvironmentProvider<GenericNodeKey> for RecursiveTypeChecker<'_> {
189 type Error = RecursiveTypeCheckingErrors;
190
191 fn maybe_get(&self, key: &GenericNodeKey) -> Result<Option<Cow<PolymorphicType>>, Self::Error> {
192 let id: GenericNodeId = (*key).into();
193 let node = self.tree.get(id, self.token).ok_or(NodeNotFound(id))?;
194
195 if let Some(node) = node.try_cast::<TypeRestrictedNode>() {
196 let search = self.search.as_tree().lookup(node, self.tree, self.token)?;
197 match node.type_of(&search, self.tree, self.token) {
198 Execution::Failed(e) => return Err(e.into()),
199 Execution::Completed(x) => {
200 return Ok(Some(Cow::Owned(x.generalize(&HashSet::new()))))
201 }
202 Execution::Skipped => {}
203 };
204 }
205
206 let depth = self.current_recursion_depth.get();
207 match depth.checked_sub(1) {
208 None => return Err(MutuallyRecursive.into()),
209 Some(x) => self.current_recursion_depth.set(x)
210 }
211
212 let model = match self.models.get(*key) {
213 Some(x) => x.clone(),
214 None => {
215 let model = node
216 .try_cast::<ModelConvertibleNode>()
217 .ok_or(InconvertibleToModel(node.describe()))?
218 .to_model(self.search.as_tree(), self.tree, self.token, self.evidence)?;
219 let model = Rc::new(model);
220 self.models.insert(*key, model.clone());
221 model
222 }
223 };
224
225 match model.infer(self) {
226 Ok(x) => Ok(Some(Cow::Owned(x))),
227 Err(e) => Err(RecursiveTypeCheckingErrors { errors: flatten(e) }),
228 }
229 }
230}
231
232impl EnvironmentProvider<Var> for RecursiveTypeChecker<'_> {
233 type Error = RecursiveTypeCheckingErrors;
234
235 fn maybe_get(&self, key: &Var) -> Result<Option<Cow<PolymorphicType>>, Self::Error> {
236 let Some(id) = self.search.id_of_var(&key.name) else {
237 return Ok(None);
238 };
239 let key: GenericNodeKey = id.into();
240 self.maybe_get(&key)
241 }
242}
243
244impl<'a> TypeChecker<'a> {
245 pub fn new(symbols: &'a ScopeTree, recursion_depth: NonZeroU16, evidence: Witness) -> Self {
246 Self {
247 symbols,
248 models: Default::default(),
249 evidence,
250 recursion_depth,
251 }
252 }
253
254 pub fn into_inner(self) -> Cache<Rc<Language>> {
255 self.models
256 }
257}
258
259impl From<InferError> for ReportMessage {
260 fn from(value: InferError) -> Self {
261 match value {
262 InferError::AlgorithmW(x) => Self::new(Severity::Error, "TI001", x.to_string()),
263 InferError::Scope(x) => x.into(),
264 Unknown => Self::new(Severity::Bug, "TI004", "Bug in implementation".to_string()),
265 }
266 }
267}
268
269impl Macro for TypeChecker<'_> {
270 type Error = InferError;
271 type Node = BodyFnDecl;
272
273 fn transform(
274 &mut self,
275 guard: VisitGuard<Self::Node>,
276 context: &mut impl Context,
277 ) -> Execution<Self::Error, ChangeSet> {
278 let (node, side) = guard.allow_all();
279 let Some(tree) = context.tree().upgrade() else {
280 return Execution::Skipped;
281 };
282 if !matches!(side, VisitSide::Leaf | VisitSide::Exiting) {
283 return Execution::Skipped;
284 }
285
286 let search = self.symbols.lookup(&*node, &tree, node.token())?;
287 let rec = RecursiveTypeChecker {
288 search,
289 token: node.token(),
290 tree: &tree,
291 models: &self.models,
292 evidence: self.evidence,
293 current_recursion_depth: Cell::new(self.recursion_depth.get()),
294 };
295 let fn_location = context
296 .access(&*node)
297 .map_or(vec![], |it: &rlt::BodiedFunction| vec![it.id.location()]);
298 let key: GenericNodeKey = node.get_id().widen().into();
299 match rec.maybe_get(&key) {
300 Ok(Some(ty)) => {
301 context.add_report(
302 fn_location,
303 TypeInfo {
304 name: &node.name,
305 ty: &ty,
306 },
307 );
308 }
309 Ok(None) => context.add_report(fn_location.clone(), CannotInfer),
310 Err(e) => e
311 .into_report_messages()
312 .into_iter()
313 .for_each(|it| context.add_report(fn_location.clone(), it)),
314 }
315
316 Execution::Completed(ChangeSet::new())
317 }
318}