1use serde::{Deserialize, Serialize};
4
/// Model topology used by the learned index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelArchitecture {
    /// Single linear model — cheapest to train and evaluate.
    Linear,
    /// Network with one hidden layer (two weight layers).
    TwoLayer,
    /// Network with two hidden layers (three weight layers).
    ThreeLayer,
    /// Recursive Model Index: a staged hierarchy of models, sized by
    /// `LearnedIndexConfig::rmi_stages`.
    Rmi,
}
17
/// Top-level configuration for building a learned index.
///
/// Construct via one of the presets (`default_config`, `speed_optimized`,
/// `accuracy_optimized`, `rmi_config`) and check with [`Self::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedIndexConfig {
    /// Which model topology to train.
    pub architecture: ModelArchitecture,

    /// Stage sizes for the RMI architecture; validated to be non-empty.
    /// NOTE(review): presumably the model count per RMI stage — confirm
    /// against the index builder.
    pub rmi_stages: Vec<usize>,

    /// Hidden-layer widths for the multi-layer architectures.
    pub hidden_sizes: Vec<usize>,

    /// Safety factor applied to the model's error bound; validated to be
    /// >= 1.0.
    pub error_bound_multiplier: f32,

    /// Minimum number of training examples required; validated to be >= 10.
    pub min_training_examples: usize,

    /// NOTE(review): assumed from the name to enable a classical-index
    /// fallback alongside the model — confirm with the consumer of this flag.
    pub enable_hybrid: bool,

    /// Hyper-parameters for the training loop.
    pub training: TrainingConfig,
}
42
43impl LearnedIndexConfig {
44 pub fn default_config() -> Self {
46 Self {
47 architecture: ModelArchitecture::TwoLayer,
48 rmi_stages: vec![1, 10, 100],
49 hidden_sizes: vec![64, 32],
50 error_bound_multiplier: 2.0,
51 min_training_examples: 1000,
52 enable_hybrid: true,
53 training: TrainingConfig::default_config(),
54 }
55 }
56
57 pub fn speed_optimized() -> Self {
59 Self {
60 architecture: ModelArchitecture::Linear,
61 rmi_stages: vec![1, 5, 25],
62 hidden_sizes: vec![32],
63 error_bound_multiplier: 3.0,
64 min_training_examples: 500,
65 enable_hybrid: true,
66 training: TrainingConfig::speed_optimized(),
67 }
68 }
69
70 pub fn accuracy_optimized() -> Self {
72 Self {
73 architecture: ModelArchitecture::ThreeLayer,
74 rmi_stages: vec![1, 20, 200],
75 hidden_sizes: vec![128, 64, 32],
76 error_bound_multiplier: 1.5,
77 min_training_examples: 5000,
78 enable_hybrid: true,
79 training: TrainingConfig::accuracy_optimized(),
80 }
81 }
82
83 pub fn rmi_config() -> Self {
85 Self {
86 architecture: ModelArchitecture::Rmi,
87 rmi_stages: vec![1, 100, 10000],
88 hidden_sizes: vec![64, 32],
89 error_bound_multiplier: 2.0,
90 min_training_examples: 10000,
91 enable_hybrid: true,
92 training: TrainingConfig::default_config(),
93 }
94 }
95
96 pub fn validate(&self) -> Result<(), String> {
98 if self.rmi_stages.is_empty() {
99 return Err("RMI stages cannot be empty".to_string());
100 }
101
102 if self.error_bound_multiplier < 1.0 {
103 return Err("Error bound multiplier must be >= 1.0".to_string());
104 }
105
106 if self.min_training_examples < 10 {
107 return Err("Minimum training examples must be >= 10".to_string());
108 }
109
110 self.training.validate()?;
111
112 Ok(())
113 }
114}
115
/// Delegates to [`LearnedIndexConfig::default_config`] so `Default::default()`
/// and the named constructor cannot drift apart.
impl Default for LearnedIndexConfig {
    fn default() -> Self {
        Self::default_config()
    }
}
121
/// Hyper-parameters for the model training loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Optimizer step size; validated to be in (0, 1].
    pub learning_rate: f32,

    /// Number of training epochs; validated to be > 0.
    pub num_epochs: usize,

    /// Mini-batch size; validated to be > 0.
    pub batch_size: usize,

    /// NOTE(review): presumably the number of epochs without improvement
    /// before training stops early — confirm against the training loop.
    pub early_stopping_patience: usize,

    /// Loss function to minimize.
    pub loss_function: LossFunction,

    /// Optimization algorithm to use.
    pub optimizer: Optimizer,

    /// NOTE(review): assumed from the name to toggle synthetic training
    /// examples — confirm with the training pipeline.
    pub use_data_augmentation: bool,

    /// Fraction of the data held out for validation; validated to be
    /// in [0, 1).
    pub validation_split: f32,
}
149
150impl TrainingConfig {
151 pub fn default_config() -> Self {
152 Self {
153 learning_rate: 0.001,
154 num_epochs: 100,
155 batch_size: 32,
156 early_stopping_patience: 10,
157 loss_function: LossFunction::MeanSquaredError,
158 optimizer: Optimizer::Adam,
159 use_data_augmentation: false,
160 validation_split: 0.2,
161 }
162 }
163
164 pub fn speed_optimized() -> Self {
165 Self {
166 learning_rate: 0.01,
167 num_epochs: 20,
168 batch_size: 128,
169 early_stopping_patience: 5,
170 loss_function: LossFunction::MeanAbsoluteError,
171 optimizer: Optimizer::Sgd,
172 use_data_augmentation: false,
173 validation_split: 0.1,
174 }
175 }
176
177 pub fn accuracy_optimized() -> Self {
178 Self {
179 learning_rate: 0.0001,
180 num_epochs: 200,
181 batch_size: 16,
182 early_stopping_patience: 20,
183 loss_function: LossFunction::Huber,
184 optimizer: Optimizer::Adam,
185 use_data_augmentation: true,
186 validation_split: 0.3,
187 }
188 }
189
190 pub fn validate(&self) -> Result<(), String> {
191 if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
192 return Err("Learning rate must be in (0, 1]".to_string());
193 }
194
195 if self.num_epochs == 0 {
196 return Err("Number of epochs must be > 0".to_string());
197 }
198
199 if self.batch_size == 0 {
200 return Err("Batch size must be > 0".to_string());
201 }
202
203 if self.validation_split < 0.0 || self.validation_split >= 1.0 {
204 return Err("Validation split must be in [0, 1)".to_string());
205 }
206
207 Ok(())
208 }
209}
210
/// Delegates to [`TrainingConfig::default_config`] so `Default::default()`
/// and the named constructor cannot drift apart.
impl Default for TrainingConfig {
    fn default() -> Self {
        Self::default_config()
    }
}
216
/// Loss function minimized during training.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LossFunction {
    /// L2 loss; penalizes large errors quadratically.
    MeanSquaredError,
    /// L1 loss; more robust to outliers than MSE.
    MeanAbsoluteError,
    /// Quadratic near zero, linear in the tails — a compromise between
    /// MSE and MAE.
    Huber,
    /// Asymmetric loss targeting a quantile of the error distribution.
    Quantile,
}
229
/// Optimization algorithm used to update model weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Optimizer {
    /// Stochastic gradient descent.
    Sgd,
    /// Adam: adaptive per-parameter learning rates with momentum.
    Adam,
    /// RMSProp: adaptive learning rates from a moving average of
    /// squared gradients.
    RmsProp,
}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a preset passes validation and selects the expected
    /// architecture.
    fn assert_preset(cfg: LearnedIndexConfig, arch: ModelArchitecture) {
        assert!(cfg.validate().is_ok());
        assert_eq!(cfg.architecture, arch);
    }

    #[test]
    fn test_default_config() {
        assert_preset(
            LearnedIndexConfig::default_config(),
            ModelArchitecture::TwoLayer,
        );
    }

    #[test]
    fn test_speed_optimized() {
        assert_preset(
            LearnedIndexConfig::speed_optimized(),
            ModelArchitecture::Linear,
        );
    }

    #[test]
    fn test_accuracy_optimized() {
        assert_preset(
            LearnedIndexConfig::accuracy_optimized(),
            ModelArchitecture::ThreeLayer,
        );
    }

    #[test]
    fn test_training_config_validation() {
        let mut cfg = TrainingConfig::default_config();
        assert!(cfg.validate().is_ok());

        // A zero learning rate falls outside (0, 1].
        cfg.learning_rate = 0.0;
        assert!(cfg.validate().is_err());

        // Restore the rate, then push the split outside [0, 1).
        cfg.learning_rate = 0.001;
        cfg.validation_split = 1.0;
        assert!(cfg.validate().is_err());
    }
}