#![allow(missing_docs)]
use serde::{Deserialize, Serialize};
#[cfg(feature = "ml-tuner")]
use aprender::{tree::RandomForestClassifier, Matrix};
#[cfg(feature = "ml-tuner")]
use super::super::error::TunerError;
use super::super::features::TunerFeatures;
use super::super::types::KernelType;
use super::KernelRecommendation;
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct KernelClassifier {
accuracy: f32,
#[cfg(feature = "ml-tuner")]
#[serde(skip)]
rf_classifier: Option<RandomForestClassifier>,
}
impl KernelClassifier {
/// Creates an untrained classifier.
///
/// `accuracy` starts at a heuristic 0.85 baseline and is overwritten with a
/// measured value once `train` succeeds (when the `ml-tuner` feature is on).
pub fn new() -> Self {
    Self {
        // Heuristic baseline, replaced after training.
        accuracy: 0.85,
        #[cfg(feature = "ml-tuner")]
        rf_classifier: None,
    }
}
/// Creates a classifier backed by a random forest with `n_estimators` trees.
///
/// The forest is constructed but not yet fitted; call `train` to fit it.
#[cfg(feature = "ml-tuner")]
pub fn with_random_forest(n_estimators: usize) -> Self {
    let forest = RandomForestClassifier::new(n_estimators);
    Self {
        accuracy: 0.85,
        rf_classifier: Some(forest),
    }
}
/// Trains the random-forest model on labeled `(features, kernel id)` pairs.
///
/// Flattens the features into a row-major matrix, fits the forest (lazily
/// creating a default 50-tree forest if `with_random_forest` was never
/// called), and records the resulting accuracy in `self.accuracy`.
///
/// NOTE: accuracy is evaluated on the training set itself, so it is an
/// optimistic estimate of generalization performance.
///
/// # Errors
///
/// Returns [`TunerError::InsufficientData`] when fewer than 10 samples are
/// supplied, and [`TunerError::TrainingFailed`] if matrix construction or
/// forest fitting fails.
#[cfg(feature = "ml-tuner")]
pub fn train(&mut self, data: &[(TunerFeatures, u32)]) -> Result<(), TunerError> {
    // Minimum sample count for a meaningful fit.
    const MIN_SAMPLES: usize = 10;
    if data.len() < MIN_SAMPLES {
        return Err(TunerError::InsufficientData(data.len()));
    }
    let n_samples = data.len();
    let n_features = TunerFeatures::DIM;

    // Flatten every feature vector into a single row-major buffer.
    let mut x_data = Vec::with_capacity(n_samples * n_features);
    let mut y_data: Vec<usize> = Vec::with_capacity(n_samples);
    for (features, label) in data {
        x_data.extend(features.to_vector());
        y_data.push(*label as usize);
    }
    let x_matrix = Matrix::from_vec(n_samples, n_features, x_data)
        .map_err(|e| TunerError::TrainingFailed(e.to_string()))?;

    // Fall back to a default 50-tree forest when none was configured.
    let rf = self.rf_classifier.get_or_insert_with(|| RandomForestClassifier::new(50));
    rf.fit(&x_matrix, &y_data).map_err(|e| TunerError::TrainingFailed(e.to_string()))?;

    // Count training-set hits; `n_samples >= MIN_SAMPLES > 0`, so the
    // division needs no zero guard (the old `.max(1)` was dead code).
    let predictions = rf.predict(&x_matrix);
    let correct = data
        .iter()
        .enumerate()
        .filter(|(i, (_, label))| predictions[*i] as u32 == *label)
        .count();
    self.accuracy = correct as f32 / n_samples as f32;
    Ok(())
}
/// Recommends a kernel for the workload described by `features`.
///
/// Decodes the normalized batch size and log-scaled sequence length, then
/// applies heuristic rules: larger batches prefer batched/vectorized Q4K
/// kernels; a single-item batch uses the vectorized kernel only when CUDA
/// graphs are active, otherwise the coalesced kernel. Long sequences
/// (> 128) surface the multi-warp attention kernel as an alternative.
///
/// Returns the top kernel with a confidence score plus up to two
/// runner-up candidates.
pub fn predict(&self, features: &TunerFeatures) -> KernelRecommendation {
    // Undo feature normalization: batch appears scaled by 1/64 and the
    // sequence length log2-scaled by 1/15 (must mirror TunerFeatures
    // encoding — confirm against the feature extractor).
    let batch_size = (features.batch_size_norm * 64.0).round() as u32;
    let seq_len = (2.0_f32.powf(features.seq_len_log * 15.0)).round() as u32;

    // Heuristic decision rules in priority order (flattened from a
    // nested `else { if … }`, per clippy::collapsible_else_if).
    let (top_kernel, confidence) = if batch_size >= 4 {
        (KernelType::BatchedQ4K, 0.90)
    } else if batch_size >= 2 {
        (KernelType::VectorizedQ4K, 0.85)
    } else if features.cuda_graphs > 0.5 {
        // With CUDA graphs active, the vectorized kernel is preferred
        // even at batch size 1.
        (KernelType::VectorizedQ4K, 0.88)
    } else {
        (KernelType::CoalescedQ4K, 0.82)
    };

    let attention_kernel = if seq_len > 128 {
        KernelType::MultiWarpAttention
    } else {
        KernelType::IncrementalAttention
    };

    // Up to two runner-up candidates, never duplicating the top pick.
    let alternatives = vec![
        (KernelType::VectorizedQ4K, 0.85),
        (KernelType::CoalescedQ4K, 0.75),
        (attention_kernel, 0.70),
    ]
    .into_iter()
    .filter(|(k, _)| *k != top_kernel)
    .take(2)
    .collect();

    KernelRecommendation { top_kernel, confidence, alternatives }
}
}