use crate::data::{ExampleMetadata, IntentLabels};
use uuid::Uuid;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Outcome of labeling a single example.
///
/// Build one with `LabelingResult::success` or `LabelingResult::failure`; a result counts as
/// successful exactly when `error` is `None` (see `is_success`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LabelingResult {
/// Identifier of the example this result belongs to.
pub example_id: Uuid,
/// Labels assigned to the example; `IntentLabels::default()` for failed results.
pub labels: IntentLabels,
/// Confidence attached to `labels`; `0.0` for failed results.
pub confidence: f32,
/// Optional raw response text, attached via `with_raw_response`
/// (presumably the labeler's unparsed output — confirm with callers).
pub raw_response: Option<String>,
/// Error description; `None` means the labeling succeeded.
pub error: Option<String>,
/// Latency of the labeling attempt (milliseconds, per the `_ms` suffix).
pub latency_ms: u64,
}
impl LabelingResult {
    /// Builds a successful result carrying `labels` and `confidence`.
    ///
    /// `raw_response` starts out empty (see [`LabelingResult::with_raw_response`]) and
    /// `error` is `None`, so [`LabelingResult::is_success`] returns `true`.
    pub fn success(
        example_id: Uuid,
        labels: IntentLabels,
        confidence: f32,
        latency_ms: u64,
    ) -> Self {
        let raw_response = None;
        let error = None;
        Self {
            example_id,
            labels,
            confidence,
            raw_response,
            error,
            latency_ms,
        }
    }

    /// Builds a failed result recording `error`.
    ///
    /// The labels are reset to `IntentLabels::default()` and the confidence to `0.0`.
    pub fn failure(example_id: Uuid, error: impl Into<String>, latency_ms: u64) -> Self {
        let labels = IntentLabels::default();
        let message: String = error.into();
        Self {
            example_id,
            labels,
            confidence: 0.0,
            raw_response: None,
            error: Some(message),
            latency_ms,
        }
    }

    /// Returns `true` when no error was recorded for this result.
    pub fn is_success(&self) -> bool {
        self.error.is_none()
    }

    /// Returns this result with `raw_response` set to `response`, replacing any prior value.
    pub fn with_raw_response(self, response: impl Into<String>) -> Self {
        Self {
            raw_response: Some(response.into()),
            ..self
        }
    }
}
/// Aggregate statistics accumulated from a stream of `LabelingResult`s via
/// `DistillationStats::update`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DistillationStats {
/// Number of results folded in by `update`, successful or not.
pub total_processed: usize,
/// Number of results for which `is_success()` was true.
pub successful: usize,
/// Number of results for which `is_success()` was false.
pub failed: usize,
/// Not modified by `update`; presumably maintained by the caller — confirm.
pub skipped: usize,
/// Sum of `latency_ms` over every processed result.
pub total_latency_ms: u64,
/// `total_latency_ms / total_processed`, recomputed on each `update`.
pub avg_latency_ms: f64,
/// Running mean of `confidence` over successful results only.
pub avg_confidence: f32,
/// Per-class counts of the dominant label of successful results, indexed by position in
/// `IntentLabels::class_names()`. Empty until the first successful `update`, which
/// initialises it to `IntentLabels::NUM_CLASSES` zeros.
pub label_distribution: Vec<usize>,
}
impl DistillationStats {
    /// Fraction of processed results that succeeded, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when nothing has been processed yet, avoiding a division by zero.
    pub fn success_rate(&self) -> f64 {
        if self.total_processed == 0 {
            return 0.0;
        }
        self.successful as f64 / self.total_processed as f64
    }

    /// Folds one labeling result into the running statistics.
    ///
    /// Every result bumps `total_processed`, adds to `total_latency_ms` and refreshes
    /// `avg_latency_ms`. Successful results also update the running `avg_confidence`
    /// and increment the `label_distribution` bucket of their dominant class; failed
    /// results only increment `failed`.
    pub fn update(&mut self, result: &LabelingResult) {
        self.total_processed += 1;
        self.total_latency_ms += result.latency_ms;

        if result.is_success() {
            self.successful += 1;
            // Incremental mean: avg_n = (avg_{n-1} * (n - 1) + x_n) / n, with n >= 1 here.
            self.avg_confidence = (self.avg_confidence * (self.successful - 1) as f32
                + result.confidence)
                / self.successful as f32;

            let (name, _) = result.labels.dominant();
            // NOTE(review): a dominant name missing from `class_names()` is silently
            // counted under class 0 — confirm that this can't happen or is intended.
            let idx = IntentLabels::class_names()
                .iter()
                .position(|&n| n == name)
                .unwrap_or(0);

            // Grow (never shrink) the histogram so `idx` is always in range. The old code
            // only sized an *empty* vector, so a shorter non-empty one (e.g. deserialized
            // from an older schema) made the indexing below panic.
            let needed = IntentLabels::NUM_CLASSES.max(idx + 1);
            if self.label_distribution.len() < needed {
                self.label_distribution.resize(needed, 0);
            }
            self.label_distribution[idx] += 1;
        } else {
            self.failed += 1;
        }

        // `total_processed` was incremented above, so the divisor is never zero.
        self.avg_latency_ms = self.total_latency_ms as f64 / self.total_processed as f64;
    }
}
/// Maximum length, in bytes, kept from each individual message (current message or
/// context entry) when building a prompt; longer input is cut on a char boundary.
pub const MAX_MESSAGE_LENGTH: usize = 10_000;
/// Byte length above which `RawExample::to_prompt` truncates the rendered prompt and
/// appends a `[truncated]` marker.
pub const MAX_PROMPT_LENGTH: usize = 50_000;
/// An unlabeled example: a message to classify plus the messages that preceded it.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RawExample {
/// Identifier; random (v4) when built with `RawExample::new`, caller-chosen with `with_id`.
pub id: Uuid,
/// Previous messages, rendered as a numbered "Context" list by `to_prompt`.
pub context: Vec<String>,
/// The message to classify.
pub message: String,
/// Optional metadata attached via `with_metadata`; not used by `to_prompt`.
pub metadata: Option<ExampleMetadata>,
}
impl RawExample {
    /// Creates an example with a freshly generated random (v4) id and no metadata.
    ///
    /// `context` holds the preceding messages; `message` is the one to classify.
    pub fn new(context: Vec<String>, message: impl Into<String>) -> Self {
        Self::with_id(Uuid::new_v4(), context, message)
    }

    /// Creates an example with a caller-supplied `id` and no metadata.
    pub fn with_id(id: Uuid, context: Vec<String>, message: impl Into<String>) -> Self {
        Self {
            id,
            context,
            message: message.into(),
            metadata: None,
        }
    }

    /// Returns this example with `metadata` attached, replacing any previous value.
    pub fn with_metadata(mut self, metadata: ExampleMetadata) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Renders the example as a classification prompt.
    ///
    /// Layout (the context section is omitted when `context` is empty):
    ///
    /// ```text
    /// Context (previous messages):
    /// 1. <context[0]>
    /// 2. <context[1]>
    ///
    /// Current message to classify:
    /// <message>
    /// ```
    ///
    /// Each message is first limited to [`MAX_MESSAGE_LENGTH`] bytes (on a char boundary)
    /// and stripped of control characters other than `\n`, `\t` and `\r`.
    ///
    /// If the rendered prompt exceeds [`MAX_PROMPT_LENGTH`] bytes, it is cut at a UTF-8
    /// char boundary leaving room for a trailing `"\n[truncated]"` marker, so the returned
    /// string never exceeds `MAX_PROMPT_LENGTH` bytes. Truncation never panics.
    pub fn to_prompt(&self) -> String {
        use std::fmt::Write as _;

        // Appended after truncation so consumers can tell the prompt was cut.
        const TRUNCATION_MARKER: &str = "\n[truncated]";

        let mut prompt = String::new();
        if !self.context.is_empty() {
            prompt.push_str("Context (previous messages):\n");
            for (i, msg) in self.context.iter().enumerate() {
                // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
                let _ = writeln!(prompt, "{}. {}", i + 1, Self::sanitize_input(msg));
            }
            prompt.push('\n');
        }
        let _ = write!(
            prompt,
            "Current message to classify:\n{}",
            Self::sanitize_input(&self.message)
        );

        if prompt.len() > MAX_PROMPT_LENGTH {
            // NOTE(review): truncation keeps the *start* of the prompt, so a very long
            // context can push the current message (which comes last) out entirely.
            // Consider trimming the context section first instead.
            let budget = MAX_PROMPT_LENGTH.saturating_sub(TRUNCATION_MARKER.len());
            // `String::truncate` panics when its index is not on a char boundary, which a
            // fixed byte offset cannot guarantee once multi-byte text is present.
            let cut = Self::floor_char_boundary(&prompt, budget);
            prompt.truncate(cut);
            prompt.push_str(TRUNCATION_MARKER);
        }
        prompt
    }

    /// Limits `input` to at most [`MAX_MESSAGE_LENGTH`] bytes (without splitting a UTF-8
    /// sequence) and drops control characters other than `\n`, `\t` and `\r`.
    fn sanitize_input(input: &str) -> String {
        let kept = &input[..Self::floor_char_boundary(input, MAX_MESSAGE_LENGTH)];
        kept.chars()
            .filter(|c| !c.is_control() || matches!(*c, '\n' | '\t' | '\r'))
            .collect()
    }

    /// Returns the largest char boundary of `s` that is `<= index`, or `s.len()` when
    /// `index` is at or past the end. `&s[..result]` therefore never splits a character.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            return s.len();
        }
        // UTF-8 sequences are at most 4 bytes, so this scans at most 4 positions; index 0
        // is always a boundary, so the fallback is never actually used.
        (0..=index).rev().find(|&i| s.is_char_boundary(i)).unwrap_or(0)
    }
}