Skip to main content

trueno/tuner/models/
mod.rs

1#![allow(missing_docs)]
2//! ML Models for Tuner
3//!
4//! Throughput regressor, kernel classifier, and bottleneck classifier implementations.
5
6mod kernel;
7mod throughput;
8
9pub use kernel::KernelClassifier;
10pub use throughput::ThroughputRegressor;
11
12use serde::{Deserialize, Serialize};
13
14use super::features::TunerFeatures;
15use super::types::{BottleneckClass, KernelType};
16
17// ============================================================================
18// Prediction Results
19// ============================================================================
20
21/// Throughput prediction result
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ThroughputPrediction {
24    /// Predicted tokens per second
25    pub predicted_tps: f32,
26    /// Confidence (0-1)
27    pub confidence: f32,
28    /// Top contributing features
29    pub top_features: Vec<(String, f32)>,
30}
31
32/// Kernel recommendation result
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct KernelRecommendation {
35    /// Top recommended kernel
36    pub top_kernel: KernelType,
37    /// Confidence (0-1)
38    pub confidence: f32,
39    /// Alternative kernels with probabilities
40    pub alternatives: Vec<(KernelType, f32)>,
41}
42
43/// Bottleneck prediction result
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct BottleneckPrediction {
46    /// Predicted bottleneck class
47    pub class: BottleneckClass,
48    /// Confidence (0-1)
49    pub confidence: f32,
50    /// Human-readable explanation
51    pub explanation: String,
52    /// Recommended action
53    pub recommended_action: String,
54}
55
56// ============================================================================
57// BottleneckClassifier
58// ============================================================================
59
60/// Bottleneck classifier using heuristics from profiler data.
61#[derive(Debug, Clone, Serialize, Deserialize, Default)]
62pub struct BottleneckClassifier {
63    /// Classification accuracy
64    accuracy: f32,
65}
66
67impl BottleneckClassifier {
68    pub fn new() -> Self {
69        Self { accuracy: 0.90 }
70    }
71
72    /// Predict bottleneck from features
73    pub fn predict(&self, features: &TunerFeatures) -> BottleneckPrediction {
74        // Use already-computed bottleneck if available
75        if let Some(class) = features.bottleneck_class {
76            return BottleneckPrediction {
77                class,
78                confidence: 0.95,
79                explanation: format!("Bottleneck classified from profiler data: {}", class),
80                recommended_action: class.recommended_action().to_string(),
81            };
82        }
83
84        // Heuristic classification based on features
85        let batch_size = (features.batch_size_norm * 64.0).round() as u32;
86        let seq_len = (2.0_f32.powf(features.seq_len_log * 15.0)).round() as u32;
87
88        let (class, confidence, explanation) = if batch_size == 1 && features.cuda_graphs < 0.5 {
89            (
90                BottleneckClass::LaunchBound,
91                0.75,
92                "Single sequence without CUDA graphs: kernel launch overhead may dominate".into(),
93            )
94        } else if seq_len > 512 {
95            (
96                BottleneckClass::AttentionBound,
97                0.80,
98                format!("Long sequence (len={}) likely makes attention the bottleneck", seq_len),
99            )
100        } else {
101            (
102                BottleneckClass::MemoryBound,
103                0.85,
104                "Q4K GEMV is typically memory-bound for LLM inference".into(),
105            )
106        };
107
108        BottleneckPrediction {
109            class,
110            confidence,
111            explanation,
112            recommended_action: class.recommended_action().to_string(),
113        }
114    }
115}