Skip to main content

entrenar/prune/calibrate/
config.rs

//! Calibration configuration for pruning.
//!
//! Defines `CalibrationConfig`, which controls how calibration data is collected.
2
3use serde::{Deserialize, Serialize};
4
5/// Configuration for calibration data collection.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct CalibrationConfig {
8    /// Number of calibration samples to collect.
9    num_samples: usize,
10    /// Sequence length for text models.
11    sequence_length: usize,
12    /// Dataset identifier for calibration.
13    dataset: String,
14    /// Batch size for calibration forward passes.
15    batch_size: usize,
16    /// Whether to normalize activation statistics.
17    normalize: bool,
18}
19
20impl Default for CalibrationConfig {
21    fn default() -> Self {
22        Self {
23            num_samples: 128,
24            sequence_length: 2048,
25            dataset: "c4".to_string(),
26            batch_size: 1,
27            normalize: true,
28        }
29    }
30}
31
32impl CalibrationConfig {
33    /// Create a new calibration configuration with default values.
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// Set the number of calibration samples.
39    pub fn with_num_samples(mut self, n: usize) -> Self {
40        self.num_samples = n;
41        self
42    }
43
44    /// Set the sequence length.
45    pub fn with_sequence_length(mut self, len: usize) -> Self {
46        self.sequence_length = len;
47        self
48    }
49
50    /// Set the dataset name.
51    pub fn with_dataset(mut self, dataset: impl Into<String>) -> Self {
52        self.dataset = dataset.into();
53        self
54    }
55
56    /// Set the batch size.
57    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
58        self.batch_size = batch_size;
59        self
60    }
61
62    /// Set whether to normalize statistics.
63    pub fn with_normalize(mut self, normalize: bool) -> Self {
64        self.normalize = normalize;
65        self
66    }
67
68    /// Get the number of samples.
69    pub fn num_samples(&self) -> usize {
70        self.num_samples
71    }
72
73    /// Get the sequence length.
74    pub fn sequence_length(&self) -> usize {
75        self.sequence_length
76    }
77
78    /// Get the dataset name.
79    pub fn dataset(&self) -> &str {
80        &self.dataset
81    }
82
83    /// Get the batch size.
84    pub fn batch_size(&self) -> usize {
85        self.batch_size
86    }
87
88    /// Check if normalization is enabled.
89    pub fn normalize(&self) -> bool {
90        self.normalize
91    }
92}