feanor_math/seq/sparse.rs

1use std::collections::HashMap;
2use std::collections::hash_map;
3use std::collections::hash_map::Entry;
4
5use crate::ring::*;
6use crate::seq::*;
7
8pub struct SparseMapVector<R: RingStore> {
9    data: HashMap<usize, El<R>>,
10    modify_entry: (usize, El<R>),
11    zero: El<R>,
12    ring: R,
13    len: usize
14}
15
16impl<R: RingStore> SparseMapVector<R> {
17
18    pub fn new(len: usize, ring: R) -> Self {
19        SparseMapVector {
20            data: HashMap::new(), 
21            modify_entry: (usize::MAX, ring.zero()),
22            zero: ring.zero(),
23            ring: ring,
24            len: len
25        }
26    }
27
28    #[stability::unstable(feature = "enable")]
29    pub fn set_len(&mut self, new_len: usize) {
30        if new_len < self.len() {
31            for (i, _) in self.nontrivial_entries() {
32                debug_assert!(i < new_len);
33            }
34        }
35        self.len = new_len;
36    }
37
38    #[stability::unstable(feature = "enable")]
39    pub fn scan<F>(&mut self, mut f: F)
40        where F: FnMut(usize, &mut El<R>)
41    {
42        self.enter_in_map((usize::MAX, self.ring.zero()));
43        self.data.retain(|i, c| {
44            f(*i, c);
45            !self.ring.is_zero(c)
46        });
47    }
48
49    #[cfg(test)]
50    fn check_consistency(&self) {
51        assert!(self.ring.is_zero(&self.modify_entry.1) || self.modify_entry.0 < self.len());
52    }
53
54    fn enter_in_map(&mut self, new_modify_entry: (usize, El<R>)) {
55        if self.modify_entry.0 != usize::MAX {
56            let (index, value) = std::mem::replace(&mut self.modify_entry, new_modify_entry);
57            match self.data.entry(index) {
58                Entry::Occupied(mut e) if !self.ring.is_zero(&value) => { *e.get_mut() = value; },
59                Entry::Occupied(e) => { _ = e.remove(); },
60                Entry::Vacant(e) if !self.ring.is_zero(&value) => { _ = e.insert(value); },
61                Entry::Vacant(_) => {}
62            };
63        } else {
64            self.modify_entry = new_modify_entry;
65        }
66    }
67}
68
69impl<R: RingStore + Clone> Clone for SparseMapVector<R> {
70
71    fn clone(&self) -> Self {
72        SparseMapVector { 
73            data: self.data.iter().map(|(i, c)| (*i, self.ring.clone_el(c))).collect(), 
74            modify_entry: (self.modify_entry.0, self.ring.clone_el(&self.modify_entry.1)), 
75            zero: self.ring.clone_el(&self.zero), 
76            ring: self.ring.clone(), 
77            len: self.len
78        }
79    }
80}
81
82impl<R: RingStore> VectorView<El<R>> for SparseMapVector<R> {
83
84    fn at(&self, i: usize) -> &El<R> {
85        assert!(i < self.len());
86        if i == self.modify_entry.0 {
87            &self.modify_entry.1
88        } else if let Some(res) = self.data.get(&i) {
89            res
90        } else {
91            &self.zero
92        }
93    }
94
95    fn len(&self) -> usize {
96        self.len
97    }
98
99    fn specialize_sparse<'a, Op: SparseVectorViewOperation<El<R>>>(&'a self, op: Op) -> Result<Op::Output<'a>, ()> {
100        Ok(op.execute(self))
101    }
102}
103
104pub struct SparseMapVectorIter<'a, R>
105    where R: RingStore
106{
107    base: hash_map::Iter<'a, usize, El<R>>,
108    skip: usize,
109    once: Option<&'a El<R>>
110}
111
112impl<'a, R> Iterator for SparseMapVectorIter<'a, R>
113    where R: RingStore
114{
115    type Item = (usize, &'a El<R>);
116
117    fn next(&mut self) -> Option<Self::Item> {
118        if let Some(start) = self.once {
119            self.once = None;
120            return Some((self.skip, start));
121        } else {
122            while let Some((index, element)) = self.base.next() {
123                if *index != self.skip {
124                    return Some((*index, element));
125                }
126            }
127            return None;
128        }
129    }
130}
131
132impl<R: RingStore> VectorViewSparse<El<R>> for SparseMapVector<R> {
133
134    type Iter<'a> = SparseMapVectorIter<'a, R>
135        where Self: 'a;
136
137    fn nontrivial_entries<'a>(&'a self) -> Self::Iter<'a> {
138        SparseMapVectorIter {
139            base: self.data.iter(),
140            skip: self.modify_entry.0,
141            once: if !self.ring.is_zero(&self.modify_entry.1) { Some(&self.modify_entry.1) } else { None }
142        }
143    }
144}
145
146impl<R: RingStore> VectorViewMut<El<R>> for SparseMapVector<R> {
147
148    fn at_mut(&mut self, i: usize) -> &mut El<R> {
149        assert!(i < self.len());
150        if i == self.modify_entry.0 {
151            return &mut self.modify_entry.1;
152        }
153        let new_value = self.ring.clone_el(self.data.get(&i).unwrap_or(&self.zero));
154        self.enter_in_map((i, new_value));
155        return &mut self.modify_entry.1;
156    }
157}
158
159#[cfg(test)]
160use crate::primitive_int::StaticRing;
161
#[cfg(test)]
fn assert_vector_eq<const N: usize>(vec: &SparseMapVector<StaticRing<i64>>, values: [i64; N]) {
    assert_eq!(vec.len(), N);
    vec.check_consistency();
    // Read only through `at()`. `at_mut()` moves entries between the map and the modification
    // slot, which would change the state under test.
    for (i, expected) in values.into_iter().enumerate() {
        assert_eq!(*vec.at(i), expected);
    }
}
171
#[test]
fn test_at_mut() {
    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    assert_vector_eq(&vector, [0, 0, 0, 0, 0]);

    // Each step: (index to modify, value expected before the write, value written, whole vector afterwards).
    // Index 1 is revisited on purpose, and the last step writes an explicit zero to an untouched index.
    let steps: [(usize, i64, i64, [i64; 5]); 5] = [
        (1, 0, 3, [0, 3, 0, 0, 0]),
        (4, 0, -1, [0, 3, 0, 0, -1]),
        (1, 3, 4, [0, 4, 0, 0, -1]),
        (1, 4, 5, [0, 5, 0, 0, -1]),
        (3, 0, 0, [0, 5, 0, 0, -1]),
    ];
    for (index, before, after, expected) in steps {
        let entry = vector.at_mut(index);
        assert_eq!(before, *entry);
        *entry = after;
        assert_vector_eq(&vector, expected);
    }
}
203
#[test]
fn test_nontrivial_entries() {
    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    // Owned snapshot of the nonzero entries, so it can be compared against plain value tables.
    let snapshot = |v: &SparseMapVector<StaticRing<i64>>| -> HashMap<usize, i64> {
        v.nontrivial_entries().map(|(i, c)| (i, *c)).collect()
    };
    assert_eq!(snapshot(&vector), HashMap::new());

    // Each step writes `value` at `index`, then lists all nonzero entries expected afterwards.
    let steps: [(usize, i64, &[(usize, i64)]); 7] = [
        (1, 3, &[(1, 3)]),
        (4, -1, &[(1, 3), (4, -1)]),
        (1, 4, &[(1, 4), (4, -1)]),
        (1, 0, &[(4, -1)]),
        (1, 5, &[(1, 5), (4, -1)]),
        (3, 0, &[(1, 5), (4, -1)]),
        (4, -2, &[(1, 5), (4, -2)]),
    ];
    for (index, value, expected) in steps {
        *vector.at_mut(index) = value;
        assert_eq!(snapshot(&vector), expected.iter().copied().collect::<HashMap<_, _>>());
    }

    *vector.at_mut(1) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 1);
    *vector.at_mut(4) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 0);
}
231
#[test]
fn test_scan() {
    let mut vector = SparseMapVector::new(5, StaticRing::<i64>::RING);
    for (index, value) in [(1, 2), (3, 1), (4, 0)] {
        *vector.at_mut(index) = value;
    }
    // Every stored entry is decremented: 2 -> 1 survives, and 1 -> 0 is dropped. The explicit zero
    // written at index 4 is never stored, so it is not visited (otherwise it would become -1).
    vector.scan(|_, c| *c -= 1);
    assert_vector_eq(&vector, [0, 1, 0, 0, 0]);
    let expected: HashMap<usize, &i64> = [(1, &1)].into_iter().collect();
    assert_eq!(vector.nontrivial_entries().collect::<HashMap<_, _>>(), expected);
}