rstorch/data/dataset/
chain.rs1use super::{Dataset, IterableDataset};
2
/// A dataset formed by concatenating two datasets: every element of `a`
/// followed by every element of `b`.
///
/// Built inside `crate::data` via `Chain::new`. Indexing and iteration are
/// provided by the `Dataset` and `IterableDataset` impls in this module.
///
/// `Debug` and `Clone` are derived, so they are available whenever both
/// halves implement them.
#[derive(Debug, Clone)]
pub struct Chain<A, B> {
    /// The leading dataset; its elements occupy indices `0..a.len()`.
    a: A,
    /// The trailing dataset; its elements come after all of `a`'s.
    b: B,
}
7
8impl<A, B> Chain<A, B> {
9 #[inline]
10 #[must_use]
11 pub(in crate::data) fn new(a: A, b: B) -> Self {
12 Self { a, b }
13 }
14}
15
16impl<A, B> Dataset for Chain<A, B>
17where
18 A: Dataset,
19 B: Dataset<Item = A::Item>,
20{
21 type Item = A::Item;
22
23 #[inline]
24 fn get(&self, index: usize) -> Option<Self::Item> {
25 match index < self.a.len() {
26 true => self.a.get(index),
27 false => self.b.get(index - self.a.len()),
28 }
29 }
30
31 #[inline]
32 fn len(&self) -> usize {
33 self.a.len() + self.b.len()
34 }
35
36 #[inline]
37 fn is_empty(&self) -> bool {
38 self.a.is_empty() && self.b.is_empty()
39 }
40}
41
42impl<'a, A, B> IterableDataset<'a> for Chain<A, B>
43where
44 A: IterableDataset<'a>,
45 B: IterableDataset<'a, Item = A::Item>,
46{
47 type Iterator = std::iter::Chain<A::Iterator, B::Iterator>;
48
49 #[inline]
50 fn iter(&'a self) -> Self::Iterator {
51 self.a.iter().chain(self.b.iter())
52 }
53}