rue_types/
subtract.rs

1use id_arena::Arena;
2use indexmap::IndexMap;
3use log::trace;
4
5use crate::{
6    Alias, BuiltinTypes, Comparison, Struct, Type, TypeId, Union, compare, stringify_impl,
7    substitute,
8};
9
10pub fn subtract(
11    arena: &mut Arena<Type>,
12    builtins: &BuiltinTypes,
13    lhs_id: TypeId,
14    rhs_id: TypeId,
15) -> TypeId {
16    let lhs_id = substitute(arena, lhs_id);
17    let rhs_id = substitute(arena, rhs_id);
18    let lhs_name = stringify_impl(arena, lhs_id, &mut IndexMap::new());
19    let rhs_name = stringify_impl(arena, rhs_id, &mut IndexMap::new());
20    trace!("Subtracting {lhs_name} from {rhs_name}");
21    let result = subtract_impl(arena, builtins, lhs_id, rhs_id);
22    let new_name = stringify_impl(arena, result, &mut IndexMap::new());
23    trace!("Subtraction from {lhs_name} to {rhs_name} yielded {new_name}");
24    result
25}
26
27fn subtract_impl(
28    arena: &mut Arena<Type>,
29    builtins: &BuiltinTypes,
30    lhs_id: TypeId,
31    rhs_id: TypeId,
32) -> TypeId {
33    let mut lhs_variants = variants_of(arena, builtins, lhs_id);
34
35    lhs_variants
36        .retain(|variant| compare(arena, builtins, variant.type_id, rhs_id) > Comparison::Cast);
37
38    repackage(arena, lhs_variants)
39}
40
41fn repackage(arena: &mut Arena<Type>, variants: Vec<Variant>) -> TypeId {
42    let mut collected = vec![];
43    let mut semantic_type_id = None;
44    let mut leftover = vec![];
45
46    for mut variant in variants {
47        if let Some(id) = variant.semantic_type_ids.last().copied() {
48            if let Some(semantic_type_id) = semantic_type_id {
49                if id == semantic_type_id {
50                    variant.semantic_type_ids.pop().unwrap();
51                    collected.push(variant);
52                } else {
53                    leftover.push(variant);
54                }
55            } else {
56                semantic_type_id = Some(id);
57                variant.semantic_type_ids.pop().unwrap();
58                collected.push(variant);
59            }
60        } else {
61            leftover.push(variant);
62        }
63    }
64
65    if collected.is_empty() {
66        let leftover: Vec<TypeId> = leftover
67            .into_iter()
68            .map(|variant| variant.type_id)
69            .collect();
70
71        if leftover.is_empty() {
72            return arena.alloc(Type::Never);
73        } else if leftover.len() == 1 {
74            return leftover[0];
75        }
76
77        return arena.alloc(Type::Union(Union::new(leftover)));
78    }
79
80    let semantic_type_id = semantic_type_id.unwrap();
81
82    let inner = repackage(arena, collected);
83
84    let result = match arena[semantic_type_id].clone() {
85        Type::Alias(alias) => arena.alloc(Type::Alias(Alias { inner, ..alias })),
86        Type::Struct(ty) => arena.alloc(Type::Struct(Struct { inner, ..ty })),
87        _ => unreachable!(),
88    };
89
90    if leftover.is_empty() {
91        return result;
92    }
93
94    let leftover = repackage(arena, leftover);
95
96    arena.alloc(Type::Union(Union::new(vec![result, leftover])))
97}
98
99struct Variant {
100    semantic_type_ids: Vec<TypeId>,
101    type_id: TypeId,
102}
103
104fn variants_of(arena: &Arena<Type>, builtins: &BuiltinTypes, id: TypeId) -> Vec<Variant> {
105    match arena[id].clone() {
106        Type::Apply(_) => unreachable!(),
107        Type::Ref(id) => variants_of(arena, builtins, id),
108        Type::Unresolved | Type::Atom(_) | Type::Pair(_) => vec![Variant {
109            semantic_type_ids: vec![],
110            type_id: id,
111        }],
112        Type::Never => vec![],
113        Type::Alias(alias) => {
114            let mut variants = variants_of(arena, builtins, alias.inner);
115
116            for variant in &mut variants {
117                variant.semantic_type_ids.push(id);
118            }
119
120            variants
121        }
122        Type::Struct(ty) => {
123            let mut variants = variants_of(arena, builtins, ty.inner);
124
125            for variant in &mut variants {
126                variant.semantic_type_ids.push(id);
127            }
128
129            variants
130        }
131        Type::Function(_) | Type::Generic => vec![
132            Variant {
133                semantic_type_ids: vec![],
134                type_id: builtins.atom,
135            },
136            Variant {
137                semantic_type_ids: vec![],
138                type_id: builtins.any_pair,
139            },
140        ],
141        Type::Union(ty) => {
142            let mut variants = Vec::new();
143
144            for variant in ty.types {
145                variants.extend(variants_of(arena, builtins, variant));
146            }
147
148            variants
149        }
150    }
151}