1use std::collections::HashMap;
4use std::fmt;
5use std::ops::{Add, AddAssign, Sub, SubAssign, Mul, MulAssign};
6
7
/// A string-keyed map whose values can be combined with arithmetic operators.
///
/// This is a thin wrapper around a [`HashMap`] from borrowed string keys to
/// values of type `V`. The wrapped map is exposed directly through the public
/// `hashmap` field. Two `ArithMap`s compare equal when their wrapped maps do,
/// and the default value wraps an empty map.
#[derive(Default, PartialEq, Eq)]
pub struct ArithMap<'a, V> {
    /// Underlying storage, keyed by string slices borrowed for `'a`.
    pub hashmap: HashMap<&'a str, V>,
}
12
13impl<V> fmt::Debug for ArithMap<'_, V>
14where
15 V: fmt::Debug,
16{
17 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18 f.debug_map()
19 .entries(self.hashmap.iter().map(|(k, v)| (k, v)))
20 .finish()
21 }
22}
23
/// Builds an `ArithMap` from a list of `key => value` pairs.
///
/// A trailing comma is accepted, and an empty invocation produces an empty
/// map. The backing `HashMap` is created with capacity for the number of
/// pairs written; if the same key appears more than once, the last value
/// wins because each pair is inserted in order.
///
/// The expansion names `ArithMap` unqualified, so that type must be in scope
/// wherever the macro is invoked.
#[macro_export(local_inner_macros)]
macro_rules! arithmap {
    // Internal helper: swallow any tokens and expand to the unit value.
    (@single $($tok:tt)*) => { () };
    // Internal helper: count the expressions by measuring a slice of units,
    // one unit per expression.
    (@count $($item:expr),*) => {
        <[()]>::len(&[$(arithmap!(@single $item)),*])
    };
    // Trailing-comma form: strip the final comma and re-dispatch.
    ($($k:expr => $v:expr,)+) => {
        arithmap!($($k => $v),+)
    };
    // Main form: pre-size the map, then insert each pair in order.
    ($($k:expr => $v:expr),*) => {{
        let _capacity = arithmap!(@count $($k),*);
        let mut _entries = ::std::collections::HashMap::with_capacity(_capacity);
        $(
            let _ = _entries.insert($k, $v);
        )*
        ArithMap { hashmap: _entries }
    }};
}
47
48impl<V> ArithMap<'_, V>
49where
50 V: Copy + Default + PartialEq,
51{
52 pub fn prune(&mut self) {
61 let zero: V = Default::default();
62 self.hashmap.retain(|_, &mut v| v != zero);
63 }
64}
65
66impl<V> Add<V> for ArithMap<'_, V>
74where
75 V: AddAssign + Copy,
76{
77 type Output = Self;
78 fn add(mut self: Self, other: V) -> Self {
79 for v in self.hashmap.values_mut() {
80 *v += other;
81 }
82 self
83 }
84}
85
86impl<V> AddAssign<V> for ArithMap<'_, V>
95where
96 V: AddAssign + Copy,
97{
98 fn add_assign(&mut self, other: V) {
99 for v in self.hashmap.values_mut() {
100 *v += other;
101 }
102 }
103}
104
105impl<V> Sub<V> for ArithMap<'_, V>
113where
114 V: SubAssign + Copy,
115{
116 type Output = Self;
117 fn sub(mut self: Self, other: V) -> Self {
118 for v in self.hashmap.values_mut() {
119 *v -= other;
120 }
121 self
122 }
123}
124
125impl<V> SubAssign<V> for ArithMap<'_, V>
134where
135 V: SubAssign + Copy,
136{
137 fn sub_assign(&mut self, other: V) {
138 for v in self.hashmap.values_mut() {
139 *v -= other;
140 }
141 }
142}
143
144impl<V> Mul<V> for ArithMap<'_, V>
152where
153 V: MulAssign + Copy,
154{
155 type Output = Self;
156 fn mul(mut self: Self, other: V) -> Self {
157 for v in self.hashmap.values_mut() {
158 *v *= other;
159 }
160 self
161 }
162}
163
164impl<V> MulAssign<V> for ArithMap<'_, V>
173where
174 V: MulAssign + Copy,
175{
176 fn mul_assign(&mut self, other: V) {
177 for v in self.hashmap.values_mut() {
178 *v *= other;
179 }
180 }
181}
182
183impl<V> Add for ArithMap<'_, V>
192where
193 V: Add<Output = V> + AddAssign + Copy + Default,
194{
195 type Output = Self;
196 fn add(self: Self, other: Self) -> Self {
197 let mut r: Self = Default::default();
198 for (k, v) in self.hashmap.iter() {
199 r.hashmap.insert(*k, *v);
200 }
201 for (k, v2) in other.hashmap.iter() {
202 if let Some(v1) = r.hashmap.get_mut(k) {
203 *v1 += *v2;
204 } else {
205 r.hashmap.insert(*k, *v2);
206 }
207 }
208 r
209 }
210}
211
212impl<V> AddAssign for ArithMap<'_, V>
222where
223 V: Add<Output = V> + AddAssign + Copy + Default,
224{
225 fn add_assign(&mut self, other: Self) {
226 for (k, v2) in other.hashmap.iter() {
227 if let Some(v1) = self.hashmap.get_mut(k) {
228 *v1 += *v2;
229 } else {
230 self.hashmap.insert(*k, *v2);
231 }
232 }
233 }
234}
235
236impl<V> Sub for ArithMap<'_, V>
245where
246 V: Sub<Output = V> + SubAssign + Copy + Default,
247{
248 type Output = Self;
249 fn sub(self: Self, other: Self) -> Self {
250 let zero: V = Default::default();
251 let mut r: Self = Default::default();
252 for (k, v) in self.hashmap.iter() {
253 r.hashmap.insert(*k, *v);
254 }
255 for (k, v2) in other.hashmap.iter() {
256 if let Some(v1) = r.hashmap.get_mut(k) {
257 *v1 -= *v2;
258 } else {
259 r.hashmap.insert(*k, zero - *v2);
260 }
261 }
262 r
263 }
264}
265
266impl<V> SubAssign for ArithMap<'_, V>
276where
277 V: Sub<Output = V> + SubAssign + Copy + Default,
278{
279 fn sub_assign(&mut self, other: Self) {
280 let zero: V = Default::default();
281 for (k, v2) in other.hashmap.iter() {
282 if let Some(v1) = self.hashmap.get_mut(k) {
283 *v1 -= *v2;
284 } else {
285 self.hashmap.insert(*k, zero - *v2);
286 }
287 }
288 }
289}