trueno/tuner/brick_tuner/mod.rs

#![allow(missing_docs)]
//! BrickTuner - ML-based ComputeBrick Tuner Ensemble
//!
//! Combines throughput regression, kernel classification, and bottleneck analysis.
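//!
//! A minimal usage sketch (hypothetical: the import paths and the
//! `TunerFeatures::default()` constructor are assumptions, not confirmed API):
//!
//! ```ignore
//! use trueno::tuner::brick_tuner::BrickTuner;
//! use trueno::tuner::features::TunerFeatures;
//!
//! let tuner = BrickTuner::new();
//! let features = TunerFeatures::default(); // hypothetical constructor
//! let rec = tuner.recommend(&features);
//! println!("confidence {:.2}", rec.confidence_overall);
//! for exp in &rec.suggested_experiments {
//!     println!("try: {}", exp);
//! }
//! ```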

mod persistence;
mod render;

use serde::{Deserialize, Serialize};

use super::error::TunerError;
use super::features::TunerFeatures;
use super::helpers::chrono_lite_now;
use super::models::{
    BottleneckClassifier, BottleneckPrediction, KernelClassifier, KernelRecommendation,
    ThroughputPrediction, ThroughputRegressor,
};
use super::types::{BottleneckClass, KernelType};

// ============================================================================
// TunerRecommendation
// ============================================================================

/// Combined tuner recommendation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TunerRecommendation {
    /// Throughput prediction
    pub throughput: ThroughputPrediction,
    /// Kernel recommendation
    pub kernel: KernelRecommendation,
    /// Bottleneck analysis
    pub bottleneck: BottleneckPrediction,
    /// Model version
    pub model_version: String,
    /// Overall confidence
    pub confidence_overall: f32,
    /// Suggested experiments to try
    pub suggested_experiments: Vec<ExperimentSuggestion>,
}

/// Suggested experiment to improve performance
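///
/// Each variant renders to a human-readable action via `Display`; a sketch
/// (doctest marked `ignore` since the crate path is an assumption):
///
/// ```ignore
/// let s = ExperimentSuggestion::IncreaseBatchSize { from: 1, to: 4 };
/// assert_eq!(s.to_string(), "Increase batch size: M=1 → M=4");
/// ```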
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExperimentSuggestion {
    /// Increase batch size
    IncreaseBatchSize { from: u32, to: u32 },
    /// Enable CUDA graphs
    EnableCudaGraphs,
    /// Try a specific kernel
    TryKernel { kernel: KernelType },
    /// Reduce sequence length
    ReduceSequenceLength { factor: f32 },
    /// Enable multi-KV cache
    EnableMultiKvCache { count: u32 },
}

impl std::fmt::Display for ExperimentSuggestion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExperimentSuggestion::IncreaseBatchSize { from, to } => {
                write!(f, "Increase batch size: M={} → M={}", from, to)
            }
            ExperimentSuggestion::EnableCudaGraphs => {
                write!(f, "Enable CUDA graphs for kernel launch amortization")
            }
            ExperimentSuggestion::TryKernel { kernel } => {
                write!(f, "Try kernel: {:?}", kernel)
            }
            ExperimentSuggestion::ReduceSequenceLength { factor } => {
                write!(f, "Reduce sequence length by {:.0}%", (1.0 - factor) * 100.0)
            }
            ExperimentSuggestion::EnableMultiKvCache { count } => {
                write!(f, "Enable {} separate KV caches for batched attention", count)
            }
        }
    }
}

// ============================================================================
// BrickTuner
// ============================================================================

/// ML-based ComputeBrick tuner ensemble.
///
/// Combines three models for comprehensive recommendations:
/// - ThroughputRegressor: Predicts tok/s
/// - KernelClassifier: Selects best kernel
/// - BottleneckClassifier: Identifies performance bottleneck
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrickTuner {
    /// Throughput regression model
    pub(crate) throughput: ThroughputRegressor,
    /// Kernel classification model
    pub(crate) kernel: KernelClassifier,
    /// Bottleneck classification model
    pub(crate) bottleneck: BottleneckClassifier,
    /// Model version
    pub(crate) version: String,
    /// Training timestamp
    pub(crate) trained_at: String,
    /// Number of training samples
    pub(crate) sample_count: usize,
}

impl Default for BrickTuner {
    fn default() -> Self {
        Self::new()
    }
}

impl BrickTuner {
    /// Model version
    pub const VERSION: &'static str = "1.0.0";

    /// Create a new tuner with default models
    pub fn new() -> Self {
        Self {
            throughput: ThroughputRegressor::new(),
            kernel: KernelClassifier::new(),
            bottleneck: BottleneckClassifier::new(),
            version: Self::VERSION.to_string(),
            trained_at: chrono_lite_now(),
            sample_count: 0,
        }
    }

    /// Get the model version string
    pub fn version(&self) -> &str {
        &self.version
    }

    /// Get the throughput regressor's MAPE (Mean Absolute Percentage Error)
    pub fn throughput_mape(&self) -> f32 {
        self.throughput.mape
    }

    /// Get the number of samples used to train the throughput regressor
    pub fn throughput_sample_count(&self) -> usize {
        self.throughput.sample_count
    }

    /// Get comprehensive tuning recommendation
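    ///
    /// Sketch of typical use (construction of `features` is elided; see
    /// `TunerFeatures`):
    ///
    /// ```ignore
    /// let rec = tuner.recommend(&features);
    /// assert_eq!(rec.model_version, tuner.version());
    /// for exp in &rec.suggested_experiments {
    ///     println!("try: {}", exp);
    /// }
    /// ```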
    pub fn recommend(&self, features: &TunerFeatures) -> TunerRecommendation {
        let throughput = self.throughput.predict(features);
        let kernel = self.kernel.predict(features);
        let bottleneck = self.bottleneck.predict(features);

        // Overall confidence is the mean of the three per-model confidences
        let confidence_overall =
            (throughput.confidence + kernel.confidence + bottleneck.confidence) / 3.0;

        // Generate experiment suggestions based on bottleneck
        let suggested_experiments = self.suggest_experiments(features, &bottleneck);

        TunerRecommendation {
            throughput,
            kernel,
            bottleneck,
            model_version: self.version.clone(),
            confidence_overall,
            suggested_experiments,
        }
    }

    /// Suggest experiments based on current bottleneck
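    ///
    /// Worked example under the assumed normalization (`batch_size_norm` =
    /// batch size / 64): `batch_size_norm = 0.03125` maps back to M = 2, so a
    /// `MemoryBound` prediction yields `IncreaseBatchSize { from: 2, to: 4 }`,
    /// `TryKernel { kernel: BatchedQ4K }`, and `EnableMultiKvCache { count: 2 }`.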
    pub fn suggest_experiments(
        &self,
        features: &TunerFeatures,
        bottleneck: &BottleneckPrediction,
    ) -> Vec<ExperimentSuggestion> {
        let mut suggestions = Vec::new();
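        // Recover the integer batch size from the normalized feature
        // (normalization scale assumed to be 64).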
        let batch_size = (features.batch_size_norm * 64.0).round() as u32;

        match bottleneck.class {
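            // Memory-bound: amortize weight reads over a larger batch and
            // prefer batched kernels.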
            BottleneckClass::MemoryBound => {
                if batch_size < 8 {
                    suggestions.push(ExperimentSuggestion::IncreaseBatchSize {
                        from: batch_size,
                        to: (batch_size * 2).min(8),
                    });
                }
                suggestions
                    .push(ExperimentSuggestion::TryKernel { kernel: KernelType::BatchedQ4K });
                if batch_size > 1 {
                    suggestions
                        .push(ExperimentSuggestion::EnableMultiKvCache { count: batch_size });
                }
            }
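            // Launch-bound: cut per-launch overhead via CUDA graphs and fusion.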
            BottleneckClass::LaunchBound => {
                if features.cuda_graphs < 0.5 {
                    suggestions.push(ExperimentSuggestion::EnableCudaGraphs);
                }
                suggestions
                    .push(ExperimentSuggestion::TryKernel { kernel: KernelType::FusedRmsNormQ4K });
            }
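            // Attention-bound: use a batched attention kernel or shrink the
            // sequence length.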
            BottleneckClass::AttentionBound => {
                suggestions
                    .push(ExperimentSuggestion::TryKernel { kernel: KernelType::BatchedAttention });
                suggestions.push(ExperimentSuggestion::ReduceSequenceLength { factor: 0.5 });
            }
            _ => {
                // Default suggestions
                if batch_size < 4 {
                    suggestions
                        .push(ExperimentSuggestion::IncreaseBatchSize { from: batch_size, to: 4 });
                }
            }
        }

        suggestions
    }

    /// Train the tuner on labeled `(features, measured tok/s)` pairs.
    ///
    /// Note: only the throughput regressor is refit here; the kernel and
    /// bottleneck classifiers keep their existing parameters.
    pub fn train(&mut self, data: &[(TunerFeatures, f32)]) -> Result<(), TunerError> {
        self.throughput.train(data)?;
        self.sample_count = data.len();
        self.trained_at = chrono_lite_now();
        Ok(())
    }
}
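
#[cfg(test)]
mod tests {
    //! Minimal smoke-test sketch. It deliberately exercises only items
    //! defined in this module (constructor, version string, `Display`
    //! formatting); training and prediction need real `TunerFeatures`
    //! values and model internals, so they are not covered here.

    use super::*;

    #[test]
    fn new_tuner_reports_current_model_version() {
        let tuner = BrickTuner::new();
        assert_eq!(tuner.version(), BrickTuner::VERSION);
    }

    #[test]
    fn suggestions_render_as_human_readable_text() {
        let s = ExperimentSuggestion::IncreaseBatchSize { from: 1, to: 4 };
        assert_eq!(s.to_string(), "Increase batch size: M=1 → M=4");

        let s = ExperimentSuggestion::ReduceSequenceLength { factor: 0.5 };
        assert_eq!(s.to_string(), "Reduce sequence length by 50%");
    }
}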