kodept_inference/
substitution.rs1use std::collections::{HashMap, HashSet};
2use std::fmt::{Debug, Display, Formatter};
3use std::mem::take;
4use std::ops::Add;
5
6use itertools::Itertools;
7
8use crate::r#type::{MonomorphicType, TVar};
9
10#[derive(Clone, PartialEq)]
11#[repr(transparent)]
12pub struct Substitutions(HashMap<TVar, MonomorphicType>);
13
14impl Substitutions {
15 #[must_use]
16 pub fn compose(&self, other: &Substitutions) -> Self {
17 let mut copy = self.clone();
18 copy.merge(other.clone());
19 copy
20 }
21
22 pub fn merge(&mut self, other: Substitutions) {
23 let a: HashSet<_> = other
24 .0
25 .iter()
26 .map(|(key, ty)| (*key, ty & &*self))
27 .collect();
28 let b: HashSet<_> = take(&mut self.0)
29 .into_iter()
30 .map(|(key, ty)| (key, ty & &other))
31 .collect();
32
33 self.0 = b.union(&a).cloned().collect()
34 }
35
36 #[must_use]
37 pub fn empty() -> Substitutions {
38 Substitutions(HashMap::new())
39 }
40
41 #[must_use]
42 pub fn single(from: TVar, to: MonomorphicType) -> Substitutions {
43 Substitutions(HashMap::from([(from, to)]))
44 }
45
46 #[must_use]
47 pub fn get(&self, key: &TVar) -> Option<&MonomorphicType> {
48 self.0.get(key)
49 }
50
51 pub fn remove(&mut self, key: &TVar) {
52 self.0.remove(key);
53 }
54
55 #[cfg(test)]
56 pub(crate) fn into_inner(self) -> HashMap<TVar, MonomorphicType> {
57 self.0
58 }
59}
60
61impl<M: Into<MonomorphicType>> FromIterator<(TVar, M)> for Substitutions {
62 fn from_iter<T: IntoIterator<Item=(TVar, M)>>(iter: T) -> Self {
63 Self(HashMap::from_iter(iter.into_iter().map(|(a, b)| (a, b.into()))))
64 }
65}
66
67impl Add for &Substitutions {
68 type Output = Substitutions;
69
70 fn add(self, rhs: Self) -> Self::Output {
71 self.compose(rhs)
72 }
73}
74
75impl Add for Substitutions {
76 type Output = Substitutions;
77
78 fn add(mut self, rhs: Self) -> Self::Output {
79 self.merge(rhs);
80 self
81 }
82}
83
84impl Add<&Substitutions> for Substitutions {
85 type Output = Substitutions;
86
87 fn add(mut self, rhs: &Substitutions) -> Self::Output {
88 self.merge(rhs.clone());
89 self
90 }
91}
92
93impl Add<Substitutions> for &Substitutions {
94 type Output = Substitutions;
95
96 fn add(self, rhs: Substitutions) -> Self::Output {
97 self.compose(&rhs)
98 }
99}
100
101impl Display for Substitutions {
102 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103 write!(
104 f,
105 "[{}]",
106 self.0
107 .iter()
108 .map(|it| format!("{} := {}", it.0, it.1))
109 .join(", ")
110 )
111 }
112}
113
114impl Debug for Substitutions {
115 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
116 write!(f, "{self}")
117 }
118}