entrenar/prune/calibrate/
config.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct CalibrationConfig {
8 num_samples: usize,
10 sequence_length: usize,
12 dataset: String,
14 batch_size: usize,
16 normalize: bool,
18}
19
20impl Default for CalibrationConfig {
21 fn default() -> Self {
22 Self {
23 num_samples: 128,
24 sequence_length: 2048,
25 dataset: "c4".to_string(),
26 batch_size: 1,
27 normalize: true,
28 }
29 }
30}
31
32impl CalibrationConfig {
33 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn with_num_samples(mut self, n: usize) -> Self {
40 self.num_samples = n;
41 self
42 }
43
44 pub fn with_sequence_length(mut self, len: usize) -> Self {
46 self.sequence_length = len;
47 self
48 }
49
50 pub fn with_dataset(mut self, dataset: impl Into<String>) -> Self {
52 self.dataset = dataset.into();
53 self
54 }
55
56 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
58 self.batch_size = batch_size;
59 self
60 }
61
62 pub fn with_normalize(mut self, normalize: bool) -> Self {
64 self.normalize = normalize;
65 self
66 }
67
68 pub fn num_samples(&self) -> usize {
70 self.num_samples
71 }
72
73 pub fn sequence_length(&self) -> usize {
75 self.sequence_length
76 }
77
78 pub fn dataset(&self) -> &str {
80 &self.dataset
81 }
82
83 pub fn batch_size(&self) -> usize {
85 self.batch_size
86 }
87
88 pub fn normalize(&self) -> bool {
90 self.normalize
91 }
92}