// trueno/tuner/data_collector/mod.rs
mod drift;
12mod persistence;
13mod types;
14
15pub use types::{ConceptDriftStatus, TrainingSample, TrainingStats, UserFeedback};
16
17use crate::brick::BrickProfiler;
18
19use super::brick_tuner::BrickTuner;
20use super::error::TunerError;
21use super::features::{FeatureExtractor, RunConfig, TunerFeatures};
22use super::helpers::chrono_lite_now;
23use super::types::{BottleneckClass, KernelType};
24
25#[derive(Debug, Default)]
31pub struct TunerDataCollector {
32 pub(crate) samples: Vec<TrainingSample>,
34 pub(crate) extractor: FeatureExtractor,
36 pub(crate) retrain_threshold: usize,
38 pub(crate) samples_at_last_train: usize,
40 pub(crate) feedback: Vec<UserFeedback>,
42 pub(crate) online_learning_enabled: bool,
44 pub(crate) error_window: Vec<f32>,
46 error_window_size: usize,
48}
49
50impl TunerDataCollector {
51 pub(super) const DEFAULT_ERROR_WINDOW_SIZE: usize = 50;
53
54 pub(super) const DRIFT_ERROR_THRESHOLD: f32 = 0.15;
56
57 pub(super) const STALENESS_THRESHOLD: usize = 100;
59
60 pub const MIN_SAMPLES_FOR_TRAINING: usize = 1000;
62
63 pub fn new() -> Self {
65 Self {
66 samples: Vec::new(),
67 extractor: FeatureExtractor::new(),
68 retrain_threshold: 100,
69 samples_at_last_train: 0,
70 feedback: Vec::new(),
71 online_learning_enabled: false, error_window: Vec::new(),
73 error_window_size: Self::DEFAULT_ERROR_WINDOW_SIZE,
74 }
75 }
76
77 pub fn with_online_learning() -> Self {
79 let mut collector = Self::new();
80 collector.online_learning_enabled = true;
81 collector
82 }
83
84 pub fn enable_online_learning(&mut self) {
86 self.online_learning_enabled = true;
87 }
88
89 pub fn disable_online_learning(&mut self) {
91 self.online_learning_enabled = false;
92 }
93
94 pub fn is_online_learning_enabled(&self) -> bool {
96 self.online_learning_enabled
97 }
98
99 pub fn record(
101 &mut self,
102 profiler: &BrickProfiler,
103 config: &RunConfig,
104 kernel: KernelType,
105 ) -> Option<()> {
106 let throughput_tps = profiler.tokens_per_sec()?;
107 let features = self.extractor.extract(profiler, config);
108 let bottleneck = features.bottleneck_class.unwrap_or(BottleneckClass::Unknown);
109
110 let sample = TrainingSample {
111 features,
112 throughput_tps,
113 best_kernel: kernel,
114 bottleneck,
115 timestamp: chrono_lite_now(),
116 hardware_id: "unknown".to_string(),
117 };
118
119 self.samples.push(sample);
120 Some(())
121 }
122
123 pub fn samples(&self) -> &[TrainingSample] {
125 &self.samples
126 }
127
128 pub fn len(&self) -> usize {
130 self.samples.len()
131 }
132
133 pub fn is_empty(&self) -> bool {
135 self.samples.is_empty()
136 }
137
138 pub fn to_json(&self) -> Result<String, TunerError> {
140 serde_json::to_string_pretty(&self.samples)
141 .map_err(|e| TunerError::Serialization(e.to_string()))
142 }
143
144 pub fn prepare_training_data(&self) -> Vec<(TunerFeatures, f32)> {
146 self.samples.iter().map(|s| (s.features.clone(), s.throughput_tps)).collect()
147 }
148
149 pub fn ready_to_train(&self) -> bool {
151 self.samples.len() >= Self::MIN_SAMPLES_FOR_TRAINING
152 }
153
154 pub fn train_if_ready(&self) -> Option<BrickTuner> {
156 if !self.ready_to_train() {
157 return None;
158 }
159
160 let training_data = self.prepare_training_data();
161 let mut tuner = BrickTuner::new();
162
163 match tuner.train(&training_data) {
164 Ok(()) => Some(tuner),
165 Err(_) => None,
166 }
167 }
168
169 pub fn training_progress(&self) -> (usize, usize) {
171 (self.samples.len(), Self::MIN_SAMPLES_FOR_TRAINING)
172 }
173
174 pub fn merge(&mut self, other: &TunerDataCollector) {
176 self.samples.extend(other.samples.iter().cloned());
177 }
178}