// trueno/tuner/features/extractor.rs
1#![allow(missing_docs)]
2//! Feature extraction and runtime configuration.
3//!
4//! Implements `FeatureExtractor` and `RunConfig`.
5
6use crate::brick::{BrickCategory, BrickProfiler};
7use crate::hardware::HardwareCapability;
8use serde::{Deserialize, Serialize};
9
10use crate::tuner::types::{BottleneckClass, KernelType, QuantType};
11
12use super::TunerFeatures;
13
14// ============================================================================
15// FeatureExtractor
16// ============================================================================
17
/// Extracts features from BrickProfiler and runtime configuration.
///
/// Combines profiler measurements (measured throughput and per-category
/// timing breakdown) with an optionally cached [`HardwareCapability`] to
/// produce [`TunerFeatures`] via [`FeatureExtractor::extract`].
#[derive(Debug)]
pub struct FeatureExtractor {
    /// Hardware capability (cached). `None` unless constructed via
    /// [`FeatureExtractor::with_hardware`]; when absent, hardware features
    /// and roofline efficiency are simply omitted from extraction.
    pub(crate) hardware: Option<HardwareCapability>,
}
24
25impl Default for FeatureExtractor {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl FeatureExtractor {
32    /// Create a new feature extractor
33    pub fn new() -> Self {
34        Self { hardware: None }
35    }
36
37    /// Create with hardware capability
38    pub fn with_hardware(hardware: HardwareCapability) -> Self {
39        Self { hardware: Some(hardware) }
40    }
41
42    /// Extract features from profiler and configuration
43    pub fn extract(&self, profiler: &BrickProfiler, config: &RunConfig) -> TunerFeatures {
44        let mut builder = TunerFeatures::builder()
45            .model_params_b(config.model_params_b)
46            .hidden_dim(config.hidden_dim)
47            .num_layers(config.num_layers)
48            .num_heads(config.num_heads)
49            .batch_size(config.batch_size)
50            .seq_len(config.seq_len)
51            .cuda_graphs(config.cuda_graphs)
52            .quant_type(config.quant_type)
53            .kernel_type(config.kernel_type);
54
55        // Add hardware features if available
56        if let Some(hw) = &self.hardware {
57            builder = builder.hardware(hw);
58        }
59
60        // Add measured throughput if available
61        if let Some(tps) = profiler.tokens_per_sec() {
62            builder = builder.measured_tps(tps);
63        }
64
65        let mut features = builder.build();
66
67        // Update derived features from profiler
68        if let Some(efficiency) = self.calculate_efficiency(profiler, config) {
69            features.theoretical_efficiency = efficiency;
70        }
71
72        // Classify bottleneck from profiler data
73        features.bottleneck_class = Some(self.classify_bottleneck(profiler));
74
75        features
76    }
77
78    /// Calculate efficiency from profiler data
79    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
80    // SAFETY: GPU bandwidth f64→f32 truncation is negligible for roofline efficiency calculation.
81    pub fn calculate_efficiency(
82        &self,
83        profiler: &BrickProfiler,
84        config: &RunConfig,
85    ) -> Option<f32> {
86        let measured_tps = profiler.tokens_per_sec()?;
87        let hw = self.hardware.as_ref()?;
88        let gpu = hw.gpu.as_ref()?;
89
90        // Calculate theoretical max based on roofline
91        let bytes_per_token = config.model_params_b * 1e9 * config.quant_type.bytes_per_param();
92        let theoretical_tps = (gpu.memory_bw_gbps as f32) * 1e9 / bytes_per_token;
93
94        Some((measured_tps / theoretical_tps).clamp(0.0, 1.0))
95    }
96
97    /// Classify bottleneck from profiler brick breakdown.
98    ///
99    /// PAR-200: Uses category_stats() for efficient aggregation.
100    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
101    // SAFETY: percentage() returns f64 in 0–100; f64→f32 truncation is negligible for threshold comparisons.
102    pub fn classify_bottleneck(&self, profiler: &BrickProfiler) -> BottleneckClass {
103        let cats = profiler.category_stats();
104        let total_ns = profiler.total_ns();
105
106        if total_ns == 0 {
107            return BottleneckClass::Unknown;
108        }
109
110        // Get category percentages
111        let attention_pct =
112            cats[BrickCategory::Attention as usize].percentage(total_ns) as f32 / 100.0;
113        let ffn_pct = cats[BrickCategory::Ffn as usize].percentage(total_ns) as f32 / 100.0;
114        let norm_pct = cats[BrickCategory::Norm as usize].percentage(total_ns) as f32 / 100.0;
115
116        // Classify based on dominant component
117        if attention_pct > 0.35 {
118            BottleneckClass::AttentionBound
119        } else if ffn_pct > 0.50 {
120            // FFN is memory-bound (large GEMV operations)
121            BottleneckClass::MemoryBound
122        } else if norm_pct > 0.20 {
123            // High norm percentage indicates launch overhead
124            BottleneckClass::LaunchBound
125        } else {
126            BottleneckClass::MemoryBound // Default for inference
127        }
128    }
129}
130
131// ============================================================================
132// RunConfig
133// ============================================================================
134
/// Runtime configuration for feature extraction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunConfig {
    /// Model size in billions of parameters (multiplied by 1e9 in the
    /// roofline bytes-per-token calculation).
    pub model_params_b: f32,
    /// Model hidden dimension.
    pub hidden_dim: u32,
    /// Number of transformer layers.
    pub num_layers: u32,
    /// Number of attention heads.
    pub num_heads: u32,
    /// Batch size of the run.
    pub batch_size: u32,
    /// Sequence length of the run (1 = single-token decode, per the default).
    pub seq_len: u32,
    /// Whether CUDA graphs are enabled for this run.
    pub cuda_graphs: bool,
    /// Weight quantization scheme; supplies `bytes_per_param()` for the
    /// roofline efficiency estimate.
    pub quant_type: QuantType,
    /// Kernel implementation selected for this run.
    pub kernel_type: KernelType,
}
148
149/// Default hidden dimension for 1.5B parameter model
150const DEFAULT_HIDDEN_DIM: u32 = 1536;
151
152impl Default for RunConfig {
153    fn default() -> Self {
154        Self {
155            model_params_b: 1.5,
156            hidden_dim: DEFAULT_HIDDEN_DIM,
157            num_layers: 28,
158            num_heads: 12,
159            batch_size: 1,
160            seq_len: 1,
161            cuda_graphs: false,
162            quant_type: QuantType::Q4K,
163            kernel_type: KernelType::VectorizedQ4K,
164        }
165    }
166}