use super::DistillationConfig;
use super::types::{DistillationStats, LabelingResult, RawExample};
use crate::data::{ExampleMetadata, IntentLabels, TrainingExample};
use crate::distill::teacher::TeacherConfig;
use crate::error::{Result, TuneError};
use chrono::Utc;
/// Labels raw examples on behalf of a configured teacher and keeps running
/// statistics about the labeling outcomes.
pub struct DistillationPipeline {
/// Teacher configuration; checked with `validate()` when the pipeline is built.
teacher: TeacherConfig,
/// Pipeline options (this file reads `normalize_labels` and `min_confidence`).
config: DistillationConfig,
/// Counters updated as examples are labeled; cleared by `reset_stats`.
stats: DistillationStats,
}
impl DistillationPipeline {
/// Builds a pipeline from an explicit teacher and pipeline configuration.
///
/// Statistics start from `DistillationStats::default()`.
///
/// # Errors
///
/// Returns [`TuneError::InvalidConfig`] when the teacher configuration fails
/// its own validation, or whatever error `config.validate()` reports.
pub fn new(teacher: TeacherConfig, config: DistillationConfig) -> Result<Self> {
    // The teacher is validated first so its error takes precedence.
    if let Err(reason) = teacher.validate() {
        return Err(TuneError::InvalidConfig(reason));
    }
    config.validate()?;

    let stats = DistillationStats::default();
    Ok(Self {
        teacher,
        config,
        stats,
    })
}
/// Builds a pipeline for `teacher` using `DistillationConfig::default()`.
///
/// # Errors
///
/// Propagates any error returned by [`Self::new`].
pub fn with_teacher(teacher: TeacherConfig) -> Result<Self> {
    let default_config = DistillationConfig::default();
    Self::new(teacher, default_config)
}
pub fn teacher(&self) -> &TeacherConfig {
&self.teacher
}
pub fn config(&self) -> &DistillationConfig {
&self.config
}
pub fn stats(&self) -> &DistillationStats {
&self.stats
}
/// Discards the accumulated statistics and starts again from the default state.
pub fn reset_stats(&mut self) {
    let fresh = DistillationStats::default();
    self.stats = fresh;
}
/// Labels one raw example and records the outcome in the pipeline statistics.
///
/// The labels and the confidence (`0.85`) are fixed constants in this
/// implementation, and the prompt built from `raw` is not used afterwards.
/// NOTE(review): this looks like a placeholder for a real teacher call — confirm.
///
/// When `config.normalize_labels` is set, the labels are passed through
/// `softmax_normalize` before being returned.
///
/// # Errors
///
/// * [`TuneError::InvalidConfig`] if label normalization fails.
/// * [`TuneError::Validation`] if `config.min_confidence` is set and the
///   confidence is below it; `stats.skipped` is incremented in that case and
///   `stats.update` is not called.
pub fn label_single(&mut self, raw: &RawExample) -> Result<LabelingResult> {
    let timer = std::time::Instant::now();
    // Built but currently unused; kept so the call still happens.
    let _prompt = raw.to_prompt();

    let mut labels = IntentLabels {
        continuation: 0.4,
        topic_shift: 0.1,
        explicit_query: 0.3,
        person_lookup: 0.1,
        health_check: 0.05,
        task_status: 0.05,
    };
    if self.config.normalize_labels {
        labels
            .softmax_normalize()
            .map_err(TuneError::InvalidConfig)?;
    }

    let confidence = 0.85;
    // Latency is sampled before the confidence gate.
    let latency_ms = timer.elapsed().as_millis() as u64;

    match self.config.min_confidence {
        Some(min_conf) if confidence < min_conf => {
            self.stats.skipped += 1;
            return Err(TuneError::Validation(format!(
                "Confidence {confidence} below threshold {min_conf}"
            )));
        }
        _ => {}
    }

    let outcome = LabelingResult::success(raw.id, labels, confidence, latency_ms);
    self.stats.update(&outcome);
    Ok(outcome)
}
/// Labels every example in `raws`, returning one result per input in the
/// same order.
///
/// Errors from [`Self::label_single`] never abort the batch: each one is
/// converted into a `LabelingResult::failure` carrying the error text and a
/// latency of `0`, and that failure is recorded via `stats.update`.
///
/// NOTE(review): a low-confidence example already bumps `stats.skipped`
/// inside `label_single` and is then also passed to `stats.update` here as
/// a failure — confirm that counting it in both places is intended.
pub fn label_batch(&mut self, raws: &[RawExample]) -> Vec<LabelingResult> {
    raws.iter()
        .map(|raw| {
            self.label_single(raw).unwrap_or_else(|err| {
                let failed = LabelingResult::failure(raw.id, err.to_string(), 0);
                self.stats.update(&failed);
                failed
            })
        })
        .collect()
}
/// Converts successful labeling results into training examples.
///
/// `results`, `context_embeddings` and `message_embeddings` are parallel
/// slices: entry `i` of each describes the same example. Results for which
/// `is_success()` is false are skipped, so the output may be shorter than
/// the input. Each produced example carries metadata with the example id as
/// its source, the teacher's display name, the current UTC time and the
/// result's confidence.
///
/// # Errors
///
/// Returns [`TuneError::DimensionMismatch`] when either embedding slice has
/// a different length than `results`. `expected` is `results.len()` and
/// `actual` is the length of the slice that disagrees (the context
/// embeddings are checked first).
pub fn to_training_examples(
    &self,
    results: &[LabelingResult],
    context_embeddings: &[Vec<Vec<f32>>],
    message_embeddings: &[Vec<f32>],
) -> Result<Vec<TrainingExample>> {
    let expected = results.len();

    // Check each slice on its own so the error reports the length that is
    // actually wrong; a combined check that always reported the context
    // length would yield `expected == actual` when only the message
    // embeddings disagree.
    if context_embeddings.len() != expected {
        return Err(TuneError::DimensionMismatch {
            expected,
            actual: context_embeddings.len(),
        });
    }
    if message_embeddings.len() != expected {
        return Err(TuneError::DimensionMismatch {
            expected,
            actual: message_embeddings.len(),
        });
    }

    let mut examples = Vec::with_capacity(expected);
    // Lengths are equal at this point, so zipping visits every index once.
    let rows = results
        .iter()
        .zip(context_embeddings)
        .zip(message_embeddings);
    for ((result, context), message) in rows {
        if !result.is_success() {
            continue;
        }
        let example = TrainingExample::with_id(
            result.example_id,
            context.clone(),
            message.clone(),
            result.labels.clone(),
        );
        let metadata = ExampleMetadata::with_source(result.example_id.to_string())
            .teacher(self.teacher.display_name())
            .labeled_at(Utc::now())
            .confidence(result.confidence);
        examples.push(example.with_metadata(metadata));
    }
    Ok(examples)
}
}