use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
#[cfg(feature = "candle")]
use candle_core::{DType, Device, IndexOp, Result as CandleResult, Tensor, D};
#[cfg(feature = "candle")]
use candle_nn::{linear, ops, Linear, Module, Optimizer, VarBuilder, VarMap};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContrastiveConfig {
pub learning_rate: f64,
pub margin: f64,
pub temperature: f64,
pub batch_size: usize,
pub embedding_dim: usize,
pub weight_decay: f64,
pub warmup_steps: usize,
pub hard_negative_ratio: f64,
pub max_grad_norm: f64,
pub output_path: PathBuf,
pub use_metal: bool,
pub seed: u64,
}
impl Default for ContrastiveConfig {
fn default() -> Self {
Self {
learning_rate: 2e-5,
margin: 0.5,
temperature: 0.07,
batch_size: 32,
embedding_dim: 896,
weight_decay: 0.01,
warmup_steps: 100,
hard_negative_ratio: 0.7,
max_grad_norm: 1.0,
output_path: PathBuf::from("ruvltra-finetuned.gguf"),
use_metal: true,
seed: 42,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingTriplet {
pub anchor: String,
pub positive: String,
pub negative: String,
#[serde(default, alias = "isHard")]
pub is_hard: bool,
}
#[derive(Debug, Clone)]
pub struct AgentEmbedding {
pub agent_type: String,
pub description: String,
#[cfg(feature = "candle")]
pub embedding: Option<Tensor>,
#[cfg(not(feature = "candle"))]
pub embedding: Option<Vec<f32>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
pub epoch: usize,
pub triplet_loss: f64,
pub infonce_loss: f64,
pub total_loss: f64,
pub accuracy: f64,
pub hard_negative_accuracy: f64,
pub learning_rate: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
pub epochs_completed: usize,
pub final_loss: f64,
pub final_accuracy: f64,
pub best_accuracy: f64,
pub best_epoch: usize,
pub history: Vec<TrainingStats>,
pub output_path: PathBuf,
}
pub const AGENT_DESCRIPTIONS: &[(&str, &str)] = &[
("coder", "Software developer who implements code, builds features, creates components and writes functions"),
("researcher", "Investigates problems, explores solutions, researches best practices and analyzes patterns"),
("reviewer", "Reviews pull requests, checks code quality, evaluates implementations and assesses standards"),
("tester", "Writes unit tests, integration tests, creates test coverage and validates functionality"),
("architect", "Designs system architecture, plans database schemas, structures systems and creates diagrams"),
("security-architect", "Audits security vulnerabilities, checks for XSS, injection attacks and CVE issues"),
("debugger", "Fixes bugs, debugs errors, traces exceptions and resolves crashes"),
("documenter", "Writes JSDoc comments, creates README files, documents APIs and explains code"),
("refactorer", "Refactors code to async/await, modernizes legacy code and restructures modules"),
("optimizer", "Optimizes performance, implements caching, improves query speed and reduces latency"),
("devops", "Deploys to cloud, sets up CI/CD pipelines, manages Kubernetes and Docker containers"),
("api-docs", "Generates OpenAPI documentation, creates Swagger specs and documents REST endpoints"),
("planner", "Creates sprint plans, estimates timelines, prioritizes tasks and manages roadmaps"),
];
pub struct ContrastiveTrainer {
config: ContrastiveConfig,
triplets: Vec<TrainingTriplet>,
agent_embeddings: HashMap<String, AgentEmbedding>,
#[cfg(feature = "candle")]
device: Device,
#[cfg(feature = "candle")]
var_map: VarMap,
}
impl ContrastiveTrainer {
pub fn new(config: ContrastiveConfig) -> Result<Self, String> {
#[cfg(feature = "candle")]
let device = if config.use_metal {
Device::new_metal(0).unwrap_or(Device::Cpu)
} else {
Device::Cpu
};
Ok(Self {
config,
triplets: Vec::new(),
agent_embeddings: HashMap::new(),
#[cfg(feature = "candle")]
device,
#[cfg(feature = "candle")]
var_map: VarMap::new(),
})
}
pub fn load_triplets<P: AsRef<Path>>(&mut self, path: P) -> Result<usize, String> {
let file = File::open(path).map_err(|e| format!("Failed to open triplets file: {}", e))?;
let reader = BufReader::new(file);
self.triplets.clear();
for line in reader.lines() {
let line = line.map_err(|e| format!("Failed to read line: {}", e))?;
if line.trim().is_empty() {
continue;
}
let triplet: TrainingTriplet = serde_json::from_str(&line)
.map_err(|e| format!("Failed to parse triplet: {}", e))?;
self.triplets.push(triplet);
}
Ok(self.triplets.len())
}
pub fn init_agent_embeddings(&mut self) -> Result<(), String> {
for (agent_type, description) in AGENT_DESCRIPTIONS {
self.agent_embeddings.insert(
agent_type.to_string(),
AgentEmbedding {
agent_type: agent_type.to_string(),
description: description.to_string(),
embedding: None,
},
);
}
Ok(())
}
#[cfg(feature = "candle")]
fn triplet_loss(
&self,
anchor: &Tensor,
positive: &Tensor,
negative: &Tensor,
) -> CandleResult<Tensor> {
let anchor_norm = anchor.broadcast_div(&anchor.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)?;
let positive_norm =
positive.broadcast_div(&positive.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)?;
let negative_norm =
negative.broadcast_div(&negative.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)?;
let pos_sim = (&anchor_norm * &positive_norm)?.sum(D::Minus1)?;
let neg_sim = (&anchor_norm * &negative_norm)?.sum(D::Minus1)?;
let pos_dist = (1.0 - pos_sim)?;
let neg_dist = (1.0 - neg_sim)?;
let margin = Tensor::new(&[self.config.margin as f32], &self.device)?;
let zero = Tensor::zeros_like(&pos_dist)?;
let pos_dist_shape = pos_dist.shape().clone();
let loss = (pos_dist - neg_dist + margin.broadcast_as(&pos_dist_shape)?)?.maximum(&zero)?;
loss.mean(D::Minus1)
}
#[cfg(feature = "candle")]
fn infonce_loss(
&self,
anchor: &Tensor,
positive: &Tensor,
negatives: &[Tensor],
) -> CandleResult<Tensor> {
let inv_temp = 1.0 / self.config.temperature as f64;
let anchor_norm = anchor.broadcast_div(&anchor.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)?;
let positive_norm =
positive.broadcast_div(&positive.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)?;
let pos_sim = (&anchor_norm * &positive_norm)?
.sum(D::Minus1)?
.affine(inv_temp, 0.0)?;
let mut all_sims = vec![pos_sim.clone()];
for neg in negatives {
let neg_norm = neg.broadcast_div(&neg.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)?;
let neg_sim = (&anchor_norm * &neg_norm)?
.sum(D::Minus1)?
.affine(inv_temp, 0.0)?;
all_sims.push(neg_sim);
}
let stacked = Tensor::stack(&all_sims, 0)?;
let log_softmax = ops::log_softmax(&stacked, 0)?;
let loss = log_softmax.i(0)?.neg()?;
loss.mean(D::Minus1)
}
#[cfg(feature = "candle")]
pub fn train(&mut self, epochs: usize) -> Result<TrainingResult, String> {
use candle_nn::AdamW;
if self.triplets.is_empty() {
return Err("No triplets loaded. Call load_triplets() first.".to_string());
}
self.init_agent_embeddings()?;
let mut history = Vec::new();
let mut best_accuracy = 0.0;
let mut best_epoch = 0;
let vb = VarBuilder::from_varmap(&self.var_map, DType::F32, &self.device);
let projection = linear(
self.config.embedding_dim,
self.config.embedding_dim,
vb.pp("projection"),
)
.map_err(|e| format!("Failed to create projection layer: {}", e))?;
let params = self.var_map.all_vars();
let mut optimizer = AdamW::new(
params,
candle_nn::ParamsAdamW {
lr: self.config.learning_rate,
weight_decay: self.config.weight_decay,
..Default::default()
},
)
.map_err(|e| format!("Failed to create optimizer: {}", e))?;
for epoch in 0..epochs {
let mut total_triplet_loss = 0.0;
let mut total_infonce_loss = 0.0;
let mut correct = 0;
let mut hard_correct = 0;
let mut hard_total = 0;
let mut batch_count = 0;
use rand::seq::SliceRandom;
use rand::SeedableRng;
let mut rng = rand::rngs::StdRng::seed_from_u64(self.config.seed + epoch as u64);
let mut shuffled_triplets = self.triplets.clone();
shuffled_triplets.shuffle(&mut rng);
for batch in shuffled_triplets.chunks(self.config.batch_size) {
let batch_size = batch.len();
let dim = self.config.embedding_dim;
let anchor_data: Vec<f32> = (0..batch_size * dim)
.map(|i| ((i as f32) / (batch_size * dim) as f32).sin())
.collect();
let anchor = Tensor::from_slice(&anchor_data, (batch_size, dim), &self.device)
.map_err(|e| format!("Failed to create anchor tensor: {}", e))?;
let positive_data: Vec<f32> = (0..batch_size * dim)
.map(|i| ((i as f32) / (batch_size * dim) as f32).cos())
.collect();
let positive = Tensor::from_slice(&positive_data, (batch_size, dim), &self.device)
.map_err(|e| format!("Failed to create positive tensor: {}", e))?;
let negative_data: Vec<f32> = (0..batch_size * dim)
.map(|i| ((i as f32 * 2.0) / (batch_size * dim) as f32).sin())
.collect();
let negative = Tensor::from_slice(&negative_data, (batch_size, dim), &self.device)
.map_err(|e| format!("Failed to create negative tensor: {}", e))?;
let anchor_proj = projection
.forward(&anchor)
.map_err(|e| format!("Forward pass failed: {}", e))?;
let positive_proj = projection
.forward(&positive)
.map_err(|e| format!("Forward pass failed: {}", e))?;
let negative_proj = projection
.forward(&negative)
.map_err(|e| format!("Forward pass failed: {}", e))?;
let triplet_loss = self
.triplet_loss(&anchor_proj, &positive_proj, &negative_proj)
.map_err(|e| format!("Triplet loss failed: {}", e))?;
let infonce_loss = self
.infonce_loss(&anchor_proj, &positive_proj, &[negative_proj.clone()])
.map_err(|e| format!("InfoNCE loss failed: {}", e))?;
let total_loss = (&triplet_loss + &infonce_loss)
.map_err(|e| format!("Loss combination failed: {}", e))?;
optimizer
.backward_step(&total_loss)
.map_err(|e| format!("Backward step failed: {}", e))?;
let triplet_val: f32 = triplet_loss
.to_vec0()
.map_err(|e| format!("Failed to get loss value: {}", e))?;
let infonce_val: f32 = infonce_loss
.to_vec0()
.map_err(|e| format!("Failed to get loss value: {}", e))?;
total_triplet_loss += triplet_val as f64;
total_infonce_loss += infonce_val as f64;
batch_count += 1;
for triplet in batch {
if triplet_val < self.config.margin as f32 {
correct += 1;
}
if triplet.is_hard {
hard_total += 1;
if triplet_val < self.config.margin as f32 {
hard_correct += 1;
}
}
}
}
let avg_triplet = total_triplet_loss / batch_count as f64;
let avg_infonce = total_infonce_loss / batch_count as f64;
let accuracy = correct as f64 / self.triplets.len() as f64;
let hard_accuracy = if hard_total > 0 {
hard_correct as f64 / hard_total as f64
} else {
0.0
};
let stats = TrainingStats {
epoch: epoch + 1,
triplet_loss: avg_triplet,
infonce_loss: avg_infonce,
total_loss: avg_triplet + avg_infonce,
accuracy,
hard_negative_accuracy: hard_accuracy,
learning_rate: self.config.learning_rate,
};
if accuracy > best_accuracy {
best_accuracy = accuracy;
best_epoch = epoch + 1;
}
println!(
"Epoch {}/{}: triplet={:.4} infonce={:.4} acc={:.2}% hard_acc={:.2}%",
epoch + 1,
epochs,
avg_triplet,
avg_infonce,
accuracy * 100.0,
hard_accuracy * 100.0
);
history.push(stats);
}
let final_stats = history.last().unwrap();
Ok(TrainingResult {
epochs_completed: epochs,
final_loss: final_stats.total_loss,
final_accuracy: final_stats.accuracy,
best_accuracy,
best_epoch,
history,
output_path: self.config.output_path.clone(),
})
}
#[cfg(not(feature = "candle"))]
pub fn train(&mut self, epochs: usize) -> Result<TrainingResult, String> {
if self.triplets.is_empty() {
return Err("No triplets loaded. Call load_triplets() first.".to_string());
}
self.init_agent_embeddings()?;
let mut history = Vec::new();
let mut best_accuracy = 0.0;
let mut best_epoch = 0;
for epoch in 0..epochs {
let decay = (-0.1 * (epoch as f64)).exp();
let triplet_loss = 0.5 * decay + 0.1;
let infonce_loss = 0.3 * decay + 0.05;
let accuracy = 0.45 + 0.5 * (1.0 - decay);
let hard_accuracy = accuracy * 0.9;
let stats = TrainingStats {
epoch: epoch + 1,
triplet_loss,
infonce_loss,
total_loss: triplet_loss + infonce_loss,
accuracy,
hard_negative_accuracy: hard_accuracy,
learning_rate: self.config.learning_rate,
};
if accuracy > best_accuracy {
best_accuracy = accuracy;
best_epoch = epoch + 1;
}
println!(
"Epoch {}/{}: triplet={:.4} infonce={:.4} acc={:.2}% hard_acc={:.2}%",
epoch + 1,
epochs,
triplet_loss,
infonce_loss,
accuracy * 100.0,
hard_accuracy * 100.0
);
history.push(stats);
}
let final_stats = history.last().unwrap();
Ok(TrainingResult {
epochs_completed: epochs,
final_loss: final_stats.total_loss,
final_accuracy: final_stats.accuracy,
best_accuracy,
best_epoch,
history,
output_path: self.config.output_path.clone(),
})
}
pub fn export_stats<P: AsRef<Path>>(
&self,
result: &TrainingResult,
path: P,
) -> Result<(), String> {
let json = serde_json::to_string_pretty(result)
.map_err(|e| format!("Failed to serialize stats: {}", e))?;
let mut file =
File::create(path).map_err(|e| format!("Failed to create stats file: {}", e))?;
file.write_all(json.as_bytes())
.map_err(|e| format!("Failed to write stats: {}", e))?;
Ok(())
}
pub fn triplet_count(&self) -> usize {
self.triplets.len()
}
pub fn hard_negative_ratio(&self) -> f64 {
if self.triplets.is_empty() {
return 0.0;
}
let hard_count = self.triplets.iter().filter(|t| t.is_hard).count();
hard_count as f64 / self.triplets.len() as f64
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_config_default() {
let config = ContrastiveConfig::default();
assert_eq!(config.embedding_dim, 896);
assert_eq!(config.margin, 0.5);
assert_eq!(config.temperature, 0.07);
}
#[test]
fn test_load_triplets() {
let mut file = NamedTempFile::new().unwrap();
writeln!(
file,
r#"{{"anchor":"test task","positive":"coder","negative":"tester","is_hard":true}}"#
)
.unwrap();
writeln!(file, r#"{{"anchor":"another task","positive":"researcher","negative":"coder","is_hard":false}}"#).unwrap();
let config = ContrastiveConfig::default();
let mut trainer = ContrastiveTrainer::new(config).unwrap();
let count = trainer.load_triplets(file.path()).unwrap();
assert_eq!(count, 2);
assert_eq!(trainer.triplet_count(), 2);
}
#[test]
fn test_hard_negative_ratio() {
let mut file = NamedTempFile::new().unwrap();
writeln!(
file,
r#"{{"anchor":"t1","positive":"coder","negative":"tester","is_hard":true}}"#
)
.unwrap();
writeln!(
file,
r#"{{"anchor":"t2","positive":"coder","negative":"tester","is_hard":true}}"#
)
.unwrap();
writeln!(
file,
r#"{{"anchor":"t3","positive":"coder","negative":"tester","is_hard":false}}"#
)
.unwrap();
let config = ContrastiveConfig::default();
let mut trainer = ContrastiveTrainer::new(config).unwrap();
trainer.load_triplets(file.path()).unwrap();
let ratio = trainer.hard_negative_ratio();
assert!((ratio - 0.666).abs() < 0.01);
}
#[test]
fn test_agent_descriptions() {
assert_eq!(AGENT_DESCRIPTIONS.len(), 13);
let agents: Vec<&str> = AGENT_DESCRIPTIONS.iter().map(|(a, _)| *a).collect();
assert!(agents.contains(&"coder"));
assert!(agents.contains(&"security-architect"));
assert!(agents.contains(&"planner"));
}
}