use crate::error::{Result, RuvLLMError};
use crate::training::grpo::{GrpoConfig, GrpoOptimizer, GrpoSample, GrpoUpdateResult, SampleGroup};
use crate::training::tool_dataset::{
DifficultyLevel, McpToolDef, ToolCallDataset, ToolCallExample, ToolDatasetConfig,
};
use ndarray::Array2;
use parking_lot::RwLock;
use rand::{rngs::StdRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
/// Hyper-parameters for MCP tool-call training (GRPO plus a supervised phase).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTrainingConfig {
    /// Configuration forwarded to the GRPO optimizer.
    pub grpo: GrpoConfig,
    /// Width of state/tool embedding vectors.
    pub embedding_dim: usize,
    /// Maximum input sequence length in tokens.
    pub max_seq_length: usize,
    /// Trajectory samples per optimization batch.
    pub batch_size: usize,
    /// Number of passes over the training data.
    pub epochs: usize,
    /// Learning rate for the supervised phase.
    pub supervised_lr: f32,
    /// Linear LR warmup steps before the full rate applies.
    pub warmup_steps: usize,
    /// Run evaluation every this many steps.
    pub eval_frequency: usize,
    /// Write a checkpoint every this many steps.
    pub checkpoint_frequency: usize,
    /// RNG seed for reproducible embedding init.
    pub seed: u64,
    /// Enable mixed-precision training.
    pub mixed_precision: bool,
    /// Accumulate gradients over this many micro-batches.
    pub gradient_accumulation: usize,
    /// Gradient clipping threshold (global norm).
    pub max_grad_norm: f32,
    /// Label smoothing factor for the supervised loss.
    pub label_smoothing: f32,
    /// L2 weight decay coefficient.
    pub weight_decay: f32,
    /// Also train parameter prediction (not just tool selection).
    pub train_params: bool,
    /// Also train error-recovery behavior.
    pub train_error_recovery: bool,
}
impl Default for McpTrainingConfig {
fn default() -> Self {
Self {
grpo: GrpoConfig::for_tool_use(),
embedding_dim: 768,
max_seq_length: 2048,
batch_size: 16,
epochs: 10,
supervised_lr: 2e-5,
warmup_steps: 500,
eval_frequency: 100,
checkpoint_frequency: 1000,
seed: 42,
mixed_precision: true,
gradient_accumulation: 4,
max_grad_norm: 1.0,
label_smoothing: 0.1,
weight_decay: 0.01,
train_params: true,
train_error_recovery: true,
}
}
}
impl McpTrainingConfig {
    /// Lightweight preset for fast experiments: smaller batches, fewer
    /// epochs, and more frequent evaluation/checkpointing than `default()`.
    pub fn quick() -> Self {
        let mut cfg = Self::default();
        cfg.batch_size = 8;
        cfg.epochs = 3;
        cfg.eval_frequency = 50;
        cfg.checkpoint_frequency = 500;
        cfg.gradient_accumulation = 2;
        cfg
    }

    /// Heavier preset for full training runs: larger batches, more epochs,
    /// and less frequent evaluation/checkpointing. Explicitly (re)enables
    /// parameter and error-recovery training.
    pub fn production() -> Self {
        let mut cfg = Self::default();
        cfg.batch_size = 32;
        cfg.epochs = 20;
        cfg.eval_frequency = 200;
        cfg.checkpoint_frequency = 2000;
        cfg.gradient_accumulation = 8;
        cfg.train_params = true;
        cfg.train_error_recovery = true;
        cfg
    }
}
/// A complete multi-step tool-use episode used as a GRPO training unit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolTrajectory {
    /// Unique trajectory identifier.
    pub id: String,
    /// Natural-language task the trajectory attempts to solve.
    pub task: String,
    /// Ordered tool-call steps taken while solving the task.
    pub steps: Vec<TrajectoryStep>,
    /// Whether the trajectory as a whole succeeded.
    pub success: bool,
    /// Sum of per-step rewards.
    pub total_reward: f32,
    /// Auxiliary information (session, difficulty, domain, ...).
    pub metadata: TrajectoryMetadata,
}
/// One executed tool call within a [`ToolTrajectory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryStep {
    /// Name of the MCP tool that was invoked.
    pub tool_name: String,
    /// JSON arguments passed to the tool.
    pub parameters: serde_json::Value,
    /// State embedding observed before the call.
    pub state_embedding: Vec<f32>,
    /// Log-probability of this action under the current policy.
    pub log_prob: f32,
    /// Log-probability under the reference (frozen) policy.
    pub ref_log_prob: f32,
    /// Scalar reward assigned to this step.
    pub reward: f32,
    /// Whether the tool call itself succeeded.
    pub success: bool,
    /// Error message when the call failed.
    pub error: Option<String>,
    /// Wall-clock duration of the call in milliseconds.
    pub duration_ms: u64,
    /// State embedding after the call, when available.
    pub next_state_embedding: Option<Vec<f32>>,
}
/// Optional bookkeeping attached to a trajectory.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrajectoryMetadata {
    /// Unix timestamp of collection (0 when unset).
    pub timestamp: u64,
    /// Originating user, if known.
    pub user_id: Option<String>,
    /// Originating session, if known.
    pub session_id: Option<String>,
    /// Task difficulty, if rated.
    pub complexity: Option<DifficultyLevel>,
    /// Problem domain label, if known.
    pub domain: Option<String>,
    /// Free-form key/value context.
    pub context: HashMap<String, String>,
}
/// Outcome of one call to [`McpToolTrainer::train_on_trajectories`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Mean total loss across the GRPO updates in this batch.
    pub avg_loss: f32,
    /// Fraction of samples whose reward exceeded the accuracy threshold.
    pub tool_accuracy: f32,
    /// Fraction of samples with non-empty parameter objects; `None` when
    /// parameter training is disabled.
    pub param_accuracy: Option<f32>,
    /// Per-group results returned by the GRPO optimizer.
    pub grpo_results: Vec<GrpoUpdateResult>,
    /// Number of samples consumed.
    pub samples_processed: usize,
    /// Global step counter value at the time of this update.
    pub step: u64,
    /// Reported gradient norm (derived from the loss, see trainer).
    pub grad_norm: f32,
    /// Learning rate in effect for the supervised phase.
    pub learning_rate: f32,
}
/// Aggregated results of evaluating tool-selection accuracy on a test set.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvaluationMetrics {
    /// Overall fraction of correctly predicted tools.
    pub tool_accuracy: f32,
    /// Accuracy broken down by example category name.
    pub accuracy_by_category: HashMap<String, f32>,
    /// Accuracy broken down by difficulty level (Debug-formatted key).
    pub accuracy_by_difficulty: HashMap<String, f32>,
    /// Parameter-prediction accuracy (0 when not measured).
    pub param_accuracy: f32,
    /// Rate of successful recovery from tool errors (0 when not measured).
    pub error_recovery_rate: f32,
    /// Mean quality score over the evaluated examples.
    pub avg_reward: f32,
    /// Number of examples evaluated.
    pub num_samples: usize,
    /// Confusion counts: expected tool -> predicted tool -> count.
    pub confusion: HashMap<String, HashMap<String, usize>>,
}
/// Running statistics accumulated across the life of a trainer.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Total optimization steps performed.
    pub total_steps: u64,
    /// Total samples processed.
    pub total_samples: u64,
    /// Total trajectories added to the buffer.
    pub total_trajectories: u64,
    /// Exponential moving average of batch loss (decay 0.99).
    pub avg_loss: f32,
    /// Best evaluation accuracy seen so far.
    pub best_accuracy: f32,
    /// Learning rate currently in effect.
    pub current_lr: f32,
    /// Per-batch average losses, in order.
    pub loss_history: Vec<f32>,
    /// Per-evaluation tool accuracies, in order.
    pub eval_history: Vec<f32>,
}
/// GRPO-based trainer for MCP tool selection and parameter prediction.
///
/// Holds the tool catalog, name<->index maps, a random tool-embedding
/// matrix, a buffered trajectory queue, and running statistics. Interior
/// mutability (`RwLock`, `AtomicU64`) lets read-only handles record
/// trajectories and stats concurrently.
pub struct McpToolTrainer {
    // Training hyper-parameters (immutable after construction except via
    // checkpoint import).
    config: McpTrainingConfig,
    // Underlying GRPO optimizer.
    grpo: GrpoOptimizer,
    // Loaded MCP tool catalog.
    tool_defs: Vec<McpToolDef>,
    // Tool name -> index into `tool_defs` / `idx_to_tool`.
    tool_to_idx: HashMap<String, usize>,
    // Index -> tool name (inverse of `tool_to_idx`).
    idx_to_tool: Vec<String>,
    // Running statistics; lock-protected for shared access.
    stats: RwLock<TrainingStats>,
    // Global step counter.
    step: AtomicU64,
    // Seeded RNG used for embedding initialization.
    rng: RwLock<StdRng>,
    // Trajectories queued for `train_buffered`.
    trajectory_buffer: RwLock<Vec<ToolTrajectory>>,
    // One embedding row per tool; (num_tools, embedding_dim).
    tool_embeddings: RwLock<Array2<f32>>,
}
impl McpToolTrainer {
pub fn new(config: McpTrainingConfig) -> Result<Self> {
let grpo = GrpoOptimizer::new(config.grpo.clone());
let rng = StdRng::seed_from_u64(config.seed);
Ok(Self {
config,
grpo,
tool_defs: Vec::new(),
tool_to_idx: HashMap::new(),
idx_to_tool: Vec::new(),
stats: RwLock::new(TrainingStats::default()),
step: AtomicU64::new(0),
rng: RwLock::new(rng),
trajectory_buffer: RwLock::new(Vec::new()),
tool_embeddings: RwLock::new(Array2::zeros((0, 0))),
})
}
pub fn load_tool_definitions(&mut self) -> Result<()> {
let config = ToolDatasetConfig::minimal();
let dataset = ToolCallDataset::generate(config)?;
self.tool_defs = dataset.tool_definitions;
self.tool_to_idx.clear();
self.idx_to_tool.clear();
for (idx, tool) in self.tool_defs.iter().enumerate() {
self.tool_to_idx.insert(tool.name.clone(), idx);
self.idx_to_tool.push(tool.name.clone());
}
let num_tools = self.tool_defs.len();
let embed_dim = self.config.embedding_dim;
let mut rng = self.rng.write();
let mut embeddings = Array2::zeros((num_tools, embed_dim));
for i in 0..num_tools {
for j in 0..embed_dim {
embeddings[[i, j]] = rng.gen::<f32>() * 0.02 - 0.01; }
}
*self.tool_embeddings.write() = embeddings;
Ok(())
}
pub fn num_tools(&self) -> usize {
self.tool_defs.len()
}
pub fn tool_index(&self, name: &str) -> Option<usize> {
self.tool_to_idx.get(name).copied()
}
pub fn tool_name(&self, idx: usize) -> Option<&str> {
self.idx_to_tool.get(idx).map(|s| s.as_str())
}
pub fn add_trajectory(&self, trajectory: ToolTrajectory) {
let mut buffer = self.trajectory_buffer.write();
buffer.push(trajectory);
self.stats.write().total_trajectories += 1;
}
pub fn train_on_trajectories(
&mut self,
trajectories: &[ToolTrajectory],
) -> Result<TrainingResult> {
if trajectories.is_empty() {
return Err(RuvLLMError::InvalidOperation(
"No trajectories provided for training".to_string(),
));
}
let mut all_samples = Vec::new();
let mut all_groups = Vec::new();
for trajectory in trajectories {
let samples = self.trajectory_to_samples(trajectory)?;
let group = SampleGroup::new(
samples.clone(),
self.step.load(Ordering::SeqCst),
trajectory.task.clone(),
);
all_groups.push(group);
all_samples.extend(samples);
}
for group in all_groups {
self.grpo.add_group(group);
}
let grpo_results = self.grpo.process_groups()?;
let avg_loss = if grpo_results.is_empty() {
0.0
} else {
grpo_results.iter().map(|r| r.total_loss).sum::<f32>() / grpo_results.len() as f32
};
let step = self.step.fetch_add(1, Ordering::SeqCst);
{
let mut stats = self.stats.write();
stats.total_steps = step + 1;
stats.total_samples += all_samples.len() as u64;
stats.avg_loss = (stats.avg_loss * 0.99) + (avg_loss * 0.01);
stats.loss_history.push(avg_loss);
}
let tool_accuracy = self.compute_batch_accuracy(&all_samples);
Ok(TrainingResult {
avg_loss,
tool_accuracy,
param_accuracy: if self.config.train_params {
Some(self.compute_param_accuracy(&all_samples))
} else {
None
},
grpo_results,
samples_processed: all_samples.len(),
step,
grad_norm: avg_loss.abs().sqrt(), learning_rate: self.config.supervised_lr,
})
}
fn trajectory_to_samples(&self, trajectory: &ToolTrajectory) -> Result<Vec<GrpoSample>> {
let mut samples = Vec::new();
for (i, step) in trajectory.steps.iter().enumerate() {
let action = self.tool_index(&step.tool_name).unwrap_or(0);
let is_done = i == trajectory.steps.len() - 1;
samples.push(GrpoSample {
state: step.state_embedding.clone(),
action,
log_prob: step.log_prob,
ref_log_prob: step.ref_log_prob,
reward: step.reward,
done: is_done,
value: None,
tool_name: step.tool_name.clone(),
parameters: Some(step.parameters.clone()),
});
}
Ok(samples)
}
fn compute_batch_accuracy(&self, samples: &[GrpoSample]) -> f32 {
if samples.is_empty() {
return 0.0;
}
let correct = samples.iter().filter(|s| s.reward > 0.5).count();
correct as f32 / samples.len() as f32
}
fn compute_param_accuracy(&self, samples: &[GrpoSample]) -> f32 {
if samples.is_empty() {
return 0.0;
}
let valid = samples
.iter()
.filter(|s| {
s.parameters
.as_ref()
.map(|p| p.is_object() && !p.as_object().unwrap().is_empty())
.unwrap_or(false)
})
.count();
valid as f32 / samples.len() as f32
}
pub fn evaluate_tool_accuracy(
&self,
test_examples: &[ToolCallExample],
) -> Result<EvaluationMetrics> {
if test_examples.is_empty() {
return Ok(EvaluationMetrics::default());
}
let mut metrics = EvaluationMetrics::default();
let mut correct = 0;
let mut by_category: HashMap<String, (usize, usize)> = HashMap::new(); let mut by_difficulty: HashMap<String, (usize, usize)> = HashMap::new();
let mut confusion: HashMap<String, HashMap<String, usize>> = HashMap::new();
for example in test_examples {
let predicted = self.predict_tool(&example.prompt)?;
let is_correct = predicted == example.expected_tool;
if is_correct {
correct += 1;
}
let cat_key = example.category.name().to_string();
let entry = by_category.entry(cat_key.clone()).or_insert((0, 0));
if is_correct {
entry.0 += 1;
}
entry.1 += 1;
let diff_key = format!("{:?}", example.difficulty);
let entry = by_difficulty.entry(diff_key.clone()).or_insert((0, 0));
if is_correct {
entry.0 += 1;
}
entry.1 += 1;
*confusion
.entry(example.expected_tool.clone())
.or_default()
.entry(predicted)
.or_insert(0) += 1;
metrics.avg_reward += example.quality_score;
}
metrics.tool_accuracy = correct as f32 / test_examples.len() as f32;
metrics.num_samples = test_examples.len();
metrics.avg_reward /= test_examples.len() as f32;
for (cat, (c, t)) in by_category {
metrics
.accuracy_by_category
.insert(cat, c as f32 / t as f32);
}
for (diff, (c, t)) in by_difficulty {
metrics
.accuracy_by_difficulty
.insert(diff, c as f32 / t as f32);
}
metrics.confusion = confusion;
{
let mut stats = self.stats.write();
if metrics.tool_accuracy > stats.best_accuracy {
stats.best_accuracy = metrics.tool_accuracy;
}
stats.eval_history.push(metrics.tool_accuracy);
}
Ok(metrics)
}
pub fn predict_tool(&self, prompt: &str) -> Result<String> {
let prompt_lower = prompt.to_lowercase();
for tool in &self.tool_defs {
for use_case in &tool.use_cases {
if prompt_lower.contains(&use_case.to_lowercase()) {
return Ok(tool.name.clone());
}
}
}
if prompt_lower.contains("spawn") || prompt_lower.contains("agent") {
return Ok("agent_spawn".to_string());
}
if prompt_lower.contains("memory") || prompt_lower.contains("store") {
return Ok("memory_store".to_string());
}
if prompt_lower.contains("search") {
return Ok("memory_search".to_string());
}
if prompt_lower.contains("swarm") || prompt_lower.contains("initialize") {
return Ok("swarm_init".to_string());
}
if prompt_lower.contains("task") {
return Ok("task_create".to_string());
}
if prompt_lower.contains("hook") || prompt_lower.contains("route") {
return Ok("hooks_route".to_string());
}
Ok("system_status".to_string())
}
pub fn generate_tool_dataset(&self, config: ToolDatasetConfig) -> Result<ToolCallDataset> {
ToolCallDataset::generate(config)
}
pub fn stats(&self) -> TrainingStats {
self.stats.read().clone()
}
pub fn grpo_stats(&self) -> crate::training::grpo::GrpoStats {
self.grpo.stats()
}
pub fn reset(&mut self) {
self.grpo.reset();
self.step.store(0, Ordering::SeqCst);
*self.stats.write() = TrainingStats::default();
self.trajectory_buffer.write().clear();
}
pub fn config(&self) -> &McpTrainingConfig {
&self.config
}
pub fn tool_definitions(&self) -> &[McpToolDef] {
&self.tool_defs
}
pub fn compute_reward(
&self,
predicted_tool: &str,
expected_tool: &str,
params_correct: bool,
execution_success: bool,
) -> f32 {
let mut reward = 0.0;
if predicted_tool == expected_tool {
reward += 0.5;
} else if self.same_category(predicted_tool, expected_tool) {
reward += 0.2; }
if params_correct {
reward += 0.3;
}
if execution_success {
reward += 0.2;
}
reward
}
fn same_category(&self, tool1: &str, tool2: &str) -> bool {
let cat1 = self
.tool_defs
.iter()
.find(|t| t.name == tool1)
.map(|t| t.category);
let cat2 = self
.tool_defs
.iter()
.find(|t| t.name == tool2)
.map(|t| t.category);
cat1.is_some() && cat1 == cat2
}
pub fn train_buffered(&mut self) -> Result<Option<TrainingResult>> {
let trajectories = {
let mut buffer = self.trajectory_buffer.write();
if buffer.is_empty() {
return Ok(None);
}
std::mem::take(&mut *buffer)
};
let result = self.train_on_trajectories(&trajectories)?;
Ok(Some(result))
}
pub fn export_checkpoint(&self) -> TrainingCheckpoint {
TrainingCheckpoint {
step: self.step.load(Ordering::SeqCst),
stats: self.stats.read().clone(),
grpo_stats: self.grpo.stats(),
tool_embeddings: {
let (vec, _offset) = self
.tool_embeddings
.read()
.clone()
.into_raw_vec_and_offset();
vec
},
embedding_shape: {
let emb = self.tool_embeddings.read();
(emb.nrows(), emb.ncols())
},
config: self.config.clone(),
}
}
pub fn import_checkpoint(&mut self, checkpoint: TrainingCheckpoint) -> Result<()> {
self.step.store(checkpoint.step, Ordering::SeqCst);
*self.stats.write() = checkpoint.stats;
let (rows, cols) = checkpoint.embedding_shape;
if checkpoint.tool_embeddings.len() == rows * cols {
let embeddings = Array2::from_shape_vec((rows, cols), checkpoint.tool_embeddings)
.map_err(|e| RuvLLMError::InvalidOperation(e.to_string()))?;
*self.tool_embeddings.write() = embeddings;
}
self.config = checkpoint.config;
Ok(())
}
}
/// Serializable snapshot of a trainer's state for save/restore.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingCheckpoint {
    /// Global step counter at export time.
    pub step: u64,
    /// Running training statistics.
    pub stats: TrainingStats,
    /// GRPO optimizer statistics (informational; not restored on import).
    pub grpo_stats: crate::training::grpo::GrpoStats,
    /// Tool-embedding matrix flattened row-major.
    pub tool_embeddings: Vec<f32>,
    /// (rows, cols) of the embedding matrix.
    pub embedding_shape: (usize, usize),
    /// Trainer configuration at export time.
    pub config: McpTrainingConfig,
}
/// Fluent builder for [`ToolTrajectory`].
pub struct TrajectoryBuilder {
    // Trajectory identifier.
    id: String,
    // Task description.
    task: String,
    // Steps accumulated via `add_step`.
    steps: Vec<TrajectoryStep>,
    // Metadata accumulated via `with_*` setters.
    metadata: TrajectoryMetadata,
}
impl TrajectoryBuilder {
pub fn new(id: impl Into<String>, task: impl Into<String>) -> Self {
Self {
id: id.into(),
task: task.into(),
steps: Vec::new(),
metadata: TrajectoryMetadata::default(),
}
}
pub fn add_step(mut self, step: TrajectoryStep) -> Self {
self.steps.push(step);
self
}
pub fn with_metadata(mut self, metadata: TrajectoryMetadata) -> Self {
self.metadata = metadata;
self
}
pub fn with_complexity(mut self, complexity: DifficultyLevel) -> Self {
self.metadata.complexity = Some(complexity);
self
}
pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
self.metadata.session_id = Some(session_id.into());
self
}
pub fn build(self) -> ToolTrajectory {
let success = self.steps.last().map(|s| s.success).unwrap_or(false);
let total_reward = self.steps.iter().map(|s| s.reward).sum();
ToolTrajectory {
id: self.id,
task: self.task,
steps: self.steps,
success,
total_reward,
metadata: self.metadata,
}
}
}
/// Fluent builder for [`TrajectoryStep`].
pub struct StepBuilder {
    // Tool that was (or will be) invoked.
    tool_name: String,
    // JSON arguments; defaults to an empty object.
    parameters: serde_json::Value,
    // State embedding before the call; defaults to empty.
    state_embedding: Vec<f32>,
    // Policy log-prob; defaults to 0.0.
    log_prob: f32,
    // Reference-policy log-prob; defaults to 0.0.
    ref_log_prob: f32,
    // Step reward; defaults to 0.0.
    reward: f32,
    // Success flag; defaults to true, cleared by `with_error`.
    success: bool,
    // Error message, if any.
    error: Option<String>,
    // Call duration in milliseconds; defaults to 0.
    duration_ms: u64,
    // Post-call state embedding; no setter currently exposed.
    next_state_embedding: Option<Vec<f32>>,
}
impl StepBuilder {
pub fn new(tool_name: impl Into<String>) -> Self {
Self {
tool_name: tool_name.into(),
parameters: serde_json::Value::Object(serde_json::Map::new()),
state_embedding: Vec::new(),
log_prob: 0.0,
ref_log_prob: 0.0,
reward: 0.0,
success: true,
error: None,
duration_ms: 0,
next_state_embedding: None,
}
}
pub fn with_params(mut self, params: serde_json::Value) -> Self {
self.parameters = params;
self
}
pub fn with_state(mut self, embedding: Vec<f32>) -> Self {
self.state_embedding = embedding;
self
}
pub fn with_log_prob(mut self, log_prob: f32) -> Self {
self.log_prob = log_prob;
self
}
pub fn with_ref_log_prob(mut self, ref_log_prob: f32) -> Self {
self.ref_log_prob = ref_log_prob;
self
}
pub fn with_reward(mut self, reward: f32) -> Self {
self.reward = reward;
self
}
pub fn with_success(mut self, success: bool) -> Self {
self.success = success;
self
}
pub fn with_error(mut self, error: impl Into<String>) -> Self {
self.error = Some(error.into());
self.success = false;
self
}
pub fn with_duration(mut self, ms: u64) -> Self {
self.duration_ms = ms;
self
}
pub fn build(self) -> TrajectoryStep {
TrajectoryStep {
tool_name: self.tool_name,
parameters: self.parameters,
state_embedding: self.state_embedding,
log_prob: self.log_prob,
ref_log_prob: self.ref_log_prob,
reward: self.reward,
success: self.success,
error: self.error,
duration_ms: self.duration_ms,
next_state_embedding: self.next_state_embedding,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh trainer has no tools until `load_tool_definitions` is called.
    #[test]
    fn test_trainer_creation() {
        let config = McpTrainingConfig::default();
        let trainer = McpToolTrainer::new(config).unwrap();
        assert_eq!(trainer.num_tools(), 0);
    }

    // Loading definitions populates the catalog and the name->index map.
    #[test]
    fn test_load_tool_definitions() {
        let config = McpTrainingConfig::default();
        let mut trainer = McpToolTrainer::new(config).unwrap();
        trainer.load_tool_definitions().unwrap();
        assert!(trainer.num_tools() > 0);
        assert!(trainer.tool_index("agent_spawn").is_some());
        assert!(trainer.tool_index("memory_store").is_some());
    }

    // The keyword heuristic resolves representative prompts to the
    // expected tools.
    #[test]
    fn test_predict_tool() {
        let config = McpTrainingConfig::default();
        let mut trainer = McpToolTrainer::new(config).unwrap();
        trainer.load_tool_definitions().unwrap();
        let prediction = trainer.predict_tool("spawn a coder agent").unwrap();
        assert_eq!(prediction, "agent_spawn");
        let prediction = trainer.predict_tool("store this in memory").unwrap();
        assert_eq!(prediction, "memory_store");
    }

    // Dataset generation through the trainer yields examples.
    #[test]
    fn test_generate_dataset() {
        let config = McpTrainingConfig::default();
        let trainer = McpToolTrainer::new(config).unwrap();
        let dataset_config = ToolDatasetConfig::minimal();
        let dataset = trainer.generate_tool_dataset(dataset_config).unwrap();
        assert!(!dataset.examples.is_empty());
    }

    // Builder produces a trajectory whose success mirrors the last step
    // and whose total reward is the sum of step rewards (0.8 + 0.9).
    #[test]
    fn test_trajectory_builder() {
        let step1 = StepBuilder::new("agent_spawn")
            .with_params(serde_json::json!({"agentType": "coder"}))
            .with_state(vec![0.1, 0.2, 0.3])
            .with_reward(0.8)
            .build();
        let step2 = StepBuilder::new("task_create")
            .with_params(serde_json::json!({"type": "feature"}))
            .with_state(vec![0.4, 0.5, 0.6])
            .with_reward(0.9)
            .build();
        let trajectory = TrajectoryBuilder::new("traj-1", "implement authentication")
            .add_step(step1)
            .add_step(step2)
            .with_complexity(DifficultyLevel::Medium)
            .build();
        assert_eq!(trajectory.steps.len(), 2);
        assert!(trajectory.success);
        assert!((trajectory.total_reward - 1.7).abs() < 0.01);
    }

    // Full-credit call scores 1.0 (0.5 + 0.3 + 0.2); a cross-category miss
    // with wrong params and failed execution stays below the partial-credit
    // threshold.
    #[test]
    fn test_compute_reward() {
        let config = McpTrainingConfig::default();
        let mut trainer = McpToolTrainer::new(config).unwrap();
        trainer.load_tool_definitions().unwrap();
        let reward = trainer.compute_reward("agent_spawn", "agent_spawn", true, true);
        assert!((reward - 1.0).abs() < 0.01);
        let reward = trainer.compute_reward("memory_store", "agent_spawn", false, false);
        assert!(reward < 0.3);
    }

    // A single-step trajectory flows through the GRPO update path.
    #[test]
    fn test_train_on_trajectories() {
        let config = McpTrainingConfig::quick();
        let mut trainer = McpToolTrainer::new(config).unwrap();
        trainer.load_tool_definitions().unwrap();
        let step = StepBuilder::new("agent_spawn")
            .with_params(serde_json::json!({"agentType": "coder"}))
            .with_state(vec![0.1; 768])
            .with_log_prob(-0.5)
            .with_ref_log_prob(-0.5)
            .with_reward(0.8)
            .build();
        let trajectory = TrajectoryBuilder::new("test-traj", "test task")
            .add_step(step)
            .build();
        let result = trainer.train_on_trajectories(&[trajectory]).unwrap();
        assert!(result.samples_processed > 0);
    }

    // Evaluation on a 5-example slice reports 5 samples and an accuracy
    // within [0, 1].
    #[test]
    fn test_evaluate_accuracy() {
        let config = McpTrainingConfig::default();
        let mut trainer = McpToolTrainer::new(config).unwrap();
        trainer.load_tool_definitions().unwrap();
        let dataset_config = ToolDatasetConfig::minimal();
        let dataset = trainer.generate_tool_dataset(dataset_config).unwrap();
        let metrics = trainer
            .evaluate_tool_accuracy(&dataset.examples[..5])
            .unwrap();
        assert!(metrics.num_samples == 5);
        assert!(metrics.tool_accuracy >= 0.0 && metrics.tool_accuracy <= 1.0);
    }

    // Export/import round-trips the step counter through a second trainer.
    #[test]
    fn test_checkpoint() {
        let config = McpTrainingConfig::default();
        let mut trainer = McpToolTrainer::new(config).unwrap();
        trainer.load_tool_definitions().unwrap();
        let checkpoint = trainer.export_checkpoint();
        assert_eq!(checkpoint.step, 0);
        let config2 = McpTrainingConfig::default();
        let mut trainer2 = McpToolTrainer::new(config2).unwrap();
        trainer2.import_checkpoint(checkpoint).unwrap();
        assert_eq!(trainer2.step.load(Ordering::SeqCst), 0);
    }

    // Buffering a trajectory bumps the trajectory counter.
    #[test]
    fn test_add_trajectory_to_buffer() {
        let config = McpTrainingConfig::default();
        let trainer = McpToolTrainer::new(config).unwrap();
        let trajectory = TrajectoryBuilder::new("buf-traj", "buffer test")
            .add_step(StepBuilder::new("system_status").build())
            .build();
        trainer.add_trajectory(trajectory);
        assert_eq!(trainer.stats().total_trajectories, 1);
    }

    // Category comparison matches within the memory family and rejects
    // cross-family pairs.
    #[test]
    fn test_same_category() {
        let config = McpTrainingConfig::default();
        let mut trainer = McpToolTrainer::new(config).unwrap();
        trainer.load_tool_definitions().unwrap();
        assert!(trainer.same_category("memory_store", "memory_search"));
        assert!(!trainer.same_category("memory_store", "agent_spawn"));
    }
}