mod drift;
mod persistence;
mod types;
pub use types::{ConceptDriftStatus, TrainingSample, TrainingStats, UserFeedback};
use crate::brick::BrickProfiler;
use super::brick_tuner::BrickTuner;
use super::error::TunerError;
use super::features::{FeatureExtractor, RunConfig, TunerFeatures};
use super::helpers::chrono_lite_now;
use super::types::{BottleneckClass, KernelType};
/// Collects profiler-derived training samples for the brick auto-tuner and
/// holds the bookkeeping used to decide when (re)training should happen.
///
/// NOTE(review): the derived `Default` zero-initializes `retrain_threshold`
/// and `error_window_size`, which differs from the values set by `new()`
/// (100 and `DEFAULT_ERROR_WINDOW_SIZE`). Confirm whether `::default()`
/// producing a degenerate collector is intentional.
#[derive(Debug, Default)]
pub struct TunerDataCollector {
    // Training samples recorded so far, in insertion order.
    pub(crate) samples: Vec<TrainingSample>,
    // Turns profiler output + run config into model features.
    pub(crate) extractor: FeatureExtractor,
    // Sample-count gap that triggers retraining — not read in this file;
    // presumably consumed by the drift/persistence submodules.
    pub(crate) retrain_threshold: usize,
    // Sample count captured at the most recent training run.
    pub(crate) samples_at_last_train: usize,
    // User-supplied feedback entries (collected, consumed elsewhere).
    pub(crate) feedback: Vec<UserFeedback>,
    // Whether incremental/online learning is active.
    pub(crate) online_learning_enabled: bool,
    // Rolling window of error values — presumably prediction errors feeding
    // the concept-drift checks in the `drift` submodule; verify there.
    pub(crate) error_window: Vec<f32>,
    // Maximum length of `error_window`.
    error_window_size: usize,
}
impl TunerDataCollector {
    /// Default capacity of `error_window` (consumed by the `drift` submodule).
    pub(super) const DEFAULT_ERROR_WINDOW_SIZE: usize = 50;
    /// Error level above which concept drift is presumably flagged — confirm
    /// against the `drift` submodule.
    pub(super) const DRIFT_ERROR_THRESHOLD: f32 = 0.15;
    /// Sample-count gap after which a trained model is considered stale.
    pub(super) const STALENESS_THRESHOLD: usize = 100;
    /// Minimum number of recorded samples before training is attempted.
    pub const MIN_SAMPLES_FOR_TRAINING: usize = 1000;

    /// Creates an empty collector with online learning disabled.
    pub fn new() -> Self {
        Self {
            samples: Vec::new(),
            extractor: FeatureExtractor::new(),
            retrain_threshold: 100,
            samples_at_last_train: 0,
            feedback: Vec::new(),
            online_learning_enabled: false,
            error_window: Vec::new(),
            error_window_size: Self::DEFAULT_ERROR_WINDOW_SIZE,
        }
    }

    /// Creates a collector with online learning already enabled.
    pub fn with_online_learning() -> Self {
        let mut collector = Self::new();
        collector.online_learning_enabled = true;
        collector
    }

    /// Turns online (incremental) learning on.
    pub fn enable_online_learning(&mut self) {
        self.online_learning_enabled = true;
    }

    /// Turns online (incremental) learning off.
    pub fn disable_online_learning(&mut self) {
        self.online_learning_enabled = false;
    }

    /// Returns whether online learning is currently enabled.
    pub fn is_online_learning_enabled(&self) -> bool {
        self.online_learning_enabled
    }

    /// Records one training sample from a profiled run.
    ///
    /// Returns `None` (recording nothing) when the profiler cannot report a
    /// tokens-per-second figure; otherwise pushes a sample and returns
    /// `Some(())`.
    pub fn record(
        &mut self,
        profiler: &BrickProfiler,
        config: &RunConfig,
        kernel: KernelType,
    ) -> Option<()> {
        let throughput_tps = profiler.tokens_per_sec()?;
        let features = self.extractor.extract(profiler, config);
        // A missing bottleneck classification degrades to `Unknown` rather
        // than dropping the sample.
        let bottleneck = features.bottleneck_class.unwrap_or(BottleneckClass::Unknown);
        self.samples.push(TrainingSample {
            features,
            throughput_tps,
            best_kernel: kernel,
            bottleneck,
            timestamp: chrono_lite_now(),
            hardware_id: "unknown".to_string(), // TODO: plumb through a real hardware id
        });
        Some(())
    }

    /// All recorded samples, in insertion order.
    pub fn samples(&self) -> &[TrainingSample] {
        &self.samples
    }

    /// Number of recorded samples.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// True when no samples have been recorded.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// Serializes the sample set as pretty-printed JSON.
    ///
    /// # Errors
    /// Returns `TunerError::Serialization` if encoding fails.
    pub fn to_json(&self) -> Result<String, TunerError> {
        serde_json::to_string_pretty(&self.samples)
            .map_err(|e| TunerError::Serialization(e.to_string()))
    }

    /// Clones the samples into `(features, throughput)` pairs for the trainer.
    pub fn prepare_training_data(&self) -> Vec<(TunerFeatures, f32)> {
        self.samples
            .iter()
            .map(|s| (s.features.clone(), s.throughput_tps))
            .collect()
    }

    /// True once at least [`Self::MIN_SAMPLES_FOR_TRAINING`] samples exist.
    pub fn ready_to_train(&self) -> bool {
        self.samples.len() >= Self::MIN_SAMPLES_FOR_TRAINING
    }

    /// Trains a fresh `BrickTuner` when enough samples are available.
    ///
    /// Returns `None` if the sample count is below
    /// [`Self::MIN_SAMPLES_FOR_TRAINING`] or if training itself fails.
    pub fn train_if_ready(&self) -> Option<BrickTuner> {
        if !self.ready_to_train() {
            return None;
        }
        let mut tuner = BrickTuner::new();
        // `.ok()?` maps a training error to `None`, replacing the former
        // explicit `match` on the `Result`.
        tuner.train(&self.prepare_training_data()).ok()?;
        Some(tuner)
    }

    /// `(current sample count, samples required for training)`.
    pub fn training_progress(&self) -> (usize, usize) {
        (self.samples.len(), Self::MIN_SAMPLES_FOR_TRAINING)
    }

    /// Appends clones of `other`'s samples to this collector.
    ///
    /// Only samples are merged; feedback, the error window, and counters are
    /// left untouched.
    pub fn merge(&mut self, other: &TunerDataCollector) {
        // `extend_from_slice` clones in one pass and reserves capacity from
        // the slice length up front.
        self.samples.extend_from_slice(&other.samples);
    }
}