1use crate::codebook::Codebook;
40
/// Schedule and hyper-parameters for phase-based training.
///
/// One training cycle is `warmup_steps + full_steps + predict_steps` steps
/// long and repeats for as long as training runs.
#[derive(Clone, Debug)]
pub struct PhaseTrainingConfig {
    /// Steps at the start of every cycle that run in the warmup phase.
    pub warmup_steps: usize,
    /// Steps following warmup that run in the full phase.
    pub full_steps: usize,
    /// Steps at the end of every cycle that run in the predict phase
    /// (interleaved with correction steps, see `correct_every`).
    pub predict_steps: usize,
    /// Within the predict portion of a cycle, every `correct_every`-th step
    /// (excluding the first one) is a correction step. Must be non-zero.
    pub correct_every: usize,
    /// Step size used when gradients are applied to the basis weights.
    pub learning_rate: f64,
    /// Number of passes over the training data.
    pub epochs: usize,
    /// Number of samples consumed per training step. Must be non-zero.
    pub batch_size: usize,
}

impl Default for PhaseTrainingConfig {
    /// Returns the stock schedule: a 100-step cycle (50 warmup, 10 full,
    /// 40 predict) with a correction every 8th predict step.
    fn default() -> Self {
        // The three schedule segments, grouped so the cycle shape is easy to see.
        let (warmup_steps, full_steps, predict_steps) = (50, 10, 40);

        PhaseTrainingConfig {
            warmup_steps,
            full_steps,
            predict_steps,
            correct_every: 8,
            learning_rate: 1e-3,
            epochs: 100,
            batch_size: 32,
        }
    }
}
73
/// Counters and summary figures collected while a `PhaseTrainer` runs.
///
/// All fields start at zero (`Default`).
#[derive(Clone, Debug, Default)]
pub struct PhaseTrainingStats {
    /// Number of steps completed, i.e. the number of `end_step` calls.
    pub total_steps: usize,
    /// Number of steps for which gradients were recorded via
    /// `record_gradients`.
    pub full_backprop_steps: usize,
    /// Number of steps reported via `record_predicted_step`.
    pub predicted_steps: usize,
    /// Loss value passed to the most recent `end_step` call.
    pub final_loss: f64,
    /// Estimated speedup over running every step in full. Only written by
    /// `PhaseTrainer::finalize`; it keeps its default of `0.0` until then.
    pub speedup: f64,
}
88
/// Phase a training step runs in, as selected by `PhaseTrainer::begin_step`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TrainingPhase {
    /// First segment of each cycle; gradients are computed in full.
    Warmup,
    /// Second segment of each cycle; gradients are computed in full.
    Full,
    /// Final segment of each cycle; gradients are predicted from history
    /// instead of being computed.
    Predict,
    /// Periodic step inside the predict segment where gradients are computed
    /// in full again.
    Correct,
}
101
102#[derive(Debug)]
104pub struct PhaseTrainer {
105 config: PhaseTrainingConfig,
106 current_step: usize,
107 phase: TrainingPhase,
108 gradient_history: Vec<Vec<f64>>,
109 stats: PhaseTrainingStats,
110}
111
112impl PhaseTrainer {
113 pub fn new(config: PhaseTrainingConfig) -> Result<Self, String> {
121 if config.correct_every == 0 {
122 return Err("correct_every must be > 0".to_string());
123 }
124
125 let cycle_length = config.warmup_steps + config.full_steps + config.predict_steps;
126 if cycle_length == 0 {
127 return Err(
128 "At least one of warmup_steps, full_steps, or predict_steps must be > 0"
129 .to_string(),
130 );
131 }
132
133 Ok(Self {
134 config,
135 current_step: 0,
136 phase: TrainingPhase::Warmup,
137 gradient_history: Vec::new(),
138 stats: PhaseTrainingStats::default(),
139 })
140 }
141
142 pub fn current_phase(&self) -> TrainingPhase {
144 self.phase
145 }
146
147 pub fn begin_step(&mut self) -> TrainingPhase {
149 let cycle_length =
151 self.config.warmup_steps + self.config.full_steps + self.config.predict_steps;
152 let step_in_cycle = self.current_step % cycle_length;
153
154 self.phase = if step_in_cycle < self.config.warmup_steps {
155 TrainingPhase::Warmup
156 } else if step_in_cycle < self.config.warmup_steps + self.config.full_steps {
157 TrainingPhase::Full
158 } else {
159 let predict_step = step_in_cycle - self.config.warmup_steps - self.config.full_steps;
160 if predict_step.is_multiple_of(self.config.correct_every) && predict_step > 0 {
161 TrainingPhase::Correct
162 } else {
163 TrainingPhase::Predict
164 }
165 };
166
167 self.phase
168 }
169
170 pub fn should_compute_full(&self) -> bool {
172 matches!(
173 self.phase,
174 TrainingPhase::Warmup | TrainingPhase::Full | TrainingPhase::Correct
175 )
176 }
177
178 pub fn record_predicted_step(&mut self) {
180 self.stats.predicted_steps += 1;
181 }
182
183 pub fn record_gradients(&mut self, gradients: Vec<f64>) {
185 self.gradient_history.push(gradients);
186 self.stats.full_backprop_steps += 1;
187
188 const MAX_HISTORY: usize = 100;
190 if self.gradient_history.len() > MAX_HISTORY {
191 self.gradient_history.remove(0);
192 }
193 }
194
195 pub fn get_predicted_gradients(&self, param_count: usize) -> Vec<f64> {
201 if self.gradient_history.is_empty() {
202 return vec![0.0; param_count];
203 }
204
205 let history_len = self.gradient_history.len();
209 let mut result = vec![0.0; param_count];
210 let mut total_weight = 0.0;
211
212 for (i, grads) in self.gradient_history.iter().enumerate() {
213 let weight = ((i + 1) as f64 / history_len as f64).powi(2);
215 total_weight += weight;
216
217 for (j, &g) in grads.iter().enumerate() {
218 if j < param_count {
219 result[j] += g * weight;
220 }
221 }
222 }
223
224 if total_weight > 0.0 {
225 for g in &mut result {
226 *g /= total_weight;
227 }
228 }
229
230 result
231 }
232
233 pub fn end_step(&mut self, loss: f64) {
235 self.current_step += 1;
236 self.stats.total_steps += 1;
237 self.stats.final_loss = loss;
238 }
239
240 pub fn stats(&self) -> &PhaseTrainingStats {
242 &self.stats
243 }
244
245 pub fn finalize(&mut self) -> PhaseTrainingStats {
247 let full = self.stats.full_backprop_steps as f64;
249 let predicted = self.stats.predicted_steps as f64;
250 let total = full + predicted;
251
252 if total > 0.0 && predicted > 0.0 {
253 let full_time = full;
255 let predicted_time = predicted * 0.25;
256 let actual_time = full_time + predicted_time;
257 self.stats.speedup = total / actual_time;
258 } else {
259 self.stats.speedup = 1.0;
260 }
261
262 self.stats.clone()
263 }
264}
265
266pub fn compute_reconstruction_loss(codebook: &Codebook, data: &[u8]) -> f64 {
270 if data.is_empty() || codebook.basis_vectors.is_empty() {
271 return 1.0; }
273
274 let projection = codebook.project(data);
276
277 1.0 - projection.quality_score
279}
280
281pub fn compute_basis_gradients(codebook: &Codebook, data: &[u8], _epsilon: f64) -> Vec<f64> {
290 let base_loss = compute_reconstruction_loss(codebook, data);
291 let mut gradients = Vec::new();
292
293 for bv in &codebook.basis_vectors {
295 let vector_norm = (bv.vector.pos.len() + bv.vector.neg.len()) as f64;
301 if vector_norm > 0.0 {
302 gradients.push(base_loss * bv.weight / vector_norm);
304 } else {
305 gradients.push(0.0);
306 }
307 }
308
309 gradients
310}
311
312pub fn apply_gradients(codebook: &mut Codebook, gradients: &[f64], learning_rate: f64) {
314 for (i, bv) in codebook.basis_vectors.iter_mut().enumerate() {
315 if i < gradients.len() {
316 bv.weight -= learning_rate * gradients[i];
318 bv.weight = bv.weight.clamp(0.01, 10.0);
320 }
321 }
322}
323
324pub fn train_codebook_with_phases(
337 codebook: &mut Codebook,
338 training_data: &[&[u8]],
339 config: &PhaseTrainingConfig,
340) -> Result<PhaseTrainingStats, String> {
341 if training_data.is_empty() {
342 return Err("No training data provided".to_string());
343 }
344
345 if codebook.basis_vectors.is_empty() {
346 return Err("Codebook has no basis vectors to train".to_string());
347 }
348
349 if config.batch_size == 0 {
350 return Err("batch_size must be > 0".to_string());
351 }
352
353 if config.correct_every == 0 {
354 return Err("correct_every must be > 0".to_string());
355 }
356
357 let cycle_length = config.warmup_steps + config.full_steps + config.predict_steps;
358 if cycle_length == 0 {
359 return Err(
360 "At least one of warmup_steps, full_steps, or predict_steps must be > 0".to_string(),
361 );
362 }
363
364 let mut trainer = PhaseTrainer::new(config.clone())?;
365 let param_count = codebook.basis_vectors.len();
366
367 for _epoch in 0..config.epochs {
368 for batch_start in (0..training_data.len()).step_by(config.batch_size) {
369 let batch_end = (batch_start + config.batch_size).min(training_data.len());
370 let batch = &training_data[batch_start..batch_end];
371
372 let _phase = trainer.begin_step();
373
374 let gradients = if trainer.should_compute_full() {
376 let mut batch_gradients = vec![0.0; param_count];
378 for sample in batch {
379 let sample_grads = compute_basis_gradients(codebook, sample, 1e-5);
380 for (i, g) in sample_grads.iter().enumerate() {
381 if i < batch_gradients.len() {
382 batch_gradients[i] += g / batch.len() as f64;
383 }
384 }
385 }
386 trainer.record_gradients(batch_gradients.clone());
387 batch_gradients
388 } else {
389 trainer.record_predicted_step();
391 trainer.get_predicted_gradients(param_count)
392 };
393
394 apply_gradients(codebook, &gradients, config.learning_rate);
396
397 let batch_loss: f64 = batch
399 .iter()
400 .map(|s| compute_reconstruction_loss(codebook, s))
401 .sum::<f64>()
402 / batch.len() as f64;
403
404 trainer.end_step(batch_loss);
405 }
406 }
407
408 Ok(trainer.finalize())
409}
410
// Unit tests for the phase scheduler, gradient prediction and the input
// validation of `train_codebook_with_phases`.
#[cfg(test)]
mod tests {
    use super::*;

    // Walks through one cycle of a small schedule (2 warmup, 1 full,
    // 3 predict, correction every 2nd predict step) and checks each phase.
    #[test]
    fn test_phase_trainer_cycles() {
        let config = PhaseTrainingConfig {
            warmup_steps: 2,
            full_steps: 1,
            predict_steps: 3,
            correct_every: 2,
            ..Default::default()
        };

        let mut trainer = PhaseTrainer::new(config).expect("valid config");

        // Steps 0 and 1: warmup.
        assert_eq!(trainer.begin_step(), TrainingPhase::Warmup);
        trainer.end_step(1.0);
        assert_eq!(trainer.begin_step(), TrainingPhase::Warmup);
        trainer.end_step(1.0);

        // Step 2: full.
        assert_eq!(trainer.begin_step(), TrainingPhase::Full);
        trainer.end_step(1.0);

        // Step 3: first predict step (never a correction).
        assert_eq!(trainer.begin_step(), TrainingPhase::Predict);
        trainer.end_step(1.0);

        // Step 4: second predict step.
        assert_eq!(trainer.begin_step(), TrainingPhase::Predict);
        trainer.end_step(1.0);

        // Step 5: third predict position is a multiple of `correct_every`.
        assert_eq!(trainer.begin_step(), TrainingPhase::Correct);
    }

    // Predicted gradients have the requested length and stay within the range
    // of the recorded values.
    #[test]
    fn test_gradient_prediction() {
        let config = PhaseTrainingConfig::default();
        let mut trainer = PhaseTrainer::new(config).expect("valid config");

        // Two recorded gradient vectors; the second is the more recent one.
        trainer.record_gradients(vec![0.1, 0.2, 0.3]);
        trainer.record_gradients(vec![0.2, 0.3, 0.4]);

        let predicted = trainer.get_predicted_gradients(3);
        assert_eq!(predicted.len(), 3);

        for g in &predicted {
            // A weighted average of positive inputs below 1.0 stays in (0, 1).
            assert!(*g > 0.0);
            assert!(*g < 1.0);
        }
    }

    // Pins the default schedule values.
    #[test]
    fn test_training_config_default() {
        let config = PhaseTrainingConfig::default();
        assert_eq!(config.warmup_steps, 50);
        assert_eq!(config.full_steps, 10);
        assert_eq!(config.predict_steps, 40);
        assert_eq!(config.correct_every, 8);
    }

    // `correct_every == 0` must be rejected with a message naming the field.
    #[test]
    fn test_phase_trainer_rejects_zero_correct_every() {
        let config = PhaseTrainingConfig {
            correct_every: 0,
            ..Default::default()
        };
        let result = PhaseTrainer::new(config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("correct_every"));
    }

    // A schedule with no steps at all must be rejected.
    #[test]
    fn test_phase_trainer_rejects_zero_cycle_length() {
        let config = PhaseTrainingConfig {
            warmup_steps: 0,
            full_steps: 0,
            predict_steps: 0,
            correct_every: 8,
            ..Default::default()
        };
        let result = PhaseTrainer::new(config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("warmup_steps"));
    }

    // `batch_size == 0` must be rejected before any training happens.
    #[test]
    fn test_train_codebook_rejects_zero_batch_size() {
        let mut codebook = Codebook::new(1000);
        codebook.initialize_byte_basis();
        let data: Vec<&[u8]> = vec![b"test"];
        let config = PhaseTrainingConfig {
            batch_size: 0,
            ..Default::default()
        };
        let result = train_codebook_with_phases(&mut codebook, &data, &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("batch_size"));
    }

    // An empty training set must be rejected.
    #[test]
    fn test_train_codebook_rejects_empty_data() {
        let mut codebook = Codebook::new(1000);
        codebook.initialize_byte_basis();
        let data: Vec<&[u8]> = vec![];
        let config = PhaseTrainingConfig::default();
        let result = train_codebook_with_phases(&mut codebook, &data, &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("training data"));
    }

    // A codebook whose basis was never initialized has nothing to train.
    #[test]
    fn test_train_codebook_rejects_empty_codebook() {
        let mut codebook = Codebook::new(1000);
        let data: Vec<&[u8]> = vec![b"test"];
        // The basis is deliberately left uninitialized here.
        let config = PhaseTrainingConfig::default();
        let result = train_codebook_with_phases(&mut codebook, &data, &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("basis vectors"));
    }
}