// entrenar/prune/data_loader/config.rs
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct CalibrationDataConfig {
9 num_samples: usize,
11 batch_size: usize,
13 sequence_length: usize,
15 dataset: String,
17 cache_dir: Option<PathBuf>,
19 seed: u64,
21}
22
23impl Default for CalibrationDataConfig {
24 fn default() -> Self {
25 Self {
26 num_samples: 128,
27 batch_size: 1,
28 sequence_length: 2048,
29 dataset: "c4".to_string(),
30 cache_dir: None,
31 seed: 42,
32 }
33 }
34}
35
36impl CalibrationDataConfig {
37 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn with_num_samples(mut self, n: usize) -> Self {
44 self.num_samples = n;
45 self
46 }
47
48 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
50 self.batch_size = batch_size.max(1);
51 self
52 }
53
54 pub fn with_sequence_length(mut self, len: usize) -> Self {
56 self.sequence_length = len;
57 self
58 }
59
60 pub fn with_dataset(mut self, dataset: impl Into<String>) -> Self {
62 self.dataset = dataset.into();
63 self
64 }
65
66 pub fn with_cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
68 self.cache_dir = Some(dir.into());
69 self
70 }
71
72 pub fn with_seed(mut self, seed: u64) -> Self {
74 self.seed = seed;
75 self
76 }
77
78 pub fn num_samples(&self) -> usize {
80 self.num_samples
81 }
82
83 pub fn batch_size(&self) -> usize {
85 self.batch_size
86 }
87
88 pub fn sequence_length(&self) -> usize {
90 self.sequence_length
91 }
92
93 pub fn dataset(&self) -> &str {
95 &self.dataset
96 }
97
98 pub fn cache_dir(&self) -> Option<&PathBuf> {
100 self.cache_dir.as_ref()
101 }
102
103 pub fn seed(&self) -> u64 {
105 self.seed
106 }
107
108 pub fn num_batches(&self) -> usize {
110 self.num_samples.div_ceil(self.batch_size)
111 }
112}