1use smallvec::SmallVec;
11use vyre_foundation::optimizer::eqsat::{extract_best, EClassId, EGraph, ENodeLang};
12
13use crate::autotune_store::AutotuneRecord;
14use crate::device_profile::DeviceProfile;
15use crate::extraction_cost::{device_aware_cost, NodeHints};
16use crate::trace_jit_policy::{decide_trace_jit_speculation, TraceJitDecision, TraceJitInputs};
17
/// One extraction target: a device profile plus optional tuning context that
/// biases which e-node is considered cheapest on that device.
///
/// The type is `Copy` and is passed by value into [`extract_best_for_device`].
/// Build it with [`ExtractionDevice::new`] and the `with_*` builders.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExtractionDevice<'a> {
    /// Profile handed to `device_aware_cost`; its `backend` is copied into every
    /// [`DeviceExtraction`] produced for this device.
    pub profile: &'a DeviceProfile,
    /// Optional autotune result. When present, compile-time-constant nodes are
    /// discounted if `unroll > 1`, and fp16-eligible nodes are discounted if any
    /// tile dimension is `> 1` (see `extraction_bias_bps`).
    pub autotune_record: Option<&'a AutotuneRecord>,
    /// Optional trace-JIT counters. When the trace-JIT policy decides to speculate,
    /// compile-time-constant nodes receive an extra discount.
    pub trace_jit: Option<TraceJitInputs>,
    /// Forwarded to `device_aware_cost` and echoed in [`DeviceExtraction::hot_path`].
    pub hot_path: bool,
}
30
31impl<'a> ExtractionDevice<'a> {
32 #[must_use]
34 pub const fn new(profile: &'a DeviceProfile, hot_path: bool) -> Self {
35 Self {
36 profile,
37 autotune_record: None,
38 trace_jit: None,
39 hot_path,
40 }
41 }
42
43 #[must_use]
45 pub const fn with_autotune_record(mut self, record: &'a AutotuneRecord) -> Self {
46 self.autotune_record = Some(record);
47 self
48 }
49
50 #[must_use]
52 pub const fn with_trace_jit(mut self, counters: TraceJitInputs) -> Self {
53 self.trace_jit = Some(counters);
54 self
55 }
56}
57
/// The e-node chosen for one device by [`extract_best_for_device`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceExtraction<L> {
    /// Copied from the device profile's `backend`.
    pub backend: &'static str,
    /// Copied from [`ExtractionDevice::hot_path`].
    pub hot_path: bool,
    /// Node selected by `extract_best` for the requested root class.
    pub node: L,
    /// Cost reported by `extract_best` under the device-biased cost function.
    /// Presumably this is the total cost of the extracted term; confirm against
    /// `extract_best`.
    pub cost: u64,
}
70
71#[must_use]
73pub fn extract_best_for_device<L, B, H>(
74 egraph: &EGraph<L>,
75 root: EClassId,
76 device: ExtractionDevice<'_>,
77 base_cost: B,
78 hint_lookup: H,
79) -> Option<DeviceExtraction<L>>
80where
81 L: ENodeLang,
82 B: Fn(&L) -> u64,
83 H: Fn(&L) -> NodeHints,
84{
85 if root.0 as usize >= egraph.class_count() {
86 return None;
87 }
88 let profile_cost = device_aware_cost(device.profile, device.hot_path, &base_cost, &hint_lookup);
89 let cost = |node: &L| {
90 let hints = hint_lookup(node);
91 let cost = profile_cost(node);
92 apply_context_bias(cost, extraction_bias_bps(device, hints))
93 };
94 extract_best(egraph, root, cost).map(|(node, cost)| DeviceExtraction {
95 backend: device.profile.backend,
96 hot_path: device.hot_path,
97 node,
98 cost,
99 })
100}
101
102#[must_use]
108pub fn extract_best_for_devices<'a, L, B, H>(
109 egraph: &EGraph<L>,
110 root: EClassId,
111 devices: impl IntoIterator<Item = ExtractionDevice<'a>>,
112 base_cost: B,
113 hint_lookup: H,
114) -> SmallVec<[DeviceExtraction<L>; 4]>
115where
116 L: ENodeLang,
117 B: Fn(&L) -> u64,
118 H: Fn(&L) -> NodeHints,
119{
120 let mut out = SmallVec::new();
121 for device in devices {
122 if let Some(extracted) =
123 extract_best_for_device(egraph, root, device, &base_cost, &hint_lookup)
124 {
125 out.push(extracted);
126 }
127 }
128 out
129}
130
131fn extraction_bias_bps(device: ExtractionDevice<'_>, hints: NodeHints) -> u32 {
132 let mut bps = 10_000u32;
133 if let Some(record) = device.autotune_record {
134 if hints.compile_time_constant && record.unroll > 1 {
135 bps = scale_bps(bps, 8_000);
136 }
137 if hints.fp16_eligible && record.tile.iter().any(|dim| *dim > 1) {
138 bps = scale_bps(bps, 9_500);
139 }
140 }
141 if hints.compile_time_constant {
142 if let Some(counters) = device.trace_jit {
143 if matches!(
144 decide_trace_jit_speculation(counters),
145 TraceJitDecision::Speculate { .. }
146 ) {
147 bps = scale_bps(bps, 7_000);
148 }
149 }
150 }
151 bps.max(1)
152}
153
154fn scale_bps(lhs_bps: u32, rhs_bps: u32) -> u32 {
155 crate::numeric::compose_basis_points_u32(
156 lhs_bps,
157 rhs_bps,
158 "device extraction bias composition",
159 "driver",
160 )
161}
162
163fn apply_context_bias(cost: u64, bps: u32) -> u64 {
164 if bps >= 10_000 {
165 return cost;
166 }
167 crate::numeric::scale_u64_by_basis_points_floor_min(
168 cost,
169 bps,
170 1,
171 "device extraction context bias",
172 "driver",
173 )
174}
175
/// Unit tests for device-aware extraction over a tiny leaf-only toy language.
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::optimizer::eqsat::{EChildren, EGraph};

    /// Leaf-only toy language: three interchangeable implementations of one value.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Toy {
        Scalar,
        TensorCore,
        Specialized,
    }

    impl ENodeLang for Toy {
        // Every toy node is a leaf, so it has no children to rewrite.
        fn children(&self) -> EChildren {
            EChildren::new()
        }

        fn with_children(&self, _children: &[EClassId]) -> Self {
            self.clone()
        }
    }

    /// Base costs: `Scalar` is cheapest before any device or context bias is applied.
    fn base_cost(node: &Toy) -> u64 {
        match node {
            Toy::Scalar => 10,
            Toy::TensorCore => 30,
            Toy::Specialized => 11,
        }
    }

    /// Hints: `TensorCore` is fp16-eligible, `Specialized` is a compile-time constant.
    fn hints(node: &Toy) -> NodeHints {
        match node {
            Toy::TensorCore => NodeHints {
                fp16_eligible: true,
                compile_time_constant: false,
            },
            Toy::Specialized => NodeHints {
                fp16_eligible: false,
                compile_time_constant: true,
            },
            Toy::Scalar => NodeHints::default(),
        }
    }

    /// One e-class holding `Scalar` and `TensorCore`; returns the graph and that class.
    fn equivalent_toy_graph() -> (EGraph<Toy>, EClassId) {
        let mut graph = EGraph::new();
        let scalar = graph.add(Toy::Scalar);
        let tensor = graph.add(Toy::TensorCore);
        graph.union(scalar, tensor);
        graph.rebuild();
        (graph, scalar)
    }

    /// One e-class holding `Scalar` and `Specialized`; returns the graph and that class.
    fn specialized_toy_graph() -> (EGraph<Toy>, EClassId) {
        let mut graph = EGraph::new();
        let scalar = graph.add(Toy::Scalar);
        let specialized = graph.add(Toy::Specialized);
        graph.union(scalar, specialized);
        graph.rebuild();
        (graph, scalar)
    }

    // The expected `cost` values below come from `device_aware_cost` on the given
    // profiles. NOTE(review): they are pinned to that cost model; update them if it changes.

    #[test]
    fn conservative_profile_extracts_scalar_variant() {
        let (graph, root) = equivalent_toy_graph();
        let profile = DeviceProfile::conservative("portable");
        let extracted = extract_best_for_device(
            &graph,
            root,
            ExtractionDevice::new(&profile, true),
            base_cost,
            hints,
        )
        .expect("Fix: equivalent toy graph must extract");

        assert_eq!(extracted.backend, "portable");
        assert_eq!(extracted.node, Toy::Scalar);
        assert_eq!(extracted.cost, 5);
    }

    #[test]
    fn tensor_core_profile_extracts_fp16_variant() {
        let (graph, root) = equivalent_toy_graph();
        let mut profile = DeviceProfile::conservative("native");
        // Enabling f16 + tensor cores should let the fp16-eligible node win despite its
        // higher base cost.
        profile.supports_f16 = true;
        profile.supports_tensor_cores = true;

        let extracted = extract_best_for_device(
            &graph,
            root,
            ExtractionDevice::new(&profile, true),
            base_cost,
            hints,
        )
        .expect("Fix: equivalent toy graph must extract");

        assert_eq!(extracted.backend, "native");
        assert_eq!(extracted.node, Toy::TensorCore);
        assert_eq!(extracted.cost, 4);
    }

    #[test]
    fn several_devices_extract_from_one_saturated_graph() {
        let (graph, root) = equivalent_toy_graph();
        let portable = DeviceProfile::conservative("portable");
        let mut native = DeviceProfile::conservative("native");
        native.supports_f16 = true;
        native.supports_tensor_cores = true;

        let variants = extract_best_for_devices(
            &graph,
            root,
            [
                ExtractionDevice::new(&portable, true),
                ExtractionDevice::new(&native, true),
            ],
            base_cost,
            hints,
        );

        // Results keep the input order of the devices.
        assert_eq!(variants.len(), 2);
        assert_eq!(variants[0].node, Toy::Scalar);
        assert_eq!(variants[1].node, Toy::TensorCore);
    }

    #[test]
    fn autotune_record_biases_compile_time_constant_variant() {
        let (graph, root) = specialized_toy_graph();
        let profile = DeviceProfile::conservative("native");
        // `unroll > 1` triggers the 8_000 bps discount for compile-time-constant nodes;
        // all tile dimensions are 0, so the fp16 discount does not apply.
        let record = AutotuneRecord {
            workgroup_size: [128, 1, 1],
            unroll: 4,
            tile: [0, 0, 0],
            recorded_at: String::new(),
        };

        let extracted = extract_best_for_device(
            &graph,
            root,
            ExtractionDevice::new(&profile, true).with_autotune_record(&record),
            base_cost,
            hints,
        )
        .expect("Fix: equivalent toy graph must extract");

        assert_eq!(extracted.node, Toy::Specialized);
        assert_eq!(extracted.cost, 4);
    }

    #[test]
    fn trace_jit_biases_specialized_variant_when_speculation_pays() {
        let (graph, root) = specialized_toy_graph();
        let profile = DeviceProfile::conservative("native");
        // Many hits, full confidence, and a speculation cost far below the miss cost.
        // This assumes `decide_trace_jit_speculation` returns `Speculate` for these
        // inputs; the test only checks that assumption indirectly.
        let counters = TraceJitInputs {
            shader_hit_count: 64,
            prediction_confidence_bps: 10_000,
            speculative_spec_cost_ns: 1,
            miss_cost_ns: 1_000_000,
        };

        let extracted = extract_best_for_device(
            &graph,
            root,
            ExtractionDevice::new(&profile, true).with_trace_jit(counters),
            base_cost,
            hints,
        )
        .expect("Fix: equivalent toy graph must extract");

        assert_eq!(extracted.node, Toy::Specialized);
        assert_eq!(extracted.cost, 4);
    }

    #[test]
    fn missing_root_returns_no_variant() {
        // An empty graph has no class 77, so the range check rejects it for every device.
        let graph: EGraph<Toy> = EGraph::new();
        let profile = DeviceProfile::conservative("portable");
        let variants = extract_best_for_devices(
            &graph,
            EClassId(77),
            [ExtractionDevice::new(&profile, true)],
            base_cost,
            hints,
        );

        assert!(variants.is_empty());
    }
}