Skip to main content

feanor_math/seq/
sparse.rs

1use std::collections::hash_map::Entry;
2use std::collections::{HashMap, hash_map};
3
4use crate::ring::*;
5use crate::seq::*;
6
7pub struct SparseMapVector<R: RingStore> {
8    data: HashMap<usize, El<R>>,
9    modify_entry: (usize, El<R>),
10    zero: El<R>,
11    ring: R,
12    len: usize,
13}
14
15impl<R: RingStore> SparseMapVector<R> {
16    pub fn new(len: usize, ring: R) -> Self {
17        SparseMapVector {
18            data: HashMap::new(),
19            modify_entry: (usize::MAX, ring.zero()),
20            zero: ring.zero(),
21            ring,
22            len,
23        }
24    }
25
26    #[stability::unstable(feature = "enable")]
27    pub fn set_len(&mut self, new_len: usize) {
28        if new_len < self.len() {
29            for (i, _) in self.nontrivial_entries() {
30                debug_assert!(i < new_len);
31            }
32        }
33        self.len = new_len;
34    }
35
36    #[stability::unstable(feature = "enable")]
37    pub fn scan<F>(&mut self, mut f: F)
38    where
39        F: FnMut(usize, &mut El<R>),
40    {
41        self.enter_in_map((usize::MAX, self.ring.zero()));
42        self.data.retain(|i, c| {
43            f(*i, c);
44            !self.ring.is_zero(c)
45        });
46    }
47
48    #[cfg(test)]
49    fn check_consistency(&self) {
50        assert!(self.ring.is_zero(&self.modify_entry.1) || self.modify_entry.0 < self.len());
51    }
52
53    fn enter_in_map(&mut self, new_modify_entry: (usize, El<R>)) {
54        if self.modify_entry.0 != usize::MAX {
55            let (index, value) = std::mem::replace(&mut self.modify_entry, new_modify_entry);
56            match self.data.entry(index) {
57                Entry::Occupied(mut e) if !self.ring.is_zero(&value) => {
58                    *e.get_mut() = value;
59                }
60                Entry::Occupied(e) => {
61                    _ = e.remove();
62                }
63                Entry::Vacant(e) if !self.ring.is_zero(&value) => {
64                    _ = e.insert(value);
65                }
66                Entry::Vacant(_) => {}
67            };
68        } else {
69            self.modify_entry = new_modify_entry;
70        }
71    }
72}
73
74impl<R: RingStore + Clone> Debug for SparseMapVector<R> {
75    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76        let mut output = f.debug_map();
77        for (key, value) in self.nontrivial_entries() {
78            _ = output.entry(&key, &self.ring.format(value));
79        }
80        output.finish()
81    }
82}
83
84impl<R: RingStore + Clone> Clone for SparseMapVector<R> {
85    fn clone(&self) -> Self {
86        SparseMapVector {
87            data: self.data.iter().map(|(i, c)| (*i, self.ring.clone_el(c))).collect(),
88            modify_entry: (self.modify_entry.0, self.ring.clone_el(&self.modify_entry.1)),
89            zero: self.ring.clone_el(&self.zero),
90            ring: self.ring.clone(),
91            len: self.len,
92        }
93    }
94}
95
96impl<R: RingStore> VectorView<El<R>> for SparseMapVector<R> {
97    fn at(&self, i: usize) -> &El<R> {
98        assert!(i < self.len());
99        if i == self.modify_entry.0 {
100            &self.modify_entry.1
101        } else if let Some(res) = self.data.get(&i) {
102            res
103        } else {
104            &self.zero
105        }
106    }
107
108    fn len(&self) -> usize { self.len }
109
110    fn specialize_sparse<'a, Op: SparseVectorViewOperation<El<R>>>(&'a self, op: Op) -> Option<Op::Output<'a>> {
111        Some(op.execute(self))
112    }
113}
114
115pub struct SparseMapVectorIter<'a, R>
116where
117    R: RingStore,
118{
119    base: hash_map::Iter<'a, usize, El<R>>,
120    skip: usize,
121    once: Option<&'a El<R>>,
122}
123
124impl<'a, R> Iterator for SparseMapVectorIter<'a, R>
125where
126    R: RingStore,
127{
128    type Item = (usize, &'a El<R>);
129
130    fn next(&mut self) -> Option<Self::Item> {
131        if let Some(start) = self.once {
132            self.once = None;
133            return Some((self.skip, start));
134        } else {
135            for (index, element) in self.base.by_ref() {
136                if *index != self.skip {
137                    return Some((*index, element));
138                }
139            }
140            return None;
141        }
142    }
143}
144
145impl<R: RingStore> VectorViewSparse<El<R>> for SparseMapVector<R> {
146    type Iter<'a>
147        = SparseMapVectorIter<'a, R>
148    where
149        Self: 'a;
150
151    fn nontrivial_entries<'a>(&'a self) -> Self::Iter<'a> {
152        SparseMapVectorIter {
153            base: self.data.iter(),
154            skip: self.modify_entry.0,
155            once: if !self.ring.is_zero(&self.modify_entry.1) {
156                Some(&self.modify_entry.1)
157            } else {
158                None
159            },
160        }
161    }
162}
163
164impl<R: RingStore> VectorViewMut<El<R>> for SparseMapVector<R> {
165    fn at_mut(&mut self, i: usize) -> &mut El<R> {
166        assert!(i < self.len());
167        if i == self.modify_entry.0 {
168            return &mut self.modify_entry.1;
169        }
170        let new_value = self.ring.clone_el(self.data.get(&i).unwrap_or(&self.zero));
171        self.enter_in_map((i, new_value));
172        return &mut self.modify_entry.1;
173    }
174}
175
176#[cfg(test)]
177use crate::primitive_int::StaticRing;
178
/// Test helper: checks that `vec` has length `N`, passes the internal
/// consistency check and holds exactly the entries given by `values`.
#[cfg(test)]
fn assert_vector_eq<const N: usize>(vec: &SparseMapVector<StaticRing<i64>>, values: [i64; N]) {
    assert_eq!(vec.len(), N);
    vec.check_consistency();
    // only read through at(); at_mut() might change the vector, so it is not tested here
    for (index, expected) in values.iter().enumerate() {
        assert_eq!(*vec.at(index), *expected);
    }
}
188
/// Writes through `at_mut()` and checks, after every write, both the value
/// seen before the write and the full contents of the vector afterwards.
#[test]
fn test_at_mut() {
    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    assert_vector_eq(&vector, [0, 0, 0, 0, 0]);

    // (index, value expected before the write, value written, contents expected afterwards)
    let steps: [(usize, i64, i64, [i64; 5]); 5] = [
        (1, 0, 3, [0, 3, 0, 0, 0]),
        (4, 0, -1, [0, 3, 0, 0, -1]),
        (1, 3, 4, [0, 4, 0, 0, -1]),
        (1, 4, 5, [0, 5, 0, 0, -1]),
        (3, 0, 0, [0, 5, 0, 0, -1]),
    ];
    for (index, before, written, after) in steps {
        let entry = vector.at_mut(index);
        assert_eq!(before, *entry);
        *entry = written;
        assert_vector_eq(&vector, after);
    }
}
220
/// Checks that `nontrivial_entries()` reports exactly the nonzero entries
/// after each of a sequence of writes through `at_mut()`.
#[test]
fn test_nontrivial_entries() {
    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    assert_eq!(vector.nontrivial_entries().count(), 0);

    // (index written, value written, nonzero entries expected afterwards)
    let steps: [(usize, i64, Vec<(usize, i64)>); 7] = [
        (1, 3, vec![(1, 3)]),
        (4, -1, vec![(1, 3), (4, -1)]),
        (1, 4, vec![(1, 4), (4, -1)]),
        (1, 0, vec![(4, -1)]),
        (1, 5, vec![(1, 5), (4, -1)]),
        (3, 0, vec![(1, 5), (4, -1)]),
        (4, -2, vec![(1, 5), (4, -2)]),
    ];
    for (index, value, expected) in steps {
        *vector.at_mut(index) = value;
        // compare as maps, since the iteration order is unspecified
        let actual: HashMap<usize, i64> = vector.nontrivial_entries().map(|(i, c)| (i, *c)).collect();
        let expected: HashMap<usize, i64> = expected.into_iter().collect();
        assert_eq!(actual, expected);
    }

    *vector.at_mut(1) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 1);
    *vector.at_mut(4) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 0);
}
272
/// Checks that `scan()` visits the stored entries (including the one staged
/// last) and drops those that become zero.
#[test]
fn test_scan() {
    let mut vector = SparseMapVector::new(5, StaticRing::<i64>::RING);
    for (index, value) in [(1, 2), (3, 1), (4, 0)] {
        *vector.at_mut(index) = value;
    }
    vector.scan(|_, c| *c -= 1);
    assert_vector_eq(&vector, [0, 1, 0, 0, 0]);
    let remaining: Vec<(usize, i64)> = vector.nontrivial_entries().map(|(i, c)| (i, *c)).collect();
    assert_eq!(remaining, vec![(1, 1)]);
}