trueno/tuner/models/
mod.rs1#![allow(missing_docs)]
2mod kernel;
7mod throughput;
8
9pub use kernel::KernelClassifier;
10pub use throughput::ThroughputRegressor;
11
12use serde::{Deserialize, Serialize};
13
14use super::features::TunerFeatures;
15use super::types::{BottleneckClass, KernelType};
16
/// Output of a [`ThroughputRegressor`] run: a throughput estimate plus
/// introspection data about which features drove it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThroughputPrediction {
    /// Predicted throughput (presumably tokens/sec, per the `tps` name — TODO confirm units).
    pub predicted_tps: f32,
    /// Model confidence in the estimate (presumably in [0, 1] — confirm against the regressor).
    pub confidence: f32,
    /// Feature names paired with their contribution/importance scores for this prediction.
    pub top_features: Vec<(String, f32)>,
}
31
/// Output of a [`KernelClassifier`] run: the top-ranked kernel choice plus
/// scored runners-up for fallback or inspection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelRecommendation {
    /// Highest-ranked kernel for the workload.
    pub top_kernel: KernelType,
    /// Confidence in the top choice (presumably in [0, 1] — TODO confirm scale).
    pub confidence: f32,
    /// Alternative kernels with their scores, ordered by the classifier.
    pub alternatives: Vec<(KernelType, f32)>,
}
42
/// Result of [`BottleneckClassifier::predict`]: the inferred bottleneck class
/// with a human-readable justification and suggested remediation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BottleneckPrediction {
    /// The predicted dominant bottleneck.
    pub class: BottleneckClass,
    /// Confidence in the classification (0.75–0.95 for the current heuristics).
    pub confidence: f32,
    /// Human-readable reason for the classification.
    pub explanation: String,
    /// Suggested remediation, taken from `BottleneckClass::recommended_action`.
    pub recommended_action: String,
}
55
56#[derive(Debug, Clone, Serialize, Deserialize, Default)]
62pub struct BottleneckClassifier {
63 accuracy: f32,
65}
66
67impl BottleneckClassifier {
68 pub fn new() -> Self {
69 Self { accuracy: 0.90 }
70 }
71
72 pub fn predict(&self, features: &TunerFeatures) -> BottleneckPrediction {
74 if let Some(class) = features.bottleneck_class {
76 return BottleneckPrediction {
77 class,
78 confidence: 0.95,
79 explanation: format!("Bottleneck classified from profiler data: {}", class),
80 recommended_action: class.recommended_action().to_string(),
81 };
82 }
83
84 let batch_size = (features.batch_size_norm * 64.0).round() as u32;
86 let seq_len = (2.0_f32.powf(features.seq_len_log * 15.0)).round() as u32;
87
88 let (class, confidence, explanation) = if batch_size == 1 && features.cuda_graphs < 0.5 {
89 (
90 BottleneckClass::LaunchBound,
91 0.75,
92 "Single sequence without CUDA graphs: kernel launch overhead may dominate".into(),
93 )
94 } else if seq_len > 512 {
95 (
96 BottleneckClass::AttentionBound,
97 0.80,
98 format!("Long sequence (len={}) likely makes attention the bottleneck", seq_len),
99 )
100 } else {
101 (
102 BottleneckClass::MemoryBound,
103 0.85,
104 "Q4K GEMV is typically memory-bound for LLM inference".into(),
105 )
106 };
107
108 BottleneckPrediction {
109 class,
110 confidence,
111 explanation,
112 recommended_action: class.recommended_action().to_string(),
113 }
114 }
115}