trueno/tuner/models/
kernel.rs1#![allow(missing_docs)]
2use serde::{Deserialize, Serialize};
5
6#[cfg(feature = "ml-tuner")]
7use aprender::{tree::RandomForestClassifier, Matrix};
8
9#[cfg(feature = "ml-tuner")]
10use super::super::error::TunerError;
11use super::super::features::TunerFeatures;
12use super::super::types::KernelType;
13use super::KernelRecommendation;
14
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct KernelClassifier {
20 accuracy: f32,
22 #[cfg(feature = "ml-tuner")]
24 #[serde(skip)]
25 rf_classifier: Option<RandomForestClassifier>,
26}
27
28impl KernelClassifier {
29 pub fn new() -> Self {
30 Self {
31 accuracy: 0.85,
32 #[cfg(feature = "ml-tuner")]
33 rf_classifier: None,
34 }
35 }
36
37 #[cfg(feature = "ml-tuner")]
39 pub fn with_random_forest(n_estimators: usize) -> Self {
40 Self { accuracy: 0.85, rf_classifier: Some(RandomForestClassifier::new(n_estimators)) }
41 }
42
43 #[cfg(feature = "ml-tuner")]
47 pub fn train(&mut self, data: &[(TunerFeatures, u32)]) -> Result<(), TunerError> {
48 if data.len() < 10 {
49 return Err(TunerError::InsufficientData(data.len()));
50 }
51
52 let n_samples = data.len();
54 let n_features = TunerFeatures::DIM;
55 let mut x_data = Vec::with_capacity(n_samples * n_features);
56 let mut y_data: Vec<usize> = Vec::with_capacity(n_samples);
57
58 for (features, label) in data {
59 x_data.extend(features.to_vector());
60 y_data.push(*label as usize);
61 }
62
63 let x_matrix = Matrix::from_vec(n_samples, n_features, x_data)
64 .map_err(|e| TunerError::TrainingFailed(e.to_string()))?;
65
66 let rf = self.rf_classifier.get_or_insert_with(|| RandomForestClassifier::new(50));
67 rf.fit(&x_matrix, &y_data).map_err(|e| TunerError::TrainingFailed(e.to_string()))?;
68
69 let predictions = rf.predict(&x_matrix);
71 let mut correct = 0;
72 for (i, (_, label)) in data.iter().enumerate() {
73 if predictions[i] as u32 == *label {
74 correct += 1;
75 }
76 }
77 self.accuracy = correct as f32 / data.len().max(1) as f32;
78
79 Ok(())
80 }
81
82 pub fn predict(&self, features: &TunerFeatures) -> KernelRecommendation {
84 let batch_size = (features.batch_size_norm * 64.0).round() as u32;
86 let seq_len = (2.0_f32.powf(features.seq_len_log * 15.0)).round() as u32;
87
88 let (top_kernel, confidence) = if batch_size >= 4 {
90 (KernelType::BatchedQ4K, 0.90)
92 } else if batch_size >= 2 {
93 (KernelType::VectorizedQ4K, 0.85)
95 } else {
96 if features.cuda_graphs > 0.5 {
98 (KernelType::VectorizedQ4K, 0.88)
99 } else {
100 (KernelType::CoalescedQ4K, 0.82)
101 }
102 };
103
104 let attention_kernel = if seq_len > 128 {
106 KernelType::MultiWarpAttention
107 } else {
108 KernelType::IncrementalAttention
109 };
110
111 let alternatives = vec![
113 (KernelType::VectorizedQ4K, 0.85),
114 (KernelType::CoalescedQ4K, 0.75),
115 (attention_kernel, 0.70),
116 ]
117 .into_iter()
118 .filter(|(k, _)| *k != top_kernel)
119 .take(2)
120 .collect();
121
122 KernelRecommendation { top_kernel, confidence, alternatives }
123 }
124}