oxirag 0.1.1

A four-layer RAG engine with SMT-based logic verification and knowledge graph support
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
//! On-the-fly Distillation Module for `OxiRAG`.
//!
//! This module implements Core Vision #3: On-the-fly Distillation, which
//! automatically generates specialized lightweight models for frequent queries.
//!
//! # Overview
//!
//! The distillation system works by:
//! 1. Tracking query patterns and their frequencies
//! 2. Collecting Q&A pairs as training data
//! 3. Detecting candidates ready for distillation
//! 4. Preparing training data for model specialization
//!
//! # Architecture
//!
//! ```text
//!     Incoming Query
//!            │
//!            ▼
//! ┌─────────────────────┐
//! │QueryFrequencyTracker│  ← Normalize and track patterns
//! └──────────┬──────────┘
//!            │
//!            ▼
//! ┌─────────────────────┐
//! │   QAPairCollector   │  ← Collect training data
//! └──────────┬──────────┘
//!            │
//!            ▼
//! ┌─────────────────────┐
//! │  CandidateDetector  │  ← Identify distillation candidates
//! └──────────┬──────────┘
//!            │
//!            ▼
//!   Ready for Distillation
//! ```
//!
//! # Usage
//!
//! ```rust,ignore
//! use oxirag::distillation::{
//!     InMemoryDistillationTracker, DistillationConfig, DistillationTracker,
//! };
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Create tracker with custom config
//!     let config = DistillationConfig {
//!         min_frequency_threshold: 3,
//!         similarity_threshold: 0.85,
//!         max_candidates: 100,
//!         collection_window_secs: 3600,
//!         max_qa_pairs_per_pattern: 50,
//!     };
//!
//!     let mut tracker = InMemoryDistillationTracker::new(config);
//!
//!     // Track queries as they come in
//!     tracker.track_query("What is Rust?", Some("A systems programming language."), 0.95).await?;
//!     tracker.track_query("What is Rust?", Some("Rust is a language focused on safety."), 0.90).await?;
//!
//!     // Check for candidates ready for distillation
//!     let candidates = tracker.get_candidates().await;
//!     for candidate in candidates {
//!         if candidate.ready_for_distillation {
//!             println!("Ready: {} (freq: {})", candidate.pattern.normalized_text, candidate.frequency);
//!         }
//!     }
//!
//!     Ok(())
//! }
//! ```
//!
//! # Features
//!
//! - **Query Normalization**: Queries are normalized (lowercase, punctuation removed) for pattern matching
//! - **Similarity Detection**: Similar queries are grouped together even with slight variations
//! - **Time Windows**: Q&A pairs expire after a configurable time window
//! - **Priority Ranking**: Candidates are ranked by frequency, confidence, and data quality
//! - **Deduplication**: Duplicate Q&A pairs are automatically rejected

pub mod candle_lora;
pub mod collector;
pub mod detector;
pub mod feature;
pub mod hotswap;
pub mod lora;
pub mod losses;
pub mod metrics;
pub mod progressive;
pub mod registry;
pub mod teacher_student;
pub mod tracker;
pub mod traits;
pub mod trigger;
pub mod types;

// Re-export main types
pub use candle_lora::{CandleLoraConfig, CandleLoraTrainer, TrainingMetrics};
pub use collector::{CollectorStatistics, QAPairCollector, TrainingExample};
pub use detector::{CandidateDetector, CandidateEvaluation, NearReadyReason};
pub use feature::{
    AttentionTransfer, FeatureDistillation, FeatureDistillationConfig, FeatureLoss, LayerMapping,
    MockFeatureDistillation, ProjectionType,
};
pub use hotswap::{ModelSelector, ModelSelectorBuilder, SelectionStrategy, SelectorStatistics};
pub use lora::{
    LoraConfig, LoraTrainer, LoraTrainingExample, MockLoraTrainer, TrainingJob, TrainingStatus,
};
pub use losses::{
    CombinedLoss, CosineLoss, DistillationLoss, HardTargetLoss, KLDivergenceLoss, LossConfig,
    LossType, MSELoss, SoftTargetLoss, TemperatureScaling,
};
pub use metrics::{
    ComparisonResult, ComparisonSummary, DistillationEvaluator, EvalStudentModel, EvalTeacherModel,
    EvaluationResult, EvaluatorConfig, ExtraEpochMetrics, KnowledgeTransferMetrics,
    LayerSimilarity, MetricsTracker, PlotData, TestExample, TestExampleMetadata, TrackerSummary,
    TrainingEpochMetrics,
};
pub use progressive::{
    EpochMetrics, LossWeights, MockProgressiveDistillation, ModelSize, ProgressiveConfig,
    ProgressiveDistillation, ProgressiveResult, ProgressiveScheduler, StageConfig, StageResult,
};
pub use registry::{ModelMetadata, ModelMetrics, ModelRegistry, RegistryStatistics};
pub use teacher_student::{
    DistillationMetrics, DistillationPair, DistillationStepConfig, DistillationStepResult,
    MockStudentModel, MockTeacherModel, StudentModel, TeacherModel,
};
pub use tracker::QueryFrequencyTracker;
pub use traits::DistillationTracker;
pub use trigger::{DistillationTrigger, TriggerCondition, TriggerEvaluation, TriggerStatistics};
pub use types::{
    DistillationCandidate, DistillationConfig, DistillationStats, QAPair, QueryPattern,
    current_timestamp,
};

use crate::error::OxiRagError;
use async_trait::async_trait;

/// An in-memory implementation of the `DistillationTracker` trait.
///
/// This combines the frequency tracker, collector, and detector into
/// a unified interface for distillation tracking.
///
/// All three components are built from the same [`DistillationConfig`] in
/// [`InMemoryDistillationTracker::new`]. The type derives `Clone`, so cloning
/// it clones each of the three components.
#[derive(Debug, Clone)]
pub struct InMemoryDistillationTracker {
    /// The frequency tracker.
    ///
    /// Every tracked query is recorded here; it is also the source of the
    /// candidate list, the readiness check, and the query/pattern counters
    /// reported by `stats()`.
    tracker: QueryFrequencyTracker,
    /// The Q&A pair collector.
    ///
    /// Only receives data when a query is tracked together with an answer;
    /// it backs `get_qa_pairs` and the training-example exports.
    collector: QAPairCollector,
    /// The candidate detector.
    ///
    /// Held for callers and exposed through `detector()`; none of the
    /// `DistillationTracker` trait methods in this file consult it.
    detector: CandidateDetector,
}

impl InMemoryDistillationTracker {
    /// Builds a tracker whose frequency tracker, collector, and detector are
    /// all configured from the same `config`.
    #[must_use]
    pub fn new(config: DistillationConfig) -> Self {
        // The first two components each take their own copy of the
        // configuration; the detector consumes the original.
        let tracker = QueryFrequencyTracker::new(config.clone());
        let collector = QAPairCollector::new(config.clone());
        let detector = CandidateDetector::new(config);

        Self {
            tracker,
            collector,
            detector,
        }
    }

    /// Builds a tracker from `DistillationConfig::default()`.
    #[must_use]
    pub fn with_defaults() -> Self {
        let config = DistillationConfig::default();
        Self::new(config)
    }

    /// Borrows the underlying frequency tracker.
    #[must_use]
    pub fn tracker(&self) -> &QueryFrequencyTracker {
        let Self { tracker, .. } = self;
        tracker
    }

    /// Mutably borrows the underlying frequency tracker.
    pub fn tracker_mut(&mut self) -> &mut QueryFrequencyTracker {
        let Self { tracker, .. } = self;
        tracker
    }

    /// Borrows the underlying Q&A pair collector.
    #[must_use]
    pub fn collector(&self) -> &QAPairCollector {
        let Self { collector, .. } = self;
        collector
    }

    /// Mutably borrows the underlying Q&A pair collector.
    pub fn collector_mut(&mut self) -> &mut QAPairCollector {
        let Self { collector, .. } = self;
        collector
    }

    /// Borrows the underlying candidate detector.
    #[must_use]
    pub fn detector(&self) -> &CandidateDetector {
        let Self { detector, .. } = self;
        detector
    }

    /// Asks the frequency tracker, then the collector, to drop their expired
    /// entries.
    pub fn cleanup(&mut self) {
        let Self {
            tracker, collector, ..
        } = self;

        tracker.cleanup_expired();
        collector.cleanup_expired();
    }

    /// Exports the collector's training examples for `pattern`.
    #[must_use]
    pub fn get_training_examples(&self, pattern: &QueryPattern) -> Vec<TrainingExample> {
        self.collector().export_for_training(pattern)
    }

    /// Exports every training example the collector currently holds.
    #[must_use]
    pub fn get_all_training_examples(&self) -> Vec<TrainingExample> {
        self.collector().export_all_for_training()
    }
}

impl Default for InMemoryDistillationTracker {
    fn default() -> Self {
        Self::with_defaults()
    }
}

#[async_trait]
impl DistillationTracker for InMemoryDistillationTracker {
    /// Records one query, and its answer when one is supplied.
    ///
    /// With an answer, the query is first registered with the frequency
    /// tracker and the resulting pattern is then handed to the collector
    /// together with the Q&A pair. Without an answer only the frequency
    /// tracker is updated.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok(())`.
    async fn track_query(
        &mut self,
        query: &str,
        answer: Option<&str>,
        confidence: f32,
    ) -> Result<(), OxiRagError> {
        match answer {
            Some(ans) => {
                let pattern = self.tracker.track_with_answer(query, ans, confidence);
                // Store the pair under the pattern the tracker just resolved.
                self.collector
                    .collect_with_pattern(query, ans, confidence, pattern);
            }
            None => {
                // Nothing to collect: the returned pattern is not needed.
                let _ = self.tracker.track(query);
            }
        }

        Ok(())
    }

    /// Returns owned copies of every candidate known to the frequency tracker.
    async fn get_candidates(&self) -> Vec<DistillationCandidate> {
        let known = self.tracker.all_candidates();
        known.into_iter().cloned().collect()
    }

    /// Returns the Q&A pairs the collector holds for `pattern`.
    async fn get_qa_pairs(&self, pattern: &QueryPattern) -> Vec<QAPair> {
        let Self { collector, .. } = self;
        collector.get_pairs(pattern)
    }

    /// Reports whether the frequency tracker considers `pattern` ready.
    async fn is_ready_for_distillation(&self, pattern: &QueryPattern) -> bool {
        let Self { tracker, .. } = self;
        tracker.is_ready(pattern)
    }

    /// Merges the frequency tracker's counters with the collector's pair count.
    fn stats(&self) -> DistillationStats {
        let from_tracker = self.tracker.stats();
        let from_collector = self.collector.statistics();

        DistillationStats {
            total_queries_tracked: from_tracker.total_queries_tracked,
            unique_patterns: from_tracker.unique_patterns,
            candidates_ready: from_tracker.candidates_ready,
            qa_pairs_collected: from_collector.total_pairs,
        }
    }

    /// Clears the frequency tracker and then the collector.
    async fn clear(&mut self) {
        let Self {
            tracker, collector, ..
        } = self;

        tracker.clear();
        collector.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tracks one query on `subject` and panics if tracking reports an error.
    async fn record(
        subject: &mut InMemoryDistillationTracker,
        query: &str,
        answer: Option<&str>,
        confidence: f32,
    ) {
        subject
            .track_query(query, answer, confidence)
            .await
            .unwrap();
    }

    /// A freshly built tracker reports zero tracked queries.
    #[tokio::test]
    async fn test_in_memory_tracker_creation() {
        let subject = InMemoryDistillationTracker::with_defaults();

        assert_eq!(subject.stats().total_queries_tracked, 0);
    }

    /// A query without an answer is counted but yields no Q&A pair.
    #[tokio::test]
    async fn test_track_query_without_answer() {
        let mut subject = InMemoryDistillationTracker::with_defaults();

        record(&mut subject, "What is Rust?", None, 0.0).await;

        let snapshot = subject.stats();
        assert_eq!(snapshot.total_queries_tracked, 1);
        assert_eq!(snapshot.qa_pairs_collected, 0);
    }

    /// A query with an answer is counted and yields one Q&A pair.
    #[tokio::test]
    async fn test_track_query_with_answer() {
        let mut subject = InMemoryDistillationTracker::with_defaults();

        record(
            &mut subject,
            "What is Rust?",
            Some("A programming language."),
            0.95,
        )
        .await;

        let snapshot = subject.stats();
        assert_eq!(snapshot.total_queries_tracked, 1);
        assert_eq!(snapshot.qa_pairs_collected, 1);
    }

    /// Two distinct queries produce two candidates.
    #[tokio::test]
    async fn test_get_candidates() {
        let mut subject = InMemoryDistillationTracker::with_defaults();

        for query in ["query1", "query2"] {
            record(&mut subject, query, None, 0.0).await;
        }

        let found = subject.get_candidates().await;
        assert_eq!(found.len(), 2);
    }

    /// Two different answers to the same query are both retrievable.
    #[tokio::test]
    async fn test_get_qa_pairs() {
        let mut subject = InMemoryDistillationTracker::with_defaults();

        for (answer, confidence) in [("Answer 1", 0.9), ("Answer 2", 0.85)] {
            record(&mut subject, "What is Rust?", Some(answer), confidence).await;
        }

        let key = QueryPattern::new("What is Rust?");
        let stored = subject.get_qa_pairs(&key).await;

        assert_eq!(stored.len(), 2);
    }

    /// A pattern seen as often as the configured threshold becomes ready.
    #[tokio::test]
    async fn test_is_ready_for_distillation() {
        let mut subject = InMemoryDistillationTracker::new(DistillationConfig {
            min_frequency_threshold: 2,
            similarity_threshold: 0.7,
            ..Default::default()
        });

        // Same query twice, each time with an answer.
        for (answer, confidence) in [("answer 1", 0.9), ("answer 2", 0.85)] {
            record(&mut subject, "test query", Some(answer), confidence).await;
        }

        let key = QueryPattern::new("test query");
        let ready = subject.is_ready_for_distillation(&key).await;
        assert!(ready);
    }

    /// `clear` resets both the query counter and the collected pairs.
    #[tokio::test]
    async fn test_clear() {
        let mut subject = InMemoryDistillationTracker::with_defaults();

        for query in ["query1", "query2"] {
            record(&mut subject, query, Some("answer"), 0.9).await;
        }

        subject.clear().await;

        let snapshot = subject.stats();
        assert_eq!(snapshot.total_queries_tracked, 0);
        assert_eq!(snapshot.qa_pairs_collected, 0);
    }

    /// Exported training examples carry the original query and answer text.
    #[tokio::test]
    async fn test_get_training_examples() {
        let mut subject = InMemoryDistillationTracker::with_defaults();

        record(
            &mut subject,
            "What is Rust?",
            Some("A programming language."),
            0.95,
        )
        .await;

        let key = QueryPattern::new("What is Rust?");
        let exported = subject.get_training_examples(&key);

        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].input, "What is Rust?");
        assert_eq!(exported[0].output, "A programming language.");
    }

    /// `cleanup` runs without panicking and leaves the query counter intact.
    #[tokio::test]
    async fn test_cleanup() {
        let mut subject = InMemoryDistillationTracker::with_defaults();

        record(&mut subject, "test", Some("answer"), 0.9).await;

        // Cleanup should not error
        subject.cleanup();

        // Stats should still be valid (total_queries_tracked should be 1)
        assert_eq!(subject.stats().total_queries_tracked, 1);
    }

    /// Three unrelated questions are tracked as three separate patterns.
    #[tokio::test]
    async fn test_multiple_patterns() {
        let mut subject = InMemoryDistillationTracker::with_defaults();

        let conversations = [
            ("What is Rust?", "A language.", 0.9),
            ("What is Python?", "Another language.", 0.85),
            ("What is JavaScript?", "Yet another language.", 0.8),
        ];
        for (question, reply, confidence) in conversations {
            record(&mut subject, question, Some(reply), confidence).await;
        }

        let snapshot = subject.stats();
        assert_eq!(snapshot.total_queries_tracked, 3);
        assert_eq!(snapshot.unique_patterns, 3);
    }

    /// Only the pattern that crosses the frequency threshold is reported ready.
    #[tokio::test]
    async fn test_ready_candidates_flow() {
        let mut subject = InMemoryDistillationTracker::new(DistillationConfig {
            min_frequency_threshold: 3,
            similarity_threshold: 0.7,
            ..Default::default()
        });

        // Five sightings of the same query, each with a distinct answer.
        for i in 0..5 {
            let reply = format!("answer {i}");
            record(&mut subject, "frequent query", Some(reply.as_str()), 0.9).await;
        }

        // A single sighting of a different query.
        record(&mut subject, "rare query", Some("answer"), 0.9).await;

        assert_eq!(subject.stats().candidates_ready, 1);

        // Verify the right pattern is ready
        let frequent = QueryPattern::new("frequent query");
        assert!(subject.is_ready_for_distillation(&frequent).await);

        let rare = QueryPattern::new("rare query");
        assert!(!subject.is_ready_for_distillation(&rare).await);
    }

    /// The core types are reachable through this module's re-exports.
    #[test]
    fn test_query_pattern_exports() {
        let _config = DistillationConfig::default();
        let _pattern = QueryPattern::new("test");
        let _stats = DistillationStats::default();
    }

    /// `CandidateDetector` is re-exported and exposes its configuration.
    #[test]
    fn test_detector_exports() {
        let under_test = CandidateDetector::with_defaults();
        let _config = under_test.config();
    }

    /// `QAPairCollector` is re-exported and exposes its statistics.
    #[test]
    fn test_collector_exports() {
        let under_test = QAPairCollector::with_defaults();
        let _stats = under_test.statistics();
    }
}