mod distill;
mod types;
pub use distill::DistillationPipeline;
pub use types::{DistillationStats, LabelingResult, RawExample};
use crate::error::{Result, TuneError};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Configuration for a distillation run.
///
/// Build one with [`Default::default`], [`DistillationConfig::fast`] or
/// [`DistillationConfig::quality`]. Adjust it with the builder methods, then check
/// it with [`DistillationConfig::validate`] before use.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DistillationConfig {
/// Batch size. Must be greater than zero (enforced by `validate`).
/// NOTE(review): presumably the number of examples labeled per batch — confirm in `distill`.
pub batch_size: usize,
/// Concurrency level. Must be greater than zero (enforced by `validate`).
/// NOTE(review): presumably the number of batches processed in parallel — confirm in `distill`.
pub concurrency: usize,
/// Whether labels are normalized. Every preset in this file enables it.
/// NOTE(review): the normalization rule itself lives outside this file — confirm in `distill`.
pub normalize_labels: bool,
/// Optional minimum confidence. When `Some`, the value must lie in `0.0..=1.0`
/// (enforced by `validate`).
/// NOTE(review): presumably results below this threshold are discarded — confirm in `distill`.
pub min_confidence: Option<f32>,
/// Whether intermediate results are saved. The `output_dir` builder sets this to `true`.
pub save_intermediate: bool,
/// Optional output directory, stored as a `String`.
/// NOTE(review): `quality()` enables `save_intermediate` without setting a directory —
/// confirm which location is used in that case.
pub output_dir: Option<String>,
/// Progress-reporting interval. `validate` does not check it, so zero is accepted.
/// NOTE(review): presumably a count of processed examples between progress updates — confirm in `distill`.
pub progress_interval: usize,
}
impl Default for DistillationConfig {
    /// Balanced preset that sits between [`DistillationConfig::fast`] and
    /// [`DistillationConfig::quality`].
    ///
    /// It uses batches of 10, concurrency 5 and label normalization on. There is
    /// no confidence floor, `save_intermediate` is off with no output directory,
    /// and the progress interval is 100.
    fn default() -> Self {
        let batch_size = 10;
        let concurrency = 5;
        let progress_interval = 100;

        Self {
            concurrency,
            batch_size,
            progress_interval,
            normalize_labels: true,
            min_confidence: None,
            output_dir: None,
            save_intermediate: false,
        }
    }
}
impl DistillationConfig {
    /// Preset favouring throughput.
    ///
    /// Compared with the default it uses larger batches (20), higher concurrency
    /// (10) and a shorter progress interval (50). It has no confidence floor and
    /// does not save intermediate results.
    #[must_use]
    pub fn fast() -> Self {
        Self {
            batch_size: 20,
            concurrency: 10,
            normalize_labels: true,
            min_confidence: None,
            save_intermediate: false,
            output_dir: None,
            progress_interval: 50,
        }
    }

    /// Preset favouring label quality.
    ///
    /// It uses small batches (5) and low concurrency (3). It sets a minimum
    /// confidence of `0.5`, enables `save_intermediate` and uses a progress
    /// interval of 20.
    ///
    /// No output directory is set. Chain [`output_dir`](Self::output_dir) to
    /// choose one.
    #[must_use]
    pub fn quality() -> Self {
        Self {
            batch_size: 5,
            concurrency: 3,
            normalize_labels: true,
            min_confidence: Some(0.5),
            save_intermediate: true,
            output_dir: None,
            progress_interval: 20,
        }
    }

    /// Returns the config with `batch_size` replaced by `size`.
    ///
    /// Zero is accepted here but rejected later by [`validate`](Self::validate).
    #[must_use]
    pub fn batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    /// Returns the config with `concurrency` replaced by `level`.
    ///
    /// Zero is accepted here but rejected later by [`validate`](Self::validate).
    #[must_use]
    pub fn concurrency(mut self, level: usize) -> Self {
        self.concurrency = level;
        self
    }

    /// Returns the config with `output_dir` set to `dir`.
    ///
    /// Also turns on `save_intermediate`, because choosing a directory implies
    /// that something should be written there.
    #[must_use]
    pub fn output_dir(mut self, dir: impl Into<String>) -> Self {
        self.output_dir = Some(dir.into());
        self.save_intermediate = true;
        self
    }

    /// Checks the invariants the rest of the pipeline relies on.
    ///
    /// `progress_interval`, `save_intermediate` and `output_dir` are not checked.
    ///
    /// # Errors
    ///
    /// Returns [`TuneError::InvalidConfig`] when:
    /// - `batch_size` is zero,
    /// - `concurrency` is zero, or
    /// - `min_confidence` is `Some` with a value outside `0.0..=1.0` (NaN included).
    pub fn validate(&self) -> Result<()> {
        if self.batch_size == 0 {
            return Err(TuneError::InvalidConfig(
                "batch_size must be > 0".to_string(),
            ));
        }
        if self.concurrency == 0 {
            return Err(TuneError::InvalidConfig(
                "concurrency must be > 0".to_string(),
            ));
        }
        if let Some(conf) = self.min_confidence {
            // Every comparison with NaN is false, so `contains` returns false for
            // NaN and it is rejected here along with out-of-range values.
            if !(0.0..=1.0).contains(&conf) {
                return Err(TuneError::InvalidConfig(format!(
                    "min_confidence must be between 0.0 and 1.0, got {conf}"
                )));
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests;