Skip to main content

burn_core/data/dataloader/
strategy.rs

/// Describes how individual items are accumulated and grouped into batches.
///
/// Implementors must be `Send + Sync`.
pub trait BatchStrategy<I>: Send + Sync {
    /// Hands one item over to the strategy.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to register with the strategy.
    fn add(&mut self, item: I);

    /// Asks the strategy for a batch.
    ///
    /// # Arguments
    ///
    /// * `force` - Whether batching should be forced.
    ///
    /// # Returns
    ///
    /// `Some` with the batched items, or `None` when no batch is produced.
    fn batch(&mut self, force: bool) -> Option<Vec<I>>;

    /// Builds a new boxed strategy of the same type.
    ///
    /// # Returns
    ///
    /// The new strategy as a trait object.
    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;

    /// Reports the batch size this strategy expects to produce.
    ///
    /// # Returns
    ///
    /// `Some(size)` for a fixed batch size, or `None` if the strategy does not
    /// have a fixed batch size.
    fn batch_size(&self) -> Option<usize>;
}
35
/// Batching strategy that groups items using a fixed batch size.
pub struct FixBatchStrategy<I> {
    /// Items added so far that have not yet been handed out as a batch.
    items: Vec<I>,
    /// The configured batch size.
    batch_size: usize,
}
41
42impl<I> FixBatchStrategy<I> {
43    /// Creates a new strategy to batch items with a fixed batch size.
44    ///
45    /// # Arguments
46    ///
47    /// * `batch_size` - The batch size.
48    ///
49    /// # Returns
50    ///
51    /// The strategy.
52    pub fn new(batch_size: usize) -> Self {
53        FixBatchStrategy {
54            items: Vec::with_capacity(batch_size),
55            batch_size,
56        }
57    }
58}
59
60impl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
61    fn add(&mut self, item: I) {
62        self.items.push(item);
63    }
64
65    fn batch(&mut self, force: bool) -> Option<Vec<I>> {
66        if self.items.len() < self.batch_size && !force {
67            return None;
68        }
69
70        let mut items = Vec::with_capacity(self.batch_size);
71        std::mem::swap(&mut items, &mut self.items);
72
73        if items.is_empty() {
74            return None;
75        }
76
77        Some(items)
78    }
79
80    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {
81        Box::new(Self::new(self.batch_size))
82    }
83
84    fn batch_size(&self) -> Option<usize> {
85        Some(self.batch_size)
86    }
87}