burn_dataset/transform/
mapper.rs1use crate::Dataset;
2use std::marker::PhantomData;
3
/// Converts a borrowed item of type `I` into a new value of type `O`.
///
/// Every mapper must be `Send + Sync`, i.e. it can be moved to and shared
/// between threads.
pub trait Mapper<I, O>
where
    Self: Send + Sync,
{
    /// Builds the output value for `item`; the input is only borrowed.
    fn map(&self, item: &I) -> O;
}
9
10#[derive(new)]
12pub struct MapperDataset<D, M, I> {
13 dataset: D,
14 mapper: M,
15 input: PhantomData<I>,
16}
17
18impl<D, M, I, O> Dataset<O> for MapperDataset<D, M, I>
19where
20 D: Dataset<I>,
21 M: Mapper<I, O> + Send + Sync,
22 I: Send + Sync,
23 O: Send + Sync,
24{
25 fn get(&self, index: usize) -> Option<O> {
26 let item = self.dataset.get(index);
27 item.map(|item| self.mapper.map(&item))
28 }
29
30 fn len(&self) -> usize {
31 self.dataset.len()
32 }
33}
34
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{InMemDataset, test_data};

    #[test]
    fn given_mapper_dataset_when_iterate_should_iterate_through_all_map_items() {
        /// Test mapper that keeps only the first character of each string.
        struct StringToFirstChar;

        impl Mapper<String, String> for StringToFirstChar {
            fn map(&self, item: &String) -> String {
                // `String::truncate(1)` panics when byte 1 is not a char
                // boundary (e.g. a leading multi-byte character); taking the
                // first `char` works for any UTF-8 input and yields "" for "".
                item.chars().take(1).collect()
            }
        }

        let items_original = test_data::string_items();
        let dataset = InMemDataset::new(items_original);
        let dataset = MapperDataset::new(dataset, StringToFirstChar);

        let items: Vec<String> = dataset.iter().collect();

        assert_eq!(vec!["1", "2", "3", "4"], items);
    }
}