// trueno/tuner/data_collector/mod.rs

1//! Training Data Collection
2//!
3//! Implements `TunerDataCollector` for collecting and persisting training samples.
4//!
5//! # Module Layout
6//!
7//! - [`types`] -- `TrainingSample`, `UserFeedback`, `ConceptDriftStatus`, `TrainingStats`
8//! - [`persistence`] -- APR binary save/load, JSON import/export, cache paths
9//! - [`drift`] -- Online learning, concept drift detection, feedback, auto-retrain
10
11mod drift;
12mod persistence;
13mod types;
14
15pub use types::{ConceptDriftStatus, TrainingSample, TrainingStats, UserFeedback};
16
17use crate::brick::BrickProfiler;
18
19use super::brick_tuner::BrickTuner;
20use super::error::TunerError;
21use super::features::{FeatureExtractor, RunConfig, TunerFeatures};
22use super::helpers::chrono_lite_now;
23use super::types::{BottleneckClass, KernelType};
24
25// ============================================================================
26// TunerDataCollector
27// ============================================================================
28
29/// Training data collector with online learning support (T-TUNER-005, GitHub #82)
30#[derive(Debug, Default)]
31pub struct TunerDataCollector {
32    /// Collected samples
33    pub(crate) samples: Vec<TrainingSample>,
34    /// Feature extractor
35    pub(crate) extractor: FeatureExtractor,
36    /// Auto-retrain threshold
37    pub(crate) retrain_threshold: usize,
38    /// Number of samples at last training
39    pub(crate) samples_at_last_train: usize,
40    /// User feedback history (sample index -> feedback)
41    pub(crate) feedback: Vec<UserFeedback>,
42    /// Online learning enabled (privacy: opt-in only)
43    pub(crate) online_learning_enabled: bool,
44    /// Moving average of prediction errors (for concept drift)
45    pub(crate) error_window: Vec<f32>,
46    /// Error window size for drift detection
47    error_window_size: usize,
48}
49
50impl TunerDataCollector {
51    /// Default error window size for concept drift detection
52    pub(super) const DEFAULT_ERROR_WINDOW_SIZE: usize = 50;
53
54    /// Error threshold for drift detection (mean absolute error)
55    pub(super) const DRIFT_ERROR_THRESHOLD: f32 = 0.15;
56
57    /// Staleness threshold (samples since training) for recommending retrain
58    pub(super) const STALENESS_THRESHOLD: usize = 100;
59
60    /// Minimum samples required before training triggers
61    pub const MIN_SAMPLES_FOR_TRAINING: usize = 1000;
62
63    /// Create a new collector
64    pub fn new() -> Self {
65        Self {
66            samples: Vec::new(),
67            extractor: FeatureExtractor::new(),
68            retrain_threshold: 100,
69            samples_at_last_train: 0,
70            feedback: Vec::new(),
71            online_learning_enabled: false, // Privacy: opt-in
72            error_window: Vec::new(),
73            error_window_size: Self::DEFAULT_ERROR_WINDOW_SIZE,
74        }
75    }
76
77    /// Create a collector with online learning enabled
78    pub fn with_online_learning() -> Self {
79        let mut collector = Self::new();
80        collector.online_learning_enabled = true;
81        collector
82    }
83
84    /// Enable online learning (privacy: explicit opt-in)
85    pub fn enable_online_learning(&mut self) {
86        self.online_learning_enabled = true;
87    }
88
89    /// Disable online learning
90    pub fn disable_online_learning(&mut self) {
91        self.online_learning_enabled = false;
92    }
93
94    /// Check if online learning is enabled
95    pub fn is_online_learning_enabled(&self) -> bool {
96        self.online_learning_enabled
97    }
98
99    /// Record a profiling run as training data
100    pub fn record(
101        &mut self,
102        profiler: &BrickProfiler,
103        config: &RunConfig,
104        kernel: KernelType,
105    ) -> Option<()> {
106        let throughput_tps = profiler.tokens_per_sec()?;
107        let features = self.extractor.extract(profiler, config);
108        let bottleneck = features.bottleneck_class.unwrap_or(BottleneckClass::Unknown);
109
110        let sample = TrainingSample {
111            features,
112            throughput_tps,
113            best_kernel: kernel,
114            bottleneck,
115            timestamp: chrono_lite_now(),
116            hardware_id: "unknown".to_string(),
117        };
118
119        self.samples.push(sample);
120        Some(())
121    }
122
123    /// Get all samples
124    pub fn samples(&self) -> &[TrainingSample] {
125        &self.samples
126    }
127
128    /// Get sample count
129    pub fn len(&self) -> usize {
130        self.samples.len()
131    }
132
133    /// Check if empty
134    pub fn is_empty(&self) -> bool {
135        self.samples.is_empty()
136    }
137
138    /// Export to JSON
139    pub fn to_json(&self) -> Result<String, TunerError> {
140        serde_json::to_string_pretty(&self.samples)
141            .map_err(|e| TunerError::Serialization(e.to_string()))
142    }
143
144    /// Prepare training data for model
145    pub fn prepare_training_data(&self) -> Vec<(TunerFeatures, f32)> {
146        self.samples.iter().map(|s| (s.features.clone(), s.throughput_tps)).collect()
147    }
148
149    /// Check if we have enough samples to train
150    pub fn ready_to_train(&self) -> bool {
151        self.samples.len() >= Self::MIN_SAMPLES_FOR_TRAINING
152    }
153
154    /// Train a BrickTuner from collected data if we have enough samples
155    pub fn train_if_ready(&self) -> Option<BrickTuner> {
156        if !self.ready_to_train() {
157            return None;
158        }
159
160        let training_data = self.prepare_training_data();
161        let mut tuner = BrickTuner::new();
162
163        match tuner.train(&training_data) {
164            Ok(()) => Some(tuner),
165            Err(_) => None,
166        }
167    }
168
169    /// Get training progress as (current, required)
170    pub fn training_progress(&self) -> (usize, usize) {
171        (self.samples.len(), Self::MIN_SAMPLES_FOR_TRAINING)
172    }
173
174    /// Merge samples from another collector
175    pub fn merge(&mut self, other: &TunerDataCollector) {
176        self.samples.extend(other.samples.iter().cloned());
177    }
178}