1use std::collections::hash_map::Entry;
2use std::collections::{HashMap, hash_map};
3
4use crate::ring::*;
5use crate::seq::*;
6
/// A vector of ring elements with logical length `len` that stores only its
/// nonzero entries, in a `HashMap` keyed by index.
///
/// The entry most recently handed out by `at_mut()` is held separately in
/// `modify_entry` and takes precedence over `data` for its index. It is written
/// back to `data` (or dropped, if it became zero) once another entry is accessed
/// mutably or `scan()` is called.
pub struct SparseMapVector<R: RingStore> {
    /// Index -> value for stored entries; indices that are absent are zero.
    data: HashMap<usize, El<R>>,
    /// `(index, value)` of the entry handed out by `at_mut()`; an index of
    /// `usize::MAX` means that no entry is currently cached.
    modify_entry: (usize, El<R>),
    /// A zero element, so that `at()` can return a reference for unstored indices.
    zero: El<R>,
    /// The ring the entries belong to.
    ring: R,
    /// Logical length of the vector.
    len: usize,
}
14
15impl<R: RingStore> SparseMapVector<R> {
16 pub fn new(len: usize, ring: R) -> Self {
17 SparseMapVector {
18 data: HashMap::new(),
19 modify_entry: (usize::MAX, ring.zero()),
20 zero: ring.zero(),
21 ring,
22 len,
23 }
24 }
25
26 #[stability::unstable(feature = "enable")]
27 pub fn set_len(&mut self, new_len: usize) {
28 if new_len < self.len() {
29 for (i, _) in self.nontrivial_entries() {
30 debug_assert!(i < new_len);
31 }
32 }
33 self.len = new_len;
34 }
35
36 #[stability::unstable(feature = "enable")]
37 pub fn scan<F>(&mut self, mut f: F)
38 where
39 F: FnMut(usize, &mut El<R>),
40 {
41 self.enter_in_map((usize::MAX, self.ring.zero()));
42 self.data.retain(|i, c| {
43 f(*i, c);
44 !self.ring.is_zero(c)
45 });
46 }
47
48 #[cfg(test)]
49 fn check_consistency(&self) {
50 assert!(self.ring.is_zero(&self.modify_entry.1) || self.modify_entry.0 < self.len());
51 }
52
53 fn enter_in_map(&mut self, new_modify_entry: (usize, El<R>)) {
54 if self.modify_entry.0 != usize::MAX {
55 let (index, value) = std::mem::replace(&mut self.modify_entry, new_modify_entry);
56 match self.data.entry(index) {
57 Entry::Occupied(mut e) if !self.ring.is_zero(&value) => {
58 *e.get_mut() = value;
59 }
60 Entry::Occupied(e) => {
61 _ = e.remove();
62 }
63 Entry::Vacant(e) if !self.ring.is_zero(&value) => {
64 _ = e.insert(value);
65 }
66 Entry::Vacant(_) => {}
67 };
68 } else {
69 self.modify_entry = new_modify_entry;
70 }
71 }
72}
73
74impl<R: RingStore + Clone> Debug for SparseMapVector<R> {
75 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76 let mut output = f.debug_map();
77 for (key, value) in self.nontrivial_entries() {
78 _ = output.entry(&key, &self.ring.format(value));
79 }
80 output.finish()
81 }
82}
83
84impl<R: RingStore + Clone> Clone for SparseMapVector<R> {
85 fn clone(&self) -> Self {
86 SparseMapVector {
87 data: self.data.iter().map(|(i, c)| (*i, self.ring.clone_el(c))).collect(),
88 modify_entry: (self.modify_entry.0, self.ring.clone_el(&self.modify_entry.1)),
89 zero: self.ring.clone_el(&self.zero),
90 ring: self.ring.clone(),
91 len: self.len,
92 }
93 }
94}
95
96impl<R: RingStore> VectorView<El<R>> for SparseMapVector<R> {
97 fn at(&self, i: usize) -> &El<R> {
98 assert!(i < self.len());
99 if i == self.modify_entry.0 {
100 &self.modify_entry.1
101 } else if let Some(res) = self.data.get(&i) {
102 res
103 } else {
104 &self.zero
105 }
106 }
107
108 fn len(&self) -> usize { self.len }
109
110 fn specialize_sparse<'a, Op: SparseVectorViewOperation<El<R>>>(&'a self, op: Op) -> Option<Op::Output<'a>> {
111 Some(op.execute(self))
112 }
113}
114
115pub struct SparseMapVectorIter<'a, R>
116where
117 R: RingStore,
118{
119 base: hash_map::Iter<'a, usize, El<R>>,
120 skip: usize,
121 once: Option<&'a El<R>>,
122}
123
124impl<'a, R> Iterator for SparseMapVectorIter<'a, R>
125where
126 R: RingStore,
127{
128 type Item = (usize, &'a El<R>);
129
130 fn next(&mut self) -> Option<Self::Item> {
131 if let Some(start) = self.once {
132 self.once = None;
133 return Some((self.skip, start));
134 } else {
135 for (index, element) in self.base.by_ref() {
136 if *index != self.skip {
137 return Some((*index, element));
138 }
139 }
140 return None;
141 }
142 }
143}
144
145impl<R: RingStore> VectorViewSparse<El<R>> for SparseMapVector<R> {
146 type Iter<'a>
147 = SparseMapVectorIter<'a, R>
148 where
149 Self: 'a;
150
151 fn nontrivial_entries<'a>(&'a self) -> Self::Iter<'a> {
152 SparseMapVectorIter {
153 base: self.data.iter(),
154 skip: self.modify_entry.0,
155 once: if !self.ring.is_zero(&self.modify_entry.1) {
156 Some(&self.modify_entry.1)
157 } else {
158 None
159 },
160 }
161 }
162}
163
164impl<R: RingStore> VectorViewMut<El<R>> for SparseMapVector<R> {
165 fn at_mut(&mut self, i: usize) -> &mut El<R> {
166 assert!(i < self.len());
167 if i == self.modify_entry.0 {
168 return &mut self.modify_entry.1;
169 }
170 let new_value = self.ring.clone_el(self.data.get(&i).unwrap_or(&self.zero));
171 self.enter_in_map((i, new_value));
172 return &mut self.modify_entry.1;
173 }
174}
175
176#[cfg(test)]
177use crate::primitive_int::StaticRing;
178
/// Asserts that `vec` has length `N`, is internally consistent, and equals `values`
/// entry by entry.
#[cfg(test)]
fn assert_vector_eq<const N: usize>(vec: &SparseMapVector<StaticRing<i64>>, values: [i64; N]) {
    assert_eq!(vec.len(), N);
    vec.check_consistency();
    for (i, expected) in values.into_iter().enumerate() {
        assert_eq!(*vec.at(i), expected);
    }
}
188
#[test]
fn test_at_mut() {
    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    assert_vector_eq(&vector, [0, 0, 0, 0, 0]);

    // (index, value seen before writing, value written, whole vector afterwards)
    let steps: [(usize, i64, i64, [i64; 5]); 5] = [
        (1, 0, 3, [0, 3, 0, 0, 0]),
        (4, 0, -1, [0, 3, 0, 0, -1]),
        (1, 3, 4, [0, 4, 0, 0, -1]),
        (1, 4, 5, [0, 5, 0, 0, -1]),
        (3, 0, 0, [0, 5, 0, 0, -1]),
    ];
    for (index, before, after, expected) in steps {
        let entry = vector.at_mut(index);
        assert_eq!(before, *entry);
        *entry = after;
        assert_vector_eq(&vector, expected);
    }
}
220
#[test]
fn test_nontrivial_entries() {
    /// Collects the reported nonzero entries into an index -> value map.
    fn entries(v: &SparseMapVector<StaticRing<i64>>) -> HashMap<usize, i64> {
        v.nontrivial_entries().map(|(i, c)| (i, *c)).collect()
    }

    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    assert_eq!(entries(&vector), HashMap::new());

    // (index written, value written, expected nonzero entries afterwards)
    let steps: [(usize, i64, &[(usize, i64)]); 7] = [
        (1, 3, &[(1, 3)]),
        (4, -1, &[(1, 3), (4, -1)]),
        (1, 4, &[(1, 4), (4, -1)]),
        (1, 0, &[(4, -1)]),
        (1, 5, &[(1, 5), (4, -1)]),
        (3, 0, &[(1, 5), (4, -1)]),
        (4, -2, &[(1, 5), (4, -2)]),
    ];
    for (index, value, expected) in steps {
        *vector.at_mut(index) = value;
        assert_eq!(entries(&vector), expected.iter().copied().collect::<HashMap<_, _>>());
    }

    // Counting (rather than collecting into a map) also catches duplicate reports.
    *vector.at_mut(1) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 1);
    *vector.at_mut(4) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 0);
}
272
#[test]
fn test_scan() {
    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    for (index, value) in [(1, 2), (3, 1), (4, 0)] {
        *vector.at_mut(index) = value;
    }

    // Decrement every nonzero entry; the one that reaches zero must disappear.
    vector.scan(|_, c| *c -= 1);

    assert_vector_eq(&vector, [0, 1, 0, 0, 0]);
    let remaining: HashMap<usize, &i64> = vector.nontrivial_entries().collect();
    assert_eq!(remaining, HashMap::from([(1, &1)]));
}