use crate::error::{Result, RuvLLMError};
use crate::lora::micro_lora::{EwcState, MicroLoRA, MicroLoraConfig, TargetModule};
use crate::lora::training::{EwcRegularizer, TrainingConfig, TrainingPipeline};
use crate::training::contrastive::{ContrastiveConfig, ContrastiveTrainer};
use crate::training::grpo::{GrpoConfig, GrpoOptimizer};
use ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Configuration for the RLM refinement pipeline (micro-LoRA corrections,
/// EWC regularization, GRPO rewards, and contrastive router repair).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RlmRefinerConfig {
/// Rank of each per-expert micro-LoRA adapter (clamped to 1..=2 at construction).
pub lora_rank: usize,
/// Learning rate shared by the LoRA updates and the GRPO optimizer.
pub learning_rate: f32,
/// Target token budget for a full refinement run.
/// NOTE(review): not read by the visible code here — presumably consumed by a caller; confirm.
pub training_tokens: usize,
/// Number of `refine_step` calls between LoRA weight updates.
pub batch_size: usize,
/// EWC penalty strength passed to the regularizer and to `apply_updates_with_ewc`.
pub ewc_lambda: f32,
/// Group size for GRPO relative-advantage computation.
pub grpo_group_size: usize,
/// Epochs of contrastive training used by `repair_router`.
pub router_repair_epochs: usize,
/// Whether the contrastive trainer should use the Metal backend.
pub use_metal: bool,
/// Save a checkpoint every N global steps.
pub checkpoint_every_n: usize,
/// Hidden dimension of expert activations; all refine_step slices must match it.
pub hidden_dim: usize,
/// Directory where step checkpoints are written.
pub checkpoint_dir: PathBuf,
}
impl Default for RlmRefinerConfig {
fn default() -> Self {
Self {
lora_rank: 2,
learning_rate: 1e-4,
training_tokens: 100_000_000,
batch_size: 2,
ewc_lambda: 2000.0,
grpo_group_size: 8,
router_repair_epochs: 5,
use_metal: false, checkpoint_every_n: 1000,
hidden_dim: 768,
checkpoint_dir: PathBuf::from("checkpoints/rlm_refiner"),
}
}
}
/// Per-step diagnostics recorded by `RlmRefiner::refine_step`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RefinementStepMetrics {
/// Global step index at the time the step ran (0-based).
pub step: usize,
/// MSE-based proxy for KL divergence between combined and teacher outputs.
pub kl_divergence: f32,
/// Raw GRPO reward (cosine similarity clamped to >= 0).
pub grpo_reward: f32,
/// EWC lambda in effect at this step (reported value, not a computed penalty).
pub ewc_penalty: f32,
/// L2 norm of the LoRA correction added to the ternary output.
pub lora_correction_norm: f32,
/// Learning rate in effect at this step.
pub learning_rate: f32,
}
/// Summary of an entire refinement run, produced by `RlmRefiner::result_summary`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RefinementResult {
/// Total number of refine steps executed.
pub total_steps: usize,
/// Approximate tokens processed (steps * batch_size).
pub tokens_processed: usize,
/// KL proxy from the most recent step (0.0 if no steps ran).
pub final_kl_divergence: f32,
/// GRPO reward from the most recent step (0.0 if no steps ran).
pub final_grpo_reward: f32,
/// Router accuracy after repair; `result_summary` itself fills in 0.0.
pub router_accuracy: f64,
/// Paths of saved checkpoints; `result_summary` itself fills in an empty list.
pub checkpoint_paths: Vec<PathBuf>,
/// Full per-step metric history.
pub history: Vec<RefinementStepMetrics>,
}
/// Orchestrates post-quantization refinement of a ternary expert model:
/// applies per-expert micro-LoRA corrections trained against a teacher,
/// regularized by EWC and rewarded via GRPO, with contrastive router repair.
pub struct RlmRefiner {
config: RlmRefinerConfig,
// One micro-LoRA adapter per expert layer, keyed by layer index.
lora_adapters: HashMap<usize, MicroLoRA>,
// Elastic Weight Consolidation state used to penalize drift from anchors.
ewc: EwcRegularizer,
// Group-relative policy optimizer providing reward-weighted advantages.
grpo: GrpoOptimizer,
// Triplet-based trainer used only by `repair_router`.
contrastive: ContrastiveTrainer,
// Constructed from the config; not otherwise used by the visible methods.
training_pipeline: TrainingPipeline,
// Monotonic count of completed refine steps.
global_step: usize,
// Metrics pushed once per refine step; serialized with checkpoints.
metrics_history: Vec<RefinementStepMetrics>,
}
impl RlmRefiner {
    /// Build a refiner with one micro-LoRA adapter per expert layer.
    ///
    /// The LoRA rank is clamped to 1..=2 to keep the per-expert parameter
    /// budget tiny; alpha follows the common `2 * rank` scaling.
    ///
    /// # Errors
    /// Returns `RuvLLMError::Config` if the contrastive trainer fails to
    /// initialize.
    pub fn new(config: RlmRefinerConfig, num_expert_layers: usize) -> Result<Self> {
        let mut lora_adapters = HashMap::with_capacity(num_expert_layers);
        let lora_config = MicroLoraConfig {
            rank: config.lora_rank.clamp(1, 2),
            alpha: (config.lora_rank as f32) * 2.0,
            dropout: 0.0,
            target_modules: TargetModule::mlp(),
            in_features: config.hidden_dim,
            out_features: config.hidden_dim,
            use_bias: false,
            standard_init: true,
            gradient_checkpointing: false,
        };
        for layer_idx in 0..num_expert_layers {
            lora_adapters.insert(layer_idx, MicroLoRA::new(lora_config.clone()));
        }
        // 0.999 is the Fisher-information decay factor for the regularizer.
        let ewc = EwcRegularizer::new(config.ewc_lambda, 0.999);
        let grpo_config = GrpoConfig {
            group_size: config.grpo_group_size,
            // `learning_rate` is already f32; no cast needed.
            learning_rate: config.learning_rate,
            normalize_rewards: true,
            normalize_advantages: true,
            ..GrpoConfig::default()
        };
        let grpo = GrpoOptimizer::new(grpo_config);
        let contrastive_config = ContrastiveConfig {
            use_metal: config.use_metal,
            ..ContrastiveConfig::default()
        };
        let contrastive = ContrastiveTrainer::new(contrastive_config)
            .map_err(|e| RuvLLMError::Config(format!("ContrastiveTrainer init: {}", e)))?;
        let training_config = TrainingConfig {
            learning_rate: config.learning_rate,
            ewc_lambda: config.ewc_lambda,
            batch_size: config.batch_size,
            ..TrainingConfig::default()
        };
        let training_pipeline = TrainingPipeline::new(training_config);
        Ok(Self {
            config,
            lora_adapters,
            ewc,
            grpo,
            contrastive,
            training_pipeline,
            global_step: 0,
            metrics_history: Vec::new(),
        })
    }

    /// Anchor the current adapter weights in the EWC regularizer.
    ///
    /// Call once after construction (or after a consolidation point) so that
    /// subsequent updates are penalized for drifting from these weights.
    pub fn init_ewc_states(&mut self) {
        for lora in self.lora_adapters.values() {
            for module in &TargetModule::mlp() {
                if let Some(adapter_lock) = lora.get_adapter(module) {
                    let adapter = adapter_lock.read();
                    self.ewc.init_module(*module, &adapter);
                }
            }
        }
    }

    /// Run one refinement step for a single expert.
    ///
    /// Adds the LoRA correction to the ternary output, scores it against the
    /// teacher output (MSE-proxy KL and cosine reward), accumulates
    /// reward-weighted gradients, and applies an EWC-regularized LoRA update
    /// once every `batch_size` steps. A checkpoint is saved (best-effort)
    /// every `checkpoint_every_n` steps; both intervals are skipped when
    /// configured as 0 instead of panicking on a zero modulus.
    ///
    /// # Errors
    /// Fails if any input slice length differs from `hidden_dim`, or if no
    /// adapter exists for `expert_idx`.
    pub fn refine_step(
        &mut self,
        expert_idx: usize,
        input: &[f32],
        ternary_output: &[f32],
        teacher_output: &[f32],
    ) -> Result<RefinementStepMetrics> {
        let dim = self.config.hidden_dim;
        if input.len() != dim || ternary_output.len() != dim || teacher_output.len() != dim {
            return Err(RuvLLMError::InvalidOperation(format!(
                "Dimension mismatch: expected {}, got input={}, ternary={}, teacher={}",
                dim,
                input.len(),
                ternary_output.len(),
                teacher_output.len(),
            )));
        }
        let lora = self.lora_adapters.get(&expert_idx).ok_or_else(|| {
            RuvLLMError::InvalidOperation(format!("No LoRA adapter for expert {}", expert_idx))
        })?;
        // Accumulate the low-rank correction from every target module.
        let mut lora_correction = vec![0.0f32; dim];
        for module in &TargetModule::mlp() {
            lora.forward_add(input, module, &mut lora_correction);
        }
        let combined: Vec<f32> = ternary_output
            .iter()
            .zip(lora_correction.iter())
            .map(|(t, l)| t + l)
            .collect();
        let kl_divergence = kl_divergence_proxy(&combined, teacher_output);
        let cosine_sim = cosine_similarity(&combined, teacher_output);
        // Negative similarity is treated as zero reward rather than a penalty.
        let grpo_reward = cosine_sim.max(0.0);
        let advantages = self.grpo.compute_relative_advantages(&[grpo_reward]);
        let _grpo_reward_normalized = advantages.first().copied().unwrap_or(0.0);
        let input_arr = Array1::from_vec(input.to_vec());
        // Gradient of 0.5 * ||teacher - combined||^2 w.r.t. combined output.
        let grad_output: Vec<f32> = teacher_output
            .iter()
            .zip(combined.iter())
            .map(|(t, c)| t - c)
            .collect();
        let grad_arr = Array1::from_vec(grad_output);
        // Floor the reward so gradients never vanish entirely on bad samples.
        let reward_signal = grpo_reward.max(0.01);
        for module in &TargetModule::mlp() {
            if let Some(adapter_lock) = lora.get_adapter(module) {
                let mut adapter = adapter_lock.write();
                adapter.accumulate_gradient(&input_arr, &grad_arr, reward_signal);
            }
        }
        // Apply accumulated gradients once per batch; a batch_size of 0 would
        // otherwise panic on the modulus, so it disables updates instead.
        if self.config.batch_size > 0 && (self.global_step + 1) % self.config.batch_size == 0 {
            let ewc_states: HashMap<TargetModule, EwcState> = TargetModule::mlp()
                .into_iter()
                .filter_map(|m| self.ewc.get_state(&m).cloned().map(|s| (m, s)))
                .collect();
            lora.apply_updates_with_ewc(
                self.config.learning_rate,
                &ewc_states,
                self.config.ewc_lambda,
            );
        }
        let lora_correction_norm = lora_correction.iter().map(|v| v * v).sum::<f32>().sqrt();
        let metrics = RefinementStepMetrics {
            step: self.global_step,
            kl_divergence,
            grpo_reward,
            ewc_penalty: self.ewc.lambda(),
            lora_correction_norm,
            learning_rate: self.config.learning_rate,
        };
        self.metrics_history.push(metrics.clone());
        self.global_step += 1;
        // Best-effort checkpoint: a failed save must not abort training.
        // Guard against checkpoint_every_n == 0 (would panic on modulus).
        if self.config.checkpoint_every_n > 0
            && self.global_step % self.config.checkpoint_every_n == 0
        {
            let _ = self.save_checkpoint(self.global_step);
        }
        Ok(metrics)
    }

    /// Repair the expert router with contrastive triplet training.
    ///
    /// Loads triplets from `triplet_path` and trains for
    /// `router_repair_epochs`, returning the best accuracy achieved.
    ///
    /// # Errors
    /// Fails if the triplet file cannot be loaded, contains no triplets, or
    /// training itself fails.
    pub fn repair_router<P: AsRef<Path>>(&mut self, triplet_path: P) -> Result<f64> {
        let count = self
            .contrastive
            .load_triplets(triplet_path)
            .map_err(|e| RuvLLMError::Config(format!("Load triplets: {}", e)))?;
        if count == 0 {
            return Err(RuvLLMError::InvalidOperation(
                "No router repair triplets loaded".to_string(),
            ));
        }
        let result = self
            .contrastive
            .train(self.config.router_repair_epochs)
            .map_err(|e| RuvLLMError::InvalidOperation(format!("Router repair failed: {}", e)))?;
        Ok(result.best_accuracy)
    }

    /// Serialize the EWC states into `dir/ewc_states.bin`.
    ///
    /// Shared by `save_checkpoint` and `export_refined_model`.
    fn write_ewc_states(&self, dir: &Path) -> Result<()> {
        let ewc_export = self.ewc.export_states();
        let ewc_bytes = bincode::serde::encode_to_vec(&ewc_export, bincode::config::standard())
            .map_err(|e| RuvLLMError::Serialization(e.to_string()))?;
        std::fs::write(dir.join("ewc_states.bin"), ewc_bytes)?;
        Ok(())
    }

    /// Save all LoRA adapters, EWC states, and the metrics history under
    /// `checkpoint_dir/step_{step}`, returning that directory.
    ///
    /// # Errors
    /// Fails on I/O or serialization errors, or if a checkpoint path is not
    /// valid UTF-8 (previously this silently saved to a fallback filename).
    pub fn save_checkpoint(&self, step: usize) -> Result<PathBuf> {
        let dir = self.config.checkpoint_dir.join(format!("step_{}", step));
        std::fs::create_dir_all(&dir)?;
        for (&layer_idx, lora) in &self.lora_adapters {
            let path = dir.join(format!("expert_{}_lora.bin", layer_idx));
            let path_str = path.to_str().ok_or_else(|| {
                RuvLLMError::InvalidOperation(format!(
                    "Checkpoint path is not valid UTF-8: {}",
                    path.display()
                ))
            })?;
            lora.save(path_str)?;
        }
        self.write_ewc_states(&dir)?;
        let metrics_json = serde_json::to_string_pretty(&self.metrics_history)
            .map_err(|e| RuvLLMError::Serialization(e.to_string()))?;
        std::fs::write(dir.join("metrics.json"), metrics_json)?;
        Ok(dir)
    }

    /// Export final LoRA states, EWC states, and the refiner config to
    /// `output_dir` for downstream packaging; returns the directory.
    pub fn export_refined_model<P: AsRef<Path>>(&self, output_dir: P) -> Result<PathBuf> {
        let dir = output_dir.as_ref();
        std::fs::create_dir_all(dir)?;
        for (&layer_idx, lora) in &self.lora_adapters {
            let state = lora.export_state();
            let bytes = bincode::serde::encode_to_vec(&state, bincode::config::standard())
                .map_err(|e| RuvLLMError::Serialization(e.to_string()))?;
            std::fs::write(
                dir.join(format!("expert_{}_lora_state.bin", layer_idx)),
                bytes,
            )?;
        }
        self.write_ewc_states(dir)?;
        let config_json = serde_json::to_string_pretty(&self.config)
            .map_err(|e| RuvLLMError::Serialization(e.to_string()))?;
        std::fs::write(dir.join("refiner_config.json"), config_json)?;
        Ok(dir.to_path_buf())
    }

    /// Summarize the run so far. `router_accuracy` and `checkpoint_paths`
    /// are left at their defaults; callers fill them in separately.
    pub fn result_summary(&self) -> RefinementResult {
        let last = self.metrics_history.last();
        RefinementResult {
            total_steps: self.global_step,
            tokens_processed: self.global_step * self.config.batch_size,
            final_kl_divergence: last.map(|m| m.kl_divergence).unwrap_or(0.0),
            final_grpo_reward: last.map(|m| m.grpo_reward).unwrap_or(0.0),
            router_accuracy: 0.0,
            checkpoint_paths: Vec::new(),
            history: self.metrics_history.clone(),
        }
    }

    /// Number of refine steps completed so far.
    pub fn global_step(&self) -> usize {
        self.global_step
    }

    /// The configuration this refiner was built with.
    pub fn config(&self) -> &RlmRefinerConfig {
        &self.config
    }

    /// LoRA adapter for a given expert, if one exists.
    pub fn get_expert_lora(&self, expert_idx: usize) -> Option<&MicroLoRA> {
        self.lora_adapters.get(&expert_idx)
    }

    /// Total trainable parameters across all adapters.
    pub fn total_trainable_params(&self) -> usize {
        self.lora_adapters.values().map(|l| l.param_count()).sum()
    }

    /// Total memory footprint of all adapters in bytes.
    pub fn total_lora_memory_bytes(&self) -> usize {
        self.lora_adapters.values().map(|l| l.memory_bytes()).sum()
    }
}
/// MSE-based proxy for KL divergence between two activation vectors.
///
/// Returns 0.0 for mismatched lengths or empty input rather than erroring,
/// since callers only use it as a monotone distance signal.
fn kl_divergence_proxy(predicted: &[f32], target: &[f32]) -> f32 {
    if predicted.len() != target.len() || predicted.is_empty() {
        return 0.0;
    }
    let squared_error = predicted.iter().zip(target.iter()).fold(0.0f32, |acc, (p, t)| {
        let diff = p - t;
        acc + diff * diff
    });
    squared_error / predicted.len() as f32
}
/// Cosine similarity between two vectors in [-1, 1].
///
/// Returns 0.0 for mismatched lengths, empty input, or when either vector's
/// norm is at or below 1e-8 (degenerate direction).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single pass accumulating dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a > 1e-8 && norm_b > 1e-8 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default config carries the documented values.
    #[test]
    fn test_config_default() {
        let config = RlmRefinerConfig::default();
        assert_eq!(config.lora_rank, 2);
        assert!(!config.use_metal);
        assert_eq!(config.ewc_lambda, 2000.0);
        assert_eq!(config.grpo_group_size, 8);
    }

    // Construction builds one adapter per expert with nonzero footprint.
    #[test]
    fn test_refiner_creation() {
        let config = RlmRefinerConfig {
            hidden_dim: 64,
            ..Default::default()
        };
        let refiner = RlmRefiner::new(config, 4).unwrap();
        assert_eq!(refiner.lora_adapters.len(), 4);
        assert_eq!(refiner.global_step(), 0);
        assert!(refiner.total_trainable_params() > 0);
        assert!(refiner.total_lora_memory_bytes() > 0);
    }

    // A single refine step produces sane metrics and advances the counter.
    #[test]
    fn test_refine_step() {
        let config = RlmRefinerConfig {
            hidden_dim: 64,
            batch_size: 1,
            ..Default::default()
        };
        let mut refiner = RlmRefiner::new(config, 2).unwrap();
        refiner.init_ewc_states();
        let input = vec![0.1f32; 64];
        let ternary = vec![0.5f32; 64];
        let teacher = vec![0.6f32; 64];
        let metrics = refiner.refine_step(0, &input, &ternary, &teacher).unwrap();
        assert_eq!(metrics.step, 0);
        assert!(metrics.kl_divergence >= 0.0);
        assert_eq!(refiner.global_step(), 1);
    }

    // Slices shorter than hidden_dim are rejected.
    #[test]
    fn test_refine_step_dimension_mismatch() {
        let config = RlmRefinerConfig {
            hidden_dim: 64,
            ..Default::default()
        };
        let mut refiner = RlmRefiner::new(config, 1).unwrap();
        assert!(refiner
            .refine_step(0, &[0.1; 32], &[0.5; 64], &[0.6; 64])
            .is_err());
    }

    // An expert index with no adapter is rejected.
    #[test]
    fn test_refine_step_invalid_expert() {
        let config = RlmRefinerConfig {
            hidden_dim: 64,
            ..Default::default()
        };
        let mut refiner = RlmRefiner::new(config, 1).unwrap();
        assert!(refiner
            .refine_step(99, &[0.1; 64], &[0.5; 64], &[0.6; 64])
            .is_err());
    }

    // The proxy is zero for identical vectors and positive otherwise.
    #[test]
    fn test_kl_divergence_proxy() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert!((kl_divergence_proxy(&a, &b)).abs() < 1e-6);
        let c = vec![2.0, 3.0, 4.0];
        assert!(kl_divergence_proxy(&a, &c) > 0.0);
    }

    // Parallel vectors score 1, orthogonal vectors score 0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 1e-6);
    }

    // The summary mirrors the number of completed steps.
    #[test]
    fn test_result_summary() {
        let config = RlmRefinerConfig {
            hidden_dim: 64,
            batch_size: 1,
            ..Default::default()
        };
        let mut refiner = RlmRefiner::new(config, 1).unwrap();
        refiner.init_ewc_states();
        let input = vec![0.1f32; 64];
        let ternary = vec![0.5f32; 64];
        let teacher = vec![0.6f32; 64];
        for _ in 0..5 {
            refiner.refine_step(0, &input, &ternary, &teacher).unwrap();
        }
        let result = refiner.result_summary();
        assert_eq!(result.total_steps, 5);
        assert_eq!(result.history.len(), 5);
    }

    // Steps across several experts share one global counter and history.
    #[test]
    fn test_multiple_expert_training() {
        let config = RlmRefinerConfig {
            hidden_dim: 64,
            batch_size: 1,
            ..Default::default()
        };
        let mut refiner = RlmRefiner::new(config, 4).unwrap();
        refiner.init_ewc_states();
        let input = vec![0.1f32; 64];
        let ternary = vec![0.5f32; 64];
        let teacher = vec![0.6f32; 64];
        for expert in 0..4 {
            for _ in 0..3 {
                refiner.refine_step(expert, &input, &ternary, &teacher).unwrap();
            }
        }
        assert_eq!(refiner.global_step(), 12);
        assert_eq!(refiner.result_summary().history.len(), 12);
    }
}