Skip to main content

entrenar/prune/data_loader/
config.rs

//! Configuration for calibration data loading.
2
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5
6/// Configuration for calibration data loading.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct CalibrationDataConfig {
9    /// Number of calibration samples to load.
10    num_samples: usize,
11    /// Batch size for calibration.
12    batch_size: usize,
13    /// Sequence length (for text models).
14    sequence_length: usize,
15    /// Dataset name or path.
16    dataset: String,
17    /// Cache directory for downloaded data.
18    cache_dir: Option<PathBuf>,
19    /// Random seed for sampling.
20    seed: u64,
21}
22
23impl Default for CalibrationDataConfig {
24    fn default() -> Self {
25        Self {
26            num_samples: 128,
27            batch_size: 1,
28            sequence_length: 2048,
29            dataset: "c4".to_string(),
30            cache_dir: None,
31            seed: 42,
32        }
33    }
34}
35
36impl CalibrationDataConfig {
37    /// Create a new configuration with default values.
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Set the number of calibration samples.
43    pub fn with_num_samples(mut self, n: usize) -> Self {
44        self.num_samples = n;
45        self
46    }
47
48    /// Set the batch size.
49    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
50        self.batch_size = batch_size.max(1);
51        self
52    }
53
54    /// Set the sequence length.
55    pub fn with_sequence_length(mut self, len: usize) -> Self {
56        self.sequence_length = len;
57        self
58    }
59
60    /// Set the dataset name or path.
61    pub fn with_dataset(mut self, dataset: impl Into<String>) -> Self {
62        self.dataset = dataset.into();
63        self
64    }
65
66    /// Set the cache directory.
67    pub fn with_cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
68        self.cache_dir = Some(dir.into());
69        self
70    }
71
72    /// Set the random seed.
73    pub fn with_seed(mut self, seed: u64) -> Self {
74        self.seed = seed;
75        self
76    }
77
78    /// Get the number of samples.
79    pub fn num_samples(&self) -> usize {
80        self.num_samples
81    }
82
83    /// Get the batch size.
84    pub fn batch_size(&self) -> usize {
85        self.batch_size
86    }
87
88    /// Get the sequence length.
89    pub fn sequence_length(&self) -> usize {
90        self.sequence_length
91    }
92
93    /// Get the dataset name.
94    pub fn dataset(&self) -> &str {
95        &self.dataset
96    }
97
98    /// Get the cache directory.
99    pub fn cache_dir(&self) -> Option<&PathBuf> {
100        self.cache_dir.as_ref()
101    }
102
103    /// Get the random seed.
104    pub fn seed(&self) -> u64 {
105        self.seed
106    }
107
108    /// Get the number of batches.
109    pub fn num_batches(&self) -> usize {
110        self.num_samples.div_ceil(self.batch_size)
111    }
112}