use crate::synth_data::decision_outcome::QualitySample;
/// A resettable, batch-at-a-time view over an owned list of [`QualitySample`]s.
///
/// Batches are yielded in order by [`SyntheticSampleStream::next_batch`]. Every batch holds
/// exactly `batch_size` samples except possibly the last, which holds the remainder. No empty
/// batch is ever yielded.
pub struct SyntheticSampleStream {
    /// All samples, in the order they are yielded.
    samples: Vec<QualitySample>,
    /// Maximum number of samples per batch; `new` guarantees it is at least 1.
    batch_size: usize,
    /// Index of the first sample the next call to `next_batch` will return.
    cursor: usize,
}

impl SyntheticSampleStream {
    /// Creates a stream over `samples` that yields batches of at most `batch_size` samples.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is 0.
    pub fn new(samples: Vec<QualitySample>, batch_size: usize) -> Self {
        assert!(batch_size >= 1, "batch_size must be >= 1");
        Self {
            samples,
            batch_size,
            cursor: 0,
        }
    }

    /// Total number of samples in the stream, regardless of how many have been consumed.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// Returns `true` if the stream holds no samples at all.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// Maximum number of samples per batch.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Number of batches that contain exactly `batch_size` samples.
    pub fn n_full_batches(&self) -> usize {
        self.samples.len() / self.batch_size
    }

    /// Size of the last batch that a full pass of `next_batch` yields.
    ///
    /// Returns 0 for an empty stream. When the sample count divides evenly by `batch_size`,
    /// it returns `batch_size`. Otherwise it returns the remainder.
    pub fn final_batch_size(&self) -> usize {
        let n = self.samples.len();
        if n == 0 {
            return 0;
        }
        match n % self.batch_size {
            // Evenly divisible: the last yielded batch is a full one, not an empty trailer.
            0 => self.batch_size,
            remainder => remainder,
        }
    }

    /// Total number of batches a full pass of `next_batch` yields, i.e. `ceil(len / batch_size)`.
    ///
    /// Returns 0 for an empty stream.
    pub fn n_batches(&self) -> usize {
        // One extra batch only when there is a non-empty partial remainder.
        let has_partial = self.samples.len() % self.batch_size != 0;
        self.n_full_batches() + usize::from(has_partial)
    }

    /// Returns a copy of the next batch, or `None` once every sample has been yielded.
    ///
    /// The returned batch has `batch_size` samples, except the final one, which may be shorter.
    pub fn next_batch(&mut self) -> Option<Vec<QualitySample>> {
        if self.cursor >= self.samples.len() {
            return None;
        }
        let end = (self.cursor + self.batch_size).min(self.samples.len());
        let batch: Vec<QualitySample> = self.samples[self.cursor..end].to_vec();
        self.cursor = end;
        Some(batch)
    }

    /// Rewinds the stream so the next `next_batch` call starts again from the first sample.
    pub fn reset(&mut self) {
        self.cursor = 0;
    }
}