// burn_core/data/dataloader/strategy.rs

/// Policy that buffers incoming items and decides when to group them into a batch.
///
/// Implementors must be [`Send`] (required by the trait bound), so a strategy can be moved
/// to another thread.
pub trait BatchStrategy<I>: Send {
    /// Buffers `item` so it can be emitted as part of a later batch.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to buffer.
    fn add(&mut self, item: I);

    /// Attempts to take a batch out of the buffered items.
    ///
    /// # Arguments
    ///
    /// * `force` - Asks the strategy to emit a batch even when its normal readiness
    ///   criterion is not met (e.g. to flush a final, partial batch). The exact
    ///   interpretation is up to the implementation.
    ///
    /// # Returns
    ///
    /// `Some(batch)` when a batch is emitted, `None` when no batch is available.
    fn batch(&mut self, force: bool) -> Option<Vec<I>>;

    /// Creates a boxed strategy of the same kind, allowing a strategy to be duplicated
    /// through a trait object.
    ///
    /// Whether buffered items are carried over is implementation-defined.
    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;
}

/// A [`BatchStrategy`]-style buffer that groups items into batches of a fixed size.
pub struct FixBatchStrategy<I> {
    /// Items added since the last emitted batch, in insertion order.
    items: Vec<I>,
    /// Number of items that make up one batch.
    batch_size: usize,
}

35impl<I> FixBatchStrategy<I> {
36 pub fn new(batch_size: usize) -> Self {
46 FixBatchStrategy {
47 items: Vec::with_capacity(batch_size),
48 batch_size,
49 }
50 }
51}
52
53impl<I: Send + 'static> BatchStrategy<I> for FixBatchStrategy<I> {
54 fn add(&mut self, item: I) {
55 self.items.push(item);
56 }
57
58 fn batch(&mut self, force: bool) -> Option<Vec<I>> {
59 if self.items.len() < self.batch_size && !force {
60 return None;
61 }
62
63 let mut items = Vec::with_capacity(self.batch_size);
64 std::mem::swap(&mut items, &mut self.items);
65
66 if items.is_empty() {
67 return None;
68 }
69
70 Some(items)
71 }
72
73 fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {
74 Box::new(Self::new(self.batch_size))
75 }
76}