// rstorch/data/dataset/chain.rs

1use super::{Dataset, IterableDataset};
2
/// A dataset made of every element of `a` followed by every element of `b`.
///
/// Its `Dataset` impl serves index `i` from `a` when `i < a.len()`, and from
/// `b` (shifted down by `a.len()`) otherwise.
#[derive(Debug, Clone)]
pub struct Chain<A, B> {
    /// The dataset whose elements come first.
    a: A,
    /// The dataset whose elements follow those of `a`.
    b: B,
}
7
8impl<A, B> Chain<A, B> {
9    #[inline]
10    #[must_use]
11    pub(in crate::data) fn new(a: A, b: B) -> Self {
12        Self { a, b }
13    }
14}
15
16impl<A, B> Dataset for Chain<A, B>
17where
18    A: Dataset,
19    B: Dataset<Item = A::Item>,
20{
21    type Item = A::Item;
22
23    #[inline]
24    fn get(&self, index: usize) -> Option<Self::Item> {
25        match index < self.a.len() {
26            true => self.a.get(index),
27            false => self.b.get(index - self.a.len()),
28        }
29    }
30
31    #[inline]
32    fn len(&self) -> usize {
33        self.a.len() + self.b.len()
34    }
35
36    #[inline]
37    fn is_empty(&self) -> bool {
38        self.a.is_empty() && self.b.is_empty()
39    }
40}
41
42impl<'a, A, B> IterableDataset<'a> for Chain<A, B>
43where
44    A: IterableDataset<'a>,
45    B: IterableDataset<'a, Item = A::Item>,
46{
47    type Iterator = std::iter::Chain<A::Iterator, B::Iterator>;
48
49    #[inline]
50    fn iter(&'a self) -> Self::Iterator {
51        self.a.iter().chain(self.b.iter())
52    }
53}