cairo_lang_utils/
collection_arithmetics.rs

1#[cfg(test)]
2#[path = "collection_arithmetics_test.rs"]
3mod test;
4
5use core::hash::{BuildHasher, Hash};
6use core::ops::{Add, Sub};
7
8use crate::ordered_hash_map::{self, OrderedHashMap};
9#[cfg(feature = "std")]
10use crate::small_ordered_map::{self, SmallOrderedMap};
11
/// A trait for types which have a zero value.
///
/// The zero value is the additive identity of the type. The `MergeCollection` implementations in
/// this module compare values against it to decide whether an entry is stored.
///
/// Functions may assume the following:
/// * `x = x + zero() = zero() + x`
pub trait HasZero {
    /// Returns the zero value for the type.
    fn zero() -> Self;
}
20impl HasZero for i32 {
21    fn zero() -> Self {
22        0
23    }
24}
25impl HasZero for i64 {
26    fn zero() -> Self {
27        0
28    }
29}
30
/// A trait for types which support addition on collections.
///
/// This module provides a blanket implementation for every type that implements
/// `MergeCollection<Key, Value>` with `Value: Add<Output = Value>`.
pub trait AddCollection<Key, Value> {
    /// Returns a new collection with the sum of the values from the given two collections, for each
    /// key.
    ///
    /// `self` is taken by value and the resulting collection is returned.
    ///
    /// If the key is missing from one of them, it is treated as zero.
    fn add_collection(self, other: impl IntoIterator<Item = (Key, Value)>) -> Self;
}
39
/// A trait for types which support subtraction on collections.
///
/// This module provides a blanket implementation for every type that implements
/// `MergeCollection<Key, Value>` with `Value: Sub<Output = Value>`.
pub trait SubCollection<Key, Value> {
    /// Returns a new collection with the difference of the values from the given two collections,
    /// for each key.
    ///
    /// `self` is taken by value and the resulting collection is returned. The blanket
    /// implementation in this module passes `|a, b| a - b` to `merge_collection`.
    ///
    /// If the key is missing from one of them, it is treated as zero.
    fn sub_collection(self, other: impl IntoIterator<Item = (Key, Value)>) -> Self;
}
48
/// A trait for types which support merging another collection into themselves, combining the
/// values of each key with a caller-supplied function.
pub trait MergeCollection<Key, Value> {
    /// Returns a collection in which, for each key, the values from the given two collections are
    /// combined using `action`.
    ///
    /// The implementations in this module call `action` with the value from `self` as the first
    /// argument and the value from `other` as the second.
    ///
    /// If the key is missing from one of them, it is treated as zero.
    fn merge_collection(
        self,
        other: impl IntoIterator<Item = (Key, Value)>,
        action: impl Fn(Value, Value) -> Value,
    ) -> Self;
}
60
61impl<Key, Value: Add<Output = Value>, T: MergeCollection<Key, Value>> AddCollection<Key, Value>
62    for T
63{
64    fn add_collection(self, other: impl IntoIterator<Item = (Key, Value)>) -> Self {
65        self.merge_collection(other, |a, b| a + b)
66    }
67}
68
69impl<Key, Value: Sub<Output = Value>, T: MergeCollection<Key, Value>> SubCollection<Key, Value>
70    for T
71{
72    fn sub_collection(self, other: impl IntoIterator<Item = (Key, Value)>) -> Self {
73        self.merge_collection(other, |a, b| a - b)
74    }
75}
76
77impl<Key: Hash + Eq, Value: HasZero + Clone + Eq, BH: BuildHasher> MergeCollection<Key, Value>
78    for OrderedHashMap<Key, Value, BH>
79{
80    fn merge_collection(
81        mut self,
82        other: impl IntoIterator<Item = (Key, Value)>,
83        action: impl Fn(Value, Value) -> Value,
84    ) -> Self {
85        for (key, other_val) in other {
86            match self.entry(key) {
87                ordered_hash_map::Entry::Occupied(mut e) => {
88                    let new_val = action(e.get().clone(), other_val);
89                    if new_val == Value::zero() {
90                        e.swap_remove();
91                    } else {
92                        e.insert(new_val);
93                    }
94                }
95                ordered_hash_map::Entry::Vacant(e) => {
96                    let zero = Value::zero();
97                    if other_val != zero {
98                        e.insert(action(zero, other_val));
99                    }
100                }
101            }
102        }
103        self
104    }
105}
106
107#[cfg(feature = "std")]
108impl<Key: Eq, Value: HasZero + Clone + Eq> MergeCollection<Key, Value>
109    for SmallOrderedMap<Key, Value>
110{
111    fn merge_collection(
112        mut self,
113        other: impl IntoIterator<Item = (Key, Value)>,
114        action: impl Fn(Value, Value) -> Value,
115    ) -> Self {
116        for (key, other_val) in other {
117            match self.entry(key) {
118                small_ordered_map::Entry::Occupied(mut e) => {
119                    let new_val = action(e.get().clone(), other_val);
120                    if new_val == Value::zero() {
121                        e.remove();
122                    } else {
123                        e.insert(new_val);
124                    }
125                }
126                small_ordered_map::Entry::Vacant(e) => {
127                    let zero = Value::zero();
128                    if other_val != zero {
129                        e.insert(action(zero, other_val));
130                    }
131                }
132            }
133        }
134        self
135    }
136}