// trueno/tuner/models/kernel.rs
1#![allow(missing_docs)]
2//! Kernel classifier model for ML tuner.
3
4use serde::{Deserialize, Serialize};
5
6#[cfg(feature = "ml-tuner")]
7use aprender::{tree::RandomForestClassifier, Matrix};
8
9#[cfg(feature = "ml-tuner")]
10use super::super::error::TunerError;
11use super::super::features::TunerFeatures;
12use super::super::types::KernelType;
13use super::KernelRecommendation;
14
15/// Kernel classifier using simple rule-based logic.
16///
17/// With `ml-tuner` feature: uses aprender::RandomForestClassifier (SHOWCASE-BRICK-001)
18#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct KernelClassifier {
20    /// Kernel accuracy on validation (for confidence)
21    accuracy: f32,
22    /// RandomForest classifier when ml-tuner feature is enabled
23    #[cfg(feature = "ml-tuner")]
24    #[serde(skip)]
25    rf_classifier: Option<RandomForestClassifier>,
26}
27
28impl KernelClassifier {
29    pub fn new() -> Self {
30        Self {
31            accuracy: 0.85,
32            #[cfg(feature = "ml-tuner")]
33            rf_classifier: None,
34        }
35    }
36
37    /// Create a classifier with aprender RandomForest (ml-tuner feature)
38    #[cfg(feature = "ml-tuner")]
39    pub fn with_random_forest(n_estimators: usize) -> Self {
40        Self { accuracy: 0.85, rf_classifier: Some(RandomForestClassifier::new(n_estimators)) }
41    }
42
43    /// Train the classifier using aprender RandomForest (ml-tuner feature)
44    ///
45    /// Labels should be kernel type indices (0=TiledQ4K, 1=CoalescedQ4K, etc.)
46    #[cfg(feature = "ml-tuner")]
47    pub fn train(&mut self, data: &[(TunerFeatures, u32)]) -> Result<(), TunerError> {
48        if data.len() < 10 {
49            return Err(TunerError::InsufficientData(data.len()));
50        }
51
52        // Convert to aprender format (Matrix<f32> for features, &[usize] for labels)
53        let n_samples = data.len();
54        let n_features = TunerFeatures::DIM;
55        let mut x_data = Vec::with_capacity(n_samples * n_features);
56        let mut y_data: Vec<usize> = Vec::with_capacity(n_samples);
57
58        for (features, label) in data {
59            x_data.extend(features.to_vector());
60            y_data.push(*label as usize);
61        }
62
63        let x_matrix = Matrix::from_vec(n_samples, n_features, x_data)
64            .map_err(|e| TunerError::TrainingFailed(e.to_string()))?;
65
66        let rf = self.rf_classifier.get_or_insert_with(|| RandomForestClassifier::new(50));
67        rf.fit(&x_matrix, &y_data).map_err(|e| TunerError::TrainingFailed(e.to_string()))?;
68
69        // Calculate accuracy on training data
70        let predictions = rf.predict(&x_matrix);
71        let mut correct = 0;
72        for (i, (_, label)) in data.iter().enumerate() {
73            if predictions[i] as u32 == *label {
74                correct += 1;
75            }
76        }
77        self.accuracy = correct as f32 / data.len().max(1) as f32;
78
79        Ok(())
80    }
81
82    /// Predict best kernel based on features
83    pub fn predict(&self, features: &TunerFeatures) -> KernelRecommendation {
84        // Rule-based kernel selection from SHOWCASE-BRICK-001 learnings
85        let batch_size = (features.batch_size_norm * 64.0).round() as u32;
86        let seq_len = (2.0_f32.powf(features.seq_len_log * 15.0)).round() as u32;
87
88        // Determine best Q4K variant based on batch size
89        let (top_kernel, confidence) = if batch_size >= 4 {
90            // M >= 4: Use batched kernels
91            (KernelType::BatchedQ4K, 0.90)
92        } else if batch_size >= 2 {
93            // M = 2-3: Vectorized is good
94            (KernelType::VectorizedQ4K, 0.85)
95        } else {
96            // M = 1: Coalesced or Vectorized
97            if features.cuda_graphs > 0.5 {
98                (KernelType::VectorizedQ4K, 0.88)
99            } else {
100                (KernelType::CoalescedQ4K, 0.82)
101            }
102        };
103
104        // Check for attention-bound cases
105        let attention_kernel = if seq_len > 128 {
106            KernelType::MultiWarpAttention
107        } else {
108            KernelType::IncrementalAttention
109        };
110
111        // Build alternatives
112        let alternatives = vec![
113            (KernelType::VectorizedQ4K, 0.85),
114            (KernelType::CoalescedQ4K, 0.75),
115            (attention_kernel, 0.70),
116        ]
117        .into_iter()
118        .filter(|(k, _)| *k != top_kernel)
119        .take(2)
120        .collect();
121
122        KernelRecommendation { top_kernel, confidence, alternatives }
123    }
124}