burn_core/data/dataloader/strategy.rs

/// Decides how individual items are grouped into batches.
///
/// Items are pushed in one at a time with [`BatchStrategy::add`]; on each call
/// to [`BatchStrategy::batch`] the strategy decides whether a batch is ready to
/// be handed out.
pub trait BatchStrategy<I>: Send {
    /// Hands an item to the strategy for inclusion in a future batch.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to buffer.
    fn add(&mut self, item: I);

    /// Attempts to produce a batch from the items added so far.
    ///
    /// # Arguments
    ///
    /// * `force` - When `true`, asks the strategy to emit a batch even if it
    ///   would otherwise keep waiting for more items.
    ///
    /// # Returns
    ///
    /// `Some(items)` when a batch is produced, `None` when no batch is available.
    fn batch(&mut self, force: bool) -> Option<Vec<I>>;

    /// Creates a fresh, boxed strategy of the same concrete type.
    ///
    /// `Clone` cannot be called through a `dyn BatchStrategy` trait object, so
    /// this method provides the equivalent for boxed strategies.
    ///
    /// # Returns
    ///
    /// The new strategy.
    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;
}
28
29/// A strategy to batch items with a fixed batch size.
30pub struct FixBatchStrategy<I> {
31    items: Vec<I>,
32    batch_size: usize,
33}
34
35impl<I> FixBatchStrategy<I> {
36    /// Creates a new strategy to batch items with a fixed batch size.
37    ///
38    /// # Arguments
39    ///
40    /// * `batch_size` - The batch size.
41    ///
42    /// # Returns
43    ///
44    /// The strategy.
45    pub fn new(batch_size: usize) -> Self {
46        FixBatchStrategy {
47            items: Vec::with_capacity(batch_size),
48            batch_size,
49        }
50    }
51}
52
53impl<I: Send + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
54    fn add(&mut self, item: I) {
55        self.items.push(item);
56    }
57
58    fn batch(&mut self, force: bool) -> Option<Vec<I>> {
59        if self.items.len() < self.batch_size && !force {
60            return None;
61        }
62
63        let mut items = Vec::with_capacity(self.batch_size);
64        std::mem::swap(&mut items, &mut self.items);
65
66        if items.is_empty() {
67            return None;
68        }
69
70        Some(items)
71    }
72
73    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {
74        Box::new(Self::new(self.batch_size))
75    }
76}