Skip to main content

feanor_math/seq/
subvector.rs

1use std::fmt::Debug;
2use std::marker::PhantomData;
3
4use super::{
5    SelfSubvectorFn, SelfSubvectorView, SparseVectorViewOperation, SwappableVectorViewMut, VectorFn, VectorView,
6    VectorViewMut, VectorViewSparse,
7};
8
9pub struct SubvectorView<V: VectorView<T>, T: ?Sized> {
10    begin: usize,
11    end: usize,
12    base: V,
13    element: PhantomData<T>,
14}
15
16impl<V: Clone + VectorView<T>, T: ?Sized> Clone for SubvectorView<V, T> {
17    fn clone(&self) -> Self {
18        Self {
19            begin: self.begin,
20            end: self.end,
21            base: self.base.clone(),
22            element: PhantomData,
23        }
24    }
25}
26
27impl<V: Copy + VectorView<T>, T: ?Sized> Copy for SubvectorView<V, T> {}
28
29impl<V: VectorView<T>, T: ?Sized> SubvectorView<V, T> {
30    pub fn new(base: V) -> Self {
31        Self {
32            begin: 0,
33            end: base.len(),
34            base,
35            element: PhantomData,
36        }
37    }
38}
39
40impl<V: VectorView<T>, T: ?Sized> VectorView<T> for SubvectorView<V, T> {
41    fn at(&self, i: usize) -> &T {
42        assert!(i < self.len());
43        self.base.at(i + self.begin)
44    }
45
46    fn len(&self) -> usize { self.end - self.begin }
47
48    fn specialize_sparse<'a, Op: SparseVectorViewOperation<T>>(&'a self, op: Op) -> Option<Op::Output<'a>> {
49        struct WrapSubvector<T: ?Sized, Op: SparseVectorViewOperation<T>> {
50            op: Op,
51            element: PhantomData<T>,
52            begin: usize,
53            end: usize,
54        }
55
56        impl<T: ?Sized, Op: SparseVectorViewOperation<T>> SparseVectorViewOperation<T> for WrapSubvector<T, Op> {
57            type Output<'a>
58                = Op::Output<'a>
59            where
60                Self: 'a;
61
62            fn execute<'a, V: 'a + VectorViewSparse<T> + Clone>(self, vector: V) -> Self::Output<'a>
63            where
64                Self: 'a,
65            {
66                self.op
67                    .execute(SubvectorView::new(vector).restrict_full(self.begin..self.end))
68            }
69        }
70
71        self.base.specialize_sparse(WrapSubvector {
72            op,
73            element: PhantomData,
74            begin: self.begin,
75            end: self.end,
76        })
77    }
78
79    fn as_slice(&self) -> Option<&[T]>
80    where
81        T: Sized,
82    {
83        self.base.as_slice().map(|slice| &slice[self.begin..self.end])
84    }
85}
86
87impl<V: VectorView<T> + Debug, T: ?Sized> Debug for SubvectorView<V, T> {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("SubvectorView")
90            .field("begin", &self.begin)
91            .field("end", &self.end)
92            .field("base", &self.base)
93            .finish()
94    }
95}
96
/// Iterator adapter over `(index, entry)` pairs that keeps only the pairs whose index lies
/// in the half-open range `begin..end`.
pub struct FilterWithinRangeIter<'a, T: ?Sized + 'a, I: Iterator<Item = (usize, &'a T)>> {
    /// The wrapped iterator of `(index, entry)` pairs.
    it: I,
    /// Smallest index that is let through.
    begin: usize,
    /// First index that is no longer let through.
    end: usize,
}

impl<'a, T: ?Sized + 'a, I: Iterator<Item = (usize, &'a T)>> Iterator for FilterWithinRangeIter<'a, T, I> {
    type Item = (usize, &'a T);

    /// Advances the wrapped iterator until it produces a pair with `begin <= index < end`
    /// and returns that pair unchanged; returns `None` once the wrapped iterator is exhausted.
    ///
    /// NOTE(review): the index is passed through as produced by the wrapped iterator and is
    /// not shifted by `begin` — confirm that consumers expect this.
    fn next(&mut self) -> Option<Self::Item> {
        for (index, entry) in self.it.by_ref() {
            if self.begin <= index && index < self.end {
                return Some((index, entry));
            }
        }
        None
    }
}
116
117impl<V: VectorViewSparse<T>, T: ?Sized> VectorViewSparse<T> for SubvectorView<V, T> {
118    type Iter<'a>
119        = FilterWithinRangeIter<'a, T, V::Iter<'a>>
120    where
121        Self: 'a,
122        T: 'a;
123
124    fn nontrivial_entries<'a>(&'a self) -> Self::Iter<'a> {
125        FilterWithinRangeIter {
126            it: self.base.nontrivial_entries(),
127            begin: self.begin,
128            end: self.end,
129        }
130    }
131}
132
133impl<V: VectorViewMut<T>, T: ?Sized> VectorViewMut<T> for SubvectorView<V, T> {
134    fn at_mut(&mut self, i: usize) -> &mut T {
135        assert!(i < self.len());
136        self.base.at_mut(i + self.begin)
137    }
138
139    fn as_slice_mut(&mut self) -> Option<&mut [T]>
140    where
141        T: Sized,
142    {
143        self.base.as_slice_mut().map(|slice| &mut slice[self.begin..self.end])
144    }
145}
146
147impl<V: SwappableVectorViewMut<T>, T: ?Sized> SwappableVectorViewMut<T> for SubvectorView<V, T> {
148    fn swap(&mut self, i: usize, j: usize) {
149        assert!(i < self.len());
150        assert!(j < self.len());
151        self.base.swap(i + self.begin, j + self.begin)
152    }
153}
154
155impl<V: VectorView<T>, T: ?Sized> SelfSubvectorView<T> for SubvectorView<V, T> {
156    fn restrict_full(mut self, range: std::ops::Range<usize>) -> Self {
157        assert!(range.end <= self.len());
158        debug_assert!(range.start <= range.end);
159        self.end = self.begin + range.end;
160        self.begin += range.start;
161        return self;
162    }
163}
164
165pub struct SubvectorFn<V: VectorFn<T>, T> {
166    begin: usize,
167    end: usize,
168    base: V,
169    element: PhantomData<T>,
170}
171
172impl<V: Clone + VectorFn<T>, T> Clone for SubvectorFn<V, T> {
173    fn clone(&self) -> Self {
174        Self {
175            begin: self.begin,
176            end: self.end,
177            base: self.base.clone(),
178            element: PhantomData,
179        }
180    }
181}
182
183impl<V: Copy + VectorFn<T>, T> Copy for SubvectorFn<V, T> {}
184
185impl<V: VectorFn<T>, T> SubvectorFn<V, T> {
186    pub fn new(base: V) -> Self {
187        Self {
188            begin: 0,
189            end: base.len(),
190            base,
191            element: PhantomData,
192        }
193    }
194}
195
196impl<V: VectorFn<T>, T> VectorFn<T> for SubvectorFn<V, T> {
197    fn at(&self, i: usize) -> T {
198        assert!(i < self.len());
199        self.base.at(i + self.begin)
200    }
201
202    fn len(&self) -> usize { self.end - self.begin }
203}
204
205impl<V: VectorFn<T>, T> SelfSubvectorFn<T> for SubvectorFn<V, T> {
206    fn restrict_full(mut self, range: std::ops::Range<usize>) -> Self {
207        assert!(range.end <= self.len());
208        debug_assert!(range.start <= range.end);
209        self.end = self.begin + range.end;
210        self.begin += range.start;
211        return self;
212    }
213}
214
215#[cfg(test)]
216use super::sparse::SparseMapVector;
217#[cfg(test)]
218use crate::primitive_int::StaticRing;
219
/// `restrict` accepts every range syntax and yields a view of the expected length.
#[test]
fn test_subvector_ranges() {
    let view = SubvectorView::new([0, 1, 2, 3, 4]);
    // Half-open and inclusive ranges describing the same three entries.
    assert_eq!(3, view.restrict(0..3).len());
    assert_eq!(3, view.restrict(0..=2).len());
    // Ranges without an upper bound extend to the end of the view.
    assert_eq!(5, view.restrict(0..).len());
    assert_eq!(5, view.restrict(..).len());
    assert_eq!(2, view.restrict(3..).len());
}
229
/// Entries of a restricted view are the base entries shifted by the start of the range.
#[test]
fn test_subvector_subvector() {
    let whole = SubvectorView::new([0, 1, 2, 3, 4]);
    let middle = whole.restrict(1..4);
    assert_eq!(3, middle.len());
    assert_eq!(1, *middle.at(0));
    assert_eq!(2, *middle.at(1));
    assert_eq!(3, *middle.at(2));
}
239
/// Restricting a length-3 view to `0..4` must panic.
#[test]
#[should_panic]
fn test_subvector_subvector_oob() {
    let whole = SubvectorView::new([0, 1, 2, 3, 4]);
    let middle = whole.restrict(1..4);
    let _ = middle.restrict(0..4);
}
247
/// `restrict` on a `SubvectorFn` accepts every range syntax and yields the expected length.
#[test]
fn test_subvector_fn_ranges() {
    let vector_fn = SubvectorFn::new([0, 1, 2, 3, 4].clone_els_by(|x| *x));
    // Half-open and inclusive ranges describing the same three entries.
    assert_eq!(3, vector_fn.restrict(0..3).len());
    assert_eq!(3, vector_fn.restrict(0..=2).len());
    // Ranges without an upper bound extend to the end.
    assert_eq!(5, vector_fn.restrict(0..).len());
    assert_eq!(5, vector_fn.restrict(..).len());
    assert_eq!(2, vector_fn.restrict(3..).len());
}
257
/// Entries of a restricted `SubvectorFn` are the base entries shifted by the range start.
#[test]
fn test_subvector_fn_subvector() {
    let whole = SubvectorFn::new([0, 1, 2, 3, 4].clone_els_by(|x| *x));
    let middle = whole.restrict(1..4);
    assert_eq!(3, middle.len());
    assert_eq!(1, middle.at(0));
    assert_eq!(2, middle.at(1));
    assert_eq!(3, middle.at(2));
}
267
/// Restricting a length-3 `SubvectorFn` to `0..4` must panic.
#[test]
#[should_panic]
fn test_subvector_fn_subvector_oob() {
    let whole = SubvectorFn::new([0, 1, 2, 3, 4].clone_els_by(|x| *x));
    let middle = whole.restrict(1..4);
    let _ = middle.restrict(0..4);
}
275
/// A restricted view over a sparse vector still specializes to a sparse view, and only the
/// entries whose index lies inside the restricted range are reported.
#[test]
fn test_subvector_sparse() {
    let mut base = SparseMapVector::new(1000, StaticRing::<i64>::RING);
    // Entries at 6 and 257 lie outside of `20..=256`; those at 20 and 256 lie on its border.
    for (index, value) in [(6, 6), (20, 20), (256, 256), (257, 257)] {
        *base.at_mut(index) = value;
    }

    let restricted = SubvectorView::new(base).restrict(20..=256);

    /// Operation asserting that exactly the entries at 20 and 256 are visible, in either order.
    struct Verify;

    impl SparseVectorViewOperation<i64> for Verify {
        type Output<'a> = ();

        fn execute<'a, V: 'a + VectorViewSparse<i64>>(self, vector: V) -> Self::Output<'a> {
            assert!(
                vec![(20, &20), (256, &256)] == vector.nontrivial_entries().collect::<Vec<_>>()
                    || vec![(256, &256), (20, &20)] == vector.nontrivial_entries().collect::<Vec<_>>()
            );
        }
    }

    // `unwrap` additionally checks that the sparse specialization was actually taken.
    restricted.specialize_sparse(Verify).unwrap();
}