1use std::collections::HashMap;
2use std::collections::hash_map;
3use std::collections::hash_map::Entry;
4
5use crate::ring::*;
6use crate::seq::*;
7
/// A vector of ring elements of fixed logical length that stores its entries in a
/// hash map keyed by index; indices without a stored entry read as zero.
///
/// Mutable access goes through a single write-back slot (`modify_entry`): the element
/// handed out by `at_mut()` lives in that slot and is only moved into `data` when
/// another index is requested mutably or `scan()` is called.
pub struct SparseMapVector<R: RingStore> {
    /// Stored entries, keyed by index. While `modify_entry` refers to some index, the value
    /// stored here for that index may be outdated; readers consult `modify_entry` first.
    data: HashMap<usize, El<R>>,
    /// Write-back slot `(index, value)`; the index `usize::MAX` marks the slot as unused.
    modify_entry: (usize, El<R>),
    /// A zero element of `ring`, returned by reference for indices without a stored entry.
    zero: El<R>,
    /// The ring the entries belong to.
    ring: R,
    /// Logical length of the vector.
    len: usize
}
15
16impl<R: RingStore> SparseMapVector<R> {
17
18 pub fn new(len: usize, ring: R) -> Self {
19 SparseMapVector {
20 data: HashMap::new(),
21 modify_entry: (usize::MAX, ring.zero()),
22 zero: ring.zero(),
23 ring: ring,
24 len: len
25 }
26 }
27
28 #[stability::unstable(feature = "enable")]
29 pub fn set_len(&mut self, new_len: usize) {
30 if new_len < self.len() {
31 for (i, _) in self.nontrivial_entries() {
32 debug_assert!(i < new_len);
33 }
34 }
35 self.len = new_len;
36 }
37
38 #[stability::unstable(feature = "enable")]
39 pub fn scan<F>(&mut self, mut f: F)
40 where F: FnMut(usize, &mut El<R>)
41 {
42 self.enter_in_map((usize::MAX, self.ring.zero()));
43 self.data.retain(|i, c| {
44 f(*i, c);
45 !self.ring.is_zero(c)
46 });
47 }
48
49 #[cfg(test)]
50 fn check_consistency(&self) {
51 assert!(self.ring.is_zero(&self.modify_entry.1) || self.modify_entry.0 < self.len());
52 }
53
54 fn enter_in_map(&mut self, new_modify_entry: (usize, El<R>)) {
55 if self.modify_entry.0 != usize::MAX {
56 let (index, value) = std::mem::replace(&mut self.modify_entry, new_modify_entry);
57 match self.data.entry(index) {
58 Entry::Occupied(mut e) if !self.ring.is_zero(&value) => { *e.get_mut() = value; },
59 Entry::Occupied(e) => { _ = e.remove(); },
60 Entry::Vacant(e) if !self.ring.is_zero(&value) => { _ = e.insert(value); },
61 Entry::Vacant(_) => {}
62 };
63 } else {
64 self.modify_entry = new_modify_entry;
65 }
66 }
67}
68
69impl<R: RingStore + Clone> Clone for SparseMapVector<R> {
70
71 fn clone(&self) -> Self {
72 SparseMapVector {
73 data: self.data.iter().map(|(i, c)| (*i, self.ring.clone_el(c))).collect(),
74 modify_entry: (self.modify_entry.0, self.ring.clone_el(&self.modify_entry.1)),
75 zero: self.ring.clone_el(&self.zero),
76 ring: self.ring.clone(),
77 len: self.len
78 }
79 }
80}
81
82impl<R: RingStore> VectorView<El<R>> for SparseMapVector<R> {
83
84 fn at(&self, i: usize) -> &El<R> {
85 assert!(i < self.len());
86 if i == self.modify_entry.0 {
87 &self.modify_entry.1
88 } else if let Some(res) = self.data.get(&i) {
89 res
90 } else {
91 &self.zero
92 }
93 }
94
95 fn len(&self) -> usize {
96 self.len
97 }
98
99 fn specialize_sparse<'a, Op: SparseVectorViewOperation<El<R>>>(&'a self, op: Op) -> Result<Op::Output<'a>, ()> {
100 Ok(op.execute(self))
101 }
102}
103
/// Iterator over `(index, &element)` pairs of a [`SparseMapVector`].
///
/// It first yields the optional `once` element (paired with the index `skip`) and then
/// all entries of the underlying map iterator except the one whose index equals `skip`.
pub struct SparseMapVectorIter<'a, R>
    where R: RingStore
{
    /// Iterator over the entries stored in the vector's hash map.
    base: hash_map::Iter<'a, usize, El<R>>,
    /// Index that is filtered out of `base`; also the index reported together with `once`.
    skip: usize,
    /// Element to yield before anything from `base`; `None` if there is none (or it was
    /// already yielded).
    once: Option<&'a El<R>>
}
111
112impl<'a, R> Iterator for SparseMapVectorIter<'a, R>
113 where R: RingStore
114{
115 type Item = (usize, &'a El<R>);
116
117 fn next(&mut self) -> Option<Self::Item> {
118 if let Some(start) = self.once {
119 self.once = None;
120 return Some((self.skip, start));
121 } else {
122 while let Some((index, element)) = self.base.next() {
123 if *index != self.skip {
124 return Some((*index, element));
125 }
126 }
127 return None;
128 }
129 }
130}
131
132impl<R: RingStore> VectorViewSparse<El<R>> for SparseMapVector<R> {
133
134 type Iter<'a> = SparseMapVectorIter<'a, R>
135 where Self: 'a;
136
137 fn nontrivial_entries<'a>(&'a self) -> Self::Iter<'a> {
138 SparseMapVectorIter {
139 base: self.data.iter(),
140 skip: self.modify_entry.0,
141 once: if !self.ring.is_zero(&self.modify_entry.1) { Some(&self.modify_entry.1) } else { None }
142 }
143 }
144}
145
146impl<R: RingStore> VectorViewMut<El<R>> for SparseMapVector<R> {
147
148 fn at_mut(&mut self, i: usize) -> &mut El<R> {
149 assert!(i < self.len());
150 if i == self.modify_entry.0 {
151 return &mut self.modify_entry.1;
152 }
153 let new_value = self.ring.clone_el(self.data.get(&i).unwrap_or(&self.zero));
154 self.enter_in_map((i, new_value));
155 return &mut self.modify_entry.1;
156 }
157}
158
159#[cfg(test)]
160use crate::primitive_int::StaticRing;
161
/// Test helper: asserts that `vec` has length `N`, passes its consistency check and
/// equals `values` entry by entry.
#[cfg(test)]
fn assert_vector_eq<const N: usize>(vec: &SparseMapVector<StaticRing<i64>>, values: [i64; N]) {
    assert_eq!(vec.len(), N);
    vec.check_consistency();
    for (i, expected) in values.iter().enumerate() {
        assert_eq!(*vec.at(i), *expected);
    }
}
171
/// Writes through `at_mut()` at several indices (including repeated writes to the same
/// index and writing a zero) and checks the complete vector after every write.
#[test]
fn test_at_mut() {
    let mut vector = SparseMapVector::new(5, StaticRing::<i64>::RING);
    assert_vector_eq(&vector, [0, 0, 0, 0, 0]);

    // (index, value expected before the write, value written, expected vector afterwards)
    let steps: [(usize, i64, i64, [i64; 5]); 5] = [
        (1, 0, 3, [0, 3, 0, 0, 0]),
        (4, 0, -1, [0, 3, 0, 0, -1]),
        (1, 3, 4, [0, 4, 0, 0, -1]),
        (1, 4, 5, [0, 5, 0, 0, -1]),
        (3, 0, 0, [0, 5, 0, 0, -1])
    ];
    for (index, before, written, expected) in steps {
        let entry = vector.at_mut(index);
        assert_eq!(before, *entry);
        *entry = written;
        assert_vector_eq(&vector, expected);
    }
}
203
/// Checks that `nontrivial_entries()` reports exactly the nonzero entries after each of a
/// series of writes, including overwrites and writes of zero.
#[test]
fn test_nontrivial_entries() {
    /// Asserts that the nontrivial entries of `vector` are exactly `expected`.
    fn assert_entries(vector: &SparseMapVector<StaticRing<i64>>, expected: &[(usize, i64)]) {
        let actual: HashMap<usize, i64> = vector.nontrivial_entries().map(|(i, c)| (i, *c)).collect();
        let expected: HashMap<usize, i64> = expected.iter().copied().collect();
        assert_eq!(actual, expected);
    }

    let mut vector = SparseMapVector::new(5, StaticRing::<i64>::RING);
    assert_entries(&vector, &[]);
    *vector.at_mut(1) = 3;
    assert_entries(&vector, &[(1, 3)]);
    *vector.at_mut(4) = -1;
    assert_entries(&vector, &[(1, 3), (4, -1)]);

    *vector.at_mut(1) = 4;
    assert_entries(&vector, &[(1, 4), (4, -1)]);
    *vector.at_mut(1) = 0;
    assert_entries(&vector, &[(4, -1)]);
    *vector.at_mut(1) = 5;
    assert_entries(&vector, &[(1, 5), (4, -1)]);

    *vector.at_mut(3) = 0;
    assert_entries(&vector, &[(1, 5), (4, -1)]);
    *vector.at_mut(4) = -2;
    assert_entries(&vector, &[(1, 5), (4, -2)]);

    *vector.at_mut(1) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 1);
    *vector.at_mut(4) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 0);
}
231
/// Checks that `scan()` applies the callback to the stored entries and removes those
/// that become zero.
#[test]
fn test_scan() {
    let mut vector = SparseMapVector::new(5, StaticRing::<i64>::RING);
    for (index, value) in [(1, 2), (3, 1), (4, 0)] {
        *vector.at_mut(index) = value;
    }

    vector.scan(|_, c| *c -= 1);

    assert_vector_eq(&vector, [0, 1, 0, 0, 0]);
    let remaining: HashMap<usize, &i64> = vector.nontrivial_entries().collect();
    let expected: HashMap<usize, &i64> = [(1, &1)].into_iter().collect();
    assert_eq!(remaining, expected);
}