Skip to main content

tsz_solver/narrowing/
utils.rs

//! Narrowing visitor and utility functions.
//!
//! Contains the `NarrowingVisitor` (a `TypeVisitor` implementation for structural narrowing)
//! and standalone public utility functions for nullish/falsy type handling.
5
6use super::{DiscriminantInfo, NarrowingContext};
7use crate::relations::subtype::SubtypeChecker;
8use crate::types::{IntrinsicKind, LiteralValue, TypeData, TypeId, TypeListId, TypeParamInfo};
9use crate::visitor::{TypeVisitor, is_object_like_type_db, literal_value, union_list_id};
10use crate::{QueryDatabase, TypeDatabase};
11use tsz_common::interner::Atom;
12
13/// Visitor that narrows a type by filtering/intersecting with a narrower type.
14pub(crate) struct NarrowingVisitor<'a> {
15    pub(crate) db: &'a dyn QueryDatabase,
16    pub(crate) narrower: TypeId,
17    /// PERF: Reusable `SubtypeChecker` to avoid per-call hash allocations
18    pub(crate) checker: SubtypeChecker<'a>,
19}
20
21impl<'a> TypeVisitor for NarrowingVisitor<'a> {
22    type Output = TypeId;
23
24    /// Override `visit_type` to handle types that need special handling.
25    /// We intercept Lazy/Ref/Application types for resolution, and Object/Function
26    /// types for proper subtype checking (we need the `TypeId` here).
27    fn visit_type(&mut self, types: &dyn TypeDatabase, type_id: TypeId) -> Self::Output {
28        // Check if this is a type that needs special handling
29        if let Some(type_key) = types.lookup(type_id) {
30            match type_key {
31                // Lazy types: resolve and recurse
32                TypeData::Lazy(_) => {
33                    // Use self.db (QueryDatabase) which has evaluate_type
34                    let resolved = self.db.evaluate_type(type_id);
35                    // If resolution changed the type, recurse with the resolved type
36                    if resolved != type_id {
37                        return self.visit_type(types, resolved);
38                    }
39                    // Otherwise, fall through to normal visitation
40                }
41                // Ref types: resolve and recurse
42                TypeData::TypeQuery(_) | TypeData::Application(_) => {
43                    let resolved = self.db.evaluate_type(type_id);
44                    if resolved != type_id {
45                        return self.visit_type(types, resolved);
46                    }
47                }
48                // Object types: check subtype relationships
49                TypeData::Object(_) => {
50                    // Case 1: type_id is subtype of narrower (e.g., { a: "foo" } narrowed by { a: string })
51                    // Result: type_id (keep the more specific type)
52                    self.checker.reset();
53                    if self.checker.is_subtype_of(type_id, self.narrower) {
54                        return type_id;
55                    }
56                    // Case 2: narrower is subtype of type_id (e.g., { a: string } narrowed by { a: "foo" })
57                    // Result: narrower (narrow down to the more specific type)
58                    self.checker.reset();
59                    if self.checker.is_subtype_of(self.narrower, type_id) {
60                        return self.narrower;
61                    }
62                    // Case 3: Both are object types but not directly related
63                    // They might overlap (e.g., interfaces with common properties)
64                    // For now, conservatively return the intersection
65                    if is_object_like_type_db(self.db, self.narrower) {
66                        return self.db.intersection2(type_id, self.narrower);
67                    }
68                    // Case 4: Disjoint object types
69                    return TypeId::NEVER;
70                }
71                // Function types: check subtype relationships
72                TypeData::Function(_) => {
73                    // Case 1: type_id is subtype of narrower (keep specific)
74                    self.checker.reset();
75                    if self.checker.is_subtype_of(type_id, self.narrower) {
76                        return type_id;
77                    }
78                    // Case 2: narrower is subtype of type_id (narrow down)
79                    self.checker.reset();
80                    if self.checker.is_subtype_of(self.narrower, type_id) {
81                        return self.narrower;
82                    }
83                    // Case 3: Disjoint function types
84                    return TypeId::NEVER;
85                }
86                _ => {}
87            }
88        }
89
90        // For all other types, use the default visit_type implementation
91        // which calls visit_type_key and dispatches to specific methods
92        <Self as TypeVisitor>::visit_type(self, types, type_id)
93    }
94
95    fn visit_intrinsic(&mut self, kind: IntrinsicKind) -> Self::Output {
96        match kind {
97            IntrinsicKind::Any => {
98                // Narrowing `any` by anything returns that type
99                self.narrower
100            }
101            IntrinsicKind::Unknown => {
102                // Narrowing `unknown` by anything returns that type
103                self.narrower
104            }
105            IntrinsicKind::Never => {
106                // Never stays never
107                TypeId::NEVER
108            }
109            _ => {
110                // For other intrinsics, we need to handle the overlap case
111                // Narrowing primitive by primitive is effectively intersection
112                let type_id = TypeId(kind as u32);
113
114                // Case 1: narrower is subtype of type_id (e.g., narrow(string, "foo"))
115                // Result: narrower
116                self.checker.reset();
117                if self.checker.is_subtype_of(self.narrower, type_id) {
118                    self.narrower
119                }
120                // Case 2: type_id is subtype of narrower (e.g., narrow("foo", string))
121                // Result: type_id (the original)
122                else {
123                    self.checker.reset();
124                    if self.checker.is_subtype_of(type_id, self.narrower) {
125                        type_id
126                    }
127                    // Case 3: Disjoint types (e.g., narrow(string, number))
128                    // Result: never
129                    else {
130                        TypeId::NEVER
131                    }
132                }
133            }
134        }
135    }
136
137    fn visit_literal(&mut self, _value: &LiteralValue) -> Self::Output {
138        // For literal types, check if assignable to narrower
139        // The literal type_id will be constructed and checked
140        // For now, return the narrower (will be refined with actual type_id)
141        self.narrower
142    }
143
144    fn visit_union(&mut self, list_id: u32) -> Self::Output {
145        let members = self.db.type_list(TypeListId(list_id));
146
147        // CRITICAL: Recursively narrow each union member, don't just check subtype
148        // This handles cases like: string narrowed by "foo" -> "foo"
149        // where "foo" is NOT a subtype of string, but string contains "foo"
150        let filtered: Vec<TypeId> = members
151            .iter()
152            .filter_map(|&member| {
153                let narrowed = self.visit_type(self.db, member);
154                if narrowed == TypeId::NEVER {
155                    None
156                } else {
157                    Some(narrowed)
158                }
159            })
160            .collect();
161
162        if filtered.is_empty() {
163            TypeId::NEVER
164        } else if filtered.len() == members.len() {
165            // All members matched - reconstruct the union
166            self.db.union(filtered)
167        } else if filtered.len() == 1 {
168            filtered[0]
169        } else {
170            self.db.union(filtered)
171        }
172    }
173
174    fn visit_intersection(&mut self, list_id: u32) -> Self::Output {
175        let members = self.db.type_list(TypeListId(list_id));
176
177        // Narrow each intersection member individually and collect non-never results
178        // For (A & B) narrowed by C, the result is (A narrowed by C) & (B narrowed by C)
179        let narrowed_members: Vec<TypeId> = members
180            .iter()
181            .filter_map(|&member| {
182                let narrowed = self.visit_type(self.db, member);
183                if narrowed == TypeId::NEVER {
184                    None
185                } else {
186                    Some(narrowed)
187                }
188            })
189            .collect();
190
191        if narrowed_members.is_empty() {
192            TypeId::NEVER
193        } else if narrowed_members.len() == 1 {
194            narrowed_members[0]
195        } else {
196            self.db.intersection(narrowed_members)
197        }
198    }
199
200    fn visit_type_parameter(&mut self, info: &TypeParamInfo) -> Self::Output {
201        // For type parameters, intersect with the narrower
202        // This constrains the generic type variable
203        if let Some(constraint) = info.constraint {
204            self.db.intersection2(constraint, self.narrower)
205        } else {
206            // No constraint, so narrowing gives us the narrower
207            self.narrower
208        }
209    }
210
211    fn visit_lazy(&mut self, _def_id: u32) -> Self::Output {
212        // Lazy types are now handled in visit_type by resolving and recursing
213        // This should never be called anymore, but if it is, return narrower
214        self.narrower
215    }
216
217    fn visit_ref(&mut self, _symbol_ref: u32) -> Self::Output {
218        // Ref types are now handled in visit_type by resolving and recursing
219        // This should never be called anymore, but if it is, return narrower
220        self.narrower
221    }
222
223    fn visit_application(&mut self, _app_id: u32) -> Self::Output {
224        // Application types are now handled in visit_type by resolving and recursing
225        // This should never be called anymore, but if it is, return narrower
226        self.narrower
227    }
228
229    fn visit_object(&mut self, _shape_id: u32) -> Self::Output {
230        // Object types are now handled in visit_type where we have the TypeId
231        // For now, conservatively return the narrower
232        self.narrower
233    }
234
235    fn visit_function(&mut self, _shape_id: u32) -> Self::Output {
236        // Function types are now handled in visit_type where we have the TypeId
237        // For now, conservatively return the narrower
238        self.narrower
239    }
240
241    fn visit_callable(&mut self, _shape_id: u32) -> Self::Output {
242        // For callable types, conservatively return the narrower
243        self.narrower
244    }
245
246    fn visit_tuple(&mut self, _list_id: u32) -> Self::Output {
247        // For tuple types, conservatively return the narrower
248        self.narrower
249    }
250
251    fn visit_array(&mut self, _element_type: TypeId) -> Self::Output {
252        // For array types, conservatively return the narrower
253        self.narrower
254    }
255
256    fn default_output() -> Self::Output {
257        // Fallback for types not explicitly handled above
258        // Conservative: return never (type doesn't match the narrower)
259        // This is safe because:
260        // - For unions, this member will be excluded from the filtered result
261        // - For other contexts, never means "no match"
262        TypeId::NEVER
263    }
264}
265
266/// Convenience function for finding discriminants.
267pub fn find_discriminants(
268    interner: &dyn QueryDatabase,
269    union_type: TypeId,
270) -> Vec<DiscriminantInfo> {
271    let ctx = NarrowingContext::new(interner);
272    ctx.find_discriminants(union_type)
273}
274
275/// Convenience function for narrowing by discriminant.
276pub fn narrow_by_discriminant(
277    interner: &dyn QueryDatabase,
278    union_type: TypeId,
279    property_path: &[Atom],
280    literal_value: TypeId,
281) -> TypeId {
282    let ctx = NarrowingContext::new(interner);
283    ctx.narrow_by_discriminant(union_type, property_path, literal_value)
284}
285
286/// Convenience function for typeof narrowing.
287pub fn narrow_by_typeof(
288    interner: &dyn QueryDatabase,
289    source_type: TypeId,
290    typeof_result: &str,
291) -> TypeId {
292    let ctx = NarrowingContext::new(interner);
293    ctx.narrow_by_typeof(source_type, typeof_result)
294}
295
// =============================================================================
// Nullish Type Helpers
// =============================================================================
299
300fn top_level_union_members(types: &dyn TypeDatabase, type_id: TypeId) -> Option<Vec<TypeId>> {
301    union_list_id(types, type_id).map(|list_id| types.type_list(list_id).to_vec())
302}
303
304const fn is_undefined_intrinsic(type_id: TypeId) -> bool {
305    matches!(type_id, TypeId::UNDEFINED | TypeId::VOID)
306}
307
308fn normalize_nullish(type_id: TypeId) -> TypeId {
309    if type_id == TypeId::VOID {
310        TypeId::UNDEFINED
311    } else {
312        type_id
313    }
314}
315
316/// Check if a type is nullish (null/undefined/void or union containing them).
317pub fn is_nullish_type(types: &dyn TypeDatabase, type_id: TypeId) -> bool {
318    if type_id.is_nullable() {
319        return true;
320    }
321    if let Some(members) = top_level_union_members(types, type_id) {
322        return members.iter().any(|&member| is_nullish_type(types, member));
323    }
324    false
325}
326
327/// Check if a type contains undefined (or void).
328pub fn type_contains_undefined(types: &dyn TypeDatabase, type_id: TypeId) -> bool {
329    if is_undefined_intrinsic(type_id) {
330        return true;
331    }
332    if let Some(members) = top_level_union_members(types, type_id) {
333        return members
334            .iter()
335            .any(|&member| type_contains_undefined(types, member));
336    }
337    false
338}
339
340/// Check if a type is definitely nullish (only null/undefined/void).
341pub fn is_definitely_nullish(types: &dyn TypeDatabase, type_id: TypeId) -> bool {
342    if type_id.is_nullable() {
343        return true;
344    }
345    if let Some(members) = top_level_union_members(types, type_id) {
346        return members
347            .iter()
348            .all(|&member| is_definitely_nullish(types, member));
349    }
350    false
351}
352
353fn split_nullish_members(
354    types: &dyn TypeDatabase,
355    type_id: TypeId,
356    non_nullish: &mut Vec<TypeId>,
357    nullish: &mut Vec<TypeId>,
358) {
359    if type_id.is_nullable() {
360        nullish.push(normalize_nullish(type_id));
361        return;
362    }
363
364    if let Some(members) = top_level_union_members(types, type_id) {
365        for member in members {
366            split_nullish_members(types, member, non_nullish, nullish);
367        }
368        return;
369    }
370
371    non_nullish.push(type_id);
372}
373
374/// Split a type into its non-nullish part and its nullish cause.
375pub fn split_nullish_type(
376    types: &dyn TypeDatabase,
377    type_id: TypeId,
378) -> (Option<TypeId>, Option<TypeId>) {
379    let mut non_nullish = Vec::new();
380    let mut nullish = Vec::new();
381
382    split_nullish_members(types, type_id, &mut non_nullish, &mut nullish);
383
384    if nullish.is_empty() {
385        return (Some(type_id), None);
386    }
387
388    let non_nullish_type = if non_nullish.is_empty() {
389        None
390    } else if non_nullish.len() == 1 {
391        Some(non_nullish[0])
392    } else {
393        Some(types.union(non_nullish))
394    };
395
396    let nullish_type = if nullish.len() == 1 {
397        Some(nullish[0])
398    } else {
399        Some(types.union(nullish))
400    };
401
402    (non_nullish_type, nullish_type)
403}
404
405/// Remove nullish parts of a type (non-null assertion).
406pub fn remove_nullish(types: &dyn TypeDatabase, type_id: TypeId) -> TypeId {
407    let (non_nullish, _) = split_nullish_type(types, type_id);
408    non_nullish.unwrap_or(TypeId::NEVER)
409}
410
411/// Remove types that are *definitely* falsy from a union, without narrowing
412/// non-falsy types. This matches TypeScript's `removeDefinitelyFalsyTypes`:
413/// removes `null`, `undefined`, `void`, `false`, `0`, `""`, `0n` but keeps
414/// `boolean`, `string`, `number`, `bigint`, and object types unchanged.
415pub fn remove_definitely_falsy_types(types: &dyn TypeDatabase, type_id: TypeId) -> TypeId {
416    if is_always_falsy(types, type_id) {
417        return TypeId::NEVER;
418    }
419    if let Some(members_id) = union_list_id(types, type_id) {
420        let members = types.type_list(members_id);
421        let remaining: Vec<TypeId> = members
422            .iter()
423            .copied()
424            .filter(|&m| !is_always_falsy(types, m))
425            .collect();
426        if remaining.is_empty() {
427            return TypeId::NEVER;
428        }
429        if remaining.len() == 1 {
430            return remaining[0];
431        }
432        if remaining.len() == members.len() {
433            return type_id;
434        }
435        return types.union(remaining);
436    }
437    type_id
438}
439
440/// Check if a type is always falsy (null, undefined, void, false, 0, "", 0n).
441fn is_always_falsy(types: &dyn TypeDatabase, type_id: TypeId) -> bool {
442    if matches!(type_id, TypeId::NULL | TypeId::UNDEFINED | TypeId::VOID) {
443        return true;
444    }
445    if let Some(lit) = literal_value(types, type_id) {
446        return match lit {
447            LiteralValue::Boolean(false) => true,
448            LiteralValue::Number(n) => n.0 == 0.0 || n.0.is_nan(),
449            LiteralValue::String(atom) => types.resolve_atom_ref(atom).is_empty(),
450            LiteralValue::BigInt(atom) => types.resolve_atom_ref(atom).as_ref() == "0",
451            _ => false,
452        };
453    }
454    false
455}