use crate::error::{Result, RuvLLMError};
use ndarray::{Array1, Array2};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
/// Hyper-parameters for Group Relative Policy Optimization (GRPO).
///
/// GRPO scores each sample against the other members of its group
/// (group-relative advantages) instead of a learned value baseline,
/// with a PPO-style clipped surrogate and a KL penalty toward a
/// frozen reference policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpoConfig {
    /// Number of samples per comparison group.
    pub group_size: usize,
    /// Optimizer step size.
    pub learning_rate: f32,
    /// Initial weight of the KL penalty against the reference policy.
    pub kl_coefficient: f32,
    /// Lower clamp for the adaptive KL coefficient.
    pub kl_min: f32,
    /// Upper clamp for the adaptive KL coefficient.
    pub kl_max: f32,
    /// KL divergence the adaptive controller steers toward.
    pub kl_target: f32,
    /// Weight of the entropy bonus (encourages exploration).
    pub entropy_coefficient: f32,
    /// Gradient clipping threshold.
    pub max_grad_norm: f32,
    /// Discount factor for returns.
    pub gamma: f32,
    /// Lambda for Generalized Advantage Estimation.
    pub gae_lambda: f32,
    /// Weight of the value-function loss term.
    pub value_coef: f32,
    /// Whether to adapt the KL coefficient toward `kl_target`.
    pub adaptive_kl: bool,
    /// Optimization epochs per batch of experience.
    pub update_epochs: usize,
    /// Mini-batch size used when splitting a batch for updates.
    pub mini_batch_size: usize,
    /// PPO-style clipping range for the probability ratio.
    pub clip_range: f32,
    /// Normalize rewards before computing advantages.
    pub normalize_rewards: bool,
    /// Z-score advantages before the policy update.
    pub normalize_advantages: bool,
}
impl Default for GrpoConfig {
    /// Conservative general-purpose defaults: standard PPO clip (0.2),
    /// small learning rate, adaptive KL enabled, and both reward and
    /// advantage normalization on.
    fn default() -> Self {
        Self {
            group_size: 8,
            learning_rate: 1e-5,
            kl_coefficient: 0.02,
            kl_min: 0.001,
            kl_max: 0.1,
            kl_target: 0.01,
            entropy_coefficient: 0.01,
            max_grad_norm: 1.0,
            gamma: 0.99,
            gae_lambda: 0.95,
            value_coef: 0.5,
            adaptive_kl: true,
            update_epochs: 4,
            mini_batch_size: 32,
            clip_range: 0.2,
            normalize_rewards: true,
            normalize_advantages: true,
        }
    }
}
impl GrpoConfig {
    /// Preset tuned for tool-use fine-tuning: smaller groups, lower
    /// learning rate, a stronger KL leash, and fewer update epochs for
    /// conservative policy movement.
    pub fn for_tool_use() -> Self {
        let mut cfg = Self::default();
        cfg.group_size = 4;
        cfg.learning_rate = 5e-6;
        cfg.kl_coefficient = 0.05;
        cfg.kl_target = 0.02;
        cfg.entropy_coefficient = 0.005;
        cfg.update_epochs = 2;
        cfg.mini_batch_size = 16;
        cfg.clip_range = 0.15;
        cfg
    }

    /// Preset biased toward exploration: larger entropy bonus, weaker KL
    /// penalty, and a wider clip range.
    pub fn exploration() -> Self {
        let mut cfg = Self::default();
        cfg.entropy_coefficient = 0.05;
        cfg.kl_coefficient = 0.01;
        cfg.clip_range = 0.3;
        cfg
    }

    /// Preset for maximum stability: tiny learning rate, strong KL
    /// penalty, tight clip range, fewer epochs.
    pub fn stable() -> Self {
        let mut cfg = Self::default();
        cfg.learning_rate = 1e-6;
        cfg.kl_coefficient = 0.1;
        cfg.clip_range = 0.1;
        cfg.update_epochs = 2;
        cfg
    }
}
/// One recorded interaction (state, action, reward) used for GRPO training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpoSample {
    /// State/observation embedding at the time of the action.
    pub state: Vec<f32>,
    /// Index of the action taken.
    pub action: usize,
    /// Log-probability of the action under the current policy.
    pub log_prob: f32,
    /// Log-probability of the same action under the frozen reference policy.
    pub ref_log_prob: f32,
    /// Scalar reward received for this step.
    pub reward: f32,
    /// True if this step ended the episode.
    pub done: bool,
    /// Optional value estimate (consumed by GAE when present).
    pub value: Option<f32>,
    /// Name of the tool invoked for this action.
    pub tool_name: String,
    /// Optional JSON parameters that were passed to the tool.
    pub parameters: Option<serde_json::Value>,
}
/// A group of samples that are compared against each other when computing
/// group-relative advantages.
#[derive(Debug, Clone)]
pub struct SampleGroup {
    /// Samples belonging to this comparison group.
    pub samples: Vec<GrpoSample>,
    /// Caller-assigned identifier for the group.
    pub group_id: u64,
    /// Free-form description of the task these samples came from.
    pub task_context: String,
}
impl SampleGroup {
pub fn new(samples: Vec<GrpoSample>, group_id: u64, task_context: String) -> Self {
Self {
samples,
group_id,
task_context,
}
}
pub fn len(&self) -> usize {
self.samples.len()
}
pub fn is_empty(&self) -> bool {
self.samples.is_empty()
}
}
/// Diagnostics produced by a single GRPO policy update.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpoUpdateResult {
    /// Mean clipped-surrogate policy loss.
    pub policy_loss: f32,
    /// Sample-based KL estimate between current and reference policy.
    pub kl_divergence: f32,
    /// Sample-based entropy estimate of the current policy.
    pub entropy: f32,
    /// Combined objective: policy loss + KL penalty - entropy bonus.
    pub total_loss: f32,
    /// Loss-derived magnitude reported alongside the update.
    pub grad_norm: f32,
    /// Number of samples used in this update.
    pub num_samples: usize,
    /// Mean of the (possibly normalized) advantages.
    pub avg_advantage: f32,
    /// Fraction of samples whose ratio fell outside the clip range.
    pub clip_fraction: f32,
    /// KL coefficient in effect for this update.
    pub kl_coef: f32,
}
/// Running statistics accumulated across GRPO updates.
///
/// Averages are exponential moving averages (alpha = 0.01).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GrpoStats {
    /// Total number of policy updates performed.
    pub total_updates: u64,
    /// Total number of samples consumed across all updates.
    pub total_samples: u64,
    /// EMA of observed rewards.
    pub avg_reward: f32,
    /// EMA of the policy loss.
    pub avg_policy_loss: f32,
    /// EMA of the KL divergence estimate.
    pub avg_kl_divergence: f32,
    /// EMA of the entropy estimate.
    pub avg_entropy: f32,
    /// KL coefficient after the most recent adaptation.
    pub current_kl_coef: f32,
    /// Recent reward values, for inspection/plotting.
    pub reward_history: Vec<f32>,
}
/// GRPO optimizer: buffers experience, computes group-relative
/// advantages, and performs clipped-surrogate policy updates with an
/// adaptive KL penalty.
pub struct GrpoOptimizer {
    // Hyper-parameters; immutable after construction.
    config: GrpoConfig,
    // Current KL-penalty weight; adapted toward `config.kl_target`.
    kl_coef: f32,
    // Bounded FIFO of individual samples.
    experience_buffer: RwLock<VecDeque<GrpoSample>>,
    // Groups queued for the next `process_groups` call.
    group_buffer: RwLock<Vec<SampleGroup>>,
    // Monotonic update counter (atomic so it can be bumped without &mut).
    update_count: AtomicU64,
    // Running EMA statistics; see `GrpoStats`.
    stats: RwLock<GrpoStats>,
    // NOTE(review): the four running-moment fields below appear reserved
    // for reward/advantage normalization but are not read by the visible
    // update path — confirm intended use before relying on them.
    reward_mean: f32,
    reward_std: f32,
    advantage_mean: f32,
    advantage_std: f32,
}
impl GrpoOptimizer {
    /// Maximum number of samples retained in the experience buffer;
    /// the oldest sample is evicted once the cap is reached.
    /// (Previously the literal 10000 was duplicated in `new` and
    /// `add_experience`.)
    const EXPERIENCE_CAPACITY: usize = 10_000;
    /// Cap on `GrpoStats::reward_history` so the stats stay bounded.
    const REWARD_HISTORY_CAPACITY: usize = 1_000;

    /// Creates a new optimizer. The adaptive KL coefficient starts at
    /// `config.kl_coefficient` and is clamped to `[kl_min, kl_max]` as it
    /// adapts.
    pub fn new(config: GrpoConfig) -> Self {
        let kl_coef = config.kl_coefficient;
        Self {
            config,
            kl_coef,
            experience_buffer: RwLock::new(VecDeque::with_capacity(Self::EXPERIENCE_CAPACITY)),
            group_buffer: RwLock::new(Vec::new()),
            update_count: AtomicU64::new(0),
            stats: RwLock::new(GrpoStats::default()),
            reward_mean: 0.0,
            reward_std: 1.0,
            advantage_mean: 0.0,
            advantage_std: 1.0,
        }
    }

    /// Group-relative advantages: the z-score of each reward within the
    /// group (`(r - mean) / std`). The standard deviation is floored at
    /// 1e-8 so a group of identical rewards yields zeros instead of NaN.
    /// Empty input yields an empty vector.
    pub fn compute_relative_advantages(&self, rewards: &[f32]) -> Vec<f32> {
        if rewards.is_empty() {
            return Vec::new();
        }
        let mean = rewards.iter().sum::<f32>() / rewards.len() as f32;
        let variance =
            rewards.iter().map(|r| (r - mean).powi(2)).sum::<f32>() / rewards.len() as f32;
        let std = variance.sqrt().max(1e-8);
        rewards.iter().map(|r| (r - mean) / std).collect()
    }

    /// Generalized Advantage Estimation computed backwards over one
    /// trajectory. `next_value` bootstraps the step after the last one;
    /// `dones[t]` stops bootstrapping across episode boundaries.
    ///
    /// # Panics
    /// Panics if `values` or `dones` is shorter than `rewards`.
    pub fn compute_gae(
        &self,
        rewards: &[f32],
        values: &[f32],
        dones: &[bool],
        next_value: f32,
    ) -> Vec<f32> {
        let n = rewards.len();
        if n == 0 {
            return Vec::new();
        }
        let mut advantages = vec![0.0f32; n];
        let mut last_gae = 0.0f32;
        for t in (0..n).rev() {
            let next_val = if t == n - 1 { next_value } else { values[t + 1] };
            // Terminal steps must not bootstrap from the next state.
            let mask = if dones[t] { 0.0 } else { 1.0 };
            let delta = rewards[t] + self.config.gamma * next_val * mask - values[t];
            last_gae = delta + self.config.gamma * self.config.gae_lambda * mask * last_gae;
            advantages[t] = last_gae;
        }
        advantages
    }

    /// Performs one clipped-surrogate (PPO-style) policy update with a KL
    /// penalty toward the reference policy and an entropy bonus, and
    /// refreshes the running statistics.
    ///
    /// # Errors
    /// Returns `InvalidOperation` when the three slices differ in length
    /// or are empty.
    pub fn grpo_update(
        &mut self,
        log_probs: &[f32],
        advantages: &[f32],
        ref_log_probs: &[f32],
    ) -> Result<GrpoUpdateResult> {
        if log_probs.len() != advantages.len() || log_probs.len() != ref_log_probs.len() {
            return Err(RuvLLMError::InvalidOperation(
                "GRPO update: array lengths must match".to_string(),
            ));
        }
        let n = log_probs.len();
        if n == 0 {
            return Err(RuvLLMError::InvalidOperation(
                "GRPO update: no samples provided".to_string(),
            ));
        }
        let normalized_advantages = if self.config.normalize_advantages {
            self.normalize_advantages(advantages)
        } else {
            advantages.to_vec()
        };
        // Importance ratios pi(a|s) / pi_ref(a|s), computed in log space.
        let ratios: Vec<f32> = log_probs
            .iter()
            .zip(ref_log_probs.iter())
            .map(|(lp, rlp)| (lp - rlp).exp())
            .collect();
        let clip_lo = 1.0 - self.config.clip_range;
        let clip_hi = 1.0 + self.config.clip_range;
        let mut policy_loss = 0.0f32;
        let mut clip_count = 0;
        for (ratio, adv) in ratios.iter().zip(normalized_advantages.iter()) {
            // Clipped surrogate objective; negated because we minimize.
            let surr1 = ratio * adv;
            let surr2 = ratio.clamp(clip_lo, clip_hi) * adv;
            policy_loss -= surr1.min(surr2);
            if *ratio < clip_lo || *ratio > clip_hi {
                clip_count += 1;
            }
        }
        policy_loss /= n as f32;
        // Sample-based KL(pi || pi_ref) estimate: mean log-ratio.
        let kl_divergence: f32 = log_probs
            .iter()
            .zip(ref_log_probs.iter())
            .map(|(lp, rlp)| lp - rlp)
            .sum::<f32>()
            / n as f32;
        // Monte-Carlo entropy estimate: -mean(log pi(a|s)).
        let entropy = -log_probs.iter().sum::<f32>() / n as f32;
        let kl_penalty = self.kl_coef * kl_divergence;
        let entropy_bonus = self.config.entropy_coefficient * entropy;
        let total_loss = policy_loss + kl_penalty - entropy_bonus;
        if self.config.adaptive_kl {
            self.adapt_kl_coefficient(kl_divergence);
        }
        // NOTE(review): no gradients are computed here, so this is only a
        // loss-derived proxy, not a true gradient norm — confirm downstream
        // consumers treat it as a heuristic.
        let grad_norm = total_loss.abs().sqrt();
        let update_count = self.update_count.fetch_add(1, Ordering::SeqCst);
        {
            // EMA (alpha = 0.01) running statistics.
            let mut stats = self.stats.write();
            stats.total_updates = update_count + 1;
            stats.total_samples += n as u64;
            stats.avg_policy_loss = (stats.avg_policy_loss * 0.99) + (policy_loss * 0.01);
            stats.avg_kl_divergence = (stats.avg_kl_divergence * 0.99) + (kl_divergence * 0.01);
            stats.avg_entropy = (stats.avg_entropy * 0.99) + (entropy * 0.01);
            stats.current_kl_coef = self.kl_coef;
        }
        Ok(GrpoUpdateResult {
            policy_loss,
            kl_divergence,
            entropy,
            total_loss,
            grad_norm,
            num_samples: n,
            avg_advantage: normalized_advantages.iter().sum::<f32>() / n as f32,
            clip_fraction: clip_count as f32 / n as f32,
            kl_coef: self.kl_coef,
        })
    }

    /// Nudges the KL coefficient toward `kl_target`: scales it up by 1.5x
    /// when the observed KL overshoots the target by more than 50%, down
    /// when it undershoots by more than 50%, clamped to `[kl_min, kl_max]`.
    fn adapt_kl_coefficient(&mut self, observed_kl: f32) {
        if observed_kl > self.config.kl_target * 1.5 {
            self.kl_coef = (self.kl_coef * 1.5).min(self.config.kl_max);
        } else if observed_kl < self.config.kl_target * 0.5 {
            self.kl_coef = (self.kl_coef / 1.5).max(self.config.kl_min);
        }
    }

    /// Z-score normalization with a 1e-8 std floor; empty input yields an
    /// empty vector.
    fn normalize_advantages(&self, advantages: &[f32]) -> Vec<f32> {
        if advantages.is_empty() {
            return Vec::new();
        }
        let mean = advantages.iter().sum::<f32>() / advantages.len() as f32;
        let variance =
            advantages.iter().map(|a| (a - mean).powi(2)).sum::<f32>() / advantages.len() as f32;
        let std = variance.sqrt().max(1e-8);
        advantages.iter().map(|a| (a - mean) / std).collect()
    }

    /// Appends a sample to the bounded experience buffer, evicting the
    /// oldest sample when full.
    pub fn add_experience(&self, sample: GrpoSample) {
        let mut buffer = self.experience_buffer.write();
        if buffer.len() >= Self::EXPERIENCE_CAPACITY {
            buffer.pop_front();
        }
        buffer.push_back(sample);
    }

    /// Queues a sample group for the next `process_groups` call.
    pub fn add_group(&self, group: SampleGroup) {
        self.group_buffer.write().push(group);
    }

    /// Drains the queued groups and runs one GRPO update per non-empty
    /// group, using group-relative advantages. Also maintains the reward
    /// statistics in `GrpoStats` (fix: `avg_reward` and `reward_history`
    /// were previously declared but never written).
    ///
    /// # Errors
    /// Propagates the first failing `grpo_update`; because the queue is
    /// drained up front, groups after the failing one are dropped.
    pub fn process_groups(&mut self) -> Result<Vec<GrpoUpdateResult>> {
        let groups = {
            let mut buffer = self.group_buffer.write();
            std::mem::take(&mut *buffer)
        };
        let mut results = Vec::with_capacity(groups.len());
        for group in groups {
            if group.samples.is_empty() {
                continue;
            }
            let rewards: Vec<f32> = group.samples.iter().map(|s| s.reward).collect();
            let log_probs: Vec<f32> = group.samples.iter().map(|s| s.log_prob).collect();
            let ref_log_probs: Vec<f32> = group.samples.iter().map(|s| s.ref_log_prob).collect();
            let advantages = self.compute_relative_advantages(&rewards);
            let result = self.grpo_update(&log_probs, &advantages, &ref_log_probs)?;
            // Track reward statistics per processed group (bounded history).
            let group_mean = rewards.iter().sum::<f32>() / rewards.len() as f32;
            {
                let mut stats = self.stats.write();
                stats.avg_reward = (stats.avg_reward * 0.99) + (group_mean * 0.01);
                if stats.reward_history.len() >= Self::REWARD_HISTORY_CAPACITY {
                    stats.reward_history.remove(0);
                }
                stats.reward_history.push(group_mean);
            }
            results.push(result);
        }
        Ok(results)
    }

    /// Snapshot of the running statistics.
    pub fn stats(&self) -> GrpoStats {
        self.stats.read().clone()
    }

    /// The configuration this optimizer was built with.
    pub fn config(&self) -> &GrpoConfig {
        &self.config
    }

    /// Current (possibly adapted) KL coefficient.
    pub fn kl_coefficient(&self) -> f32 {
        self.kl_coef
    }

    /// Resets all mutable state back to that of a freshly constructed
    /// optimizer with the same configuration.
    pub fn reset(&mut self) {
        self.kl_coef = self.config.kl_coefficient;
        self.experience_buffer.write().clear();
        self.group_buffer.write().clear();
        self.update_count.store(0, Ordering::SeqCst);
        *self.stats.write() = GrpoStats::default();
        self.reward_mean = 0.0;
        self.reward_std = 1.0;
        self.advantage_mean = 0.0;
        self.advantage_std = 1.0;
    }

    /// Discounted Monte-Carlo returns computed backwards; the running
    /// return resets at each episode boundary (`dones[t] == true`).
    ///
    /// # Panics
    /// Panics if `dones` is shorter than `rewards`.
    pub fn compute_returns(&self, rewards: &[f32], dones: &[bool]) -> Vec<f32> {
        let n = rewards.len();
        if n == 0 {
            return Vec::new();
        }
        let mut returns = vec![0.0f32; n];
        let mut running_return = 0.0f32;
        for t in (0..n).rev() {
            if dones[t] {
                running_return = 0.0;
            }
            running_return = rewards[t] + self.config.gamma * running_return;
            returns[t] = running_return;
        }
        returns
    }
}
/// Column-oriented training batch assembled from `GrpoSample`s.
#[derive(Debug, Clone)]
pub struct GrpoBatch {
    /// State embeddings, shape `(num_samples, embedding_dim)`.
    pub states: Array2<f32>,
    /// Action index per sample.
    pub actions: Vec<usize>,
    /// Current-policy log-probabilities per sample.
    pub log_probs: Array1<f32>,
    /// Reference-policy log-probabilities per sample.
    pub ref_log_probs: Array1<f32>,
    /// Per-sample advantages.
    pub advantages: Array1<f32>,
    /// Per-sample returns.
    pub returns: Array1<f32>,
    /// Per-sample value estimates.
    pub values: Array1<f32>,
}
impl GrpoBatch {
    /// Builds a batch from raw samples.
    ///
    /// States are copied into an `(n, embedding_dim)` matrix: states
    /// longer than `embedding_dim` are truncated, shorter ones are
    /// zero-padded. `advantages` and `returns` start zeroed and are
    /// expected to be filled in by the caller; a missing per-sample
    /// `value` defaults to 0.0.
    ///
    /// Returns `None` when `samples` is empty.
    pub fn from_samples(samples: &[GrpoSample], embedding_dim: usize) -> Option<Self> {
        if samples.is_empty() {
            return None;
        }
        let n = samples.len();
        let mut states = Array2::zeros((n, embedding_dim));
        for (i, sample) in samples.iter().enumerate() {
            for (j, &val) in sample.state.iter().enumerate().take(embedding_dim) {
                states[[i, j]] = val;
            }
        }
        let actions: Vec<usize> = samples.iter().map(|s| s.action).collect();
        let log_probs = Array1::from_vec(samples.iter().map(|s| s.log_prob).collect());
        let ref_log_probs = Array1::from_vec(samples.iter().map(|s| s.ref_log_prob).collect());
        let advantages = Array1::zeros(n);
        let returns = Array1::zeros(n);
        let values = Array1::from_vec(samples.iter().map(|s| s.value.unwrap_or(0.0)).collect());
        Some(Self {
            states,
            actions,
            log_probs,
            ref_log_probs,
            advantages,
            returns,
            values,
        })
    }

    /// Number of samples in the batch.
    pub fn len(&self) -> usize {
        self.actions.len()
    }

    /// True when the batch holds no samples.
    pub fn is_empty(&self) -> bool {
        self.actions.is_empty()
    }

    /// Splits the batch into contiguous mini-batches of at most
    /// `mini_batch_size` samples each; the final mini-batch may be
    /// smaller.
    ///
    /// Fix: a `mini_batch_size` of 0 previously caused a divide-by-zero
    /// panic; it now returns the whole batch unsplit.
    pub fn into_mini_batches(self, mini_batch_size: usize) -> Vec<GrpoBatch> {
        let n = self.len();
        if mini_batch_size == 0 || n <= mini_batch_size {
            return vec![self];
        }
        // Ceiling division: number of chunks covering n samples.
        let num_batches = (n + mini_batch_size - 1) / mini_batch_size;
        let mut batches = Vec::with_capacity(num_batches);
        for i in 0..num_batches {
            let start = i * mini_batch_size;
            let end = (start + mini_batch_size).min(n);
            batches.push(GrpoBatch {
                states: self.states.slice(ndarray::s![start..end, ..]).to_owned(),
                actions: self.actions[start..end].to_vec(),
                log_probs: self.log_probs.slice(ndarray::s![start..end]).to_owned(),
                ref_log_probs: self.ref_log_probs.slice(ndarray::s![start..end]).to_owned(),
                advantages: self.advantages.slice(ndarray::s![start..end]).to_owned(),
                returns: self.returns.slice(ndarray::s![start..end]).to_owned(),
                values: self.values.slice(ndarray::s![start..end]).to_owned(),
            });
        }
        batches
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sample whose reference log-prob equals its policy
    /// log-prob (the pattern every test here uses).
    fn sample(
        state: Vec<f32>,
        action: usize,
        log_prob: f32,
        reward: f32,
        done: bool,
        value: Option<f32>,
        tool: &str,
    ) -> GrpoSample {
        GrpoSample {
            state,
            action,
            log_prob,
            ref_log_prob: log_prob,
            reward,
            done,
            value,
            tool_name: tool.to_string(),
            parameters: None,
        }
    }

    #[test]
    fn test_grpo_config_default() {
        let cfg = GrpoConfig::default();
        assert_eq!(cfg.group_size, 8);
        assert!((cfg.learning_rate - 1e-5).abs() < 1e-10);
    }

    #[test]
    fn test_compute_relative_advantages() {
        let opt = GrpoOptimizer::new(GrpoConfig::default());
        let rewards = [0.8, 0.6, 0.9, 0.5];
        let adv = opt.compute_relative_advantages(&rewards);
        assert_eq!(adv.len(), 4);
        // Z-scored advantages are zero-mean.
        let mean: f32 = adv.iter().sum::<f32>() / adv.len() as f32;
        assert!(mean.abs() < 1e-5);
        // Ranking is preserved: best reward gets the best advantage.
        let argmax = |xs: &[f32]| {
            xs.iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(i, _)| i)
                .unwrap()
        };
        assert_eq!(argmax(&rewards), argmax(&adv));
    }

    #[test]
    fn test_grpo_update() {
        let mut opt = GrpoOptimizer::new(GrpoConfig::default());
        let log_probs = [-0.5f32, -0.3, -0.7, -0.4];
        let advantages = [0.5f32, 0.2, -0.3, 0.1];
        // Reference identical to current policy => KL should be ~0.
        let result = opt.grpo_update(&log_probs, &advantages, &log_probs).unwrap();
        assert_eq!(result.num_samples, 4);
        assert!(result.kl_divergence.abs() < 1e-5);
    }

    #[test]
    fn test_compute_gae() {
        let opt = GrpoOptimizer::new(GrpoConfig::default());
        let rewards = [1.0f32, 0.0, 1.0, 0.0];
        let values = [0.5f32; 4];
        let dones = [false, false, false, true];
        let adv = opt.compute_gae(&rewards, &values, &dones, 0.5);
        assert_eq!(adv.len(), 4);
        // Terminal step does not bootstrap: A = r - V.
        let expected_last = rewards[3] - values[3];
        assert!((adv[3] - expected_last).abs() < 1e-5);
    }

    #[test]
    fn test_compute_returns() {
        let opt = GrpoOptimizer::new(GrpoConfig {
            gamma: 0.9,
            ..Default::default()
        });
        let returns = opt.compute_returns(&[1.0, 1.0, 1.0], &[false, false, true]);
        assert_eq!(returns.len(), 3);
        // Discounted sums: 1, 1 + 0.9, 1 + 0.9 + 0.81.
        assert!((returns[2] - 1.0).abs() < 1e-5);
        assert!((returns[1] - 1.9).abs() < 1e-5);
        assert!((returns[0] - 2.71).abs() < 1e-5);
    }

    #[test]
    fn test_adaptive_kl() {
        let mut opt = GrpoOptimizer::new(GrpoConfig {
            adaptive_kl: true,
            kl_coefficient: 0.02,
            kl_target: 0.01,
            kl_min: 0.001,
            kl_max: 0.1,
            ..Default::default()
        });
        // KL well above target -> coefficient grows.
        opt.adapt_kl_coefficient(0.05);
        assert!(opt.kl_coef > 0.02);
        // KL well below target -> coefficient shrinks.
        opt.kl_coef = 0.02;
        opt.adapt_kl_coefficient(0.001);
        assert!(opt.kl_coef < 0.02);
    }

    #[test]
    fn test_grpo_sample() {
        let s = sample(vec![0.1, 0.2, 0.3], 5, -0.5, 0.8, false, Some(0.7), "agent_spawn");
        assert_eq!(s.action, 5);
        assert_eq!(s.tool_name, "agent_spawn");
    }

    #[test]
    fn test_sample_group() {
        let members = vec![
            sample(vec![0.1, 0.2], 0, -0.5, 0.8, false, None, "memory_store"),
            sample(vec![0.3, 0.4], 1, -0.3, 0.6, false, None, "memory_search"),
        ];
        let group = SampleGroup::new(members, 1, "test task".to_string());
        assert_eq!(group.len(), 2);
        assert_eq!(group.group_id, 1);
        assert!(!group.is_empty());
    }

    #[test]
    fn test_batch_creation() {
        let samples = vec![
            sample(vec![0.1, 0.2, 0.3, 0.4], 0, -0.5, 0.8, false, Some(0.7), "test"),
            sample(vec![0.5, 0.6, 0.7, 0.8], 1, -0.3, 0.6, true, Some(0.5), "test2"),
        ];
        let batch = GrpoBatch::from_samples(&samples, 4).unwrap();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch.states.shape(), &[2, 4]);
    }

    #[test]
    fn test_mini_batches() {
        let samples: Vec<GrpoSample> = (0..10)
            .map(|i| {
                sample(
                    vec![i as f32; 4],
                    i,
                    -(i as f32) * 0.1,
                    i as f32 * 0.1,
                    false,
                    None,
                    &format!("tool_{}", i),
                )
            })
            .collect();
        let batch = GrpoBatch::from_samples(&samples, 4).unwrap();
        let minis = batch.into_mini_batches(3);
        // 10 samples at size 3 => three full mini-batches plus a remainder.
        assert_eq!(minis.len(), 4);
        assert_eq!(minis[0].len(), 3);
        assert_eq!(minis[1].len(), 3);
        assert_eq!(minis[2].len(), 3);
        assert_eq!(minis[3].len(), 1);
    }
}