use super::types::{DistillationCandidate, DistillationStats, QAPair, QueryPattern};
use crate::error::OxiRagError;
use async_trait::async_trait;
/// Interface for components that record queries and expose distillation data.
///
/// Implementors must be `Send + Sync`. Every method except [`stats`](Self::stats)
/// is asynchronous (provided through `async_trait`), and the trait can be used
/// behind `dyn DistillationTracker`.
#[async_trait]
pub trait DistillationTracker: Send + Sync {
    /// Records a single query, its answer (if one exists) and a confidence value.
    ///
    /// NOTE(review): the valid range of `confidence` and how a missing `answer`
    /// is treated are decided by the implementor; neither is fixed by this trait.
    ///
    /// # Errors
    ///
    /// Returns an [`OxiRagError`] when the implementor fails to record the query.
    async fn track_query(
        &mut self,
        query: &str,
        answer: Option<&str>,
        confidence: f32,
    ) -> Result<(), OxiRagError>;

    /// Returns the distillation candidates held by the tracker as an owned `Vec`.
    async fn get_candidates(&self) -> Vec<DistillationCandidate>;

    /// Returns the question/answer pairs the tracker associates with `pattern`.
    async fn get_qa_pairs(&self, pattern: &QueryPattern) -> Vec<QAPair>;

    /// Reports whether `pattern` is considered ready for distillation.
    ///
    /// The readiness criteria are implementation-defined.
    async fn is_ready_for_distillation(&self, pattern: &QueryPattern) -> bool;

    /// Returns the tracker's statistics by value.
    ///
    /// Unlike the other methods this one is synchronous.
    fn stats(&self) -> DistillationStats;

    /// Clears the tracker.
    ///
    /// NOTE(review): exactly which state is reset is up to the implementor.
    async fn clear(&mut self);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// In-memory [`DistillationTracker`] used only to exercise the trait in tests.
    struct MockTracker {
        /// Candidates handed back verbatim by `get_candidates`.
        candidates: Vec<DistillationCandidate>,
        /// QA pairs handed back verbatim by `get_qa_pairs`, whatever the pattern.
        qa_pairs: Vec<QAPair>,
        /// Counters; `track_query` bumps `total_queries_tracked`.
        stats: DistillationStats,
    }

    impl MockTracker {
        /// Builds a tracker with no candidates, no QA pairs and default stats.
        fn empty() -> Self {
            Self {
                candidates: Vec::new(),
                qa_pairs: Vec::new(),
                stats: DistillationStats::default(),
            }
        }
    }

    #[async_trait]
    impl DistillationTracker for MockTracker {
        /// Ignores its arguments and only counts the call; never fails.
        async fn track_query(
            &mut self,
            _query: &str,
            _answer: Option<&str>,
            _confidence: f32,
        ) -> Result<(), OxiRagError> {
            self.stats.total_queries_tracked += 1;
            Ok(())
        }

        /// Returns a copy of the stored candidates.
        async fn get_candidates(&self) -> Vec<DistillationCandidate> {
            self.candidates.clone()
        }

        /// Returns a copy of the stored QA pairs; the pattern is not consulted.
        async fn get_qa_pairs(&self, _pattern: &QueryPattern) -> Vec<QAPair> {
            self.qa_pairs.clone()
        }

        /// The mock treats every pattern as ready.
        async fn is_ready_for_distillation(&self, _pattern: &QueryPattern) -> bool {
            true
        }

        /// Returns a copy of the current counters.
        fn stats(&self) -> DistillationStats {
            self.stats.clone()
        }

        /// Resets the counters to their defaults and empties both collections.
        async fn clear(&mut self) {
            // The three resets are independent of each other.
            self.stats = DistillationStats::default();
            self.qa_pairs.clear();
            self.candidates.clear();
        }
    }

    /// A single tracked query must show up in `total_queries_tracked`.
    #[tokio::test]
    async fn test_mock_tracker_track_query() {
        let mut mock = MockTracker::empty();

        let outcome = mock.track_query("test", Some("answer"), 0.9).await;
        outcome.unwrap();

        let snapshot = mock.stats();
        assert_eq!(snapshot.total_queries_tracked, 1);
    }

    /// `clear` must drop the stored candidates and zero the query counter.
    #[tokio::test]
    async fn test_mock_tracker_clear() {
        // Seed the mock with one candidate and a non-zero counter.
        let mut mock = MockTracker::empty();
        mock.candidates
            .push(DistillationCandidate::new(QueryPattern::new("test")));
        mock.stats.total_queries_tracked = 10;

        mock.clear().await;

        let remaining = mock.get_candidates().await;
        assert!(remaining.is_empty());
        assert_eq!(mock.stats().total_queries_tracked, 0);
    }

    /// The trait must be usable through `dyn DistillationTracker` behind a lock.
    #[tokio::test]
    async fn test_trait_object_safety() {
        let shared: Arc<Mutex<dyn DistillationTracker>> =
            Arc::new(Mutex::new(MockTracker::empty()));

        let mut locked = shared.lock().await;
        let outcome = locked.track_query("test", None, 0.5).await;
        outcome.expect("tracking should work");

        assert_eq!(locked.stats().total_queries_tracked, 1);
    }
}