// burn_core/data/dataloader/strategy.rs

/// Contract for collecting individual items and releasing them in groups.
///
/// Every implementor must be `Send + Sync`, and the trait is used as a trait
/// object (see [`BatchStrategy::clone_dyn`]).
pub trait BatchStrategy<I>: Send + Sync {
    /// Hands a single item over to the strategy.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to take ownership of.
    fn add(&mut self, item: I);

    /// Requests a batch from the strategy.
    ///
    /// # Arguments
    ///
    /// * `force` - Request a batch even if the strategy would otherwise keep
    ///   waiting for more items (`FixBatchStrategy`, for example, then
    ///   releases a partially filled batch).
    ///
    /// # Returns
    ///
    /// `Some(items)` when a batch is released, `None` otherwise.
    fn batch(&mut self, force: bool) -> Option<Vec<I>>;

    /// Produces a boxed strategy from this one without requiring `Self: Sized`.
    ///
    /// Whether pending items are carried over is up to the implementor;
    /// `FixBatchStrategy` returns an empty strategy with the same batch size.
    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;

    /// Reports the batch size used by the strategy, or `None` when the
    /// implementor has no batch size to report.
    fn batch_size(&self) -> Option<usize>;
}
35
/// Batching strategy built around a fixed target batch size.
///
/// Items are buffered in `items` until a batch is requested.
pub struct FixBatchStrategy<I> {
    /// Target number of items per batch.
    batch_size: usize,
    /// Items received so far that have not yet been handed out.
    items: Vec<I>,
}
41
42impl<I> FixBatchStrategy<I> {
43 pub fn new(batch_size: usize) -> Self {
53 FixBatchStrategy {
54 items: Vec::with_capacity(batch_size),
55 batch_size,
56 }
57 }
58}
59
60impl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
61 fn add(&mut self, item: I) {
62 self.items.push(item);
63 }
64
65 fn batch(&mut self, force: bool) -> Option<Vec<I>> {
66 if self.items.len() < self.batch_size && !force {
67 return None;
68 }
69
70 let mut items = Vec::with_capacity(self.batch_size);
71 std::mem::swap(&mut items, &mut self.items);
72
73 if items.is_empty() {
74 return None;
75 }
76
77 Some(items)
78 }
79
80 fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {
81 Box::new(Self::new(self.batch_size))
82 }
83
84 fn batch_size(&self) -> Option<usize> {
85 Some(self.batch_size)
86 }
87}