feanor_math/seq/sparse.rs

1use std::collections::HashMap;
2use std::collections::hash_map;
3use std::collections::hash_map::Entry;
4
5use crate::ring::*;
6use crate::seq::*;
7
8pub struct SparseMapVector<R: RingStore> {
9    data: HashMap<usize, El<R>>,
10    modify_entry: (usize, El<R>),
11    zero: El<R>,
12    ring: R,
13    len: usize
14}
15
16impl<R: RingStore> SparseMapVector<R> {
17
18    pub fn new(len: usize, ring: R) -> Self {
19        SparseMapVector {
20            data: HashMap::new(), 
21            modify_entry: (usize::MAX, ring.zero()),
22            zero: ring.zero(),
23            ring: ring,
24            len: len
25        }
26    }
27
28    #[stability::unstable(feature = "enable")]
29    pub fn set_len(&mut self, new_len: usize) {
30        if new_len < self.len() {
31            for (i, _) in self.nontrivial_entries() {
32                debug_assert!(i < new_len);
33            }
34        }
35        self.len = new_len;
36    }
37
38    #[stability::unstable(feature = "enable")]
39    pub fn scan<F>(&mut self, mut f: F)
40        where F: FnMut(usize, &mut El<R>)
41    {
42        self.enter_in_map((usize::MAX, self.ring.zero()));
43        self.data.retain(|i, c| {
44            f(*i, c);
45            !self.ring.is_zero(c)
46        });
47    }
48
49    #[cfg(test)]
50    fn check_consistency(&self) {
51        assert!(self.ring.is_zero(&self.modify_entry.1) || self.modify_entry.0 < self.len());
52    }
53
54    fn enter_in_map(&mut self, new_modify_entry: (usize, El<R>)) {
55        if self.modify_entry.0 != usize::MAX {
56            let (index, value) = std::mem::replace(&mut self.modify_entry, new_modify_entry);
57            match self.data.entry(index) {
58                Entry::Occupied(mut e) if !self.ring.is_zero(&value) => { *e.get_mut() = value; },
59                Entry::Occupied(e) => { _ = e.remove(); },
60                Entry::Vacant(e) if !self.ring.is_zero(&value) => { _ = e.insert(value); },
61                Entry::Vacant(_) => {}
62            };
63        } else {
64            self.modify_entry = new_modify_entry;
65        }
66    }
67}
68
69impl<R: RingStore + Clone> Debug for SparseMapVector<R> {
70    
71    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
72        let mut output = f.debug_map();
73        for (key, value) in self.nontrivial_entries() {
74            _ = output.entry(&key, &self.ring.format(value));
75        }
76        output.finish()
77    }
78}
79
80impl<R: RingStore + Clone> Clone for SparseMapVector<R> {
81
82    fn clone(&self) -> Self {
83        SparseMapVector { 
84            data: self.data.iter().map(|(i, c)| (*i, self.ring.clone_el(c))).collect(), 
85            modify_entry: (self.modify_entry.0, self.ring.clone_el(&self.modify_entry.1)), 
86            zero: self.ring.clone_el(&self.zero), 
87            ring: self.ring.clone(), 
88            len: self.len
89        }
90    }
91}
92
93impl<R: RingStore> VectorView<El<R>> for SparseMapVector<R> {
94
95    fn at(&self, i: usize) -> &El<R> {
96        assert!(i < self.len());
97        if i == self.modify_entry.0 {
98            &self.modify_entry.1
99        } else if let Some(res) = self.data.get(&i) {
100            res
101        } else {
102            &self.zero
103        }
104    }
105
106    fn len(&self) -> usize {
107        self.len
108    }
109
110    fn specialize_sparse<'a, Op: SparseVectorViewOperation<El<R>>>(&'a self, op: Op) -> Result<Op::Output<'a>, ()> {
111        Ok(op.execute(self))
112    }
113}
114
115pub struct SparseMapVectorIter<'a, R>
116    where R: RingStore
117{
118    base: hash_map::Iter<'a, usize, El<R>>,
119    skip: usize,
120    once: Option<&'a El<R>>
121}
122
123impl<'a, R> Iterator for SparseMapVectorIter<'a, R>
124    where R: RingStore
125{
126    type Item = (usize, &'a El<R>);
127
128    fn next(&mut self) -> Option<Self::Item> {
129        if let Some(start) = self.once {
130            self.once = None;
131            return Some((self.skip, start));
132        } else {
133            while let Some((index, element)) = self.base.next() {
134                if *index != self.skip {
135                    return Some((*index, element));
136                }
137            }
138            return None;
139        }
140    }
141}
142
143impl<R: RingStore> VectorViewSparse<El<R>> for SparseMapVector<R> {
144
145    type Iter<'a> = SparseMapVectorIter<'a, R>
146        where Self: 'a;
147
148    fn nontrivial_entries<'a>(&'a self) -> Self::Iter<'a> {
149        SparseMapVectorIter {
150            base: self.data.iter(),
151            skip: self.modify_entry.0,
152            once: if !self.ring.is_zero(&self.modify_entry.1) { Some(&self.modify_entry.1) } else { None }
153        }
154    }
155}
156
157impl<R: RingStore> VectorViewMut<El<R>> for SparseMapVector<R> {
158
159    fn at_mut(&mut self, i: usize) -> &mut El<R> {
160        assert!(i < self.len());
161        if i == self.modify_entry.0 {
162            return &mut self.modify_entry.1;
163        }
164        let new_value = self.ring.clone_el(self.data.get(&i).unwrap_or(&self.zero));
165        self.enter_in_map((i, new_value));
166        return &mut self.modify_entry.1;
167    }
168}
169
170#[cfg(test)]
171use crate::primitive_int::StaticRing;
172
/// Test helper: asserts that `vec` is internally consistent and equals `values`.
#[cfg(test)]
fn assert_vector_eq<const N: usize>(vec: &SparseMapVector<StaticRing<i64>>, values: [i64; N]) {
    assert_eq!(vec.len(), N);
    vec.check_consistency();
    // at_mut() might change the vector, so only at() is used here
    for (i, &expected) in values.iter().enumerate() {
        assert_eq!(*vec.at(i), expected);
    }
}
182
// Checks that `at_mut()` returns the current value of an entry, and that writes through
// the returned reference are visible via `at()`, also when switching between entries.
#[test]
fn test_at_mut() {
    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);

    // `assert_vector_eq()` only needs a shared reference
    assert_vector_eq(&vector, [0, 0, 0, 0, 0]);
    let mut entry = vector.at_mut(1);
    assert_eq!(0, *entry);
    *entry = 3;
    assert_vector_eq(&vector, [0, 3, 0, 0, 0]);

    entry = vector.at_mut(4);
    assert_eq!(0, *entry);
    *entry = -1;
    assert_vector_eq(&vector, [0, 3, 0, 0, -1]);

    entry = vector.at_mut(1);
    assert_eq!(3, *entry);
    *entry = 4;
    assert_vector_eq(&vector, [0, 4, 0, 0, -1]);

    entry = vector.at_mut(1);
    assert_eq!(4, *entry);
    *entry = 5;
    assert_vector_eq(&vector, [0, 5, 0, 0, -1]);

    entry = vector.at_mut(3);
    assert_eq!(0, *entry);
    *entry = 0;
    assert_vector_eq(&vector, [0, 5, 0, 0, -1]);
}
214
// Checks that `nontrivial_entries()` reports exactly the nonzero entries after a
// sequence of writes via `at_mut()`, including writes of zero.
#[test]
fn test_nontrivial_entries() {
    // collects the nonzero entries by value, so they can be compared against literals
    fn entries(vector: &SparseMapVector<StaticRing<i64>>) -> HashMap<usize, i64> {
        vector.nontrivial_entries().map(|(i, c)| (i, *c)).collect()
    }

    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    assert_eq!(entries(&vector), HashMap::new());
    *vector.at_mut(1) = 3;
    assert_eq!(entries(&vector), HashMap::from([(1, 3)]));
    *vector.at_mut(4) = -1;
    assert_eq!(entries(&vector), HashMap::from([(1, 3), (4, -1)]));

    *vector.at_mut(1) = 4;
    assert_eq!(entries(&vector), HashMap::from([(1, 4), (4, -1)]));
    *vector.at_mut(1) = 0;
    assert_eq!(entries(&vector), HashMap::from([(4, -1)]));
    *vector.at_mut(1) = 5;
    assert_eq!(entries(&vector), HashMap::from([(1, 5), (4, -1)]));

    *vector.at_mut(3) = 0;
    assert_eq!(entries(&vector), HashMap::from([(1, 5), (4, -1)]));
    *vector.at_mut(4) = -2;
    assert_eq!(entries(&vector), HashMap::from([(1, 5), (4, -2)]));

    *vector.at_mut(1) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 1);
    *vector.at_mut(4) = 0;
    assert_eq!(vector.nontrivial_entries().count(), 0);
}
242
// Checks that `scan()` visits exactly the nonzero entries and drops entries it sets to zero.
#[test]
fn test_scan() {
    let ring = StaticRing::<i64>::RING;
    let mut vector = SparseMapVector::new(5, ring);
    *vector.at_mut(1) = 2;
    *vector.at_mut(3) = 1;
    // index 4 is the cached entry with value zero, so `scan()` must not visit it
    *vector.at_mut(4) = 0;
    let mut visited = Vec::new();
    vector.scan(|i, c| {
        visited.push(i);
        *c -= 1;
    });
    // the visiting order of `scan()` is unspecified
    visited.sort_unstable();
    assert_eq!(visited, vec![1, 3]);
    assert_vector_eq(&vector, [0, 1, 0, 0, 0]);
    assert_eq!(vector.nontrivial_entries().collect::<HashMap<_, _>>(), [(1, &1)].into_iter().collect());
}