use super::lora::{LoraConfig, LoraTrainer, LoraTrainingExample, TrainingJob, TrainingStatus};
use super::types::QueryPattern;
use crate::error::{DistillationError, OxiRagError};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
#[cfg(feature = "speculator")]
use candle_core::{DType, Device, Result as CandleResult, Tensor, Var};
#[cfg(feature = "speculator")]
use candle_nn::{VarBuilder, VarMap};
/// Configuration for Candle-backed LoRA training.
///
/// Wraps the generic [`LoraConfig`] in `base` and adds backend-specific
/// settings. Use `validate` to check a configuration before using it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandleLoraConfig {
/// Generic LoRA settings; the builder methods write `rank`, `learning_rate`,
/// `num_epochs` and `batch_size` into this value.
pub base: LoraConfig,
/// Identifier of the base model. `validate` rejects an empty string.
pub model_id: String,
/// Device specification string parsed by `get_device`
/// (`"cpu"`, and depending on enabled features `"cuda"`, `"cuda:N"` or `"metal"`).
pub device: String,
/// Data type name parsed by `get_dtype`: `"f32"`, `"f16"` or `"bf16"`.
pub dtype: String,
/// Directory under which per-job checkpoint files are placed.
pub checkpoint_dir: PathBuf,
/// Maximum gradient norm. `validate` requires it to be non-negative.
/// NOTE(review): presumably a gradient-clipping threshold; it is not read
/// by any training code in this file — confirm.
pub max_grad_norm: f32,
/// Weight decay coefficient. Not read in this file; presumably an optimizer
/// setting — confirm.
pub weight_decay: f64,
/// Adam `beta1`. Not read in this file; presumably an optimizer setting — confirm.
pub adam_beta1: f64,
/// Adam `beta2`. Not read in this file; presumably an optimizer setting — confirm.
pub adam_beta2: f64,
/// Adam epsilon. Not read in this file; presumably an optimizer setting — confirm.
pub adam_eps: f64,
/// Number of warm-up steps. Not read in this file — TODO confirm how it is consumed.
pub warmup_steps: usize,
/// Early-stopping patience. Not read in this file — TODO confirm units (epochs?).
pub early_stopping_patience: usize,
/// Minimum improvement threshold. Not read in this file; presumably paired
/// with `early_stopping_patience` — confirm.
pub min_improvement: f32,
/// Validation split; `validate` requires it to lie in `[0.0, 1.0)`.
pub validation_split: f32,
/// Maximum sequence length; `validate` requires it to be greater than zero.
pub max_seq_len: usize,
}
impl Default for CandleLoraConfig {
    /// Returns the baseline configuration: the default [`LoraConfig`], the
    /// `microsoft/phi-2` model on the CPU in `f32`, and checkpoints stored
    /// under `./lora_checkpoints`.
    fn default() -> Self {
        // Owned values are prepared up front so the struct literal below
        // reads as a plain table of settings.
        let base = LoraConfig::default();
        let model_id = String::from("microsoft/phi-2");
        let device = String::from("cpu");
        let dtype = String::from("f32");
        let checkpoint_dir = PathBuf::from("./lora_checkpoints");

        Self {
            base,
            model_id,
            device,
            dtype,
            checkpoint_dir,
            max_grad_norm: 1.0,
            weight_decay: 0.01,
            adam_beta1: 0.9,
            adam_beta2: 0.999,
            adam_eps: 1e-8,
            warmup_steps: 100,
            early_stopping_patience: 3,
            min_improvement: 0.001,
            validation_split: 0.1,
            max_seq_len: 512,
        }
    }
}
impl CandleLoraConfig {
    /// Creates a configuration for `model_id`, with every other field taken
    /// from [`Default`].
    #[must_use]
    pub fn with_model(model_id: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            ..Default::default()
        }
    }

    /// Sets the device specification string (see [`Self::get_device`]).
    #[must_use]
    pub fn with_device(mut self, device: impl Into<String>) -> Self {
        self.device = device.into();
        self
    }

    /// Sets the directory in which checkpoints are stored.
    #[must_use]
    pub fn with_checkpoint_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.checkpoint_dir = dir.into();
        self
    }

    /// Sets the LoRA rank on the wrapped base configuration.
    #[must_use]
    pub fn with_rank(mut self, rank: usize) -> Self {
        self.base.rank = rank;
        self
    }

    /// Sets the learning rate on the wrapped base configuration.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.base.learning_rate = lr;
        self
    }

    /// Sets the number of training epochs on the wrapped base configuration.
    #[must_use]
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.base.num_epochs = epochs;
        self
    }

    /// Sets the batch size on the wrapped base configuration.
    #[must_use]
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.base.batch_size = batch_size;
        self
    }

    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns [`DistillationError::InvalidConfig`] (converted into
    /// [`OxiRagError`]) when:
    /// - the base LoRA configuration reports itself invalid,
    /// - `model_id` is empty,
    /// - `max_grad_norm` is negative or NaN,
    /// - `validation_split` is outside `[0.0, 1.0)` or NaN,
    /// - `max_seq_len` is zero.
    pub fn validate(&self) -> Result<(), OxiRagError> {
        if !self.base.is_valid() {
            return Err(DistillationError::InvalidConfig(
                "Invalid base LoRA configuration".to_string(),
            )
            .into());
        }
        if self.model_id.is_empty() {
            return Err(
                DistillationError::InvalidConfig("Model ID cannot be empty".to_string()).into(),
            );
        }
        // Every comparison against NaN is false, so a plain `< 0.0` test
        // would silently accept NaN; reject it explicitly.
        if self.max_grad_norm.is_nan() || self.max_grad_norm < 0.0 {
            return Err(DistillationError::InvalidConfig(
                "Max gradient norm must be non-negative".to_string(),
            )
            .into());
        }
        // `contains` is false for NaN, so NaN is rejected together with
        // out-of-range values.
        if !(0.0..1.0).contains(&self.validation_split) {
            return Err(DistillationError::InvalidConfig(
                "Validation split must be in range [0.0, 1.0)".to_string(),
            )
            .into());
        }
        if self.max_seq_len == 0 {
            return Err(DistillationError::InvalidConfig(
                "Maximum sequence length must be positive".to_string(),
            )
            .into());
        }
        Ok(())
    }

    /// Resolves the `device` string into a Candle [`Device`].
    ///
    /// Accepted values are `"cpu"`; with the `cuda` feature, `"cuda"`
    /// (device index 0) and `"cuda:N"`; with the `metal` feature, `"metal"`
    /// (device index 0).
    ///
    /// NOTE(review): the CUDA branches go through `Device::cuda_if_available`,
    /// which presumably falls back to the CPU when CUDA is unavailable at
    /// runtime — confirm against the Candle documentation.
    ///
    /// # Errors
    ///
    /// Returns a Candle error when the string is not recognised (including
    /// CUDA/Metal names when the matching feature is disabled), when the
    /// index after `cuda:` is not a valid `usize`, or when device creation
    /// fails.
    #[cfg(feature = "speculator")]
    pub fn get_device(&self) -> CandleResult<Device> {
        match self.device.as_str() {
            "cpu" => Ok(Device::Cpu),
            #[cfg(feature = "cuda")]
            dev if dev.starts_with("cuda") => {
                if dev == "cuda" {
                    Device::cuda_if_available(0)
                } else if let Some(idx_str) = dev.strip_prefix("cuda:") {
                    let idx = idx_str.parse::<usize>().map_err(|e| {
                        candle_core::Error::Msg(format!("Invalid CUDA device index: {e}"))
                    })?;
                    Device::cuda_if_available(idx)
                } else {
                    // Starts with "cuda" but is neither "cuda" nor "cuda:<n>".
                    Err(candle_core::Error::Msg(format!(
                        "Invalid device specification: {dev}"
                    )))
                }
            }
            #[cfg(feature = "metal")]
            "metal" => Device::new_metal(0),
            _ => Err(candle_core::Error::Msg(format!(
                "Unsupported device: {}",
                self.device
            ))),
        }
    }

    /// Resolves the `dtype` string into a Candle [`DType`].
    ///
    /// # Errors
    ///
    /// Returns a Candle error for anything other than `"f32"`, `"f16"` or
    /// `"bf16"`.
    #[cfg(feature = "speculator")]
    pub fn get_dtype(&self) -> CandleResult<DType> {
        match self.dtype.as_str() {
            "f32" => Ok(DType::F32),
            "f16" => Ok(DType::F16),
            "bf16" => Ok(DType::BF16),
            _ => Err(candle_core::Error::Msg(format!(
                "Unsupported dtype: {}",
                self.dtype
            ))),
        }
    }
}
/// A 2-D weight matrix paired with a low-rank (LoRA) update.
///
/// With the adapter enabled, `forward` computes
/// `x · base_weightᵀ + scaling · (x · lora_aᵀ · lora_bᵀ)`.
#[cfg(feature = "speculator")]
#[derive(Debug)]
pub struct LoraLayer {
/// Base weight; `new` requires it to be 2-D and reads its shape as `(d, k)`.
/// It is not part of `trainable_vars`.
base_weight: Tensor,
/// Low-rank factor of shape `(rank, k)`, randomly initialised in `new`.
lora_a: Var,
/// Low-rank factor of shape `(d, rank)`, initialised to zeros in `new`.
lora_b: Var,
/// Multiplier applied to the low-rank product; `new` sets it to `alpha / rank`.
scaling: f32,
/// When `false`, `forward` returns only the base projection.
enabled: bool,
}
#[cfg(feature = "speculator")]
impl LoraLayer {
    /// Builds a LoRA layer around `base_weight`.
    ///
    /// `base_weight` must be 2-D; its shape is read as `(d, k)`. Factor `A`
    /// has shape `(rank, k)` and is drawn uniformly from
    /// `[-sqrt(1/k), sqrt(1/k))`; factor `B` has shape `(d, rank)` and starts
    /// at zero, so the low-rank product is initially zero. The scaling factor
    /// is `alpha / rank`.
    ///
    /// `_vb` and `_layer_name` are currently unused.
    ///
    /// # Errors
    ///
    /// Returns a Candle error when `base_weight` is not 2-D, when `rank` is
    /// zero, or when tensor creation fails.
    pub fn new(
        base_weight: Tensor,
        rank: usize,
        alpha: f32,
        _vb: &VarBuilder,
        _layer_name: &str,
    ) -> CandleResult<Self> {
        let shape = base_weight.dims();
        if shape.len() != 2 {
            return Err(candle_core::Error::Msg(format!(
                "Base weight must be 2D, got shape: {shape:?}"
            )));
        }
        // A zero rank would make `alpha / rank` infinite (or NaN) and produce
        // empty factor matrices, so refuse it up front.
        if rank == 0 {
            return Err(candle_core::Error::Msg(
                "LoRA rank must be greater than zero".to_string(),
            ));
        }
        let d = shape[0];
        let k = shape[1];

        // The bound is computed in f64 and narrowed once; precision loss is
        // irrelevant for an initialisation range.
        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        let bound = (1.0 / k as f64).sqrt() as f32;
        let a_uniform = Tensor::rand(-bound, bound, (rank, k), base_weight.device())?;
        // Match the base weight's dtype so the matmuls in `forward` agree.
        let a_init = a_uniform.to_dtype(base_weight.dtype())?;
        let lora_a = Var::from_tensor(&a_init)?;

        let b_init = Tensor::zeros((d, rank), base_weight.dtype(), base_weight.device())?;
        let lora_b = Var::from_tensor(&b_init)?;

        #[allow(clippy::cast_precision_loss)]
        let scaling = alpha / rank as f32;

        Ok(Self {
            base_weight,
            lora_a,
            lora_b,
            scaling,
            enabled: true,
        })
    }

    /// Applies the layer to `x`.
    ///
    /// Computes `x · Wᵀ`, and when the adapter is enabled adds
    /// `scaling · (x · Aᵀ · Bᵀ)`.
    ///
    /// # Errors
    ///
    /// Propagates any Candle error from the matrix products (for example a
    /// shape mismatch between `x` and the weights).
    pub fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
        let base_out = x.matmul(&self.base_weight.t()?)?;
        if !self.enabled {
            return Ok(base_out);
        }
        let lora_out = x
            .matmul(&self.lora_a.as_tensor().t()?)?
            .matmul(&self.lora_b.as_tensor().t()?)?;
        let scaled_lora = lora_out.affine(f64::from(self.scaling), 0.0)?;
        base_out.add(&scaled_lora)
    }

    /// Turns the low-rank contribution in [`Self::forward`] on or off.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Returns the trainable variables: factor `A` followed by factor `B`.
    #[must_use]
    pub fn trainable_vars(&self) -> Vec<&Var> {
        vec![&self.lora_a, &self.lora_b]
    }

    /// Returns `W + scaling · (B · A)` as a single tensor, without modifying
    /// the layer.
    ///
    /// # Errors
    ///
    /// Propagates any Candle error from the tensor operations.
    pub fn merge_weights(&self) -> CandleResult<Tensor> {
        let delta = self.lora_b.as_tensor().matmul(self.lora_a.as_tensor())?;
        let scaled_delta = (delta * f64::from(self.scaling))?;
        self.base_weight.add(&scaled_delta)
    }
}
/// Metrics recorded for one point in a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
/// Epoch the metrics belong to; `train_impl` in this file numbers epochs from 1.
pub epoch: usize,
/// Training loss reported for this entry.
pub train_loss: f32,
/// Validation loss, if one was computed.
pub val_loss: Option<f32>,
/// Learning rate in effect when the entry was recorded.
pub learning_rate: f64,
/// Step counter; `train_impl` in this file sets it equal to `epoch`.
pub step: usize,
}
/// LoRA trainer that keeps its training jobs in memory.
#[derive(Debug)]
pub struct CandleLoraTrainer {
/// Trainer configuration, validated in `new`.
#[allow(dead_code)]
config: CandleLoraConfig,
/// Jobs created through this trainer, keyed by job id.
jobs: HashMap<String, TrainingJob>,
/// Counter used to build job ids; incremented before each id is formatted,
/// so the first id is `candle-lora-1`.
next_job_id: u64,
}
impl CandleLoraTrainer {
    /// Creates a trainer from `config`.
    ///
    /// With the `native` feature enabled, the checkpoint directory is created
    /// if it does not exist yet.
    ///
    /// # Errors
    ///
    /// Returns an error when `config` fails validation, or (with `native`)
    /// when the checkpoint directory cannot be created.
    pub fn new(config: CandleLoraConfig) -> Result<Self, OxiRagError> {
        config.validate()?;
        #[cfg(feature = "native")]
        {
            if !config.checkpoint_dir.exists() {
                std::fs::create_dir_all(&config.checkpoint_dir).map_err(|e| {
                    DistillationError::StorageError(format!(
                        "Failed to create checkpoint directory: {e}"
                    ))
                })?;
            }
        }
        Ok(Self {
            config,
            jobs: HashMap::new(),
            next_job_id: 0,
        })
    }

    /// Creates a trainer using [`CandleLoraConfig::default`].
    ///
    /// # Errors
    ///
    /// See [`Self::new`].
    pub fn with_defaults() -> Result<Self, OxiRagError> {
        Self::new(CandleLoraConfig::default())
    }

    /// Returns the next job id, of the form `candle-lora-<n>` with `n`
    /// starting at 1.
    fn generate_job_id(&mut self) -> String {
        self.next_job_id += 1;
        format!("candle-lora-{}", self.next_job_id)
    }

    /// Returns the checkpoint file path for `job_id`:
    /// `<checkpoint_dir>/<job_id>.safetensors`.
    #[allow(dead_code)]
    fn checkpoint_path(&self, job_id: &str) -> PathBuf {
        self.config
            .checkpoint_dir
            .join(format!("{job_id}.safetensors"))
    }

    /// Runs the training loop for `job` and returns one metrics entry per epoch.
    ///
    /// The loop currently produces synthetic, steadily decreasing loss values
    /// rather than training a model. On success the job is completed with the
    /// last training loss; on every error path the job is marked failed so it
    /// is not left in the `Preparing` state.
    ///
    /// # Errors
    ///
    /// Returns an error when the job has no examples or the configured device
    /// cannot be resolved.
    #[allow(dead_code)]
    #[cfg(feature = "speculator")]
    async fn train_impl(&self, job: &mut TrainingJob) -> Result<Vec<TrainingMetrics>, OxiRagError> {
        job.update_status(TrainingStatus::Preparing);
        if job.examples.is_empty() {
            // Keep the job state consistent with the error handed back to the
            // caller, mirroring the non-`speculator` variant.
            job.fail("No training examples");
            return Err(
                DistillationError::CollectionFailed("No training examples".to_string()).into(),
            );
        }
        let _device = match self.config.get_device() {
            Ok(device) => device,
            Err(e) => {
                job.fail("Failed to get device");
                return Err(DistillationError::InvalidConfig(format!(
                    "Failed to get device: {e}"
                ))
                .into());
            }
        };

        let num_epochs = self.config.base.num_epochs;
        let mut metrics = Vec::with_capacity(num_epochs);
        for epoch in 1..=num_epochs {
            // Placeholder losses that shrink as the epoch number grows.
            #[allow(clippy::cast_precision_loss)]
            let train_loss = 2.0 / (epoch as f32 + 1.0);
            #[allow(clippy::cast_precision_loss)]
            let val_loss = 2.2 / (epoch as f32 + 1.0);
            job.update_status(TrainingStatus::Training {
                epoch,
                loss: train_loss,
            });
            metrics.push(TrainingMetrics {
                epoch,
                train_loss,
                val_loss: Some(val_loss),
                learning_rate: self.config.base.learning_rate,
                step: epoch,
            });
            #[cfg(feature = "native")]
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        // With zero epochs there is no final loss, so the job is not completed.
        if let Some(last_metric) = metrics.last() {
            job.complete(last_metric.train_loss);
        }
        Ok(metrics)
    }

    /// Fallback used when the `speculator` feature is disabled: marks the job
    /// failed and returns an `InvalidConfig` error.
    #[allow(dead_code)]
    #[cfg(not(feature = "speculator"))]
    async fn train_impl(&self, job: &mut TrainingJob) -> Result<Vec<TrainingMetrics>, OxiRagError> {
        job.fail("Training requires 'speculator' feature to be enabled");
        Err(DistillationError::InvalidConfig(
            "Candle LoRA training requires 'speculator' feature".to_string(),
        )
        .into())
    }

    /// Writes the checkpoint file for `job_id`.
    ///
    /// Currently a placeholder: with the `native` feature it writes an empty
    /// file at the checkpoint path and ignores `_varmap`; without `native` it
    /// does nothing.
    ///
    /// # Errors
    ///
    /// Returns a `StorageError` when the file cannot be written.
    #[allow(dead_code)]
    #[cfg(feature = "speculator")]
    fn save_checkpoint(&self, job_id: &str, _varmap: &VarMap) -> Result<(), OxiRagError> {
        let checkpoint_path = self.checkpoint_path(job_id);
        #[cfg(feature = "native")]
        {
            std::fs::write(&checkpoint_path, b"").map_err(|e| {
                DistillationError::StorageError(format!("Failed to save checkpoint: {e}"))
            })?;
        }
        Ok(())
    }

    /// Loads the checkpoint for `job_id`.
    ///
    /// Currently a placeholder: it only checks that the checkpoint file exists
    /// and then returns an empty [`VarMap`].
    ///
    /// # Errors
    ///
    /// Returns `PatternNotFound` when the checkpoint file does not exist.
    #[allow(dead_code)]
    #[cfg(feature = "speculator")]
    fn load_checkpoint(&self, job_id: &str) -> Result<VarMap, OxiRagError> {
        let checkpoint_path = self.checkpoint_path(job_id);
        if !checkpoint_path.exists() {
            return Err(DistillationError::PatternNotFound(format!(
                "Checkpoint not found: {job_id}"
            ))
            .into());
        }
        Ok(VarMap::new())
    }

    /// Returns the job with the given id, if it exists.
    #[must_use]
    pub fn get_job(&self, job_id: &str) -> Option<&TrainingJob> {
        self.jobs.get(job_id)
    }

    /// Returns a mutable reference to the job with the given id, if it exists.
    pub fn get_job_mut(&mut self, job_id: &str) -> Option<&mut TrainingJob> {
        self.jobs.get_mut(job_id)
    }
}
#[async_trait]
impl LoraTrainer for CandleLoraTrainer {
    /// Registers a new pending job for `pattern` and returns its id.
    ///
    /// Fails with `InvalidConfig` when `config` is not valid, and with
    /// `CollectionFailed` when `examples` is empty (checked in that order).
    async fn create_job(
        &mut self,
        pattern: &QueryPattern,
        examples: Vec<LoraTrainingExample>,
        config: LoraConfig,
    ) -> Result<String, OxiRagError> {
        if !config.is_valid() {
            let invalid =
                DistillationError::InvalidConfig("Invalid LoRA configuration".to_string());
            return Err(invalid.into());
        }
        if examples.is_empty() {
            let no_examples =
                DistillationError::CollectionFailed("No training examples provided".to_string());
            return Err(no_examples.into());
        }

        let id = self.generate_job_id();
        let mut new_job = TrainingJob::new(id.clone(), pattern.clone(), config, examples);
        new_job.update_status(TrainingStatus::Pending);
        self.jobs.insert(id.clone(), new_job);
        Ok(id)
    }

    /// Returns a copy of the job's current status, or `None` for an unknown id.
    async fn get_status(&self, job_id: &str) -> Option<TrainingStatus> {
        self.get_job(job_id).map(|job| job.status.clone())
    }

    /// Cancels a job by marking it failed with the reason "Cancelled by user".
    ///
    /// Fails with `PatternNotFound` for an unknown id and with
    /// `TrackingFailed` when the job is already in a terminal state.
    async fn cancel_job(&mut self, job_id: &str) -> Result<(), OxiRagError> {
        let Some(target) = self.jobs.get_mut(job_id) else {
            let missing = DistillationError::PatternNotFound(format!("Job not found: {job_id}"));
            return Err(missing.into());
        };
        if target.status.is_terminal() {
            let finished =
                DistillationError::TrackingFailed("Cannot cancel a completed job".to_string());
            return Err(finished.into());
        }
        target.fail("Cancelled by user");
        Ok(())
    }

    /// Returns references to every job known to this trainer, in the map's
    /// iteration order.
    fn list_jobs(&self) -> Vec<&TrainingJob> {
        let mut all_jobs = Vec::with_capacity(self.jobs.len());
        all_jobs.extend(self.jobs.values());
        all_jobs
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default trainer, or `None` when construction fails (tests
    /// that depend on a trainer then return early, as before).
    fn try_trainer() -> Option<CandleLoraTrainer> {
        CandleLoraTrainer::with_defaults().ok()
    }

    /// A minimal, non-empty example set.
    fn one_example() -> Vec<LoraTrainingExample> {
        vec![LoraTrainingExample::new("input", "output")]
    }

    #[test]
    fn test_candle_config_default() {
        let defaults = CandleLoraConfig::default();
        assert_eq!(defaults.model_id, "microsoft/phi-2");
        assert_eq!(defaults.device, "cpu");
        assert!(defaults.validate().is_ok());
    }

    #[test]
    fn test_candle_config_builder() {
        let built = CandleLoraConfig::with_model("my-model")
            .with_device("cpu")
            .with_rank(16)
            .with_learning_rate(1e-4)
            .with_epochs(5);
        assert_eq!(built.model_id, "my-model");
        assert_eq!(built.base.rank, 16);
        assert_eq!(built.base.num_epochs, 5);
    }

    #[test]
    fn test_candle_config_validation() {
        // An empty model id is rejected.
        let empty_model = CandleLoraConfig {
            model_id: String::new(),
            ..Default::default()
        };
        assert!(empty_model.validate().is_err());

        // A valid model id with an out-of-range split is rejected too.
        let bad_split = CandleLoraConfig {
            model_id: "valid-model".to_string(),
            validation_split: 1.5,
            ..Default::default()
        };
        assert!(bad_split.validate().is_err());
    }

    #[cfg(feature = "speculator")]
    #[test]
    fn test_get_device() {
        let resolved = CandleLoraConfig::default().get_device();
        assert!(resolved.is_ok());
    }

    #[cfg(feature = "speculator")]
    #[test]
    fn test_get_dtype() {
        let resolved = CandleLoraConfig::default().get_dtype();
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), DType::F32);
    }

    #[test]
    fn test_trainer_creation() {
        let Some(_trainer) = try_trainer() else {
            return;
        };
    }

    #[tokio::test]
    async fn test_create_job() {
        let Some(mut trainer) = try_trainer() else {
            return;
        };
        let pattern = QueryPattern::new("test query");
        let created = trainer
            .create_job(&pattern, one_example(), LoraConfig::default())
            .await;
        assert!(created.is_ok());
    }

    #[tokio::test]
    async fn test_get_status() {
        let Some(mut trainer) = try_trainer() else {
            return;
        };
        let pattern = QueryPattern::new("test");
        let job_id = trainer
            .create_job(&pattern, one_example(), LoraConfig::default())
            .await
            .unwrap();
        assert!(trainer.get_status(&job_id).await.is_some());
    }

    #[tokio::test]
    async fn test_cancel_job() {
        let Some(mut trainer) = try_trainer() else {
            return;
        };
        let pattern = QueryPattern::new("test");
        let job_id = trainer
            .create_job(&pattern, one_example(), LoraConfig::default())
            .await
            .unwrap();

        let cancelled = trainer.cancel_job(&job_id).await;
        assert!(cancelled.is_ok());

        let status = trainer.get_status(&job_id).await.unwrap();
        assert!(matches!(status, TrainingStatus::Failed { .. }));
    }

    #[tokio::test]
    async fn test_list_jobs() {
        let Some(mut trainer) = try_trainer() else {
            return;
        };
        let pattern = QueryPattern::new("test");
        for i in 0..3 {
            let examples = vec![LoraTrainingExample::new(format!("input{i}"), "output")];
            let _ = trainer
                .create_job(&pattern, examples, LoraConfig::default())
                .await;
        }
        assert_eq!(trainer.list_jobs().len(), 3);
    }

    #[tokio::test]
    async fn test_invalid_config() {
        let Some(mut trainer) = try_trainer() else {
            return;
        };
        let pattern = QueryPattern::new("test");
        let zero_rank = LoraConfig {
            rank: 0,
            ..Default::default()
        };
        let outcome = trainer.create_job(&pattern, one_example(), zero_rank).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_empty_examples() {
        let Some(mut trainer) = try_trainer() else {
            return;
        };
        let pattern = QueryPattern::new("test");
        let outcome = trainer
            .create_job(&pattern, Vec::new(), LoraConfig::default())
            .await;
        assert!(outcome.is_err());
    }

    #[cfg(feature = "speculator")]
    #[test]
    fn test_lora_layer_creation() {
        use candle_core::{DType, Device, Tensor};
        use candle_nn::VarMap;

        let device = Device::Cpu;
        let weight =
            Tensor::zeros((128, 256), DType::F32, &device).expect("Failed to create tensor");
        let varmap = VarMap::new();
        let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let created = LoraLayer::new(weight, 8, 16.0, &vb, "test_layer");
        assert!(created.is_ok());
        assert_eq!(created.unwrap().trainable_vars().len(), 2);
    }

    #[cfg(feature = "speculator")]
    #[test]
    fn test_lora_layer_forward() {
        use candle_core::{DType, Device, Tensor};
        use candle_nn::VarMap;

        let device = Device::Cpu;
        let weight =
            Tensor::rand(0.0, 1.0, (128, 256), &device).expect("Failed to create tensor");
        let varmap = VarMap::new();
        let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let layer = LoraLayer::new(weight, 8, 16.0, &vb, "test_layer")
            .expect("Failed to create layer");
        let input =
            Tensor::rand(0.0, 1.0, (2, 256), &device).expect("Failed to create input tensor");

        let projected = match layer.forward(&input) {
            Ok(tensor) => tensor,
            Err(e) => panic!("Forward pass failed: {e}"),
        };
        assert_eq!(projected.dims(), &[2, 128]);
    }

    #[cfg(feature = "speculator")]
    #[test]
    fn test_lora_layer_enable_disable() {
        use candle_core::{DType, Device, Tensor};
        use candle_nn::VarMap;

        let device = Device::Cpu;
        let weight = Tensor::ones((128, 256), DType::F32, &device).unwrap();
        let varmap = VarMap::new();
        let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let mut layer = LoraLayer::new(weight, 8, 16.0, &vb, "test").unwrap();
        let input = Tensor::ones((2, 256), DType::F32, &device).unwrap();

        let enabled_first = layer.forward(&input).unwrap();
        layer.set_enabled(false);
        let disabled = layer.forward(&input).unwrap();
        layer.set_enabled(true);
        let enabled_again = layer.forward(&input).unwrap();

        assert!(enabled_first.dims() == disabled.dims());
        assert!(enabled_first.dims() == enabled_again.dims());
    }
}