use super::types::QueryPattern;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "distillation")]
use crate::error::{DistillationError, OxiRagError};
/// Configuration values for a LoRA training run.
///
/// Baseline values come from the `Default` impl; `is_valid` lists the
/// constraints a usable configuration has to satisfy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
/// Adapter rank. Must be non-zero to be valid; it is the divisor in `scaling_factor`.
pub rank: usize,
/// Alpha value. Must be positive to be valid; it is the numerator in `scaling_factor`.
pub alpha: f32,
/// Dropout value. `with_dropout` clamps it to `[0.0, 1.0]` and `is_valid` requires that range.
pub dropout: f32,
/// Names of the modules to target. Must be non-empty to be valid.
pub target_modules: Vec<String>,
/// Learning rate. Must be positive to be valid.
pub learning_rate: f64,
/// Number of epochs. Must be non-zero to be valid; multiplies the per-epoch step estimate.
pub num_epochs: usize,
/// Batch size. `with_batch_size` raises `0` to `1`; `is_valid` requires a non-zero value.
pub batch_size: usize,
}
impl Default for LoraConfig {
    /// Builds the baseline configuration: rank 8, alpha 16.0, dropout 0.05,
    /// target modules `q_proj` and `v_proj`, learning rate 1e-4, 3 epochs and
    /// a batch size of 4.
    fn default() -> Self {
        let target_modules: Vec<String> = ["q_proj", "v_proj"]
            .iter()
            .map(|name| String::from(*name))
            .collect();

        Self {
            batch_size: 4,
            num_epochs: 3,
            learning_rate: 1e-4,
            target_modules,
            dropout: 0.05,
            alpha: 16.0,
            rank: 8,
        }
    }
}
impl LoraConfig {
    /// Creates a configuration with the given `rank` and the default value
    /// for every other field.
    #[must_use]
    pub fn with_rank(rank: usize) -> Self {
        Self {
            rank,
            ..Default::default()
        }
    }

    /// Returns the configuration with `alpha` replaced (stored as given).
    #[must_use]
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Returns the configuration with `dropout` replaced, clamped to
    /// `[0.0, 1.0]`.
    ///
    /// A NaN input survives the clamp as NaN and is then rejected by
    /// [`is_valid`](Self::is_valid).
    #[must_use]
    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout.clamp(0.0, 1.0);
        self
    }

    /// Returns the configuration with the list of target modules replaced.
    #[must_use]
    pub fn with_target_modules(mut self, modules: Vec<String>) -> Self {
        self.target_modules = modules;
        self
    }

    /// Returns the configuration with the learning rate replaced (stored as given).
    #[must_use]
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Returns the configuration with the number of epochs replaced (stored as given).
    #[must_use]
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.num_epochs = epochs;
        self
    }

    /// Returns the configuration with the batch size replaced; `0` is raised to `1`.
    #[must_use]
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size.max(1);
        self
    }

    /// Reports whether every field holds a usable value.
    ///
    /// Requirements: non-zero `rank`, `num_epochs` and `batch_size`; finite,
    /// strictly positive `alpha` and `learning_rate`; `dropout` within
    /// `[0.0, 1.0]`; at least one target module.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.rank > 0
            // `INFINITY > 0.0` is true, so finiteness has to be checked
            // separately; NaN already fails every comparison below.
            && self.alpha.is_finite()
            && self.alpha > 0.0
            && self.dropout >= 0.0
            && self.dropout <= 1.0
            && self.learning_rate.is_finite()
            && self.learning_rate > 0.0
            && self.num_epochs > 0
            && self.batch_size > 0
            && !self.target_modules.is_empty()
    }

    /// Returns `alpha / rank`.
    ///
    /// The division is done in `f32`, so a `rank` of zero yields a non-finite
    /// result instead of panicking; check [`is_valid`](Self::is_valid) first.
    #[must_use]
    // The usize -> f32 conversion may round for very large ranks; accepted here.
    #[allow(clippy::cast_precision_loss)]
    pub fn scaling_factor(&self) -> f32 {
        self.alpha / self.rank as f32
    }
}
/// A single weighted input/output pair used as training data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraTrainingExample {
/// Input text; `is_valid` requires it to be non-blank after trimming.
pub input: String,
/// Output text; `is_valid` requires it to be non-blank after trimming.
pub output: String,
/// Weight of this example. The constructors default it to `1.0` or clamp
/// negative values to `0.0`; `is_valid` requires it to be strictly positive.
pub weight: f32,
}
impl LoraTrainingExample {
    /// Builds an example from an input/output pair with a weight of `1.0`.
    #[must_use]
    pub fn new(input: impl Into<String>, output: impl Into<String>) -> Self {
        let (input, output) = (input.into(), output.into());
        Self {
            input,
            output,
            weight: 1.0,
        }
    }

    /// Builds an example with an explicit weight; negative weights become `0.0`.
    #[must_use]
    pub fn with_weight(input: impl Into<String>, output: impl Into<String>, weight: f32) -> Self {
        Self::new(input, output).weight(weight)
    }

    /// Returns the example with its weight replaced; negative weights become `0.0`.
    #[must_use]
    pub fn weight(mut self, weight: f32) -> Self {
        self.weight = f32::max(weight, 0.0);
        self
    }

    /// Reports whether the example is usable: a strictly positive weight and
    /// input/output texts that are not blank once trimmed.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let has_text = |text: &str| !text.trim().is_empty();
        has_text(&self.input) && has_text(&self.output) && self.weight > 0.0
    }

    /// Returns this example's share of `total_weight`, or `0.0` when
    /// `total_weight` is zero or negative.
    #[must_use]
    pub fn effective_contribution(&self, total_weight: f32) -> f32 {
        if total_weight <= 0.0 {
            return 0.0;
        }
        self.weight / total_weight
    }
}
/// Lifecycle state of a training job.
///
/// `Preparing` and `Training` are reported as active by `is_active`;
/// `Completed` and `Failed` are reported as terminal by `is_terminal`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub enum TrainingStatus {
/// Initial state; this is the `Default` value.
#[default]
Pending,
/// Active state that carries no progress data.
Preparing,
/// Active state carrying the current progress.
Training {
/// Epoch reported by `current_epoch`.
epoch: usize,
/// Loss reported by `current_loss`.
loss: f32,
},
/// Terminal state counted as a success by `is_success`.
Completed {
/// Final loss, also reported by `current_loss`.
final_loss: f32,
},
/// Terminal state that is not a success.
Failed {
/// Description of why the job ended.
error: String,
},
}
impl TrainingStatus {
    /// True for the two end states, `Completed` and `Failed`.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        match self {
            Self::Completed { .. } | Self::Failed { .. } => true,
            Self::Pending | Self::Preparing | Self::Training { .. } => false,
        }
    }

    /// True while the job is `Preparing` or `Training`.
    #[must_use]
    pub fn is_active(&self) -> bool {
        match self {
            Self::Preparing | Self::Training { .. } => true,
            Self::Pending | Self::Completed { .. } | Self::Failed { .. } => false,
        }
    }

    /// True only for `Completed`.
    #[must_use]
    pub fn is_success(&self) -> bool {
        match self {
            Self::Completed { .. } => true,
            Self::Pending | Self::Preparing | Self::Training { .. } | Self::Failed { .. } => false,
        }
    }

    /// Loss carried by the status: the running loss while `Training`, the
    /// final loss once `Completed`, otherwise `None`.
    #[must_use]
    pub fn current_loss(&self) -> Option<f32> {
        match *self {
            Self::Training { loss, .. } => Some(loss),
            Self::Completed { final_loss } => Some(final_loss),
            Self::Pending | Self::Preparing | Self::Failed { .. } => None,
        }
    }

    /// Epoch carried by a `Training` status; `None` for every other state.
    #[must_use]
    pub fn current_epoch(&self) -> Option<usize> {
        if let Self::Training { epoch, .. } = self {
            Some(*epoch)
        } else {
            None
        }
    }
}
/// A training job: its inputs, its current status and its timestamps.
#[derive(Debug, Clone)]
pub struct TrainingJob {
/// Identifier of the job.
pub job_id: String,
/// Query pattern the job was created for.
pub pattern: QueryPattern,
/// Configuration the job runs with.
pub config: LoraConfig,
/// Training examples attached to the job (valid and invalid alike).
pub examples: Vec<LoraTrainingExample>,
/// Current lifecycle state; `new` starts it at `Pending`.
pub status: TrainingStatus,
/// Value of `super::types::current_timestamp()` when the job was created.
pub created_at: u64,
/// Timestamp of the first terminal status applied through `update_status`; `None` until then.
pub completed_at: Option<u64>,
}
impl TrainingJob {
    /// Creates a job in the `Pending` state.
    ///
    /// `created_at` is taken from `super::types::current_timestamp()` and
    /// `completed_at` starts out as `None`.
    #[must_use]
    pub fn new(
        job_id: impl Into<String>,
        pattern: QueryPattern,
        config: LoraConfig,
        examples: Vec<LoraTrainingExample>,
    ) -> Self {
        Self {
            job_id: job_id.into(),
            pattern,
            config,
            examples,
            status: TrainingStatus::Pending,
            created_at: super::types::current_timestamp(),
            completed_at: None,
        }
    }

    /// Sum of the weights of all attached examples, valid or not.
    #[must_use]
    pub fn total_weight(&self) -> f32 {
        self.examples.iter().map(|e| e.weight).sum()
    }

    /// Number of attached examples for which `LoraTrainingExample::is_valid` holds.
    #[must_use]
    pub fn valid_example_count(&self) -> usize {
        self.examples.iter().filter(|e| e.is_valid()).count()
    }

    /// Reports whether the job can be started: it is still `Pending`, its
    /// configuration is valid, and it has at least one full batch of valid
    /// examples.
    #[must_use]
    pub fn is_ready(&self) -> bool {
        matches!(self.status, TrainingStatus::Pending)
            && self.config.is_valid()
            && self.valid_example_count() >= self.config.batch_size
    }

    /// Estimates the total number of training steps:
    /// `ceil(valid examples / batch size) * epochs`.
    ///
    /// Returns `0` when the batch size is zero. The multiplication saturates
    /// at `usize::MAX` rather than overflowing, because `num_epochs` is
    /// caller-controlled and may be arbitrarily large.
    #[must_use]
    pub fn estimated_steps(&self) -> usize {
        // Guard first: `div_ceil` panics on a zero divisor.
        if self.config.batch_size == 0 {
            return 0;
        }
        self.valid_example_count()
            .div_ceil(self.config.batch_size)
            .saturating_mul(self.config.num_epochs)
    }

    /// Replaces the job's status.
    ///
    /// The first time a terminal status is applied, `completed_at` is stamped
    /// with the current timestamp; later updates never change or clear it,
    /// even if a non-terminal status is applied afterwards.
    pub fn update_status(&mut self, status: TrainingStatus) {
        if status.is_terminal() && self.completed_at.is_none() {
            self.completed_at = Some(super::types::current_timestamp());
        }
        self.status = status;
    }

    /// Marks the job as `Failed` with the given error message.
    pub fn fail(&mut self, error: impl Into<String>) {
        self.update_status(TrainingStatus::Failed {
            error: error.into(),
        });
    }

    /// Marks the job as `Completed` with the given final loss.
    pub fn complete(&mut self, final_loss: f32) {
        self.update_status(TrainingStatus::Completed { final_loss });
    }

    /// Difference between `completed_at` and `created_at`, or `None` while the
    /// job has no completion timestamp.
    ///
    /// The subtraction saturates at zero. The unit is whatever
    /// `current_timestamp` returns (presumably seconds, given the method
    /// name; confirm against that helper).
    #[must_use]
    pub fn duration_secs(&self) -> Option<u64> {
        self.completed_at.map(|c| c.saturating_sub(self.created_at))
    }
}
/// Interface for creating, inspecting and cancelling training jobs.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait LoraTrainer: Send + Sync {
/// Creates a training job for `pattern` from `examples` and `config`.
///
/// On success the returned `String` is expected to identify the new job
/// (that is what `MockLoraTrainer` returns).
///
/// # Errors
///
/// Returns an `OxiRagError` when the job cannot be created.
async fn create_job(
&mut self,
pattern: &QueryPattern,
examples: Vec<LoraTrainingExample>,
config: LoraConfig,
) -> Result<String, OxiRagError>;
/// Returns the status of the job identified by `job_id`, if any.
///
/// `None` presumably means the job is unknown; `MockLoraTrainer` behaves that way.
async fn get_status(&self, job_id: &str) -> Option<TrainingStatus>;
/// Cancels the job identified by `job_id`.
///
/// # Errors
///
/// Returns an `OxiRagError` when the job cannot be cancelled.
async fn cancel_job(&mut self, job_id: &str) -> Result<(), OxiRagError>;
/// Returns references to the jobs known to this trainer.
fn list_jobs(&self) -> Vec<&TrainingJob>;
}
/// In-memory `LoraTrainer` implementation that only records jobs; progress
/// and completion are driven manually through the `simulate_*` helpers.
#[derive(Debug, Default)]
pub struct MockLoraTrainer {
/// Jobs created so far, keyed by job id.
jobs: HashMap<String, TrainingJob>,
/// Counter behind generated ids; incremented before each id is formatted.
next_job_id: u64,
/// When `true`, `create_job` fails immediately with a simulated error.
simulate_failure: bool,
}
impl MockLoraTrainer {
    /// Creates an empty trainer with failure simulation switched off.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the trainer with failure simulation switched on or off.
    #[must_use]
    pub fn with_simulated_failure(self, simulate: bool) -> Self {
        Self {
            simulate_failure: simulate,
            ..self
        }
    }

    /// Looks up a job by id.
    #[must_use]
    pub fn get_job(&self, job_id: &str) -> Option<&TrainingJob> {
        self.jobs.get(job_id)
    }

    /// Looks up a job by id for in-place modification.
    pub fn get_job_mut(&mut self, job_id: &str) -> Option<&mut TrainingJob> {
        self.jobs.get_mut(job_id)
    }

    /// Moves the job into `Training` with the given epoch and loss.
    /// Unknown ids are ignored.
    pub fn simulate_progress(&mut self, job_id: &str, epoch: usize, loss: f32) {
        let Some(job) = self.jobs.get_mut(job_id) else {
            return;
        };
        job.update_status(TrainingStatus::Training { epoch, loss });
    }

    /// Marks the job as completed with the given final loss.
    /// Unknown ids are ignored.
    pub fn simulate_completion(&mut self, job_id: &str, final_loss: f32) {
        let Some(job) = self.jobs.get_mut(job_id) else {
            return;
        };
        job.complete(final_loss);
    }

    /// Number of jobs whose status is currently active.
    #[must_use]
    pub fn active_job_count(&self) -> usize {
        self.jobs
            .values()
            .map(|job| usize::from(job.status.is_active()))
            .sum()
    }

    /// Number of jobs that finished successfully.
    #[must_use]
    pub fn completed_job_count(&self) -> usize {
        self.jobs
            .values()
            .map(|job| usize::from(job.status.is_success()))
            .sum()
    }

    /// Drops every job that reached a terminal status (completed or failed).
    pub fn clear_completed(&mut self) {
        self.jobs.retain(|_id, job| {
            let finished = job.status.is_terminal();
            !finished
        });
    }

    /// Advances the id counter and formats the next `mock-job-N` identifier.
    fn generate_job_id(&mut self) -> String {
        let sequence = self.next_job_id + 1;
        self.next_job_id = sequence;
        format!("mock-job-{}", sequence)
    }
}
#[async_trait]
impl LoraTrainer for MockLoraTrainer {
    /// Records a new `Pending` job and returns its generated id.
    ///
    /// Checks run in this order: simulated failure, configuration validity,
    /// non-empty example list. The first failing check produces the error.
    async fn create_job(
        &mut self,
        pattern: &QueryPattern,
        examples: Vec<LoraTrainingExample>,
        config: LoraConfig,
    ) -> Result<String, OxiRagError> {
        let rejection = if self.simulate_failure {
            Some(DistillationError::TrackingFailed(
                "Simulated failure".to_string(),
            ))
        } else if !config.is_valid() {
            Some(DistillationError::InvalidConfig(
                "Invalid LoRA configuration".to_string(),
            ))
        } else if examples.is_empty() {
            Some(DistillationError::CollectionFailed(
                "No training examples provided".to_string(),
            ))
        } else {
            None
        };
        if let Some(reason) = rejection {
            return Err(reason.into());
        }

        let id = self.generate_job_id();
        let new_job = TrainingJob::new(id.as_str(), pattern.clone(), config, examples);
        self.jobs.insert(id.clone(), new_job);
        Ok(id)
    }

    /// Returns a copy of the job's status, or `None` for an unknown id.
    async fn get_status(&self, job_id: &str) -> Option<TrainingStatus> {
        let job = self.jobs.get(job_id)?;
        Some(job.status.clone())
    }

    /// Cancels a job by marking it `Failed` with "Cancelled by user".
    ///
    /// Fails when the id is unknown or the job already reached a terminal
    /// status.
    async fn cancel_job(&mut self, job_id: &str) -> Result<(), OxiRagError> {
        match self.jobs.get_mut(job_id) {
            None => {
                Err(DistillationError::PatternNotFound(format!("Job not found: {job_id}")).into())
            }
            Some(job) if job.status.is_terminal() => Err(DistillationError::TrackingFailed(
                "Cannot cancel a completed job".to_string(),
            )
            .into()),
            Some(job) => {
                job.fail("Cancelled by user");
                Ok(())
            }
        }
    }

    /// Returns references to every stored job, in the map's iteration order.
    fn list_jobs(&self) -> Vec<&TrainingJob> {
        let mut listed = Vec::with_capacity(self.jobs.len());
        listed.extend(self.jobs.values());
        listed
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the baseline values and is valid.
    #[test]
    fn test_lora_config_default() {
        let config = LoraConfig::default();
        assert_eq!(config.rank, 8);
        assert!((config.alpha - 16.0).abs() < f32::EPSILON);
        assert!(config.is_valid());
    }

    /// Builder methods overwrite exactly the fields they name.
    #[test]
    fn test_lora_config_builder() {
        let config = LoraConfig::with_rank(16)
            .with_alpha(32.0)
            .with_dropout(0.1)
            .with_epochs(5)
            .with_batch_size(8);
        assert_eq!(config.rank, 16);
        assert!((config.alpha - 32.0).abs() < f32::EPSILON);
        assert!((config.dropout - 0.1).abs() < f32::EPSILON);
        assert_eq!(config.num_epochs, 5);
        assert_eq!(config.batch_size, 8);
    }

    /// The scaling factor is `alpha / rank`.
    #[test]
    fn test_lora_config_scaling_factor() {
        let config = LoraConfig {
            rank: 8,
            alpha: 16.0,
            ..Default::default()
        };
        assert!((config.scaling_factor() - 2.0).abs() < f32::EPSILON);
    }

    /// A zero rank is invalid, and out-of-range dropout is clamped.
    #[test]
    fn test_lora_config_validation() {
        let invalid_config = LoraConfig {
            rank: 0,
            ..Default::default()
        };
        assert!(!invalid_config.is_valid());
        let invalid_dropout = LoraConfig::default().with_dropout(1.5);
        assert!((invalid_dropout.dropout - 1.0).abs() < f32::EPSILON);
    }

    /// `new` stores both texts and defaults the weight to 1.0.
    #[test]
    fn test_training_example_creation() {
        let example = LoraTrainingExample::new("input", "output");
        assert_eq!(example.input, "input");
        assert_eq!(example.output, "output");
        assert!((example.weight - 1.0).abs() < f32::EPSILON);
    }

    /// `with_weight` stores the explicit weight.
    #[test]
    fn test_training_example_with_weight() {
        let example = LoraTrainingExample::with_weight("input", "output", 2.5);
        assert!((example.weight - 2.5).abs() < f32::EPSILON);
    }

    /// Empty input and zero weight both make an example invalid.
    #[test]
    fn test_training_example_validation() {
        let valid = LoraTrainingExample::new("input", "output");
        assert!(valid.is_valid());
        let empty_input = LoraTrainingExample::new("", "output");
        assert!(!empty_input.is_valid());
        let zero_weight = LoraTrainingExample::with_weight("input", "output", 0.0);
        assert!(!zero_weight.is_valid());
    }

    /// `Pending` is neither terminal nor active, `Preparing` is active, and
    /// `Completed` / `Failed` are terminal.
    #[test]
    fn test_training_status_states() {
        assert!(!TrainingStatus::Pending.is_terminal());
        assert!(!TrainingStatus::Pending.is_active());
        assert!(TrainingStatus::Preparing.is_active());
        assert!(TrainingStatus::Completed { final_loss: 0.1 }.is_terminal());
        assert!(
            TrainingStatus::Failed {
                error: "test".to_string()
            }
            .is_terminal()
        );
    }

    /// `current_loss` reports the running or final loss, and nothing for `Pending`.
    #[test]
    fn test_training_status_loss() {
        let training = TrainingStatus::Training {
            epoch: 1,
            loss: 0.5,
        };
        assert!((training.current_loss().unwrap() - 0.5).abs() < f32::EPSILON);
        let completed = TrainingStatus::Completed { final_loss: 0.1 };
        assert!((completed.current_loss().unwrap() - 0.1).abs() < f32::EPSILON);
        assert!(TrainingStatus::Pending.current_loss().is_none());
    }

    /// A new job starts `Pending` with no completion timestamp.
    #[test]
    fn test_training_job_creation() {
        let pattern = QueryPattern::new("test query");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let job = TrainingJob::new("job-1", pattern, config, examples);
        assert_eq!(job.job_id, "job-1");
        assert_eq!(job.status, TrainingStatus::Pending);
        assert!(job.completed_at.is_none());
    }

    /// `total_weight` sums the weights of all examples.
    #[test]
    fn test_training_job_total_weight() {
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![
            LoraTrainingExample::with_weight("a", "b", 1.0),
            LoraTrainingExample::with_weight("c", "d", 2.0),
            LoraTrainingExample::with_weight("e", "f", 0.5),
        ];
        let job = TrainingJob::new("job", pattern, config, examples);
        assert!((job.total_weight() - 3.5).abs() < f32::EPSILON);
    }

    /// Status updates are stored, and completion stamps `completed_at`.
    #[test]
    fn test_training_job_status_update() {
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let mut job = TrainingJob::new("job", pattern, config, examples);
        job.update_status(TrainingStatus::Training {
            epoch: 1,
            loss: 0.5,
        });
        assert!(matches!(job.status, TrainingStatus::Training { .. }));
        job.complete(0.1);
        assert!(job.completed_at.is_some());
    }

    /// Four examples in batches of two over two epochs give four steps.
    #[test]
    fn test_training_job_estimated_steps() {
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default().with_epochs(2).with_batch_size(2);
        let examples = vec![
            LoraTrainingExample::new("a", "b"),
            LoraTrainingExample::new("c", "d"),
            LoraTrainingExample::new("e", "f"),
            LoraTrainingExample::new("g", "h"),
        ];
        let job = TrainingJob::new("job", pattern, config, examples);
        assert_eq!(job.estimated_steps(), 4);
    }

    /// Creating a job succeeds and makes it retrievable by id.
    #[tokio::test]
    async fn test_mock_trainer_create_job() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let result = trainer.create_job(&pattern, examples, config).await;
        assert!(result.is_ok());
        let job_id = result.unwrap();
        assert!(trainer.get_job(&job_id).is_some());
    }

    /// A freshly created job reports `Pending`.
    #[tokio::test]
    async fn test_mock_trainer_get_status() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let job_id = trainer
            .create_job(&pattern, examples, config)
            .await
            .unwrap();
        let status = trainer.get_status(&job_id).await;
        assert!(matches!(status, Some(TrainingStatus::Pending)));
    }

    /// Cancelling a pending job marks it as `Failed`.
    #[tokio::test]
    async fn test_mock_trainer_cancel_job() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let job_id = trainer
            .create_job(&pattern, examples, config)
            .await
            .unwrap();
        let result = trainer.cancel_job(&job_id).await;
        assert!(result.is_ok());
        assert!(matches!(
            trainer.get_status(&job_id).await,
            Some(TrainingStatus::Failed { .. })
        ));
    }

    /// With failure simulation enabled, job creation returns an error.
    #[tokio::test]
    async fn test_mock_trainer_simulate_failure() {
        let mut trainer = MockLoraTrainer::new().with_simulated_failure(true);
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let result = trainer.create_job(&pattern, examples, config).await;
        assert!(result.is_err());
    }

    /// Every created job shows up in `list_jobs`.
    #[tokio::test]
    async fn test_mock_trainer_list_jobs() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        for i in 0..3 {
            let examples = vec![LoraTrainingExample::new(format!("input{i}"), "output")];
            let _ = trainer.create_job(&pattern, examples, config.clone()).await;
        }
        assert_eq!(trainer.list_jobs().len(), 3);
    }

    /// Simulated progress and completion are reflected in the reported status.
    #[tokio::test]
    async fn test_mock_trainer_simulate_progress() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let job_id = trainer
            .create_job(&pattern, examples, config)
            .await
            .unwrap();
        trainer.simulate_progress(&job_id, 1, 0.5);
        let status = trainer.get_status(&job_id).await.unwrap();
        assert!(matches!(status, TrainingStatus::Training { epoch: 1, .. }));
        trainer.simulate_completion(&job_id, 0.1);
        let status = trainer.get_status(&job_id).await.unwrap();
        assert!(matches!(status, TrainingStatus::Completed { .. }));
    }

    /// Active and completed counters follow the job through its lifecycle.
    #[tokio::test]
    async fn test_mock_trainer_active_count() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let job_id = trainer
            .create_job(&pattern, examples, config)
            .await
            .unwrap();
        assert_eq!(trainer.active_job_count(), 0);
        trainer.simulate_progress(&job_id, 1, 0.5);
        assert_eq!(trainer.active_job_count(), 1);
        trainer.simulate_completion(&job_id, 0.1);
        assert_eq!(trainer.active_job_count(), 0);
        assert_eq!(trainer.completed_job_count(), 1);
    }

    /// `clear_completed` removes jobs that reached a terminal status.
    #[tokio::test]
    async fn test_mock_trainer_clear_completed() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let job_id = trainer
            .create_job(&pattern, examples, config)
            .await
            .unwrap();
        trainer.simulate_completion(&job_id, 0.1);
        assert_eq!(trainer.list_jobs().len(), 1);
        trainer.clear_completed();
        assert_eq!(trainer.list_jobs().len(), 0);
    }

    /// An invalid configuration is rejected at job creation.
    #[tokio::test]
    async fn test_mock_trainer_invalid_config() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig {
            rank: 0,
            ..Default::default()
        };
        let examples = vec![LoraTrainingExample::new("input", "output")];
        let result = trainer.create_job(&pattern, examples, config).await;
        assert!(result.is_err());
    }

    /// An empty example list is rejected at job creation.
    #[tokio::test]
    async fn test_mock_trainer_empty_examples() {
        let mut trainer = MockLoraTrainer::new();
        let pattern = QueryPattern::new("test");
        let config = LoraConfig::default();
        let examples: Vec<LoraTrainingExample> = vec![];
        let result = trainer.create_job(&pattern, examples, config).await;
        assert!(result.is_err());
    }
}