// oxirs_vec/learned_index/training.rs

use super::config::TrainingConfig;
use super::types::{LearnedIndexResult, TrainingExample};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct TrainingStats {
11 pub epochs_completed: usize,
13
14 pub final_loss: f64,
16
17 pub validation_loss: f64,
19
20 pub final_accuracy: f64,
22
23 pub training_time_secs: f64,
25
26 pub early_stopped: bool,
28}
29
30pub struct IndexTrainer {
32 config: TrainingConfig,
33}
34
35impl IndexTrainer {
36 pub fn new(config: TrainingConfig) -> Self {
37 Self { config }
38 }
39
40 #[allow(clippy::ptr_arg)]
42 pub fn train(
43 &self,
44 weights: &mut Vec<Vec<f32>>,
45 biases: &mut Vec<f32>,
46 examples: &[TrainingExample],
47 ) -> LearnedIndexResult<TrainingStats> {
48 let start = std::time::Instant::now();
49
50 let split_idx = (examples.len() as f32 * (1.0 - self.config.validation_split)) as usize;
52 let (train_examples, val_examples) = examples.split_at(split_idx);
53
54 tracing::info!(
55 "Training on {} examples, validating on {}",
56 train_examples.len(),
57 val_examples.len()
58 );
59
60 self.initialize_weights(weights, biases, examples[0].features.len());
62
63 let mut best_val_loss = f64::INFINITY;
64 let mut patience_counter = 0;
65 let mut final_loss = 0.0;
66 let mut validation_loss = 0.0;
67 let mut early_stopped = false;
68
69 for epoch in 0..self.config.num_epochs {
71 let train_loss = self.train_epoch(weights, biases, train_examples)?;
72 let val_loss = self.validate(weights, biases, val_examples)?;
73
74 final_loss = train_loss;
75 validation_loss = val_loss;
76
77 if val_loss < best_val_loss {
79 best_val_loss = val_loss;
80 patience_counter = 0;
81 } else {
82 patience_counter += 1;
83 }
84
85 if patience_counter >= self.config.early_stopping_patience {
86 tracing::info!("Early stopping at epoch {}", epoch);
87 early_stopped = true;
88 break;
89 }
90
91 if epoch % 10 == 0 {
92 tracing::debug!(
93 "Epoch {}: train_loss={:.4}, val_loss={:.4}",
94 epoch,
95 train_loss,
96 val_loss
97 );
98 }
99 }
100
101 let accuracy = self.compute_accuracy(weights, biases, val_examples);
103
104 let elapsed = start.elapsed().as_secs_f64();
105
106 Ok(TrainingStats {
107 epochs_completed: if early_stopped {
108 self.config.num_epochs - patience_counter
109 } else {
110 self.config.num_epochs
111 },
112 final_loss,
113 validation_loss,
114 final_accuracy: accuracy,
115 training_time_secs: elapsed,
116 early_stopped,
117 })
118 }
119
120 fn initialize_weights(
121 &self,
122 weights: &mut Vec<Vec<f32>>,
123 biases: &mut Vec<f32>,
124 input_size: usize,
125 ) {
126 let mut rng = Random::seed(42);
127
128 let hidden_size = 32;
130 let output_size = 1;
131
132 let mut layer1 = Vec::new();
134 for _ in 0..(input_size * hidden_size) {
135 layer1.push(rng.gen_range(-0.1..0.1));
136 }
137 weights.push(layer1);
138
139 let mut layer2 = Vec::new();
141 for _ in 0..(hidden_size * output_size) {
142 layer2.push(rng.gen_range(-0.1..0.1));
143 }
144 weights.push(layer2);
145
146 biases.push(rng.gen_range(-0.1..0.1));
148 biases.push(rng.gen_range(-0.1..0.1));
149 }
150
151 #[allow(clippy::ptr_arg)]
152 fn train_epoch(
153 &self,
154 weights: &mut Vec<Vec<f32>>,
155 biases: &mut Vec<f32>,
156 examples: &[TrainingExample],
157 ) -> LearnedIndexResult<f64> {
158 let mut total_loss = 0.0;
159 let mut rng = Random::seed(42);
160
161 let mut indices: Vec<usize> = (0..examples.len()).collect();
163 for i in (1..indices.len()).rev() {
164 let j = rng.gen_range(0..=i);
165 indices.swap(i, j);
166 }
167
168 for batch_start in (0..examples.len()).step_by(self.config.batch_size) {
170 let batch_end = (batch_start + self.config.batch_size).min(examples.len());
171 let batch_indices = &indices[batch_start..batch_end];
172
173 let mut batch_loss = 0.0;
174
175 for &idx in batch_indices {
176 let example = &examples[idx];
177
178 let prediction = self.forward_simple(&example.features, weights, biases);
180 let target = example.target_position as f32 / examples.len() as f32;
181
182 let loss = (prediction - target).powi(2);
184 batch_loss += loss as f64;
185
186 let gradient = 2.0 * (prediction - target);
188 self.update_weights(weights, biases, gradient, &example.features);
189 }
190
191 total_loss += batch_loss;
192 }
193
194 Ok(total_loss / examples.len() as f64)
195 }
196
197 fn validate(
198 &self,
199 weights: &[Vec<f32>],
200 biases: &[f32],
201 examples: &[TrainingExample],
202 ) -> LearnedIndexResult<f64> {
203 let mut total_loss = 0.0;
204
205 for example in examples {
206 let prediction = self.forward_simple(&example.features, weights, biases);
207 let target = example.target_position as f32 / examples.len() as f32;
208 let loss = (prediction - target).powi(2);
209 total_loss += loss as f64;
210 }
211
212 Ok(total_loss / examples.len() as f64)
213 }
214
215 fn compute_accuracy(
216 &self,
217 weights: &[Vec<f32>],
218 biases: &[f32],
219 examples: &[TrainingExample],
220 ) -> f64 {
221 let mut correct = 0;
222 let tolerance = 0.1; for example in examples {
225 let prediction = self.forward_simple(&example.features, weights, biases);
226 let predicted_pos = (prediction * examples.len() as f32) as usize;
227 let error = predicted_pos.abs_diff(example.target_position) as f32;
228
229 if error / (examples.len() as f32) < tolerance {
230 correct += 1;
231 }
232 }
233
234 correct as f64 / examples.len() as f64
235 }
236
237 fn forward_simple(&self, input: &[f32], _weights: &[Vec<f32>], _biases: &[f32]) -> f32 {
238 let sum: f32 = input.iter().sum();
240 let normalized = sum / input.len() as f32;
241
242 1.0 / (1.0 + (-normalized).exp())
244 }
245
246 fn update_weights(
247 &self,
248 weights: &mut [Vec<f32>],
249 biases: &mut [f32],
250 gradient: f32,
251 _features: &[f32],
252 ) {
253 let lr = self.config.learning_rate * 0.1; for weight_layer in weights.iter_mut() {
257 for w in weight_layer.iter_mut() {
258 *w -= lr * gradient;
259 }
260 }
261
262 for bias in biases.iter_mut() {
263 *bias -= lr * gradient;
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` examples whose single feature is the normalized position.
    fn create_test_examples(n: usize) -> Vec<TrainingExample> {
        let mut examples = Vec::with_capacity(n);
        for i in 0..n {
            examples.push(TrainingExample::new(vec![i as f32 / n as f32], i));
        }
        examples
    }

    #[test]
    fn test_trainer_creation() {
        let trainer = IndexTrainer::new(TrainingConfig::default_config());
        assert!(trainer.config.num_epochs > 0);
    }

    #[test]
    fn test_training() {
        let trainer = IndexTrainer::new(TrainingConfig::speed_optimized());
        let examples = create_test_examples(100);

        let mut weights = Vec::new();
        let mut biases = Vec::new();
        let result = trainer.train(&mut weights, &mut biases, &examples);
        assert!(result.is_ok());

        let stats = result.unwrap();
        assert!(stats.epochs_completed > 0);
        assert!(stats.final_loss >= 0.0);
        assert!((0.0..=1.0).contains(&stats.final_accuracy));
    }
}