oxirs_vec/learned_index/training.rs

//! Training logic for learned indexes

use super::config::TrainingConfig;
use super::types::{LearnedIndexResult, TrainingExample};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};

/// Statistics from training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Number of epochs completed
    pub epochs_completed: usize,

    /// Final training loss
    pub final_loss: f64,

    /// Final validation loss
    pub validation_loss: f64,

    /// Final accuracy (fraction of predictions within error bounds)
    pub final_accuracy: f64,

    /// Training time (seconds)
    pub training_time_secs: f64,

    /// Whether training stopped early
    pub early_stopped: bool,
}

/// Trains learned-index model weights with simplified mini-batch SGD
pub struct IndexTrainer {
    config: TrainingConfig,
}

impl IndexTrainer {
    pub fn new(config: TrainingConfig) -> Self {
        Self { config }
    }

    /// Train the model weights on the given examples.
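    ///
    /// Returns [`TrainingStats`] for the run. A minimal usage sketch (the
    /// weight/bias buffers are populated internally by `initialize_weights`;
    /// `examples` is assumed non-empty):
    ///
    /// ```ignore
    /// let trainer = IndexTrainer::new(TrainingConfig::default_config());
    /// let mut weights = Vec::new();
    /// let mut biases = Vec::new();
    /// let stats = trainer.train(&mut weights, &mut biases, &examples)?;
    /// tracing::info!("trained for {} epochs", stats.epochs_completed);
    /// ```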
    #[allow(clippy::ptr_arg)]
    pub fn train(
        &self,
        weights: &mut Vec<Vec<f32>>,
        biases: &mut Vec<f32>,
        examples: &[TrainingExample],
    ) -> LearnedIndexResult<TrainingStats> {
        let start = std::time::Instant::now();

        // Split into train/validation sets (assumes `examples` is non-empty)
        let split_idx = (examples.len() as f32 * (1.0 - self.config.validation_split)) as usize;
        let (train_examples, val_examples) = examples.split_at(split_idx);

        tracing::info!(
            "Training on {} examples, validating on {}",
            train_examples.len(),
            val_examples.len()
        );

        // Initialize weights randomly
        self.initialize_weights(weights, biases, examples[0].features.len());

        let mut best_val_loss = f64::INFINITY;
        let mut patience_counter = 0;
        let mut epochs_completed = 0;
        let mut final_loss = 0.0;
        let mut validation_loss = 0.0;
        let mut early_stopped = false;

        // Training loop
        for epoch in 0..self.config.num_epochs {
            let train_loss = self.train_epoch(weights, biases, train_examples, epoch)?;
            let val_loss = self.validate(weights, biases, val_examples)?;

            epochs_completed = epoch + 1;
            final_loss = train_loss;
            validation_loss = val_loss;

            // Early stopping check
            if val_loss < best_val_loss {
                best_val_loss = val_loss;
                patience_counter = 0;
            } else {
                patience_counter += 1;
            }

            if patience_counter >= self.config.early_stopping_patience {
                tracing::info!("Early stopping at epoch {}", epoch);
                early_stopped = true;
                break;
            }

            if epoch % 10 == 0 {
                tracing::debug!(
                    "Epoch {}: train_loss={:.4}, val_loss={:.4}",
                    epoch,
                    train_loss,
                    val_loss
                );
            }
        }

        // Compute final accuracy on the validation set
        let accuracy = self.compute_accuracy(weights, biases, val_examples);

        let elapsed = start.elapsed().as_secs_f64();

        Ok(TrainingStats {
            epochs_completed,
            final_loss,
            validation_loss,
            final_accuracy: accuracy,
            training_time_secs: elapsed,
            early_stopped,
        })
    }

    fn initialize_weights(
        &self,
        weights: &mut Vec<Vec<f32>>,
        biases: &mut Vec<f32>,
        input_size: usize,
    ) {
        let mut rng = Random::seed(42);

        // Simple architecture: input -> hidden -> output
        let hidden_size = 32;
        let output_size = 1;

        // Reset any previous state so retraining starts from scratch
        weights.clear();
        biases.clear();

        // Input-to-hidden weights, stored flat in row-major order
        // (weights[0][h * input_size + i] connects input i to hidden unit h)
        let mut layer1 = Vec::new();
        for _ in 0..(input_size * hidden_size) {
            layer1.push(rng.gen_range(-0.1..0.1));
        }
        weights.push(layer1);

        // Hidden-to-output weights
        let mut layer2 = Vec::new();
        for _ in 0..(hidden_size * output_size) {
            layer2.push(rng.gen_range(-0.1..0.1));
        }
        weights.push(layer2);

        // One shared bias per layer
        biases.push(rng.gen_range(-0.1..0.1));
        biases.push(rng.gen_range(-0.1..0.1));
    }
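
    /// Xavier/Glorot-style uniform range for a layer, as an alternative to
    /// the fixed +/-0.1 range used above; it scales the range with the layer
    /// widths. A hypothetical helper sketch, not called anywhere yet.
    #[allow(dead_code)]
    fn xavier_range(fan_in: usize, fan_out: usize) -> f32 {
        (6.0f32 / (fan_in + fan_out) as f32).sqrt()
    }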

    #[allow(clippy::ptr_arg)]
    fn train_epoch(
        &self,
        weights: &mut Vec<Vec<f32>>,
        biases: &mut Vec<f32>,
        examples: &[TrainingExample],
        epoch: usize,
    ) -> LearnedIndexResult<f64> {
        let mut total_loss = 0.0;
        // Vary the seed per epoch so the shuffle order differs between
        // epochs while the run as a whole stays reproducible
        let mut rng = Random::seed(42 + epoch as u64);

        // Shuffle example indices (Fisher-Yates)
        let mut indices: Vec<usize> = (0..examples.len()).collect();
        for i in (1..indices.len()).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }

        // Mini-batch training
        for batch_start in (0..examples.len()).step_by(self.config.batch_size) {
            let batch_end = (batch_start + self.config.batch_size).min(examples.len());
            let batch_indices = &indices[batch_start..batch_end];

            let mut batch_loss = 0.0;

            for &idx in batch_indices {
                let example = &examples[idx];

                // Forward pass (simplified)
                let prediction = self.forward_simple(&example.features, weights, biases);
                // Normalize the target position into [0, 1)
                let target = example.target_position as f32 / examples.len() as f32;

                // Squared-error loss
                let loss = (prediction - target).powi(2);
                batch_loss += loss as f64;

                // Backward pass (simplified gradient descent)
                let gradient = 2.0 * (prediction - target);
                self.update_weights(weights, biases, gradient, &example.features);
            }

            total_loss += batch_loss;
        }

        Ok(total_loss / examples.len() as f64)
    }

    fn validate(
        &self,
        weights: &[Vec<f32>],
        biases: &[f32],
        examples: &[TrainingExample],
    ) -> LearnedIndexResult<f64> {
        // An empty validation set (validation_split == 0) would otherwise
        // divide by zero below
        if examples.is_empty() {
            return Ok(0.0);
        }

        let mut total_loss = 0.0;

        for example in examples {
            let prediction = self.forward_simple(&example.features, weights, biases);
            let target = example.target_position as f32 / examples.len() as f32;
            let loss = (prediction - target).powi(2);
            total_loss += loss as f64;
        }

        Ok(total_loss / examples.len() as f64)
    }

    fn compute_accuracy(
        &self,
        weights: &[Vec<f32>],
        biases: &[f32],
        examples: &[TrainingExample],
    ) -> f64 {
        if examples.is_empty() {
            return 0.0;
        }

        let mut correct = 0;
        let tolerance = 0.1; // predictions within 10% of the dataset size count as correct

        for example in examples {
            let prediction = self.forward_simple(&example.features, weights, biases);
            let predicted_pos = (prediction * examples.len() as f32) as usize;
            let error = predicted_pos.abs_diff(example.target_position) as f32;

            if error / (examples.len() as f32) < tolerance {
                correct += 1;
            }
        }

        correct as f64 / examples.len() as f64
    }

    fn forward_simple(&self, input: &[f32], _weights: &[Vec<f32>], _biases: &[f32]) -> f32 {
        // Placeholder forward pass: ignores the learned parameters and maps
        // the mean feature value through a sigmoid into (0, 1)
        let sum: f32 = input.iter().sum();
        let normalized = sum / input.len() as f32;

        1.0 / (1.0 + (-normalized).exp())
    }
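
    /// A sketch of a full forward pass that actually consumes the stored
    /// parameters, assuming the row-major layout produced by
    /// `initialize_weights` (weights[0]: input -> hidden, weights[1]:
    /// hidden -> output, one shared bias per layer). It is not wired into
    /// training yet; `forward_simple` remains the placeholder used above.
    #[allow(dead_code)]
    fn forward_full(&self, input: &[f32], weights: &[Vec<f32>], biases: &[f32]) -> f32 {
        let input_size = input.len();
        // With output_size == 1, the second layer holds one weight per hidden unit
        let hidden_size = weights[1].len();

        // Input -> hidden with ReLU activation
        let mut hidden = vec![0.0f32; hidden_size];
        for (h, unit) in hidden.iter_mut().enumerate() {
            let mut sum = biases[0];
            for (i, &x) in input.iter().enumerate() {
                sum += weights[0][h * input_size + i] * x;
            }
            *unit = sum.max(0.0);
        }

        // Hidden -> output with a sigmoid, matching forward_simple's range
        let mut output = biases[1];
        for (h, &activation) in hidden.iter().enumerate() {
            output += weights[1][h] * activation;
        }
        1.0 / (1.0 + (-output).exp())
    }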

    fn update_weights(
        &self,
        weights: &mut [Vec<f32>],
        biases: &mut [f32],
        gradient: f32,
        _features: &[f32],
    ) {
        // Simplified update: the same scalar gradient is applied uniformly to
        // every parameter instead of backpropagating per-weight gradients
        let lr = self.config.learning_rate * 0.1; // Scale down for stability

        for weight_layer in weights.iter_mut() {
            for w in weight_layer.iter_mut() {
                *w -= lr * gradient;
            }
        }

        for bias in biases.iter_mut() {
            *bias -= lr * gradient;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| TrainingExample::new(vec![i as f32 / n as f32], i))
            .collect()
    }

    #[test]
    fn test_trainer_creation() {
        let config = TrainingConfig::default_config();
        let trainer = IndexTrainer::new(config);
        assert!(trainer.config.num_epochs > 0);
    }

    #[test]
    fn test_training() {
        let config = TrainingConfig::speed_optimized();
        let trainer = IndexTrainer::new(config);

        let examples = create_test_examples(100);
        let mut weights = Vec::new();
        let mut biases = Vec::new();

        let stats = trainer.train(&mut weights, &mut biases, &examples);
        assert!(stats.is_ok());

        let stats = stats.unwrap();
        assert!(stats.epochs_completed > 0);
        assert!(stats.final_loss >= 0.0);
        assert!(stats.final_accuracy >= 0.0 && stats.final_accuracy <= 1.0);
    }
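
    // With fixed RNG seeds throughout, two training runs over the same data
    // should produce identical statistics; this pins that determinism down.
    #[test]
    fn test_training_is_deterministic() {
        let examples = create_test_examples(50);

        let run = || {
            let trainer = IndexTrainer::new(TrainingConfig::speed_optimized());
            let mut weights = Vec::new();
            let mut biases = Vec::new();
            trainer.train(&mut weights, &mut biases, &examples).unwrap()
        };

        let a = run();
        let b = run();
        assert_eq!(a.epochs_completed, b.epochs_completed);
        assert!((a.final_loss - b.final_loss).abs() < 1e-12);
    }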
}