1use std::{cmp::max, marker::PhantomData, num::NonZeroUsize};
2
3use crate::Dataset;
4
/// Extracts a fixed-size run of consecutive items from an indexable collection.
pub trait Window<I> {
    /// Returns the `size` items starting at index `current`.
    ///
    /// The result is an `Option`, so implementations can report that no
    /// window is available at that position by returning `None`.
    fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>>;
}
14
15impl<I, T: Dataset<I> + ?Sized> Window<I> for T {
16 fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>> {
17 (current..current + size.get())
18 .map(|x| self.get(x))
19 .collect()
20 }
21}
22
/// Iteration over the windows of a collection.
pub trait Windows<I> {
    /// Returns an iterator that yields windows of `size` items.
    ///
    /// The blanket implementation for datasets in this file panics when
    /// `size` is zero.
    fn windows(&self, size: usize) -> WindowsIterator<'_, I>;
}
28
29impl<I, T: Dataset<I>> Windows<I> for T {
30 fn windows(&self, size: usize) -> WindowsIterator<'_, I> {
52 let size = NonZeroUsize::new(size).expect("window size must be non-zero");
53 WindowsIterator::new(self, size)
54 }
55}
56
/// Iterator that yields windows read from a borrowed dataset.
pub struct WindowsIterator<'a, I> {
    /// Number of items in each yielded window.
    pub size: NonZeroUsize,
    // Start index of the next window to yield.
    current: usize,
    // Dataset the windows are read from.
    dataset: &'a dyn Dataset<I>,
}
64
65impl<'a, I> WindowsIterator<'a, I> {
66 pub fn new(dataset: &'a dyn Dataset<I>, size: NonZeroUsize) -> Self {
74 WindowsIterator {
75 current: 0,
76 dataset,
77 size,
78 }
79 }
80}
81
82impl<I> Iterator for WindowsIterator<'_, I> {
83 type Item = Vec<I>;
84
85 fn next(&mut self) -> Option<Vec<I>> {
86 self.current += 1;
87 self.dataset.window(self.current - 1, self.size)
88 }
89}
90
91impl<I> Clone for WindowsIterator<'_, I> {
92 fn clone(&self) -> Self {
93 WindowsIterator {
94 size: self.size,
95 dataset: self.dataset,
96 current: self.current,
97 }
98 }
99}
100
/// Dataset adapter whose items are windows over an inner dataset.
pub struct WindowsDataset<D, I> {
    /// Number of items in each window.
    pub size: NonZeroUsize,
    // Inner dataset the windows are read from.
    dataset: D,
    // Ties the item type `I` to the struct without storing a value of it.
    input: PhantomData<I>,
}
108
109impl<D, I> WindowsDataset<D, I>
110where
111 D: Dataset<I>,
112{
113 pub fn new(dataset: D, size: usize) -> Self
121 where
122 D:,
123 {
124 let size = NonZeroUsize::new(size).expect("window size must be non-zero");
125 WindowsDataset::<D, I> {
126 size,
127 dataset,
128 input: PhantomData,
129 }
130 }
131}
132
133impl<D, I> Dataset<Vec<I>> for WindowsDataset<D, I>
134where
135 D: Dataset<I>,
136 I: Clone + Send + Sync,
137{
138 fn get(&self, index: usize) -> Option<Vec<I>> {
148 self.dataset.window(index, self.size)
149 }
150
151 fn len(&self) -> usize {
157 let len = self.dataset.len() as isize - self.size.get() as isize + 1;
158 max(len, 0) as usize
159 }
160}
161
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::{
        Dataset, InMemDataset,
        transform::{Windows, WindowsDataset},
    };

    /// The iterator yields the same windows as `slice::windows`.
    #[rstest]
    pub fn windows_should_be_equal_to_vec_windows() {
        let items = [1, 2, 3, 4, 5].to_vec();
        let expected = items
            .windows(3)
            .map(|x| x.to_vec())
            .collect::<Vec<Vec<i32>>>();
        // `expected` owns its data, so `items` can be moved into the dataset.
        let dataset = InMemDataset::new(items);

        let result = dataset.windows(3).collect::<Vec<Vec<i32>>>();

        assert_eq!(result, expected);
    }

    /// Iterating the dataset adapter yields the same windows as `slice::windows`.
    #[rstest]
    pub fn windows_dataset_should_be_equal_to_vec_windows() {
        let items = [1, 2, 3, 4, 5].to_vec();
        let expected = items
            .windows(3)
            .map(|x| x.to_vec())
            .collect::<Vec<Vec<i32>>>();
        // `expected` owns its data, so `items` can be moved into the dataset.
        let dataset = InMemDataset::new(items);

        let result = WindowsDataset::new(dataset, 3)
            .iter()
            .collect::<Vec<Vec<i32>>>();

        assert_eq!(result, expected);
    }

    /// A clone borrows the same dataset and starts from the same state.
    #[rstest]
    pub fn cloned_iterator_should_be_equal() {
        let dataset = InMemDataset::new([1, 2, 3, 4, 5].to_vec());
        let original = dataset.windows(4);

        let cloned = original.clone();

        assert!(std::ptr::eq(cloned.dataset, original.dataset));
        assert_eq!(cloned.size, original.size);
        assert_eq!(cloned.current, original.current);
    }

    /// Moving the original's cursor does not move the clone's cursor.
    #[rstest]
    pub fn cloned_iterator_should_be_unaffected() {
        let dataset = InMemDataset::new([1, 2, 3, 4, 5].to_vec());
        let mut original = dataset.windows(4);

        let cloned = original.clone();
        original.current = 2;

        assert_ne!(cloned.current, original.current);
    }

    /// A zero window size is rejected by `Windows::windows`.
    #[rstest]
    #[should_panic(expected = "window size must be non-zero")]
    pub fn windows_should_panic() {
        let dataset = InMemDataset::new([1, 2].to_vec());

        dataset.windows(0);
    }

    /// A zero window size is rejected by `WindowsDataset::new`.
    #[rstest]
    #[should_panic(expected = "window size must be non-zero")]
    pub fn new_window_dataset_should_panic() {
        let dataset = InMemDataset::new([1, 2].to_vec());

        WindowsDataset::new(dataset, 0);
    }

    /// Four items with window size two give three windows.
    #[rstest]
    pub fn window_dataset_len_should_be_equal() {
        let dataset = InMemDataset::new([1, 2, 3, 4].to_vec());

        let result = WindowsDataset::new(dataset, 2).len();

        assert_eq!(result, 3);
    }

    /// A window larger than the dataset yields nothing from the iterator.
    #[rstest]
    pub fn window_iterator_should_be_empty() {
        let dataset = InMemDataset::new([1, 2].to_vec());
        let mut peekable = dataset.windows(4).peekable();

        let result = peekable.peek();

        assert_eq!(result, None);
    }

    /// A window larger than the dataset gives a length of zero.
    #[rstest]
    pub fn window_dataset_len_should_be_zero() {
        let dataset = InMemDataset::new([1, 2].to_vec());

        let result = WindowsDataset::new(dataset, 4).len();

        assert_eq!(result, 0);
    }

    /// `get(0)` returns the first window.
    #[rstest]
    pub fn window_dataset_get_should_be_equal() {
        let dataset = InMemDataset::new([1, 2, 3, 4].to_vec());
        let expected = Some([1, 2, 3].to_vec());

        let result = WindowsDataset::new(dataset, 3).get(0);

        assert_eq!(result, expected);
    }

    /// `get` returns `None` when the window does not fit in the dataset.
    #[rstest]
    pub fn window_dataset_get_should_be_none() {
        let dataset = InMemDataset::new([1, 2].to_vec());

        let result = WindowsDataset::new(dataset, 4).get(0);

        assert_eq!(result, None);
    }
}