Skip to main content

ds_r1_rs/training/
trainer.rs

1//! # Training Infrastructure
2//!
3//! Basic trainer implementations for supervised and reinforcement learning.
4
5#![allow(dead_code)]
6use crate::model::DeepSeekR1Model;
7use crate::training::data::{TrainingBatch, TrainingExample};
8use crate::training::loss::{CrossEntropyLoss, LossFunction, TrainingMetrics};
9use crate::training::optimizer::{Optimizer, OptimizerConfig};
10use crate::utils::error::{ModelError, Result};
11
12/// Basic supervised trainer
13pub struct BasicTrainer {
14    model: DeepSeekR1Model,
15    optimizer: Optimizer,
16    loss_fn: CrossEntropyLoss,
17    step_count: usize,
18    vocab_size: usize,
19}
20
21impl BasicTrainer {
22    /// Create a new basic trainer
23    pub fn new(model: DeepSeekR1Model) -> Result<Self> {
24        let optimizer_config = OptimizerConfig::default();
25        let optimizer = Optimizer::new(optimizer_config)?;
26        let loss_fn = CrossEntropyLoss::new();
27        let vocab_size = model.config().vocab_size;
28
29        Ok(Self {
30            model,
31            optimizer,
32            loss_fn,
33            step_count: 0,
34            vocab_size,
35        })
36    }
37
38    /// Create a new basic trainer with custom optimizer config
39    pub fn with_optimizer_config(
40        model: DeepSeekR1Model,
41        optimizer_config: OptimizerConfig,
42    ) -> Result<Self> {
43        let optimizer = Optimizer::new(optimizer_config)?;
44        let loss_fn = CrossEntropyLoss::new();
45        let vocab_size = model.config().vocab_size;
46
47        Ok(Self {
48            model,
49            optimizer,
50            loss_fn,
51            step_count: 0,
52            vocab_size,
53        })
54    }
55
56    /// Convert text to token IDs using a simple word-based tokenizer
57    fn tokenize(&self, text: &str) -> Vec<u32> {
58        // Simple word-based tokenization with basic preprocessing
59        let binding = text.to_lowercase();
60        let words: Vec<&str> = binding.split_whitespace().collect();
61
62        let mut token_ids = Vec::new();
63
64        for word in words {
65            // Create a simple hash-based token ID
66            let mut hash = 0u32;
67            for byte in word.bytes() {
68                hash = hash.wrapping_mul(31).wrapping_add(byte as u32);
69            }
70            // Ensure token ID is within vocabulary range
71            token_ids.push(hash % self.vocab_size as u32);
72        }
73
74        // Add special tokens if empty
75        if token_ids.is_empty() {
76            token_ids.push(0); // UNK token
77        }
78
79        token_ids
80    }
81
82    /// Prepare training data from examples
83    fn prepare_training_data(&self, examples: &[TrainingExample]) -> Result<(Vec<u32>, Vec<u32>)> {
84        let mut input_ids = Vec::new();
85        let mut target_ids = Vec::new();
86
87        for example in examples {
88            // Tokenize input and target
89            let input_tokens = self.tokenize(&example.input);
90            let target_tokens = self.tokenize(&example.target);
91
92            // For next-token prediction, we shift targets by one position
93            input_ids.extend(input_tokens);
94            target_ids.extend(target_tokens);
95        }
96
97        Ok((input_ids, target_ids))
98    }
99
100    /// Prepare last-step training data: per-example input token sequences and single-target class
101    fn prepare_last_step_data(
102        &self,
103        examples: &[TrainingExample],
104    ) -> Result<(Vec<Vec<u32>>, Vec<u32>)> {
105        let mut inputs: Vec<Vec<u32>> = Vec::with_capacity(examples.len());
106        let mut targets: Vec<u32> = Vec::with_capacity(examples.len());
107
108        for example in examples {
109            let input_tokens = self.tokenize(&example.input);
110            // For a classification-like next-token objective, use the first token of the target string
111            let target_tokens = self.tokenize(&example.target);
112            let target_id = target_tokens.get(0).copied().unwrap_or(0);
113
114            inputs.push(input_tokens);
115            targets.push(target_id);
116        }
117
118        Ok((inputs, targets))
119    }
120
121    /// Compute gradients using backpropagation through cross-entropy loss
122    fn compute_gradients(&self, predictions: &[f32], targets: &[u32]) -> Result<Vec<f32>> {
123        let mut gradients = vec![0.0; predictions.len()];
124        let vocab_size = self.vocab_size;
125        let num_samples = targets.len();
126
127        if predictions.len() != num_samples * vocab_size {
128            return Err(ModelError::Training(format!(
129                "Prediction size mismatch: expected {}, got {}",
130                num_samples * vocab_size,
131                predictions.len()
132            )));
133        }
134
135        for (i, &target) in targets.iter().enumerate() {
136            let start_idx = i * vocab_size;
137            let end_idx = start_idx + vocab_size;
138
139            if end_idx <= predictions.len() && (target as usize) < vocab_size {
140                let logits = &predictions[start_idx..end_idx];
141
142                // Compute softmax probabilities with numerical stability
143                let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
144                let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
145                let sum_exp: f32 = exp_logits.iter().sum();
146
147                // Compute gradients of cross-entropy loss w.r.t. logits
148                for j in 0..vocab_size {
149                    let grad_idx = start_idx + j;
150                    let prob = exp_logits[j] / sum_exp;
151
152                    if j == target as usize {
153                        // Gradient for correct class: p - 1
154                        gradients[grad_idx] = prob - 1.0;
155                    } else {
156                        // Gradient for incorrect class: p
157                        gradients[grad_idx] = prob;
158                    }
159                }
160            }
161        }
162
163        // Normalize gradients by batch size
164        let batch_size = num_samples as f32;
165        for grad in &mut gradients {
166            *grad /= batch_size;
167        }
168
169        Ok(gradients)
170    }
171
172    /// Perform a training step
173    pub fn train_step(&mut self, batch: &TrainingBatch) -> Result<TrainingMetrics> {
174        if batch.examples.is_empty() {
175            return Err(ModelError::Training("Empty batch".to_string()));
176        }
177
178        // Prepare per-example last-step data
179        let (inputs_per_example, targets) = self.prepare_last_step_data(&batch.examples)?;
180
181        if inputs_per_example.is_empty() || targets.is_empty() {
182            return Err(ModelError::Training("No valid training data".to_string()));
183        }
184
185        // Forward per-example (last-step) and collect last hiddens and last input ids
186        let (predictions, last_hiddens, last_input_ids) =
187            self.forward_last_step(&inputs_per_example)?;
188
189        // Compute loss
190        let target_floats: Vec<f32> = targets.iter().map(|&x| x as f32).collect();
191        let loss = self.loss_fn.compute_loss(&predictions, &target_floats)?;
192
193        // Compute accuracy
194        let accuracy = self.loss_fn.compute_accuracy(&predictions, &targets);
195
196        // Compute gradients
197        let gradients = self.compute_gradients(&predictions, &targets)?;
198
199        // Update model parameters using computed gradients (LM head weights + bias + embeddings)
200        self.update_model_parameters(&gradients, &last_hiddens, &last_input_ids)?;
201
202        self.step_count += 1;
203
204        Ok(TrainingMetrics::new(loss, accuracy, self.step_count))
205    }
206
207    /// Forward pass through the model
208    fn forward_pass(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
209        // Use the actual model forward pass
210        let logits = self.model.forward(input_ids)?;
211        Ok(logits)
212    }
213
214    /// Forward per-example and return concatenated last-step logits, last hidden states, and last input ids
215    fn forward_last_step(
216        &mut self,
217        inputs: &[Vec<u32>],
218    ) -> Result<(Vec<f32>, Vec<Vec<f32>>, Vec<u32>)> {
219        let mut predictions: Vec<f32> = Vec::new();
220        let mut last_hiddens: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());
221        let mut last_input_ids: Vec<u32> = Vec::with_capacity(inputs.len());
222
223        for input in inputs {
224            if input.is_empty() {
225                return Err(ModelError::Training("Empty input sequence".to_string()));
226            }
227
228            // Full forward for this example
229            let logits = self.model.forward(input)?;
230            let vocab_size = self.vocab_size;
231            if logits.len() < vocab_size {
232                return Err(ModelError::Training(
233                    "Model output size doesn't match vocabulary size".to_string(),
234                ));
235            }
236            // Take last position logits only
237            let last_logits = &logits[logits.len() - vocab_size..];
238            predictions.extend_from_slice(last_logits);
239
240            // Final hidden states and take only last one
241            let hidden = self.model.forward_hidden(input)?;
242            let last_h = hidden
243                .last()
244                .ok_or_else(|| ModelError::Training("No hidden states".to_string()))?
245                .clone();
246            last_hiddens.push(last_h);
247
248            // Track last input id for embedding update
249            last_input_ids.push(*input.last().unwrap());
250        }
251
252        Ok((predictions, last_hiddens, last_input_ids))
253    }
254
255    /// Update model parameters using gradients from CE on logits.
256    /// Implements a minimal backward path for the LM head (weights + bias) and token embeddings.
257    /// - dW_lm_head = sum_t outer(dlogits_t, h_t)
258    /// - db_lm_head = sum_t dlogits_t
259    /// - dembed[token_t] += W^T * dlogits_t   (minimal proxy ignoring intermediate layers)
260    fn update_model_parameters(
261        &mut self,
262        gradients: &[f32],
263        hidden: &[Vec<f32>],
264        input_ids: &[u32],
265    ) -> Result<()> {
266        let vocab_size = self.vocab_size;
267        if vocab_size == 0 {
268            return Err(ModelError::Training("Vocab size is zero".to_string()));
269        }
270        if gradients.len() % vocab_size != 0 {
271            return Err(ModelError::Training(format!(
272                "Gradients length {} is not divisible by vocab_size {}",
273                gradients.len(),
274                vocab_size
275            )));
276        }
277        let num_samples = gradients.len() / vocab_size;
278        if hidden.len() != num_samples || input_ids.len() != num_samples {
279            return Err(ModelError::Training(format!(
280                "Mismatch: hidden len {} / input_ids len {} vs samples {}",
281                hidden.len(),
282                input_ids.len(),
283                num_samples
284            )));
285        }
286
287        // 1) Bias gradients: db = sum_t dlogits_t
288        let mut bias_grads = vec![0.0f32; vocab_size];
289        for (i, chunk) in gradients.chunks(vocab_size).enumerate() {
290            let _ = i; // unused
291            for v in 0..vocab_size {
292                bias_grads[v] += chunk[v];
293            }
294        }
295
296        // Apply bias update
297        {
298            let name = "lm_head.bias";
299            let bias_slice = self.model.lm_head_bias_mut();
300            self.optimizer
301                .step_parameter(name, bias_slice, &bias_grads)?;
302        }
303
304        // 2) Weight gradients: dW[v] = sum_t dlogits_t[v] * h_t
305        // Accumulate per row to avoid storing full matrix if not needed.
306        // Use hidden_size inferred from a row of lm_head.
307        let lm_w_snapshot = self.model.lm_head_weights().clone();
308        if lm_w_snapshot.is_empty() {
309            return Err(ModelError::Training(
310                "LM head weights are empty".to_string(),
311            ));
312        }
313        let hidden_size = lm_w_snapshot[0].len();
314
315        for v in 0..vocab_size {
316            let mut row_grad = vec![0.0f32; hidden_size];
317            for t in 0..num_samples {
318                let g_vt = gradients[t * vocab_size + v];
319                if g_vt != 0.0 {
320                    let h_t = &hidden[t];
321                    // Safety: hidden[t] must match hidden_size
322                    if h_t.len() != hidden_size {
323                        return Err(ModelError::Training(format!(
324                            "Hidden size {} mismatch at t={} (expected {})",
325                            h_t.len(),
326                            t,
327                            hidden_size
328                        )));
329                    }
330                    for k in 0..hidden_size {
331                        row_grad[k] += g_vt * h_t[k];
332                    }
333                }
334            }
335            // Apply update to lm_head.weight[v]
336            let name = format!("lm_head.weight[{}]", v);
337            let row_slice = self.model.lm_head_row_mut(v)?;
338            self.optimizer.step_parameter(&name, row_slice, &row_grad)?;
339        }
340
341        // 3) Embedding gradients: for each position t, dembed[token_t] += W^T * dlogits_t
342        // We compute per-t grad_hidden = W^T * dlogits_t, then update the corresponding embedding row.
343        for t in 0..num_samples {
344            let token_id = input_ids[t] as usize;
345            if token_id >= self.vocab_size {
346                continue; // skip OOB
347            }
348            let dlogits_t = &gradients[t * vocab_size..(t + 1) * vocab_size];
349
350            // grad_hidden = W^T * dlogits_t
351            let mut grad_hidden = vec![0.0f32; hidden_size];
352            for v in 0..vocab_size {
353                let g = dlogits_t[v];
354                if g != 0.0 {
355                    let w_row = &lm_w_snapshot[v];
356                    for k in 0..hidden_size {
357                        grad_hidden[k] += w_row[k] * g;
358                    }
359                }
360            }
361
362            // Apply update to embedding row for this token
363            let name = format!("embeddings.weight[{}]", token_id);
364            let row_slice = self.model.embedding_row_mut(token_id)?;
365            self.optimizer
366                .step_parameter(&name, row_slice, &grad_hidden)?;
367        }
368
369        Ok(())
370    }
371
372    /// Get current training step
373    pub fn step_count(&self) -> usize {
374        self.step_count
375    }
376
377    /// Evaluate on a batch of examples
378    pub fn evaluate(&mut self, examples: &[TrainingExample]) -> Result<TrainingMetrics> {
379        if examples.is_empty() {
380            return Err(ModelError::Training("Empty evaluation set".to_string()));
381        }
382
383        let (inputs_per_example, targets) = self.prepare_last_step_data(examples)?;
384        let (predictions, _last_hiddens, _last_input_ids) =
385            self.forward_last_step(&inputs_per_example)?;
386
387        let target_floats: Vec<f32> = targets.iter().map(|&x| x as f32).collect();
388        let loss = self.loss_fn.compute_loss(&predictions, &target_floats)?;
389        let accuracy = self.loss_fn.compute_accuracy(&predictions, &targets);
390
391        Ok(TrainingMetrics::new(loss, accuracy, self.step_count))
392    }
393}
394
/// Reward function for evaluating reasoning quality.
pub trait RewardFunction {
    /// Score a prediction and its reasoning chain against the expected `target`.
    fn compute_reward(&self, reasoning_chain: &[String], target: &str, predicted: &str) -> f32;
}

/// Simple reward function based on correctness and reasoning quality.
///
/// The reward is the sum of the following terms, clamped below at `0.0`:
/// - `1.0` if `predicted` equals `target` after trimming and lowercasing,
/// - a reasoning-quality term (see `evaluate_reasoning_quality`),
/// - a length term in `[-0.3, 0.0]` (see `evaluate_reasoning_length`).
pub struct SimpleRewardFunction;

impl RewardFunction for SimpleRewardFunction {
    fn compute_reward(&self, reasoning_chain: &[String], target: &str, predicted: &str) -> f32 {
        let mut reward: f32 = 0.0;

        // Base reward for a correct answer (case- and surrounding-whitespace-insensitive).
        if predicted.trim().to_lowercase() == target.trim().to_lowercase() {
            reward += 1.0;
        }

        // Bonus for reasoning quality.
        reward += self.evaluate_reasoning_quality(reasoning_chain);

        // Penalty for very short or very long reasoning.
        reward += self.evaluate_reasoning_length(reasoning_chain);

        reward.max(0.0) // Ensure non-negative reward
    }
}

impl SimpleRewardFunction {
    /// Score the structure and vocabulary of a reasoning chain.
    ///
    /// Returns `-0.5` for an empty chain. Otherwise it starts at `0.0`, adds the terms below,
    /// and caps the result at `0.5`:
    /// - `0.2` if the chain has at least two steps,
    /// - `0.1` per math keyword found anywhere in the text. This is a substring match, so
    ///   inflections such as "adding" or "equations" count, as does the symbol `=`.
    /// - `0.1` per logical connector present as a whole word.
    ///
    /// Matching is case-insensitive. Connectors are matched as whole words because substring
    /// matching let short connectors fire inside unrelated words, e.g. "so" in "reasoning",
    /// "also" or "solve".
    fn evaluate_reasoning_quality(&self, reasoning_chain: &[String]) -> f32 {
        if reasoning_chain.is_empty() {
            return -0.5; // Penalty for no reasoning
        }

        let mut quality_score: f32 = 0.0;

        // Reward for step-by-step structure.
        if reasoning_chain.len() >= 2 {
            quality_score += 0.2;
        }

        let reasoning_text = reasoning_chain.join(" ").to_lowercase();

        // Reward for mathematical keywords (substring match).
        let math_keywords = ["add", "subtract", "multiply", "solve", "equation", "="];
        for keyword in &math_keywords {
            if reasoning_text.contains(*keyword) {
                quality_score += 0.1;
            }
        }

        // Reward for logical connectors (whole words only: split on non-alphanumerics).
        let words: Vec<&str> = reasoning_text
            .split(|c: char| !c.is_alphanumeric())
            .filter(|w| !w.is_empty())
            .collect();
        let logical_connectors = ["therefore", "since", "because", "so", "thus"];
        for connector in &logical_connectors {
            if words.contains(connector) {
                quality_score += 0.1;
            }
        }

        quality_score.min(0.5) // Cap the bonus
    }

    /// Length term for a reasoning chain, by number of steps.
    ///
    /// Step counts map to penalties as follows: 0 steps → `-0.3`, 1 step → `-0.1`,
    /// 2–5 steps → `0.0`, 6–8 steps → `-0.1`, more than 8 steps → `-0.2`.
    fn evaluate_reasoning_length(&self, reasoning_chain: &[String]) -> f32 {
        let length = reasoning_chain.len();

        match length {
            0 => -0.3,     // Too short
            1 => -0.1,     // Still too short
            2..=5 => 0.0,  // Good length
            6..=8 => -0.1, // Getting long
            _ => -0.2,     // Too long
        }
    }
}
472
/// Policy gradient computation for the REINFORCE algorithm.
///
/// Holds per-action probabilities, per-action rewards, and the baseline subtracted from each
/// reward to form the advantage.
#[derive(Debug, Clone)]
pub struct PolicyGradient {
    /// Probability the policy assigned to each action.
    pub action_probs: Vec<f32>,
    /// Reward observed for each action, paired with `action_probs` by index.
    pub rewards: Vec<f32>,
    /// Value subtracted from every reward; `new` sets it to the mean reward.
    pub baseline: f32,
}

impl PolicyGradient {
    /// Build a policy gradient whose baseline is the mean of `rewards` (`0.0` when empty).
    pub fn new(action_probs: Vec<f32>, rewards: Vec<f32>) -> Self {
        let baseline = match rewards.len() {
            0 => 0.0,
            count => rewards.iter().sum::<f32>() / count as f32,
        };

        Self {
            action_probs,
            rewards,
            baseline,
        }
    }

    /// Simplified REINFORCE gradient, one entry per action probability.
    ///
    /// Entry `i` is `(rewards[i] - baseline) / action_probs[i]`, i.e. the advantage scaled by
    /// the derivative of `ln p`. This applies when `action_probs[i] > 0` and a matching reward
    /// exists; otherwise the entry is `0.0`. The output length always equals
    /// `action_probs.len()`.
    pub fn compute_gradients(&self) -> Vec<f32> {
        self.action_probs
            .iter()
            .enumerate()
            .map(|(i, &prob)| match self.rewards.get(i) {
                Some(&reward) if prob > 0.0 => (reward - self.baseline) / prob,
                _ => 0.0,
            })
            .collect()
    }
}
517
518/// Reinforcement learning trainer using REINFORCE algorithm
519pub struct RLTrainer {
520    model: DeepSeekR1Model,
521    optimizer: Optimizer,
522    reward_fn: SimpleRewardFunction,
523    step_count: usize,
524    vocab_size: usize,
525    baseline_history: Vec<f32>,
526    max_baseline_history: usize,
527}
528
529impl RLTrainer {
530    /// Create a new RL trainer
531    pub fn new(model: DeepSeekR1Model) -> Result<Self> {
532        let optimizer_config = OptimizerConfig {
533            learning_rate: 1e-5, // Lower learning rate for RL
534            ..OptimizerConfig::default()
535        };
536        let optimizer = Optimizer::new(optimizer_config)?;
537        let reward_fn = SimpleRewardFunction;
538        let vocab_size = model.config().vocab_size;
539
540        Ok(Self {
541            model,
542            optimizer,
543            reward_fn,
544            step_count: 0,
545            vocab_size,
546            baseline_history: Vec::new(),
547            max_baseline_history: 100,
548        })
549    }
550
551    /// Create RL trainer with custom optimizer config
552    pub fn with_optimizer_config(
553        model: DeepSeekR1Model,
554        optimizer_config: OptimizerConfig,
555    ) -> Result<Self> {
556        let optimizer = Optimizer::new(optimizer_config)?;
557        let reward_fn = SimpleRewardFunction;
558        let vocab_size = model.config().vocab_size;
559
560        Ok(Self {
561            model,
562            optimizer,
563            reward_fn,
564            step_count: 0,
565            vocab_size,
566            baseline_history: Vec::new(),
567            max_baseline_history: 100,
568        })
569    }
570
571    /// Generate response with reasoning for RL training
572    fn generate_response_with_reasoning(&mut self, input: &str) -> Result<(String, Vec<String>)> {
573        // Tokenize input using RLTrainer's tokenizer
574        let input_tokens = self.tokenize(input);
575        // Forward pass through the model
576        let logits = self.model.forward(&input_tokens)?;
577        // Decode response using argmax for simplicity
578        let response_token = logits
579            .iter()
580            .enumerate()
581            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
582            .map(|(idx, _)| idx)
583            .unwrap_or(0) as u32;
584        let response = self.decode(&[response_token]);
585        // Reasoning chain extraction (if available, otherwise empty)
586        let reasoning_chain = Vec::new(); // TODO: Extract reasoning chain from model output if supported
587        Ok((response, reasoning_chain))
588    }
589
590    /// Compute action probabilities (simplified)
591    fn compute_action_probabilities(&mut self, input: &str) -> Result<Vec<f32>> {
592        // Use model logits for probability computation
593        let input_tokens = self.tokenize(input);
594        let logits = self.model.forward(&input_tokens)?;
595        let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
596        let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
597        let sum_exp: f32 = exp_logits.iter().sum();
598        let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
599        Ok(probs)
600    }
601
602    /// Update baseline estimate
603    fn update_baseline(&mut self, reward: f32) {
604        self.baseline_history.push(reward);
605
606        // Keep only recent history
607        if self.baseline_history.len() > self.max_baseline_history {
608            self.baseline_history.remove(0);
609        }
610    }
611
612    /// Get current baseline estimate
613    fn get_baseline(&self) -> f32 {
614        if self.baseline_history.is_empty() {
615            0.0
616        } else {
617            self.baseline_history.iter().sum::<f32>() / self.baseline_history.len() as f32
618        }
619    }
620
621    /// Perform an RL training step using REINFORCE
622    pub fn train_step(&mut self, batch: &TrainingBatch) -> Result<TrainingMetrics> {
623        if batch.examples.is_empty() {
624            return Err(ModelError::Training("Empty batch".to_string()));
625        }
626
627        let mut total_reward = 0.0;
628        let mut total_loss = 0.0;
629        let mut correct_predictions = 0;
630
631        for example in &batch.examples {
632            // Generate response with reasoning
633            let (predicted_response, reasoning_chain) =
634                self.generate_response_with_reasoning(&example.input)?;
635
636            // Compute reward
637            let reward = self.reward_fn.compute_reward(
638                &reasoning_chain,
639                &example.target,
640                &predicted_response,
641            );
642
643            total_reward += reward;
644            self.update_baseline(reward);
645
646            // Check if prediction is correct
647            if predicted_response.trim().to_lowercase() == example.target.trim().to_lowercase() {
648                correct_predictions += 1;
649            }
650
651            // Compute action probabilities
652            let action_probs = self.compute_action_probabilities(&example.input)?;
653
654            // Compute policy gradient
655            let rewards = vec![reward; action_probs.len()];
656            let policy_grad = PolicyGradient::new(action_probs, rewards);
657            let gradients = policy_grad.compute_gradients();
658
659            // Compute loss (negative expected reward)
660            let loss = -(reward - self.get_baseline());
661            total_loss += loss;
662
663            // Update parameters using policy gradients
664            // In a real implementation, this would update actual model parameters
665            let mut dummy_params = vec![0.1; gradients.len()];
666            self.optimizer.step_parameter(
667                &format!("rl_params_{}", example.input.len()),
668                &mut dummy_params,
669                &gradients,
670            )?;
671        }
672
673        self.step_count += 1;
674
675        let _avg_reward = total_reward / batch.examples.len() as f32;
676        let avg_loss = total_loss / batch.examples.len() as f32;
677        let accuracy = correct_predictions as f32 / batch.examples.len() as f32;
678
679        Ok(TrainingMetrics::new(avg_loss, accuracy, self.step_count))
680    }
681
682    /// Evaluate the RL policy on examples
683    pub fn evaluate(&mut self, examples: &[TrainingExample]) -> Result<RLEvaluationMetrics> {
684        if examples.is_empty() {
685            return Err(ModelError::Training("Empty evaluation set".to_string()));
686        }
687
688        let mut total_reward = 0.0;
689        let mut correct_predictions = 0;
690        let mut reasoning_quality_scores = Vec::new();
691
692        for example in examples {
693            let (predicted_response, reasoning_chain) =
694                self.generate_response_with_reasoning(&example.input)?;
695
696            let reward = self.reward_fn.compute_reward(
697                &reasoning_chain,
698                &example.target,
699                &predicted_response,
700            );
701
702            total_reward += reward;
703            reasoning_quality_scores.push(reward);
704
705            if predicted_response.trim().to_lowercase() == example.target.trim().to_lowercase() {
706                correct_predictions += 1;
707            }
708        }
709
710        let avg_reward = total_reward / examples.len() as f32;
711        let accuracy = correct_predictions as f32 / examples.len() as f32;
712        let avg_reasoning_quality =
713            reasoning_quality_scores.iter().sum::<f32>() / reasoning_quality_scores.len() as f32;
714
715        Ok(RLEvaluationMetrics {
716            average_reward: avg_reward,
717            accuracy,
718            reasoning_quality: avg_reasoning_quality,
719            baseline: self.get_baseline(),
720            total_examples: examples.len(),
721        })
722    }
723
724    /// Get current step count
725    pub fn step_count(&self) -> usize {
726        self.step_count
727    }
728
729    /// Get current baseline
730    pub fn baseline(&self) -> f32 {
731        self.get_baseline()
732    }
733}
734
735impl RLTrainer {
736    /// Tokenize text using word-based hashing (same as BasicTrainer)
737    fn tokenize(&self, text: &str) -> Vec<u32> {
738        let binding = text.to_lowercase();
739        let words: Vec<&str> = binding.split_whitespace().collect();
740
741        let mut token_ids = Vec::new();
742
743        for word in words {
744            let mut hash = 0u32;
745            for byte in word.bytes() {
746                hash = hash.wrapping_mul(31).wrapping_add(byte as u32);
747            }
748            token_ids.push(hash % self.vocab_size as u32);
749        }
750
751        if token_ids.is_empty() {
752            token_ids.push(0); // UNK token
753        }
754
755        token_ids
756    }
757
758    /// Decode token IDs to a string (simple implementation)
759    fn decode(&self, token_ids: &[u32]) -> String {
760        token_ids
761            .iter()
762            .map(|id| format!("<{}>", id))
763            .collect::<Vec<_>>()
764            .join(" ")
765    }
766}
767
/// RL-specific evaluation metrics.
#[derive(Debug, Clone)]
pub struct RLEvaluationMetrics {
    /// Mean reward per example.
    pub average_reward: f32,
    /// Fraction of examples whose prediction matched the target, in `[0, 1]`.
    pub accuracy: f32,
    /// Mean reasoning-quality score per example.
    pub reasoning_quality: f32,
    /// Reward baseline at evaluation time.
    pub baseline: f32,
    /// Number of examples evaluated.
    pub total_examples: usize,
}

impl RLEvaluationMetrics {
    /// Print the metrics report to stdout, one field per line.
    ///
    /// The text is exactly the `Display` rendering of `self` followed by a newline.
    pub fn display(&self) {
        println!("{}", self);
    }
}

impl std::fmt::Display for RLEvaluationMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "RL Evaluation Metrics:")?;
        writeln!(f, "  Average Reward: {:.4}", self.average_reward)?;
        writeln!(f, "  Accuracy: {:.2}%", self.accuracy * 100.0)?;
        writeln!(f, "  Reasoning Quality: {:.4}", self.reasoning_quality)?;
        writeln!(f, "  Baseline: {:.4}", self.baseline)?;
        // No trailing newline: `display()` / `println!` supplies it.
        write!(f, "  Total Examples: {}", self.total_examples)
    }
}
819
#[cfg(test)]
mod tests {
    // Unit tests for the supervised (`BasicTrainer`) and RL (`RLTrainer`) trainers,
    // the reward function, and the policy-gradient helper.
    use super::*;
    use crate::model::{DeepSeekR1Model, ModelConfig};
    use crate::training::data::{ProblemType, TrainingExample};

    /// Build a fresh model from the default configuration; panics if construction fails.
    fn default_model() -> DeepSeekR1Model {
        DeepSeekR1Model::new(ModelConfig::default()).unwrap()
    }

    /// Wrap one (input, target, problem type) triple in a single-element example list.
    fn single_example(input: &str, target: &str, kind: ProblemType) -> Vec<TrainingExample> {
        vec![TrainingExample::new(input.to_string(), target.to_string(), kind)]
    }

    #[test]
    fn test_basic_trainer_creation() {
        assert!(BasicTrainer::new(default_model()).is_ok());
    }

    #[test]
    fn test_basic_trainer_with_custom_config() {
        let model = default_model();
        let custom = OptimizerConfig {
            learning_rate: 0.001,
            ..OptimizerConfig::default()
        };
        assert!(BasicTrainer::with_optimizer_config(model, custom).is_ok());
    }

    #[test]
    fn test_training_step() {
        let mut trainer = BasicTrainer::new(default_model()).unwrap();
        let batch = TrainingBatch::new(vec![
            TrainingExample::new("2 + 2".to_string(), "4".to_string(), ProblemType::Math),
            TrainingExample::new("3 * 3".to_string(), "9".to_string(), ProblemType::Math),
        ]);

        let metrics = trainer.train_step(&batch).unwrap();
        assert!(metrics.loss >= 0.0);
        assert!(metrics.accuracy >= 0.0);
        assert!(metrics.accuracy <= 1.0);
        assert_eq!(metrics.step, 1);
    }

    #[test]
    fn test_evaluation() {
        let mut trainer = BasicTrainer::new(default_model()).unwrap();
        let examples = single_example("test", "result", ProblemType::General);
        assert!(trainer.evaluate(&examples).is_ok());
    }

    #[test]
    fn test_empty_batch_error() {
        let mut trainer = BasicTrainer::new(default_model()).unwrap();
        let empty = TrainingBatch::new(vec![]);
        assert!(trainer.train_step(&empty).is_err());
    }

    #[test]
    fn test_rl_trainer_creation() {
        assert!(RLTrainer::new(default_model()).is_ok());
    }

    #[test]
    fn test_rl_trainer_with_custom_config() {
        let model = default_model();
        let custom = OptimizerConfig {
            learning_rate: 1e-6,
            ..OptimizerConfig::default()
        };
        assert!(RLTrainer::with_optimizer_config(model, custom).is_ok());
    }

    #[test]
    fn test_reward_function() {
        let reward_fn = SimpleRewardFunction;
        let steps = vec![
            "I need to add 2 and 2".to_string(),
            "2 + 2 = 4".to_string(),
            "Therefore, the answer is 4".to_string(),
        ];

        // Correct answer with reasoning: base reward plus bonuses.
        let correct = reward_fn.compute_reward(&steps, "4", "4");
        assert!(correct > 1.0);

        // Same reasoning but a wrong answer must score lower.
        let wrong = reward_fn.compute_reward(&steps, "4", "5");
        assert!(wrong < correct);

        // Correct answer without any reasoning must also score lower.
        let unreasoned = reward_fn.compute_reward(&[], "4", "4");
        assert!(unreasoned < correct);
    }

    #[test]
    fn test_policy_gradient() {
        let policy_grad = PolicyGradient::new(vec![0.3, 0.5, 0.2], vec![1.0, 0.5, 0.8]);

        // Baseline is the mean reward: (1.0 + 0.5 + 0.8) / 3 ≈ 0.767.
        assert!((policy_grad.baseline - 0.767).abs() < 0.01);
        assert_eq!(policy_grad.compute_gradients().len(), 3);
    }

    #[test]
    fn test_rl_training_step() {
        let mut trainer = RLTrainer::new(default_model()).unwrap();
        let batch = TrainingBatch::new(single_example("2 + 2", "4", ProblemType::Math));

        let metrics = trainer.train_step(&batch).unwrap();
        assert_eq!(metrics.step, 1);
    }

    #[test]
    fn test_rl_evaluation() {
        let mut trainer = RLTrainer::new(default_model()).unwrap();
        let examples = single_example("test", "result", ProblemType::General);

        let metrics = trainer.evaluate(&examples).unwrap();
        assert_eq!(metrics.total_examples, 1);
        assert!(metrics.average_reward >= 0.0);
    }

    #[test]
    fn test_trainer_basic_functionality() {
        let trainer = BasicTrainer::new(default_model()).unwrap();
        assert!(trainer.model.config().vocab_size > 0);

        // Building examples and a batch from them must not panic.
        let _batch = TrainingBatch::new(single_example(
            "What is 2 + 2?",
            "4",
            ProblemType::Math,
        ));
    }
}