oxirs_vec/learned_index/config.rs

1//! Configuration for learned indexes
2
3use serde::{Deserialize, Serialize};
4
/// Model architecture for learned index
///
/// Architectures trade prediction accuracy against inference cost; see the
/// `LearnedIndexConfig` preset constructors for typical pairings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelArchitecture {
    /// Simple linear model (fast, low accuracy)
    Linear,
    /// Two-layer neural network
    TwoLayer,
    /// Three-layer neural network
    ThreeLayer,
    /// Recursive Model Index (RMI) with multiple stages
    Rmi,
}
17
/// Configuration for learned index
///
/// Construct via one of the preset constructors (`default_config`,
/// `speed_optimized`, `accuracy_optimized`, `rmi_config`) and check the
/// result with [`LearnedIndexConfig::validate`] before use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedIndexConfig {
    /// Model architecture
    pub architecture: ModelArchitecture,

    /// Number of models in each RMI stage (presets start with a single root model)
    pub rmi_stages: Vec<usize>,

    /// Hidden layer sizes for the neural-network architectures
    pub hidden_sizes: Vec<usize>,

    /// Error bound multiplier for search range (validated to be >= 1.0)
    pub error_bound_multiplier: f32,

    /// Minimum training examples required (validated to be >= 10)
    pub min_training_examples: usize,

    /// Enable hybrid mode (fallback to binary search)
    pub enable_hybrid: bool,

    /// Training configuration
    pub training: TrainingConfig,
}
42
43impl LearnedIndexConfig {
44    /// Create default configuration
45    pub fn default_config() -> Self {
46        Self {
47            architecture: ModelArchitecture::TwoLayer,
48            rmi_stages: vec![1, 10, 100],
49            hidden_sizes: vec![64, 32],
50            error_bound_multiplier: 2.0,
51            min_training_examples: 1000,
52            enable_hybrid: true,
53            training: TrainingConfig::default_config(),
54        }
55    }
56
57    /// Create configuration optimized for speed
58    pub fn speed_optimized() -> Self {
59        Self {
60            architecture: ModelArchitecture::Linear,
61            rmi_stages: vec![1, 5, 25],
62            hidden_sizes: vec![32],
63            error_bound_multiplier: 3.0,
64            min_training_examples: 500,
65            enable_hybrid: true,
66            training: TrainingConfig::speed_optimized(),
67        }
68    }
69
70    /// Create configuration optimized for accuracy
71    pub fn accuracy_optimized() -> Self {
72        Self {
73            architecture: ModelArchitecture::ThreeLayer,
74            rmi_stages: vec![1, 20, 200],
75            hidden_sizes: vec![128, 64, 32],
76            error_bound_multiplier: 1.5,
77            min_training_examples: 5000,
78            enable_hybrid: true,
79            training: TrainingConfig::accuracy_optimized(),
80        }
81    }
82
83    /// Create RMI configuration
84    pub fn rmi_config() -> Self {
85        Self {
86            architecture: ModelArchitecture::Rmi,
87            rmi_stages: vec![1, 100, 10000],
88            hidden_sizes: vec![64, 32],
89            error_bound_multiplier: 2.0,
90            min_training_examples: 10000,
91            enable_hybrid: true,
92            training: TrainingConfig::default_config(),
93        }
94    }
95
96    /// Validate configuration
97    pub fn validate(&self) -> Result<(), String> {
98        if self.rmi_stages.is_empty() {
99            return Err("RMI stages cannot be empty".to_string());
100        }
101
102        if self.error_bound_multiplier < 1.0 {
103            return Err("Error bound multiplier must be >= 1.0".to_string());
104        }
105
106        if self.min_training_examples < 10 {
107            return Err("Minimum training examples must be >= 10".to_string());
108        }
109
110        self.training.validate()?;
111
112        Ok(())
113    }
114}
115
impl Default for LearnedIndexConfig {
    /// Delegates to [`LearnedIndexConfig::default_config`] so `Default`
    /// and the explicit constructor stay in sync.
    fn default() -> Self {
        Self::default_config()
    }
}
121
/// Training configuration
///
/// Hyperparameters for fitting the learned-index model; check with
/// [`TrainingConfig::validate`] before use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Learning rate (validated to lie in (0, 1])
    pub learning_rate: f32,

    /// Number of training epochs (validated to be > 0)
    pub num_epochs: usize,

    /// Batch size (validated to be > 0)
    pub batch_size: usize,

    /// Early stopping patience (epochs without improvement)
    pub early_stopping_patience: usize,

    /// Loss function
    pub loss_function: LossFunction,

    /// Optimizer
    pub optimizer: Optimizer,

    /// Use data augmentation
    pub use_data_augmentation: bool,

    /// Validation split (validated to lie in [0, 1))
    pub validation_split: f32,
}
149
150impl TrainingConfig {
151    pub fn default_config() -> Self {
152        Self {
153            learning_rate: 0.001,
154            num_epochs: 100,
155            batch_size: 32,
156            early_stopping_patience: 10,
157            loss_function: LossFunction::MeanSquaredError,
158            optimizer: Optimizer::Adam,
159            use_data_augmentation: false,
160            validation_split: 0.2,
161        }
162    }
163
164    pub fn speed_optimized() -> Self {
165        Self {
166            learning_rate: 0.01,
167            num_epochs: 20,
168            batch_size: 128,
169            early_stopping_patience: 5,
170            loss_function: LossFunction::MeanAbsoluteError,
171            optimizer: Optimizer::Sgd,
172            use_data_augmentation: false,
173            validation_split: 0.1,
174        }
175    }
176
177    pub fn accuracy_optimized() -> Self {
178        Self {
179            learning_rate: 0.0001,
180            num_epochs: 200,
181            batch_size: 16,
182            early_stopping_patience: 20,
183            loss_function: LossFunction::Huber,
184            optimizer: Optimizer::Adam,
185            use_data_augmentation: true,
186            validation_split: 0.3,
187        }
188    }
189
190    pub fn validate(&self) -> Result<(), String> {
191        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
192            return Err("Learning rate must be in (0, 1]".to_string());
193        }
194
195        if self.num_epochs == 0 {
196            return Err("Number of epochs must be > 0".to_string());
197        }
198
199        if self.batch_size == 0 {
200            return Err("Batch size must be > 0".to_string());
201        }
202
203        if self.validation_split < 0.0 || self.validation_split >= 1.0 {
204            return Err("Validation split must be in [0, 1)".to_string());
205        }
206
207        Ok(())
208    }
209}
210
impl Default for TrainingConfig {
    /// Delegates to [`TrainingConfig::default_config`] so `Default`
    /// and the explicit constructor stay in sync.
    fn default() -> Self {
        Self::default_config()
    }
}
216
/// Loss function for training
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LossFunction {
    /// Mean Squared Error
    MeanSquaredError,
    /// Mean Absolute Error
    MeanAbsoluteError,
    /// Huber loss (robust to outliers)
    Huber,
    /// Quantile loss
    Quantile,
}
229
/// Optimizer for training
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Optimizer {
    /// Stochastic Gradient Descent
    Sgd,
    /// Adam optimizer
    Adam,
    /// RMSprop
    RmsProp,
}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = LearnedIndexConfig::default_config();
        assert!(config.validate().is_ok());
        assert_eq!(config.architecture, ModelArchitecture::TwoLayer);
    }

    #[test]
    fn test_speed_optimized() {
        let config = LearnedIndexConfig::speed_optimized();
        assert!(config.validate().is_ok());
        assert_eq!(config.architecture, ModelArchitecture::Linear);
    }

    #[test]
    fn test_accuracy_optimized() {
        let config = LearnedIndexConfig::accuracy_optimized();
        assert!(config.validate().is_ok());
        assert_eq!(config.architecture, ModelArchitecture::ThreeLayer);
    }

    // Previously untested preset.
    #[test]
    fn test_rmi_config() {
        let config = LearnedIndexConfig::rmi_config();
        assert!(config.validate().is_ok());
        assert_eq!(config.architecture, ModelArchitecture::Rmi);
    }

    // Failure paths of LearnedIndexConfig::validate were previously untested.
    #[test]
    fn test_learned_index_config_validation() {
        let mut config = LearnedIndexConfig::default_config();
        config.rmi_stages.clear();
        assert!(config.validate().is_err());

        let mut config = LearnedIndexConfig::default_config();
        config.error_bound_multiplier = 0.5;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_training_config_validation() {
        let mut config = TrainingConfig::default_config();
        assert!(config.validate().is_ok());

        config.learning_rate = 0.0;
        assert!(config.validate().is_err());

        config.learning_rate = 0.001;
        config.validation_split = 1.0;
        assert!(config.validate().is_err());
    }
}