vyre_driver/strategy/mod.rs
1//! Backend-specific lowering strategies.
2//!
3//! # Two-Layer Optimization Architecture
4//!
//! Vyre splits optimizations into two layers with a clear separation of
//! concerns:
7//!
8//! ## Layer 1 - IR-Level Passes (`vyre-foundation/src/optimizer/passes/`)
9//!
10//! Pure mathematical rewrites that transform `Expr → Expr` in the IR.
11//! Backend-agnostic - every backend benefits equally.
12//!
13//! | Pass | Example | Lives In |
14//! |------|---------|----------|
15//! | Strength reduce | `x / 7` → `mulhi(x, M) >> s` | `strength_reduce/` |
16//! | Const fold | `3 + 4` → `7` | `const_fold/` |
17//! | Shift-add decomp | `x * 5` → `(x<<2) + x` | `strength_reduce/` |
18//! | FMA synthesis | `a*b + c` → `fma(a,b,c)` | `strength_reduce/` |
19//! | Exact division | `(x*6)/3` → `x * inv(3)` | `strength_reduce/` |
20//! | Lemire remainder | `x % 7` → `lowbits(x*M)*7>>32` | `strength_reduce/` |
21//!
22//! ## Layer 2 - Backend Lowering Strategies (this module)
23//!
24//! Target-dependent emission decisions. These don't change WHAT the program
25//! computes - they change HOW it's emitted for a specific chip/API.
26//!
27//! | Strategy | Backend | Effect |
28//! |----------|---------|--------|
29//! | primary-binary native multiply-high | backend | `MulHigh` → 1 instruction |
30//! | secondary-text native multiply-high | backend | `MulHigh` → 1 instruction |
31//! | 16-bit half-word decomp | target-text fallback | `MulHigh` → 14 ALU ops |
32//! | Dual-issue FP32/INT32 | capable device | Division via FP pipeline |
33//! | Matrix-core batching | capable device | Batched int8 multiply |
34//!
35//! # Adding a New Strategy
36//!
37//! 1. Implement [`crate::strategy::LoweringStrategy`] in your backend crate
38//! 2. Register it via `inventory::submit!`
39//! 3. The lowering pipeline auto-selects the highest-priority applicable
40//! strategy based on [`vyre_foundation::validate::BackendCapabilities`]
41//!
42//! # Vyre Law Zero
43//!
44//! > Runtime performance is sacred. No avoidable runtime overhead, ever.
45//!
46//! Layer 1 runs at compile time - zero cost.
47//! Layer 2 runs at kernel compile time (once for the megakernel) - amortized to zero.
48//! At GPU runtime, only the optimal native instructions execute.
49
50use vyre_foundation::ir::{BinOp, Expr};
51use vyre_foundation::optimizer::passes::algebraic::precision_hint::{
52 PrecisionHint, TranscendentalOp,
53};
54use vyre_foundation::validate::BackendCapabilities;
55
56/// A lowered expression ready for backend emission.
57///
58/// This is the output of a [`LoweringStrategy`]. It can be either a
59/// rewritten Vyre `Expr` or a backend-specific opaque instruction
60/// sequence (represented as a tagged enum for extensibility).
61#[derive(Debug, Clone)]
62pub enum LoweredExpr {
63 /// Rewritten as a Vyre IR expression (most strategies do this).
64 Expr(Expr),
65 /// The strategy handled emission directly - the lowering pipeline
66 /// should not process this expression further.
67 Emitted,
68}
69
70/// A backend-specific lowering strategy.
71///
72/// Strategies are the extensibility point for target-dependent
73/// optimizations. Each strategy declares:
74/// - **what** it can optimize (via [`can_apply`](LoweringStrategy::can_apply))
75/// - **how well** (via [`priority`](LoweringStrategy::priority))
76/// - **the transformation** (via [`lower`](LoweringStrategy::lower))
77///
78/// The lowering pipeline selects the highest-priority applicable
79/// strategy for each expression.
80pub trait LoweringStrategy: Send + Sync + std::fmt::Debug {
81 /// Human-readable name for diagnostics and telemetry.
82 fn name(&self) -> &str;
83
84 /// Check whether this strategy applies given the target capabilities
85 /// and the expression being lowered.
86 fn can_apply(&self, caps: &BackendCapabilities, op: &BinOp) -> bool;
87
88 /// Priority for strategy selection. Higher = preferred.
89 ///
90 /// Guidelines:
91 /// - 100: native hardware instruction (OpUMulExtended, mul.hi.u32)
92 /// - 50: multi-instruction but optimal (dual-issue trick)
93 /// - 10: portable decomposition (16-bit arithmetic expansion)
94 fn priority(&self) -> u32;
95
96 /// Lower the given expression using this strategy.
97 ///
98 /// `left` and `right` are the operands of the binary operation.
99 /// The strategy may return a rewritten `Expr` or signal that it
100 /// handled emission directly.
101 fn lower(&self, op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr;
102}
103
104/// Select the best available strategy for the given operation.
105///
106/// Returns `None` if no registered strategy applies, in which case
107/// the lowering pipeline should use its default emission path.
108pub fn select_strategy<'a>(
109 strategies: &'a [Box<dyn LoweringStrategy>],
110 caps: &BackendCapabilities,
111 op: &BinOp,
112) -> Option<&'a dyn LoweringStrategy> {
113 strategies
114 .iter()
115 .filter(|s| s.can_apply(caps, op))
116 .max_by_key(|s| s.priority())
117 .map(|s| s.as_ref())
118}
119
120/// Concrete lower/emit plan selected from a foundation precision hint.
121#[derive(Debug, Clone, Copy, PartialEq)]
122pub enum PrecisionLoweringPlan {
123 /// Keep the default f32/device-transcendental lowering.
124 DefaultF32,
125 /// Emit this site through native f16 ALU and widen the result to f32.
126 NativeF16 {
127 /// Maximum absolute source operand carried from the foundation hint.
128 max_abs_operand: f32,
129 },
130 /// Emit a bounded polynomial for the transcendental instead of a native
131 /// device call.
132 PolynomialTranscendental {
133 /// Target operation.
134 op: TranscendentalOp,
135 /// Maximum absolute argument bound from the foundation hint.
136 argument_bound: f32,
137 /// Required backend-side polynomial degree.
138 degree: u8,
139 },
140}
141
142/// Select a backend-neutral lower/emit plan for a precision hint.
143///
144/// Foundation owns candidate discovery. This function owns the shared
145/// capability gate every emitter uses before choosing the faster code shape.
146#[must_use]
147pub fn select_precision_lowering(
148 caps: &BackendCapabilities,
149 hint: &PrecisionHint,
150) -> PrecisionLoweringPlan {
151 match hint {
152 PrecisionHint::F16Eligible { max_abs_operand } if caps.has_native_f16 => {
153 PrecisionLoweringPlan::NativeF16 {
154 max_abs_operand: *max_abs_operand,
155 }
156 }
157 PrecisionHint::TranscendentalPolynomial { op, argument_bound }
158 if caps.has_transcendental_polynomial_emit =>
159 {
160 PrecisionLoweringPlan::PolynomialTranscendental {
161 op: *op,
162 argument_bound: *argument_bound,
163 degree: polynomial_degree_for(*op, *argument_bound),
164 }
165 }
166 _ => PrecisionLoweringPlan::DefaultF32,
167 }
168}
169
170fn polynomial_degree_for(op: TranscendentalOp, argument_bound: f32) -> u8 {
171 match op {
172 TranscendentalOp::Sin => {
173 if argument_bound <= 0.25 {
174 3
175 } else {
176 5
177 }
178 }
179 TranscendentalOp::Cos => {
180 if argument_bound <= 0.25 {
181 4
182 } else {
183 6
184 }
185 }
186 TranscendentalOp::Exp | TranscendentalOp::Ln => 5,
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in for a native multiply-high strategy: applies only when the
    /// backend reports `has_mul_high`.
    #[derive(Debug)]
    struct MockNativeStrategy;

    impl LoweringStrategy for MockNativeStrategy {
        fn name(&self) -> &str {
            "mock-native"
        }
        fn can_apply(&self, caps: &BackendCapabilities, op: &BinOp) -> bool {
            caps.has_mul_high && matches!(op, BinOp::MulHigh)
        }
        fn priority(&self) -> u32 {
            100
        }
        fn lower(&self, _op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr {
            // In real impl: emit OpUMulExtended
            LoweredExpr::Expr(Expr::mulhi(left.clone(), right.clone()))
        }
    }

    /// Stand-in for a portable fallback: applies to `MulHigh` regardless of
    /// capabilities, at low priority.
    #[derive(Debug)]
    struct MockFallbackStrategy;

    impl LoweringStrategy for MockFallbackStrategy {
        fn name(&self) -> &str {
            "mock-fallback"
        }
        fn can_apply(&self, _caps: &BackendCapabilities, op: &BinOp) -> bool {
            matches!(op, BinOp::MulHigh)
        }
        fn priority(&self) -> u32 {
            10
        }
        fn lower(&self, _op: &BinOp, left: &Expr, right: &Expr) -> LoweredExpr {
            // In real impl: 16-bit decomposition
            LoweredExpr::Expr(Expr::mul(left.clone(), right.clone()))
        }
    }

    #[test]
    fn selects_highest_priority() {
        let strategies: Vec<Box<dyn LoweringStrategy>> =
            vec![Box::new(MockFallbackStrategy), Box::new(MockNativeStrategy)];
        let caps = BackendCapabilities {
            has_mul_high: true,
            ..Default::default()
        };
        let selected = select_strategy(&strategies, &caps, &BinOp::MulHigh);
        assert_eq!(selected.unwrap().name(), "mock-native");
    }

    #[test]
    fn falls_back_when_native_unavailable() {
        let strategies: Vec<Box<dyn LoweringStrategy>> =
            vec![Box::new(MockFallbackStrategy), Box::new(MockNativeStrategy)];
        let caps = BackendCapabilities {
            has_mul_high: false,
            ..Default::default()
        };
        let selected = select_strategy(&strategies, &caps, &BinOp::MulHigh);
        assert_eq!(selected.unwrap().name(), "mock-fallback");
    }

    #[test]
    fn returns_none_for_unsupported_op() {
        let strategies: Vec<Box<dyn LoweringStrategy>> = vec![Box::new(MockNativeStrategy)];
        let caps = BackendCapabilities {
            has_mul_high: true,
            ..Default::default()
        };
        let selected = select_strategy(&strategies, &caps, &BinOp::Add);
        assert!(selected.is_none());
    }

    #[test]
    fn returns_none_for_empty_strategy_list() {
        let strategies: Vec<Box<dyn LoweringStrategy>> = Vec::new();
        let caps = BackendCapabilities::default();
        assert!(select_strategy(&strategies, &caps, &BinOp::MulHigh).is_none());
    }

    #[test]
    fn precision_hint_selects_native_f16_when_supported() {
        let caps = BackendCapabilities {
            has_native_f16: true,
            ..Default::default()
        };
        let plan = select_precision_lowering(
            &caps,
            &PrecisionHint::F16Eligible {
                max_abs_operand: 4.0,
            },
        );
        assert_eq!(
            plan,
            PrecisionLoweringPlan::NativeF16 {
                max_abs_operand: 4.0
            }
        );
    }

    #[test]
    fn precision_hint_keeps_f32_without_native_f16() {
        let plan = select_precision_lowering(
            &BackendCapabilities::default(),
            &PrecisionHint::F16Eligible {
                max_abs_operand: 4.0,
            },
        );
        assert_eq!(plan, PrecisionLoweringPlan::DefaultF32);
    }

    #[test]
    fn precision_hint_keeps_f32_with_only_polynomial_emit() {
        // The polynomial capability must not unlock the f16 path.
        let caps = BackendCapabilities {
            has_transcendental_polynomial_emit: true,
            ..Default::default()
        };
        let plan = select_precision_lowering(
            &caps,
            &PrecisionHint::F16Eligible {
                max_abs_operand: 4.0,
            },
        );
        assert_eq!(plan, PrecisionLoweringPlan::DefaultF32);
    }

    #[test]
    fn transcendental_hint_selects_polynomial_when_supported() {
        let caps = BackendCapabilities {
            has_transcendental_polynomial_emit: true,
            ..Default::default()
        };
        let plan = select_precision_lowering(
            &caps,
            &PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Sin,
                argument_bound: 0.2,
            },
        );
        assert_eq!(
            plan,
            PrecisionLoweringPlan::PolynomialTranscendental {
                op: TranscendentalOp::Sin,
                argument_bound: 0.2,
                degree: 3,
            }
        );
    }

    #[test]
    fn transcendental_hint_uses_higher_degree_for_wider_sin_range() {
        let caps = BackendCapabilities {
            has_transcendental_polynomial_emit: true,
            ..Default::default()
        };
        let plan = select_precision_lowering(
            &caps,
            &PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Sin,
                argument_bound: 0.75,
            },
        );
        assert_eq!(
            plan,
            PrecisionLoweringPlan::PolynomialTranscendental {
                op: TranscendentalOp::Sin,
                argument_bound: 0.75,
                degree: 5,
            }
        );
    }

    #[test]
    fn transcendental_hint_keeps_f32_without_polynomial_emit() {
        // The f16 capability must not unlock the polynomial path.
        let caps = BackendCapabilities {
            has_native_f16: true,
            ..Default::default()
        };
        let plan = select_precision_lowering(
            &caps,
            &PrecisionHint::TranscendentalPolynomial {
                op: TranscendentalOp::Sin,
                argument_bound: 0.2,
            },
        );
        assert_eq!(plan, PrecisionLoweringPlan::DefaultF32);
    }

    #[test]
    fn polynomial_degree_table_covers_every_op() {
        // `0.25` is inclusive for the narrow-range degree.
        assert_eq!(polynomial_degree_for(TranscendentalOp::Sin, 0.25), 3);
        assert_eq!(polynomial_degree_for(TranscendentalOp::Cos, 0.2), 4);
        assert_eq!(polynomial_degree_for(TranscendentalOp::Cos, 0.75), 6);
        assert_eq!(polynomial_degree_for(TranscendentalOp::Exp, 0.1), 5);
        assert_eq!(polynomial_degree_for(TranscendentalOp::Ln, 10.0), 5);
    }

    #[test]
    fn nan_bound_uses_conservative_higher_degree() {
        assert_eq!(polynomial_degree_for(TranscendentalOp::Sin, f32::NAN), 5);
        assert_eq!(polynomial_degree_for(TranscendentalOp::Cos, f32::NAN), 6);
    }
}