use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Configuration for calibration data: sample count, batching, sequence
/// length, source dataset, an optional cache directory, and a seed.
///
/// Build one with [`CalibrationDataConfig::new`] (or [`Default`]) and the
/// `with_*` builder methods. Fields are private; read them through the
/// accessor methods.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CalibrationDataConfig {
    /// Total number of calibration samples.
    num_samples: usize,
    /// Number of samples per batch. `with_batch_size` clamps this to at
    /// least 1, but a deserialized value is not validated here.
    batch_size: usize,
    /// Length of each sample sequence (units defined by the consumer).
    sequence_length: usize,
    /// Identifier of the source dataset (defaults to `"c4"`).
    dataset: String,
    /// Optional cache directory; `None` when not configured.
    cache_dir: Option<PathBuf>,
    /// Seed value — presumably for reproducible sampling; confirm against
    /// the code that consumes this config.
    seed: u64,
}
impl Default for CalibrationDataConfig {
    /// Builds the baseline configuration: 128 samples in batches of 1, a
    /// sequence length of 2048, the `"c4"` dataset, no cache directory, and
    /// seed 42.
    fn default() -> Self {
        let dataset = String::from("c4");
        let cache_dir = None;
        Self {
            seed: 42,
            dataset,
            cache_dir,
            sequence_length: 2048,
            batch_size: 1,
            num_samples: 128,
        }
    }
}
impl CalibrationDataConfig {
    /// Creates a configuration populated with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the total number of calibration samples.
    #[must_use]
    pub fn with_num_samples(mut self, n: usize) -> Self {
        self.num_samples = n;
        self
    }

    /// Sets the number of samples per batch.
    ///
    /// A value of `0` is clamped to `1`, so [`Self::num_batches`] never
    /// divides by zero.
    #[must_use]
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size.max(1);
        self
    }

    /// Sets the length of each sample sequence.
    #[must_use]
    pub fn with_sequence_length(mut self, len: usize) -> Self {
        self.sequence_length = len;
        self
    }

    /// Sets the dataset identifier.
    #[must_use]
    pub fn with_dataset(mut self, dataset: impl Into<String>) -> Self {
        self.dataset = dataset.into();
        self
    }

    /// Sets the cache directory.
    #[must_use]
    pub fn with_cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.cache_dir = Some(dir.into());
        self
    }

    /// Sets the seed value.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Returns the total number of calibration samples.
    pub fn num_samples(&self) -> usize {
        self.num_samples
    }

    /// Returns the effective number of samples per batch; never `0`.
    ///
    /// [`Self::with_batch_size`] already clamps to at least 1. A stored `0`
    /// can still arrive through deserialization, which bypasses the builder.
    /// Such a value is reported as `1`, matching the builder's clamp.
    pub fn batch_size(&self) -> usize {
        self.batch_size.max(1)
    }

    /// Returns the configured sequence length.
    pub fn sequence_length(&self) -> usize {
        self.sequence_length
    }

    /// Returns the dataset identifier.
    pub fn dataset(&self) -> &str {
        &self.dataset
    }

    /// Returns the cache directory, or `None` if none is configured.
    pub fn cache_dir(&self) -> Option<&PathBuf> {
        self.cache_dir.as_ref()
    }

    /// Returns the seed value.
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Returns how many batches are needed to cover all samples.
    ///
    /// The count rounds up, so a trailing partial batch is included. It
    /// returns `0` when `num_samples` is `0`.
    pub fn num_batches(&self) -> usize {
        // Go through the getter so a deserialized `batch_size` of 0 cannot
        // cause a division-by-zero panic.
        self.num_samples.div_ceil(self.batch_size())
    }
}