1use crate::runtime::reify::Reifiable;
2use std::{clone::Clone, fmt, hash::Hash, rc::Rc};
3
4extern crate im_rc;
5
6use self::im_rc::HashMap;
7
thread_local! {
    /// Per-thread counter from which `get_next_id` hands out identity tags.
    static NEXT_ID: std::cell::Cell<u32> = std::cell::Cell::new(0);
}

/// Returns a fresh identity tag (stored in `Assoc::id` and compared by
/// `Assoc::almost_ptr_eq`).
///
/// Tags increase by one per call on the current thread. On overflow the counter
/// wraps back to 0 in every build profile, rather than panicking in debug builds
/// and silently wrapping in release builds.
fn get_next_id() -> u32 {
    NEXT_ID.with(|id| {
        let res = id.get();
        id.set(res.wrapping_add(1));
        res
    })
}
19
20#[derive(Clone)]
22pub struct Assoc<K, V>
23where K: Eq + Hash + Clone
24{
25 hamt: HashMap<K, V>,
26 id: u32,
30}
31
32impl<K: Eq + Hash + Clone, V: Clone + PartialEq> PartialEq for Assoc<K, V> {
33 fn eq(&self, other: &Self) -> bool { self.hamt == other.hamt }
35}
36
37impl<K: Eq + Hash + Clone, V: Clone + Eq> Eq for Assoc<K, V> {}
38
39impl<K: Eq + Hash + Clone, V: Clone> Default for Assoc<K, V> {
40 fn default() -> Self { Self::new() }
41}
42
43impl<K: Eq + Hash + Clone + Reifiable, V: Clone + Reifiable> Reifiable for Assoc<K, V> {
44 fn ty_name() -> crate::name::Name { crate::name::n("Assoc") }
45
46 fn concrete_arguments() -> Option<Vec<crate::ast::Ast>> {
47 Some(vec![K::ty_invocation(), V::ty_invocation()])
48 }
49
50 fn reify(&self) -> crate::runtime::eval::Value {
51 let res: Vec<_> =
52 self.hamt.iter().map(|(k, v)| Rc::new((k.clone(), v.clone()).reify())).collect();
53
54 crate::runtime::eval::Value::Sequence(res)
55 }
56
57 fn reflect(v: &crate::runtime::eval::Value) -> Self {
58 let mut res = Assoc::<K, V>::new();
59
60 extract!((v) crate::runtime::eval::Value::Sequence = (ref parts) => {
61 for part in parts {
62 let (k_part, v_part) = <(K,V)>::reflect(&**part);
63 res = res.set(k_part, v_part);
64 }
65 });
66 res
67 }
68}
69
70impl<K: Eq + Hash + Clone + fmt::Debug, V: Clone + fmt::Debug> fmt::Debug for Assoc<K, V> {
71 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
72 write!(f, "⟦")?;
73 let mut first = true;
74 for (k, v) in self.iter_pairs() {
75 if !first {
76 write!(f, ", ")?;
77 }
78 write!(f, "{:#?} ⇒ {:#?}", k, v)?;
79 first = false;
80 }
81 write!(f, "⟧")
82 }
83}
84
85impl<K: Eq + Hash + Clone + fmt::Display, V: Clone + fmt::Display> fmt::Display for Assoc<K, V> {
86 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 write!(f, "⟦")?;
88 let mut first = true;
89 for (k, v) in self.iter_pairs() {
90 if !first {
91 write!(f, ", ")?;
92 }
93 write!(f, "{} ⇒ {}", k, v)?;
94 first = false;
95 }
96 write!(f, "⟧")
97 }
98}
99
100impl<K: Eq + Hash + Clone, V: Clone> Assoc<K, V> {
103 fn from_hamt(hamt: HashMap<K, V>) -> Self { Assoc { hamt: hamt, id: get_next_id() } }
104
105 pub fn new() -> Self { Self::from_hamt(HashMap::new()) }
106
107 pub fn find(&self, key: &K) -> Option<&V> { self.hamt.get(key) }
108
109 pub fn set(&self, key: K, value: V) -> Self { Self::from_hamt(self.hamt.update(key, value)) }
110
111 pub fn set_assoc(&self, other: &Self) -> Self {
112 Self::from_hamt(other.hamt.clone().union(self.hamt.clone()))
113 }
114
115 pub fn mut_set(&mut self, key: K, value: V) { self.hamt.insert(key, value); }
116
117 pub fn single(key: K, value: V) -> Self { Self::new().set(key, value) }
118
119 pub fn empty(&self) -> bool { self.hamt.is_empty() }
120
121 pub fn iter_pairs(&self) -> im_rc::hashmap::Iter<K, V> { self.hamt.iter() }
122
123 pub fn iter_keys(&self) -> im_rc::hashmap::Keys<K, V> { self.hamt.keys() }
124
125 pub fn iter_values(&self) -> im_rc::hashmap::Values<K, V> { self.hamt.values() }
126
127 pub fn map<NewV: Clone, F>(&self, mut f: F) -> Assoc<K, NewV>
128 where F: FnMut(&V) -> NewV {
129 self.map_borrow_f(&mut f)
130 }
131
132 pub fn map_borrow_f<'assoc, NewV: Clone, F>(&'assoc self, f: &mut F) -> Assoc<K, NewV>
133 where F: FnMut(&'assoc V) -> NewV {
134 Assoc::<K, NewV>::from_hamt(self.hamt.iter().map(|(k, ref v)| (k.clone(), f(v))).collect())
135 }
136 pub fn keyed_map_borrow_f<NewV: Clone, F>(&self, f: &mut F) -> Assoc<K, NewV>
137 where F: FnMut(&K, &V) -> NewV {
138 Assoc::<K, NewV>::from_hamt(
139 self.hamt.iter().map(|(k, ref v)| (k.clone(), f(k, v))).collect(),
140 )
141 }
142
143 pub fn map_with<OtherV: Clone, NewV: Clone>(
144 &self,
145 other: &Assoc<K, OtherV>,
146 f: &dyn Fn(&V, &OtherV) -> NewV,
147 ) -> Assoc<K, NewV> {
148 Assoc::<K, NewV>::from_hamt(
149 self.hamt
150 .clone()
151 .intersection_with_key(other.hamt.clone(), |_, ref v_l, ref v_r| f(v_l, v_r)),
152 )
153 }
154
155 pub fn keyed_map_with<OtherV: Clone, NewV: Clone>(
156 &self,
157 other: &Assoc<K, OtherV>,
158 f: &dyn Fn(&K, &V, &OtherV) -> NewV,
159 ) -> Assoc<K, NewV> {
160 Assoc::<K, NewV>::from_hamt(
161 self.hamt
162 .clone()
163 .intersection_with_key(other.hamt.clone(), |ref k, ref v_l, ref v_r| {
164 f(k, v_l, v_r)
165 }),
166 )
167 }
168
169 pub fn find_value<'assoc, 'f>(&'assoc self, target: &'f V) -> Option<&'assoc K>
170 where V: PartialEq {
171 self.hamt.iter().find(|(_, v)| v == &target).map(|(k, _)| k)
172 }
173
174 pub fn find_or_panic<'assoc, 'f>(&'assoc self, target: &'f K) -> &'assoc V
175 where K: fmt::Display {
176 self.find(target).unwrap_or_else(|| icp!("{} not found in {}", target, self.map(|_| "…")))
177 }
178
179 pub fn remove<'assoc, 'f>(&'assoc mut self, target: &'f K) -> Option<V> {
180 self.hamt.remove(target)
181 }
182
183 pub fn remove_or_panic<'assoc, 'f>(&'assoc mut self, target: &'f K) -> V
184 where K: fmt::Display {
185 self.hamt
186 .remove(target)
187 .unwrap_or_else(|| icp!("{} not found in {}", target, self.map(|_| "…")))
188 }
189
190 pub fn cut_common(&self, other: &Assoc<K, V>) -> Assoc<K, V>
192 where V: PartialEq {
193 let mut hamt = self.hamt.clone();
194 hamt.retain(|k, v| other.find(k) != Some(v));
195 Self::from_hamt(hamt)
196 }
197
198 pub fn unset(&self, k: &K) -> Assoc<K, V> { Self::from_hamt(self.hamt.without(k)) }
199
200 pub fn reduce<Out>(&self, red: &dyn Fn(&K, &V, Out) -> Out, base: Out) -> Out {
201 self.hamt.iter().fold(base, |base, (k, v)| red(k, v, base))
202 }
203}
204
205impl<K: Eq + Hash + Clone, V: Clone> Assoc<K, V> {
206 pub fn almost_ptr_eq(&self, other: &Assoc<K, V>) -> bool {
207 self.id == other.id }
209}
210
211impl<K: Eq + Hash + Clone, V: Clone, E: Clone> Assoc<K, Result<V, E>> {
212 pub fn lift_result(self) -> Result<Assoc<K, V>, E> {
213 let mut oks = vec![];
214 for (k, res_v) in self.hamt.into_iter() {
215 oks.push((k, res_v?))
216 }
217 Ok(Assoc::<K, V>::from_hamt(HashMap::from(oks)))
218 }
219}
220
#[test]
fn basic_assoc() {
    let mt: Assoc<i32, i32> = Assoc::new();
    let a1 = mt.set(5, 6);
    let a2 = a1.set(6, 7);
    let a_override = a2.set(5, 500);

    // Unbound keys.
    assert_eq!(mt.find(&5), None);
    assert_eq!(a1.find(&6), None);
    assert_eq!(a2.find(&999), None);
    assert_eq!(a_override.find(&999), None);

    // `set` is persistent: earlier maps are unaffected by later `set`s.
    assert_eq!(mt.find(&6), None);
    assert_eq!(a1.find(&5), Some(&6));
    assert_eq!(a2.find(&5), Some(&6));
    assert_eq!(a2.find(&6), Some(&7));
    assert_eq!(a_override.find(&5), Some(&500));
    assert_eq!(a_override.find(&6), Some(&7));

    // `unset` removes only the given key.
    assert_eq!(a_override.unset(&5).find(&5), None);
    assert_eq!(a_override.unset(&6).find(&6), None);

    assert_eq!(a_override.unset(&6).find(&5), Some(&500));
    assert_eq!(a_override.unset(&5).find(&6), Some(&7));

    // Unsetting an absent key is harmless.
    assert_eq!(a_override.unset(&-111).find(&5), Some(&500));
}
247
#[test]
fn assoc_equality() {
    let mt: Assoc<i32, i32> = Assoc::new();
    let a1 = mt.set(5, 6);
    let a2 = a1.set(6, 7);
    let a_override = a2.set(5, 500);

    let a2_opposite = mt.set(6, 7).set(5, 6);
    let a_override_direct = mt.set(5, 500).set(6, 7);

    // Equality is structural: insertion order and overridden bindings don't matter.
    assert_eq!(mt, Assoc::new());
    assert_eq!(a1, a1);
    assert_ne!(a1, mt);
    assert_ne!(mt, a1);
    assert_eq!(a2, a2);
    assert_eq!(a2, a2_opposite);
    assert_eq!(a_override, a_override_direct);
    assert_ne!(a2, a_override);

    let a1_again = mt.set(5, 6);

    // `cut_common` keeps only bindings that the other map lacks or binds differently.
    assert_eq!(mt.cut_common(&mt), mt);
    assert_eq!(a1.cut_common(&mt), a1);
    assert_eq!(mt.cut_common(&a1), mt);

    assert_eq!(a1_again.cut_common(&a1), mt);
    assert_eq!(a_override_direct.cut_common(&a_override), mt);
    assert_eq!(a_override.cut_common(&a_override_direct), mt);
    assert_eq!(a1.cut_common(&a1), mt);
    assert_eq!(a2.cut_common(&a2), mt);

    assert_eq!(a2.cut_common(&a1), mt.set(6, 7));
    assert_eq!(a_override.cut_common(&a2), mt.set(5, 500));

    // `almost_ptr_eq` compares identity, not contents.
    assert!(mt.almost_ptr_eq(&mt));
    assert!(a2.almost_ptr_eq(&a2));
    assert!(a_override_direct.almost_ptr_eq(&a_override_direct));
    assert!(!a2.almost_ptr_eq(&a2_opposite));
}
291
#[test]
fn assoc_r_and_r_roundtrip() {
    use num::BigInt;

    let big = |n: i32| BigInt::from(n);
    let empty: Assoc<BigInt, BigInt> = Assoc::new();
    let two_entries = empty.set(big(5), big(6)).set(big(6), big(7));

    // Reifying and then reflecting must give back an equal map.
    for original in &[empty, two_entries] {
        let round_tripped = Assoc::<BigInt, BigInt>::reflect(&original.reify());
        assert_eq!(*original, round_tripped);
    }
}
302
#[test]
fn assoc_map() {
    let base = assoc_n!("x" => 1, "y" => 2, "z" => 3);

    // `map` transforms every value and keeps the keys.
    let incremented = base.map(|n| n + 1);
    assert_eq!(incremented, assoc_n!("x" => 2, "y" => 3, "z" => 4));

    // `map_with` pairs values by key, regardless of construction order.
    let negated = assoc_n!("y" => -2, "z" => -3, "x" => -1);
    let summed = base.map_with(&negated, &|l, r| l + r);
    assert_eq!(summed, assoc_n!("x" => 0, "y" => 0, "z" => 0));
}
311
#[test]
fn assoc_reduce() {
    let a1 = assoc_n!("x" => 1, "y" => 2, "z" => 3);

    // Sum of all values.
    assert_eq!(a1.reduce(&|_key, a, b| a + b, 0), 6);

    // Sum of all values except the one bound to "y".
    assert_eq!(a1.reduce(&|key, a, b| if key.is("y") { b } else { a + b }, 0), 4);
}