1#![allow(dead_code)]
6use crate::model::DeepSeekR1Model;
7use crate::training::data::{TrainingBatch, TrainingExample};
8use crate::training::loss::{CrossEntropyLoss, LossFunction, TrainingMetrics};
9use crate::training::optimizer::{Optimizer, OptimizerConfig};
10use crate::utils::error::{ModelError, Result};
11
/// Supervised trainer that pairs a [`DeepSeekR1Model`] with an optimizer and a
/// cross-entropy loss.
pub struct BasicTrainer {
    /// Model that is run forward and whose parameters are updated.
    model: DeepSeekR1Model,
    /// Optimizer used to step individual parameter slices.
    optimizer: Optimizer,
    /// Loss object used to compute the reported loss and accuracy.
    loss_fn: CrossEntropyLoss,
    /// Number of `train_step` calls that completed successfully.
    step_count: usize,
    /// Vocabulary size copied from `model.config()` when the trainer is built.
    vocab_size: usize,
}
20
21impl BasicTrainer {
22 pub fn new(model: DeepSeekR1Model) -> Result<Self> {
24 let optimizer_config = OptimizerConfig::default();
25 let optimizer = Optimizer::new(optimizer_config)?;
26 let loss_fn = CrossEntropyLoss::new();
27 let vocab_size = model.config().vocab_size;
28
29 Ok(Self {
30 model,
31 optimizer,
32 loss_fn,
33 step_count: 0,
34 vocab_size,
35 })
36 }
37
38 pub fn with_optimizer_config(
40 model: DeepSeekR1Model,
41 optimizer_config: OptimizerConfig,
42 ) -> Result<Self> {
43 let optimizer = Optimizer::new(optimizer_config)?;
44 let loss_fn = CrossEntropyLoss::new();
45 let vocab_size = model.config().vocab_size;
46
47 Ok(Self {
48 model,
49 optimizer,
50 loss_fn,
51 step_count: 0,
52 vocab_size,
53 })
54 }
55
56 fn tokenize(&self, text: &str) -> Vec<u32> {
58 let binding = text.to_lowercase();
60 let words: Vec<&str> = binding.split_whitespace().collect();
61
62 let mut token_ids = Vec::new();
63
64 for word in words {
65 let mut hash = 0u32;
67 for byte in word.bytes() {
68 hash = hash.wrapping_mul(31).wrapping_add(byte as u32);
69 }
70 token_ids.push(hash % self.vocab_size as u32);
72 }
73
74 if token_ids.is_empty() {
76 token_ids.push(0); }
78
79 token_ids
80 }
81
82 fn prepare_training_data(&self, examples: &[TrainingExample]) -> Result<(Vec<u32>, Vec<u32>)> {
84 let mut input_ids = Vec::new();
85 let mut target_ids = Vec::new();
86
87 for example in examples {
88 let input_tokens = self.tokenize(&example.input);
90 let target_tokens = self.tokenize(&example.target);
91
92 input_ids.extend(input_tokens);
94 target_ids.extend(target_tokens);
95 }
96
97 Ok((input_ids, target_ids))
98 }
99
100 fn prepare_last_step_data(
102 &self,
103 examples: &[TrainingExample],
104 ) -> Result<(Vec<Vec<u32>>, Vec<u32>)> {
105 let mut inputs: Vec<Vec<u32>> = Vec::with_capacity(examples.len());
106 let mut targets: Vec<u32> = Vec::with_capacity(examples.len());
107
108 for example in examples {
109 let input_tokens = self.tokenize(&example.input);
110 let target_tokens = self.tokenize(&example.target);
112 let target_id = target_tokens.get(0).copied().unwrap_or(0);
113
114 inputs.push(input_tokens);
115 targets.push(target_id);
116 }
117
118 Ok((inputs, targets))
119 }
120
121 fn compute_gradients(&self, predictions: &[f32], targets: &[u32]) -> Result<Vec<f32>> {
123 let mut gradients = vec![0.0; predictions.len()];
124 let vocab_size = self.vocab_size;
125 let num_samples = targets.len();
126
127 if predictions.len() != num_samples * vocab_size {
128 return Err(ModelError::Training(format!(
129 "Prediction size mismatch: expected {}, got {}",
130 num_samples * vocab_size,
131 predictions.len()
132 )));
133 }
134
135 for (i, &target) in targets.iter().enumerate() {
136 let start_idx = i * vocab_size;
137 let end_idx = start_idx + vocab_size;
138
139 if end_idx <= predictions.len() && (target as usize) < vocab_size {
140 let logits = &predictions[start_idx..end_idx];
141
142 let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
144 let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
145 let sum_exp: f32 = exp_logits.iter().sum();
146
147 for j in 0..vocab_size {
149 let grad_idx = start_idx + j;
150 let prob = exp_logits[j] / sum_exp;
151
152 if j == target as usize {
153 gradients[grad_idx] = prob - 1.0;
155 } else {
156 gradients[grad_idx] = prob;
158 }
159 }
160 }
161 }
162
163 let batch_size = num_samples as f32;
165 for grad in &mut gradients {
166 *grad /= batch_size;
167 }
168
169 Ok(gradients)
170 }
171
172 pub fn train_step(&mut self, batch: &TrainingBatch) -> Result<TrainingMetrics> {
174 if batch.examples.is_empty() {
175 return Err(ModelError::Training("Empty batch".to_string()));
176 }
177
178 let (inputs_per_example, targets) = self.prepare_last_step_data(&batch.examples)?;
180
181 if inputs_per_example.is_empty() || targets.is_empty() {
182 return Err(ModelError::Training("No valid training data".to_string()));
183 }
184
185 let (predictions, last_hiddens, last_input_ids) =
187 self.forward_last_step(&inputs_per_example)?;
188
189 let target_floats: Vec<f32> = targets.iter().map(|&x| x as f32).collect();
191 let loss = self.loss_fn.compute_loss(&predictions, &target_floats)?;
192
193 let accuracy = self.loss_fn.compute_accuracy(&predictions, &targets);
195
196 let gradients = self.compute_gradients(&predictions, &targets)?;
198
199 self.update_model_parameters(&gradients, &last_hiddens, &last_input_ids)?;
201
202 self.step_count += 1;
203
204 Ok(TrainingMetrics::new(loss, accuracy, self.step_count))
205 }
206
207 fn forward_pass(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
209 let logits = self.model.forward(input_ids)?;
211 Ok(logits)
212 }
213
214 fn forward_last_step(
216 &mut self,
217 inputs: &[Vec<u32>],
218 ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<u32>)> {
219 let mut predictions: Vec<f32> = Vec::new();
220 let mut last_hiddens: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());
221 let mut last_input_ids: Vec<u32> = Vec::with_capacity(inputs.len());
222
223 for input in inputs {
224 if input.is_empty() {
225 return Err(ModelError::Training("Empty input sequence".to_string()));
226 }
227
228 let logits = self.model.forward(input)?;
230 let vocab_size = self.vocab_size;
231 if logits.len() < vocab_size {
232 return Err(ModelError::Training(
233 "Model output size doesn't match vocabulary size".to_string(),
234 ));
235 }
236 let last_logits = &logits[logits.len() - vocab_size..];
238 predictions.extend_from_slice(last_logits);
239
240 let hidden = self.model.forward_hidden(input)?;
242 let last_h = hidden
243 .last()
244 .ok_or_else(|| ModelError::Training("No hidden states".to_string()))?
245 .clone();
246 last_hiddens.push(last_h);
247
248 last_input_ids.push(*input.last().unwrap());
250 }
251
252 Ok((predictions, last_hiddens, last_input_ids))
253 }
254
255 fn update_model_parameters(
261 &mut self,
262 gradients: &[f32],
263 hidden: &[Vec<f32>],
264 input_ids: &[u32],
265 ) -> Result<()> {
266 let vocab_size = self.vocab_size;
267 if vocab_size == 0 {
268 return Err(ModelError::Training("Vocab size is zero".to_string()));
269 }
270 if gradients.len() % vocab_size != 0 {
271 return Err(ModelError::Training(format!(
272 "Gradients length {} is not divisible by vocab_size {}",
273 gradients.len(),
274 vocab_size
275 )));
276 }
277 let num_samples = gradients.len() / vocab_size;
278 if hidden.len() != num_samples || input_ids.len() != num_samples {
279 return Err(ModelError::Training(format!(
280 "Mismatch: hidden len {} / input_ids len {} vs samples {}",
281 hidden.len(),
282 input_ids.len(),
283 num_samples
284 )));
285 }
286
287 let mut bias_grads = vec![0.0f32; vocab_size];
289 for (i, chunk) in gradients.chunks(vocab_size).enumerate() {
290 let _ = i; for v in 0..vocab_size {
292 bias_grads[v] += chunk[v];
293 }
294 }
295
296 {
298 let name = "lm_head.bias";
299 let bias_slice = self.model.lm_head_bias_mut();
300 self.optimizer
301 .step_parameter(name, bias_slice, &bias_grads)?;
302 }
303
304 let lm_w_snapshot = self.model.lm_head_weights().clone();
308 if lm_w_snapshot.is_empty() {
309 return Err(ModelError::Training(
310 "LM head weights are empty".to_string(),
311 ));
312 }
313 let hidden_size = lm_w_snapshot[0].len();
314
315 for v in 0..vocab_size {
316 let mut row_grad = vec![0.0f32; hidden_size];
317 for t in 0..num_samples {
318 let g_vt = gradients[t * vocab_size + v];
319 if g_vt != 0.0 {
320 let h_t = &hidden[t];
321 if h_t.len() != hidden_size {
323 return Err(ModelError::Training(format!(
324 "Hidden size {} mismatch at t={} (expected {})",
325 h_t.len(),
326 t,
327 hidden_size
328 )));
329 }
330 for k in 0..hidden_size {
331 row_grad[k] += g_vt * h_t[k];
332 }
333 }
334 }
335 let name = format!("lm_head.weight[{}]", v);
337 let row_slice = self.model.lm_head_row_mut(v)?;
338 self.optimizer.step_parameter(&name, row_slice, &row_grad)?;
339 }
340
341 for t in 0..num_samples {
344 let token_id = input_ids[t] as usize;
345 if token_id >= self.vocab_size {
346 continue; }
348 let dlogits_t = &gradients[t * vocab_size..(t + 1) * vocab_size];
349
350 let mut grad_hidden = vec![0.0f32; hidden_size];
352 for v in 0..vocab_size {
353 let g = dlogits_t[v];
354 if g != 0.0 {
355 let w_row = &lm_w_snapshot[v];
356 for k in 0..hidden_size {
357 grad_hidden[k] += w_row[k] * g;
358 }
359 }
360 }
361
362 let name = format!("embeddings.weight[{}]", token_id);
364 let row_slice = self.model.embedding_row_mut(token_id)?;
365 self.optimizer
366 .step_parameter(&name, row_slice, &grad_hidden)?;
367 }
368
369 Ok(())
370 }
371
    /// Returns the number of training steps that have completed successfully.
    pub fn step_count(&self) -> usize {
        self.step_count
    }
376
377 pub fn evaluate(&mut self, examples: &[TrainingExample]) -> Result<TrainingMetrics> {
379 if examples.is_empty() {
380 return Err(ModelError::Training("Empty evaluation set".to_string()));
381 }
382
383 let (inputs_per_example, targets) = self.prepare_last_step_data(examples)?;
384 let (predictions, _last_hiddens, _last_input_ids) =
385 self.forward_last_step(&inputs_per_example)?;
386
387 let target_floats: Vec<f32> = targets.iter().map(|&x| x as f32).collect();
388 let loss = self.loss_fn.compute_loss(&predictions, &target_floats)?;
389 let accuracy = self.loss_fn.compute_accuracy(&predictions, &targets);
390
391 Ok(TrainingMetrics::new(loss, accuracy, self.step_count))
392 }
393}
394
/// Scores a predicted answer together with the reasoning steps behind it.
pub trait RewardFunction {
    /// Returns the reward for `predicted` given the expected `target` and the
    /// `reasoning_chain` that led to the prediction.
    fn compute_reward(&self, reasoning_chain: &[String], target: &str, predicted: &str) -> f32;
}

/// Heuristic reward: answer correctness plus a reasoning-quality bonus and a
/// reasoning-length adjustment.
pub struct SimpleRewardFunction;

impl RewardFunction for SimpleRewardFunction {
    /// Returns `1.0` for a correct answer (case-insensitive, surrounding
    /// whitespace ignored), plus the reasoning-quality score, plus the
    /// reasoning-length adjustment. The total is clamped to be non-negative.
    fn compute_reward(&self, reasoning_chain: &[String], target: &str, predicted: &str) -> f32 {
        let mut reward = 0.0;

        if predicted.trim().to_lowercase() == target.trim().to_lowercase() {
            reward += 1.0;
        }

        reward += self.evaluate_reasoning_quality(reasoning_chain);
        reward += self.evaluate_reasoning_length(reasoning_chain);

        reward.max(0.0)
    }
}

impl SimpleRewardFunction {
    /// Scores the reasoning text: `-0.5` for no reasoning at all, otherwise
    /// `+0.2` for two or more steps, `+0.1` per math keyword found and `+0.1`
    /// per logical connector found, capped at `0.5`.
    ///
    /// Math keywords are matched as substrings. Logical connectors are matched
    /// as whole words, so that e.g. "so" does not fire inside "reasoning",
    /// "also" or "solve".
    fn evaluate_reasoning_quality(&self, reasoning_chain: &[String]) -> f32 {
        if reasoning_chain.is_empty() {
            return -0.5;
        }

        let mut quality_score: f32 = 0.0;

        if reasoning_chain.len() >= 2 {
            quality_score += 0.2;
        }

        let reasoning_text = reasoning_chain.join(" ").to_lowercase();

        let math_keywords = ["add", "subtract", "multiply", "solve", "equation", "="];
        for keyword in math_keywords.iter() {
            if reasoning_text.contains(*keyword) {
                quality_score += 0.1;
            }
        }

        let logical_connectors = ["therefore", "since", "because", "so", "thus"];
        for connector in logical_connectors.iter() {
            // Split on anything that is not a letter or digit so punctuation
            // next to a connector ("Therefore,") does not hide it.
            let found = reasoning_text
                .split(|c: char| !c.is_alphanumeric())
                .any(|word| word == *connector);
            if found {
                quality_score += 0.1;
            }
        }

        quality_score.min(0.5)
    }

    /// Returns a non-positive adjustment based on the number of reasoning
    /// steps: two to five steps are neutral, fewer or more are penalised.
    fn evaluate_reasoning_length(&self, reasoning_chain: &[String]) -> f32 {
        match reasoning_chain.len() {
            0 => -0.3,
            1 => -0.1,
            2..=5 => 0.0,
            6..=8 => -0.1,
            _ => -0.2,
        }
    }
}
472
/// Action probabilities and rewards for a policy-gradient estimate, together
/// with the mean-reward baseline derived from them.
#[derive(Debug, Clone)]
pub struct PolicyGradient {
    /// Probability assigned to each action.
    pub action_probs: Vec<f32>,
    /// Reward associated with each action.
    pub rewards: Vec<f32>,
    /// Mean of `rewards` at construction time (`0.0` when there are none).
    pub baseline: f32,
}

impl PolicyGradient {
    /// Stores the inputs and sets the baseline to the mean reward, or `0.0`
    /// when `rewards` is empty.
    pub fn new(action_probs: Vec<f32>, rewards: Vec<f32>) -> Self {
        let baseline = match rewards.len() {
            0 => 0.0,
            count => rewards.iter().sum::<f32>() / count as f32,
        };

        Self {
            action_probs,
            rewards,
            baseline,
        }
    }

    /// Returns one value per action probability: `(reward - baseline) / prob`
    /// where a reward exists and `prob > 0.0`, and `0.0` otherwise.
    pub fn compute_gradients(&self) -> Vec<f32> {
        let mut gradients = vec![0.0; self.action_probs.len()];

        let pairs = self.action_probs.iter().zip(&self.rewards);
        for (slot, (&prob, &reward)) in gradients.iter_mut().zip(pairs) {
            if prob > 0.0 {
                *slot = (reward - self.baseline) / prob;
            }
        }

        gradients
    }
}
517
/// Reward-driven trainer that scores model responses with a
/// [`SimpleRewardFunction`] and tracks a running reward baseline.
pub struct RLTrainer {
    /// Model used to generate responses and action probabilities.
    model: DeepSeekR1Model,
    /// Optimizer that receives the policy-gradient values.
    optimizer: Optimizer,
    /// Reward function applied to each generated response.
    reward_fn: SimpleRewardFunction,
    /// Number of `train_step` calls that completed successfully.
    step_count: usize,
    /// Vocabulary size copied from `model.config()` when the trainer is built.
    vocab_size: usize,
    /// Most recent rewards, oldest first; their mean is the running baseline.
    baseline_history: Vec<f32>,
    /// Maximum number of rewards kept in `baseline_history`.
    max_baseline_history: usize,
}
528
529impl RLTrainer {
530 pub fn new(model: DeepSeekR1Model) -> Result<Self> {
532 let optimizer_config = OptimizerConfig {
533 learning_rate: 1e-5, ..OptimizerConfig::default()
535 };
536 let optimizer = Optimizer::new(optimizer_config)?;
537 let reward_fn = SimpleRewardFunction;
538 let vocab_size = model.config().vocab_size;
539
540 Ok(Self {
541 model,
542 optimizer,
543 reward_fn,
544 step_count: 0,
545 vocab_size,
546 baseline_history: Vec::new(),
547 max_baseline_history: 100,
548 })
549 }
550
551 pub fn with_optimizer_config(
553 model: DeepSeekR1Model,
554 optimizer_config: OptimizerConfig,
555 ) -> Result<Self> {
556 let optimizer = Optimizer::new(optimizer_config)?;
557 let reward_fn = SimpleRewardFunction;
558 let vocab_size = model.config().vocab_size;
559
560 Ok(Self {
561 model,
562 optimizer,
563 reward_fn,
564 step_count: 0,
565 vocab_size,
566 baseline_history: Vec::new(),
567 max_baseline_history: 100,
568 })
569 }
570
571 fn generate_response_with_reasoning(&mut self, input: &str) -> Result<(String, Vec<String>)> {
573 let input_tokens = self.tokenize(input);
575 let logits = self.model.forward(&input_tokens)?;
577 let response_token = logits
579 .iter()
580 .enumerate()
581 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
582 .map(|(idx, _)| idx)
583 .unwrap_or(0) as u32;
584 let response = self.decode(&[response_token]);
585 let reasoning_chain = Vec::new(); Ok((response, reasoning_chain))
588 }
589
590 fn compute_action_probabilities(&mut self, input: &str) -> Result<Vec<f32>> {
592 let input_tokens = self.tokenize(input);
594 let logits = self.model.forward(&input_tokens)?;
595 let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
596 let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
597 let sum_exp: f32 = exp_logits.iter().sum();
598 let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
599 Ok(probs)
600 }
601
602 fn update_baseline(&mut self, reward: f32) {
604 self.baseline_history.push(reward);
605
606 if self.baseline_history.len() > self.max_baseline_history {
608 self.baseline_history.remove(0);
609 }
610 }
611
612 fn get_baseline(&self) -> f32 {
614 if self.baseline_history.is_empty() {
615 0.0
616 } else {
617 self.baseline_history.iter().sum::<f32>() / self.baseline_history.len() as f32
618 }
619 }
620
621 pub fn train_step(&mut self, batch: &TrainingBatch) -> Result<TrainingMetrics> {
623 if batch.examples.is_empty() {
624 return Err(ModelError::Training("Empty batch".to_string()));
625 }
626
627 let mut total_reward = 0.0;
628 let mut total_loss = 0.0;
629 let mut correct_predictions = 0;
630
631 for example in &batch.examples {
632 let (predicted_response, reasoning_chain) =
634 self.generate_response_with_reasoning(&example.input)?;
635
636 let reward = self.reward_fn.compute_reward(
638 &reasoning_chain,
639 &example.target,
640 &predicted_response,
641 );
642
643 total_reward += reward;
644 self.update_baseline(reward);
645
646 if predicted_response.trim().to_lowercase() == example.target.trim().to_lowercase() {
648 correct_predictions += 1;
649 }
650
651 let action_probs = self.compute_action_probabilities(&example.input)?;
653
654 let rewards = vec![reward; action_probs.len()];
656 let policy_grad = PolicyGradient::new(action_probs, rewards);
657 let gradients = policy_grad.compute_gradients();
658
659 let loss = -(reward - self.get_baseline());
661 total_loss += loss;
662
663 let mut dummy_params = vec![0.1; gradients.len()];
666 self.optimizer.step_parameter(
667 &format!("rl_params_{}", example.input.len()),
668 &mut dummy_params,
669 &gradients,
670 )?;
671 }
672
673 self.step_count += 1;
674
675 let _avg_reward = total_reward / batch.examples.len() as f32;
676 let avg_loss = total_loss / batch.examples.len() as f32;
677 let accuracy = correct_predictions as f32 / batch.examples.len() as f32;
678
679 Ok(TrainingMetrics::new(avg_loss, accuracy, self.step_count))
680 }
681
682 pub fn evaluate(&mut self, examples: &[TrainingExample]) -> Result<RLEvaluationMetrics> {
684 if examples.is_empty() {
685 return Err(ModelError::Training("Empty evaluation set".to_string()));
686 }
687
688 let mut total_reward = 0.0;
689 let mut correct_predictions = 0;
690 let mut reasoning_quality_scores = Vec::new();
691
692 for example in examples {
693 let (predicted_response, reasoning_chain) =
694 self.generate_response_with_reasoning(&example.input)?;
695
696 let reward = self.reward_fn.compute_reward(
697 &reasoning_chain,
698 &example.target,
699 &predicted_response,
700 );
701
702 total_reward += reward;
703 reasoning_quality_scores.push(reward);
704
705 if predicted_response.trim().to_lowercase() == example.target.trim().to_lowercase() {
706 correct_predictions += 1;
707 }
708 }
709
710 let avg_reward = total_reward / examples.len() as f32;
711 let accuracy = correct_predictions as f32 / examples.len() as f32;
712 let avg_reasoning_quality =
713 reasoning_quality_scores.iter().sum::<f32>() / reasoning_quality_scores.len() as f32;
714
715 Ok(RLEvaluationMetrics {
716 average_reward: avg_reward,
717 accuracy,
718 reasoning_quality: avg_reasoning_quality,
719 baseline: self.get_baseline(),
720 total_examples: examples.len(),
721 })
722 }
723
    /// Returns the number of training steps that have completed successfully.
    pub fn step_count(&self) -> usize {
        self.step_count
    }
728
    /// Returns the current running baseline: the mean of the stored rewards, or
    /// `0.0` when no reward has been recorded yet.
    pub fn baseline(&self) -> f32 {
        self.get_baseline()
    }
733}
734
735impl RLTrainer {
736 fn tokenize(&self, text: &str) -> Vec<u32> {
738 let binding = text.to_lowercase();
739 let words: Vec<&str> = binding.split_whitespace().collect();
740
741 let mut token_ids = Vec::new();
742
743 for word in words {
744 let mut hash = 0u32;
745 for byte in word.bytes() {
746 hash = hash.wrapping_mul(31).wrapping_add(byte as u32);
747 }
748 token_ids.push(hash % self.vocab_size as u32);
749 }
750
751 if token_ids.is_empty() {
752 token_ids.push(0); }
754
755 token_ids
756 }
757
758 fn decode(&self, token_ids: &[u32]) -> String {
760 token_ids
761 .iter()
762 .map(|id| format!("<{}>", id))
763 .collect::<Vec<_>>()
764 .join(" ")
765 }
766}
767
/// Summary of one `RLTrainer::evaluate` run.
#[derive(Debug, Clone)]
pub struct RLEvaluationMetrics {
    /// Mean reward over the evaluated examples.
    pub average_reward: f32,
    /// Fraction of examples answered correctly, in `[0, 1]`.
    pub accuracy: f32,
    /// Mean reasoning-quality score over the evaluated examples.
    pub reasoning_quality: f32,
    /// Running reward baseline at evaluation time.
    pub baseline: f32,
    /// Number of examples evaluated.
    pub total_examples: usize,
}

impl std::fmt::Display for RLEvaluationMetrics {
    /// Writes the multi-line report printed by [`RLEvaluationMetrics::display`],
    /// without a trailing newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "RL Evaluation Metrics:")?;
        writeln!(f, " Average Reward: {:.4}", self.average_reward)?;
        writeln!(f, " Accuracy: {:.2}%", self.accuracy * 100.0)?;
        writeln!(f, " Reasoning Quality: {:.4}", self.reasoning_quality)?;
        writeln!(f, " Baseline: {:.4}", self.baseline)?;
        write!(f, " Total Examples: {}", self.total_examples)
    }
}

impl RLEvaluationMetrics {
    /// Prints the metrics report to standard output.
    ///
    /// Accuracy is shown as a percentage with two decimals; the other floats
    /// are shown with four decimals.
    pub fn display(&self) {
        println!("{}", self);
    }
}
819
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{DeepSeekR1Model, ModelConfig};
    use crate::training::data::{ProblemType, TrainingExample};

    /// Builds a model from the default configuration.
    fn default_model() -> DeepSeekR1Model {
        DeepSeekR1Model::new(ModelConfig::default()).unwrap()
    }

    /// Builds a single-example list.
    fn single_example(input: &str, target: &str, kind: ProblemType) -> Vec<TrainingExample> {
        vec![TrainingExample::new(
            input.to_string(),
            target.to_string(),
            kind,
        )]
    }

    #[test]
    fn test_basic_trainer_creation() {
        assert!(BasicTrainer::new(default_model()).is_ok());
    }

    #[test]
    fn test_basic_trainer_with_custom_config() {
        let optimizer_config = OptimizerConfig {
            learning_rate: 0.001,
            ..OptimizerConfig::default()
        };

        let trainer = BasicTrainer::with_optimizer_config(default_model(), optimizer_config);
        assert!(trainer.is_ok());
    }

    #[test]
    fn test_training_step() {
        let mut trainer = BasicTrainer::new(default_model()).unwrap();

        let batch = TrainingBatch::new(vec![
            TrainingExample::new("2 + 2".to_string(), "4".to_string(), ProblemType::Math),
            TrainingExample::new("3 * 3".to_string(), "9".to_string(), ProblemType::Math),
        ]);

        let metrics = trainer
            .train_step(&batch)
            .expect("training step should succeed");
        assert!(metrics.loss >= 0.0);
        assert!(metrics.accuracy >= 0.0 && metrics.accuracy <= 1.0);
        assert_eq!(metrics.step, 1);
    }

    #[test]
    fn test_evaluation() {
        let mut trainer = BasicTrainer::new(default_model()).unwrap();
        let examples = single_example("test", "result", ProblemType::General);

        assert!(trainer.evaluate(&examples).is_ok());
    }

    #[test]
    fn test_empty_batch_error() {
        let mut trainer = BasicTrainer::new(default_model()).unwrap();
        let batch = TrainingBatch::new(vec![]);

        assert!(trainer.train_step(&batch).is_err());
    }

    #[test]
    fn test_rl_trainer_creation() {
        assert!(RLTrainer::new(default_model()).is_ok());
    }

    #[test]
    fn test_rl_trainer_with_custom_config() {
        let optimizer_config = OptimizerConfig {
            learning_rate: 1e-6,
            ..OptimizerConfig::default()
        };

        let trainer = RLTrainer::with_optimizer_config(default_model(), optimizer_config);
        assert!(trainer.is_ok());
    }

    #[test]
    fn test_reward_function() {
        let reward_fn = SimpleRewardFunction;
        let reasoning = vec![
            "I need to add 2 and 2".to_string(),
            "2 + 2 = 4".to_string(),
            "Therefore, the answer is 4".to_string(),
        ];

        // Correct answer with good reasoning earns more than the bare answer reward.
        let reward = reward_fn.compute_reward(&reasoning, "4", "4");
        assert!(reward > 1.0);

        // A wrong answer scores lower.
        let reward_wrong = reward_fn.compute_reward(&reasoning, "4", "5");
        assert!(reward_wrong < reward);

        // A correct answer without reasoning scores lower as well.
        let reward_no_reasoning = reward_fn.compute_reward(&[], "4", "4");
        assert!(reward_no_reasoning < reward);
    }

    #[test]
    fn test_policy_gradient() {
        let policy_grad = PolicyGradient::new(vec![0.3, 0.5, 0.2], vec![1.0, 0.5, 0.8]);

        // Baseline is the mean reward.
        assert!((policy_grad.baseline - 0.767).abs() < 0.01);

        let gradients = policy_grad.compute_gradients();
        assert_eq!(gradients.len(), 3);
    }

    #[test]
    fn test_rl_training_step() {
        let mut trainer = RLTrainer::new(default_model()).unwrap();
        let batch = TrainingBatch::new(single_example("2 + 2", "4", ProblemType::Math));

        let metrics = trainer
            .train_step(&batch)
            .expect("RL training step should succeed");
        assert_eq!(metrics.step, 1);
    }

    #[test]
    fn test_rl_evaluation() {
        let mut trainer = RLTrainer::new(default_model()).unwrap();
        let examples = single_example("test", "result", ProblemType::General);

        let metrics = trainer
            .evaluate(&examples)
            .expect("RL evaluation should succeed");
        assert_eq!(metrics.total_examples, 1);
        assert!(metrics.average_reward >= 0.0);
    }

    #[test]
    fn test_trainer_basic_functionality() {
        let trainer = BasicTrainer::new(default_model()).unwrap();

        assert!(trainer.model.config().vocab_size > 0);

        // Building a batch must succeed; it is not trained on here.
        let examples = single_example("What is 2 + 2?", "4", ProblemType::Math);
        let _batch = TrainingBatch::new(examples);
    }
}