kodept_inference/
substitution.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::{Debug, Display, Formatter};
3use std::mem::take;
4use std::ops::Add;
5
6use itertools::Itertools;
7
8use crate::r#type::{MonomorphicType, TVar};
9
10#[derive(Clone, PartialEq)]
11#[repr(transparent)]
12pub struct Substitutions(HashMap<TVar, MonomorphicType>);
13
14impl Substitutions {
15    #[must_use]
16    pub fn compose(&self, other: &Substitutions) -> Self {
17        let mut copy = self.clone();
18        copy.merge(other.clone());
19        copy
20    }
21    
22    pub fn merge(&mut self, other: Substitutions) {
23        let a: HashSet<_> = other
24            .0
25            .iter()
26            .map(|(key, ty)| (*key, ty & &*self))
27            .collect();
28        let b: HashSet<_> = take(&mut self.0)
29            .into_iter()
30            .map(|(key, ty)| (key, ty & &other))
31            .collect();
32
33        self.0 = b.union(&a).cloned().collect()
34    }
35
36    #[must_use]
37    pub fn empty() -> Substitutions {
38        Substitutions(HashMap::new())
39    }
40
41    #[must_use]
42    pub fn single(from: TVar, to: MonomorphicType) -> Substitutions {
43        Substitutions(HashMap::from([(from, to)]))
44    }
45    
46    #[must_use]
47    pub fn get(&self, key: &TVar) -> Option<&MonomorphicType> {
48        self.0.get(key)
49    }
50    
51    pub fn remove(&mut self, key: &TVar) {
52        self.0.remove(key);
53    }
54    
55    #[cfg(test)]
56    pub(crate) fn into_inner(self) -> HashMap<TVar, MonomorphicType> {
57        self.0
58    }
59}
60
61impl<M: Into<MonomorphicType>> FromIterator<(TVar, M)> for Substitutions {
62    fn from_iter<T: IntoIterator<Item=(TVar, M)>>(iter: T) -> Self {
63        Self(HashMap::from_iter(iter.into_iter().map(|(a, b)| (a, b.into()))))
64    }
65}
66
67impl Add for &Substitutions {
68    type Output = Substitutions;
69
70    fn add(self, rhs: Self) -> Self::Output {
71        self.compose(rhs)
72    }
73}
74
75impl Add for Substitutions {
76    type Output = Substitutions;
77
78    fn add(mut self, rhs: Self) -> Self::Output {
79        self.merge(rhs);
80        self
81    }
82}
83
84impl Add<&Substitutions> for Substitutions {
85    type Output = Substitutions;
86
87    fn add(mut self, rhs: &Substitutions) -> Self::Output {
88        self.merge(rhs.clone());
89        self
90    }
91}
92
93impl Add<Substitutions> for &Substitutions {
94    type Output = Substitutions;
95
96    fn add(self, rhs: Substitutions) -> Self::Output {
97        self.compose(&rhs)
98    }
99}
100
101impl Display for Substitutions {
102    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103        write!(
104            f,
105            "[{}]",
106            self.0
107                .iter()
108                .map(|it| format!("{} := {}", it.0, it.1))
109                .join(", ")
110        )
111    }
112}
113
114impl Debug for Substitutions {
115    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
116        write!(f, "{self}")
117    }
118}