burn_dataset/transform/
window.rs

1use std::{cmp::max, marker::PhantomData, num::NonZeroUsize};
2
3use crate::Dataset;
4
5/// Functionality to create a window.
6pub trait Window<I> {
7    /// Creates a window of a collection.
8    ///
9    /// # Returns
10    ///
11    /// A `Vec<I>` representing the window.
12    fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>>;
13}
14
15impl<I, T: Dataset<I> + ?Sized> Window<I> for T {
16    fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>> {
17        (current..current + size.get())
18            .map(|x| self.get(x))
19            .collect()
20    }
21}
22
23/// Functionality to create a `WindowsIterator`.
24pub trait Windows<I> {
25    /// Creates and returns an iterator over all the windows of length `size`.
26    fn windows(&self, size: usize) -> WindowsIterator<'_, I>;
27}
28
29impl<I, T: Dataset<I>> Windows<I> for T {
30    /// Is empty if the `Dataset` is shorter than `size`.
31    ///
32    /// # Panics
33    ///
34    /// Panics if `size` is 0.    
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use crate::burn_dataset::{
40    ///    transform::{Windows, WindowsDataset},
41    ///    Dataset, InMemDataset,
42    /// };
43    ///
44    /// let items = [1, 2, 3, 4].to_vec();
45    /// let dataset = InMemDataset::new(items.clone());
46    ///
47    /// for window in dataset.windows(2) {
48    ///  // do sth with window
49    /// }
50    /// ```
51    fn windows(&self, size: usize) -> WindowsIterator<'_, I> {
52        let size = NonZeroUsize::new(size).expect("window size must be non-zero");
53        WindowsIterator::new(self, size)
54    }
55}
56
57/// Overlapping windows iterator.
58pub struct WindowsIterator<'a, I> {
59    /// The size of the windows.
60    pub size: NonZeroUsize,
61    current: usize,
62    dataset: &'a dyn Dataset<I>,
63}
64
65impl<'a, I> WindowsIterator<'a, I> {
66    /// Creates a new `WindowsIterator` instance. The windows overlap.
67    /// Is empty if the input `Dataset` is shorter than `size`.
68    ///
69    /// # Parameters
70    ///
71    /// - `dataset`: The dataset over which windows will be created.
72    /// - `size`: The size of the windows.
73    pub fn new(dataset: &'a dyn Dataset<I>, size: NonZeroUsize) -> Self {
74        WindowsIterator {
75            current: 0,
76            dataset,
77            size,
78        }
79    }
80}
81
82impl<I> Iterator for WindowsIterator<'_, I> {
83    type Item = Vec<I>;
84
85    fn next(&mut self) -> Option<Vec<I>> {
86        self.current += 1;
87        self.dataset.window(self.current - 1, self.size)
88    }
89}
90
91impl<I> Clone for WindowsIterator<'_, I> {
92    fn clone(&self) -> Self {
93        WindowsIterator {
94            size: self.size,
95            dataset: self.dataset,
96            current: self.current,
97        }
98    }
99}
100
/// Dataset whose items are the overlapping windows of an inner dataset.
pub struct WindowsDataset<D, I> {
    /// The wrapped dataset the windows are taken from.
    dataset: D,
    /// The size of the windows.
    pub size: NonZeroUsize,
    /// Ties the inner item type `I` to this struct without storing an `I`.
    input: PhantomData<I>,
}
108
109impl<D, I> WindowsDataset<D, I>
110where
111    D: Dataset<I>,
112{
113    /// Creates a new `WindowsDataset` instance. The windows overlap.
114    /// Is empty if the input `Dataset` is shorter than `size`.
115    ///
116    /// # Parameters
117    ///
118    /// - `dataset`: The dataset over which windows will be created.
119    /// - `size`: The size of the windows.
120    pub fn new(dataset: D, size: usize) -> Self
121    where
122        D:,
123    {
124        let size = NonZeroUsize::new(size).expect("window size must be non-zero");
125        WindowsDataset::<D, I> {
126            size,
127            dataset,
128            input: PhantomData,
129        }
130    }
131}
132
133impl<D, I> Dataset<Vec<I>> for WindowsDataset<D, I>
134where
135    D: Dataset<I>,
136    I: Clone + Send + Sync,
137{
138    /// Retrieves a window of items from the dataset.
139    ///
140    /// # Parameters
141    ///
142    /// - `index`: The index of the window.
143    ///
144    /// # Returns
145    ///
146    /// A vector representing the window.
147    fn get(&self, index: usize) -> Option<Vec<I>> {
148        self.dataset.window(index, self.size)
149    }
150
151    /// Retrieves the number of windows in the dataset.
152    ///
153    /// # Returns
154    ///
155    /// A size representing the number of windows.
156    fn len(&self) -> usize {
157        let len = self.dataset.len() as isize - self.size.get() as isize + 1;
158        max(len, 0) as usize
159    }
160}
161
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::{
        Dataset, InMemDataset,
        transform::{Windows, WindowsDataset},
    };

    /// Expected windows, computed with the standard `slice::windows`.
    fn reference_windows(items: &[i32], size: usize) -> Vec<Vec<i32>> {
        items.windows(size).map(|window| window.to_vec()).collect()
    }

    #[rstest]
    fn windows_should_be_equal_to_vec_windows() {
        let items = vec![1, 2, 3, 4, 5];
        let expected = reference_windows(&items, 3);
        let dataset = InMemDataset::new(items);

        let actual: Vec<Vec<i32>> = dataset.windows(3).collect();

        assert_eq!(actual, expected);
    }

    #[rstest]
    fn windows_dataset_should_be_equal_to_vec_windows() {
        let items = vec![1, 2, 3, 4, 5];
        let expected = reference_windows(&items, 3);
        let windowed = WindowsDataset::new(InMemDataset::new(items), 3);

        let actual: Vec<Vec<i32>> = windowed.iter().collect();

        assert_eq!(actual, expected);
    }

    #[rstest]
    fn cloned_iterator_should_be_equal() {
        let dataset = InMemDataset::new(vec![1, 2, 3, 4, 5]);
        let original = dataset.windows(4);

        let cloned = original.clone();

        assert!(std::ptr::eq(cloned.dataset, original.dataset));
        assert_eq!(cloned.size, original.size);
        assert_eq!(cloned.current, original.current);
    }

    #[rstest]
    fn cloned_iterator_should_be_unaffected() {
        let dataset = InMemDataset::new(vec![1, 2, 3, 4, 5]);
        let mut original = dataset.windows(4);

        let cloned = original.clone();
        original.current = 2;

        assert_ne!(cloned.current, original.current);
    }

    #[rstest]
    #[should_panic(expected = "window size must be non-zero")]
    fn windows_should_panic() {
        let dataset = InMemDataset::new(vec![1, 2]);

        dataset.windows(0);
    }

    #[rstest]
    #[should_panic(expected = "window size must be non-zero")]
    fn new_window_dataset_should_panic() {
        let dataset = InMemDataset::new(vec![1, 2]);

        WindowsDataset::new(dataset, 0);
    }

    #[rstest]
    fn window_dataset_len_should_be_equal() {
        let windowed = WindowsDataset::new(InMemDataset::new(vec![1, 2, 3, 4]), 2);

        assert_eq!(windowed.len(), 3);
    }

    #[rstest]
    fn window_iterator_should_be_empty() {
        let dataset = InMemDataset::new(vec![1, 2]);
        let mut windows = dataset.windows(4).peekable();

        assert_eq!(windows.peek(), None);
    }

    #[rstest]
    fn window_dataset_len_should_be_zero() {
        let windowed = WindowsDataset::new(InMemDataset::new(vec![1, 2]), 4);

        assert_eq!(windowed.len(), 0);
    }

    #[rstest]
    fn window_dataset_get_should_be_equal() {
        let windowed = WindowsDataset::new(InMemDataset::new(vec![1, 2, 3, 4]), 3);

        assert_eq!(windowed.get(0), Some(vec![1, 2, 3]));
    }

    #[rstest]
    fn window_dataset_get_should_be_none() {
        let windowed = WindowsDataset::new(InMemDataset::new(vec![1, 2]), 4);

        assert_eq!(windowed.get(0), None);
    }
}