use oxirag::distillation::{
CandleLoraConfig, CandleLoraTrainer, DistillationConfig, DistillationTracker,
InMemoryDistillationTracker, LoraConfig, LoraTrainer, LoraTrainingExample, TrainingStatus,
};
/// Walks through the on-the-fly distillation workflow end to end:
/// track queries, pick a candidate pattern that is ready for distillation,
/// build weighted training examples from its Q&A pairs, create a LoRA
/// training job and poll its status until it finishes.
///
/// # Errors
///
/// Returns an error when tracking a query or creating the training job
/// fails, or when the trainer stops reporting any status for the job.
/// A trainer that cannot be constructed and a training run that ends in
/// `TrainingStatus::Failed` are reported on stdout and still yield `Ok(())`,
/// so the example stays runnable in restricted environments.
#[cfg(all(feature = "native", feature = "distillation"))]
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Delay between two status polls of the training job.
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(500);
    // Consecutive polls without any status after which the job is considered lost.
    // Without this bound a job the trainer does not know would be polled forever.
    const MAX_MISSING_STATUS_POLLS: u32 = 20;

    println!("=== LoRA Training Example: On-the-Fly Distillation ===\n");

    println!("Step 1: Setting up distillation tracker...");
    let config = DistillationConfig {
        min_frequency_threshold: 3,
        similarity_threshold: 0.85,
        max_candidates: 100,
        collection_window_secs: 3600,
        max_qa_pairs_per_pattern: 50,
    };
    let mut tracker = InMemoryDistillationTracker::new(config);
    println!("✓ Tracker initialized\n");

    println!("Step 2: Tracking incoming queries...");
    // Fixed demo data: (query, answer) pairs that all revolve around one topic.
    let queries = [
        (
            "What is Rust?",
            "Rust is a systems programming language focused on safety and performance.",
        ),
        (
            "What is Rust?",
            "Rust is a modern language for building reliable and efficient software.",
        ),
        (
            "What is Rust programming?",
            "Rust is a statically typed compiled language with memory safety guarantees.",
        ),
        (
            "Tell me about Rust",
            "Rust is developed by Mozilla and provides zero-cost abstractions.",
        ),
        (
            "Explain Rust",
            "Rust eliminates memory bugs through its ownership system.",
        ),
    ];
    for (query, answer) in &queries {
        tracker.track_query(query, Some(answer), 0.95).await?;
        println!(" Tracked: {query}");
    }
    let stats = tracker.stats();
    println!(
        "\n✓ Tracked {} queries, {} unique patterns",
        stats.total_queries_tracked, stats.unique_patterns
    );

    println!("\nStep 3: Detecting distillation candidates...");
    let candidates = tracker.get_candidates().await;
    let ready_candidates: Vec<_> = candidates
        .iter()
        .filter(|c| c.ready_for_distillation)
        .collect();
    if ready_candidates.is_empty() {
        println!("⚠ No candidates ready for distillation yet");
        println!(" (This is expected with only {} queries)", queries.len());
        println!(" In production, track more queries to reach the threshold");
        return Ok(());
    }
    println!(
        "✓ Found {} candidates ready for distillation",
        ready_candidates.len()
    );
    // Non-empty is guaranteed by the early return above; the demo trains the first one.
    let candidate = &ready_candidates[0];

    println!("\nStep 4: Preparing training data...");
    println!(" Pattern: {}", candidate.pattern.normalized_text);
    println!(" Frequency: {}", candidate.frequency);
    println!(" Avg Confidence: {:.2}", candidate.avg_confidence);
    println!(" Q&A Pairs: {}", candidate.qa_pairs.len());
    // Each Q&A pair becomes one example, weighted by the pair's confidence.
    let training_examples: Vec<LoraTrainingExample> = candidate
        .qa_pairs
        .iter()
        .map(|pair| LoraTrainingExample::with_weight(&pair.query, &pair.answer, pair.confidence))
        .collect();
    println!("✓ Prepared {} training examples\n", training_examples.len());

    println!("Step 5: Setting up LoRA trainer...");
    let lora_config = CandleLoraConfig {
        base: LoraConfig {
            rank: 8,
            alpha: 16.0,
            dropout: 0.05,
            target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
            learning_rate: 1e-4,
            num_epochs: 3,
            batch_size: 2,
        },
        model_id: "microsoft/phi-2".to_string(),
        device: "cpu".to_string(),
        dtype: "f32".to_string(),
        checkpoint_dir: std::env::temp_dir().join("oxirag_lora_checkpoints"),
        max_grad_norm: 1.0,
        weight_decay: 0.01,
        adam_beta1: 0.9,
        adam_beta2: 0.999,
        adam_eps: 1e-8,
        warmup_steps: 10,
        early_stopping_patience: 3,
        min_improvement: 0.001,
        validation_split: 0.1,
        max_seq_len: 512,
    };
    // The trainer takes its own copy; `lora_config` is still needed for the report below.
    let mut trainer = match CandleLoraTrainer::new(lora_config.clone()) {
        Ok(t) => t,
        Err(e) => {
            println!("⚠ Failed to create trainer: {e}");
            println!(" (This is expected in environments without filesystem access)");
            return Ok(());
        }
    };
    println!("✓ Trainer initialized");
    println!(" Model: {}", lora_config.model_id);
    println!(" Device: {}", lora_config.device);
    println!(" LoRA Rank: {}", lora_config.base.rank);
    println!(" Learning Rate: {}", lora_config.base.learning_rate);
    println!(" Epochs: {}\n", lora_config.base.num_epochs);

    println!("Step 6: Creating training job...");
    let job_id = trainer
        .create_job(
            &candidate.pattern,
            training_examples,
            lora_config.base.clone(),
        )
        .await?;
    println!("✓ Training job created: {job_id}\n");

    println!("Step 7: Monitoring training progress...");
    println!(" (In production, training runs asynchronously in the background)\n");
    let mut missing_polls = 0_u32;
    // The loop yields `true` when the job completed and `false` when it failed.
    let training_succeeded = loop {
        match trainer.get_status(&job_id).await {
            Some(status) => {
                missing_polls = 0;
                match &status {
                    TrainingStatus::Pending => {
                        println!(" Status: Pending...");
                    }
                    TrainingStatus::Preparing => {
                        println!(" Status: Preparing data...");
                    }
                    TrainingStatus::Training { epoch, loss } => {
                        println!(" Epoch {epoch}: Training Loss = {loss:.4}");
                    }
                    TrainingStatus::Completed { final_loss } => {
                        println!("\n✓ Training completed!");
                        println!(" Final Loss: {final_loss:.4}");
                        break true;
                    }
                    TrainingStatus::Failed { error } => {
                        println!("\n✗ Training failed: {error}");
                        break false;
                    }
                }
            }
            None => {
                // No status at all for this job: tolerate a few misses, then give up
                // instead of spinning forever.
                missing_polls += 1;
                if missing_polls >= MAX_MISSING_STATUS_POLLS {
                    return Err(format!(
                        "no status reported for training job {job_id} after {missing_polls} consecutive polls"
                    )
                    .into());
                }
            }
        }
        tokio::time::sleep(POLL_INTERVAL).await;
    };

    println!("\n=== Summary ===");
    if !training_succeeded {
        // Do not claim a trained adapter when the job reported a failure.
        println!(
            "✗ Training did not complete for pattern: {}",
            candidate.pattern.normalized_text
        );
        return Ok(());
    }
    println!("✓ Successfully demonstrated LoRA training workflow");
    println!(
        "✓ Trained adapter for pattern: {}",
        candidate.pattern.normalized_text
    );
    println!("✓ Model ready for deployment");
    println!("\nNext Steps:");
    println!(" - Integrate with hot-swap system for automatic model switching");
    println!(" - Monitor specialized model performance vs. base model");
    println!(" - Collect feedback to refine training data");
    println!(" - Scale to multiple query patterns");
    Ok(())
}
/// Fallback entry point compiled when the `native` and `distillation`
/// features are not both enabled; it only prints how to run the example.
#[cfg(not(all(feature = "native", feature = "distillation")))]
fn main() {
    // Emitted one per line, in order.
    const USAGE_LINES: [&str; 2] = [
        "This example requires 'native' and 'distillation' features.",
        "Run with: cargo run --example lora_training_example --features speculator,distillation,native",
    ];
    for line in USAGE_LINES.iter() {
        println!("{line}");
    }
}