/// Inputs to [`decide_trace_jit_speculation`] describing one candidate shader shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TraceJitInputs {
/// How many times this shader shape has been hit; speculation requires at least
/// `TRACE_JIT_HOT_SHAPE_THRESHOLD`.
pub shader_hit_count: u32,
/// Prediction confidence in basis points (10_000 bps = 100%); speculation requires
/// at least `TRACE_JIT_MIN_CONFIDENCE_BPS`, and the value weights `miss_cost_ns`.
pub prediction_confidence_bps: u32,
/// Cost of speculating, in nanoseconds; subtracted from the confidence-weighted
/// miss cost to obtain the expected savings.
pub speculative_spec_cost_ns: u64,
/// Cost of a miss, in nanoseconds; weighted by confidence to estimate what
/// speculation could save. A value of 0 always yields `HoldSteady`.
pub miss_cost_ns: u64,
}
/// Outcome of [`decide_trace_jit_speculation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraceJitDecision {
/// Do not speculate: at least one gate (hotness, confidence, net savings) failed.
HoldSteady,
/// Speculate: the confidence-weighted miss cost strictly exceeds the speculation cost.
Speculate {
/// Confidence-weighted miss cost minus speculation cost, in nanoseconds.
/// When produced by `decide_trace_jit_speculation` this is always non-zero.
expected_savings_ns: u64,
},
}
impl std::fmt::Display for TraceJitDecision {
    /// Renders the decision as a compact token: `hold-steady`, or
    /// `speculate:<expected_savings_ns>` for a speculation decision.
    ///
    /// Output is written directly (via `write_str` / `write!`), so width and
    /// fill flags on the outer format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Speculate {
                expected_savings_ns: savings,
            } => {
                f.write_str("speculate:")?;
                write!(f, "{}", savings)
            }
            Self::HoldSteady => f.write_str("hold-steady"),
        }
    }
}
/// Minimum `shader_hit_count` (inclusive) for a shape to be considered hot enough
/// to speculate on.
pub const TRACE_JIT_HOT_SHAPE_THRESHOLD: u32 = 8;
/// Minimum `prediction_confidence_bps` (inclusive) required to speculate;
/// 6_000 bps corresponds to 60% confidence.
pub const TRACE_JIT_MIN_CONFIDENCE_BPS: u32 = 6_000;
/// Decides whether the trace JIT should speculate on a shader shape.
///
/// Speculation is chosen only when all of the following hold:
///
/// 1. the shape is hot: `shader_hit_count >= TRACE_JIT_HOT_SHAPE_THRESHOLD`;
/// 2. the prediction is confident enough:
///    `prediction_confidence_bps >= TRACE_JIT_MIN_CONFIDENCE_BPS`;
/// 3. the confidence-weighted miss cost,
///    `miss_cost_ns * min(prediction_confidence_bps, 10_000) / 10_000`
///    (integer division), strictly exceeds `speculative_spec_cost_ns`.
///
/// Confidence is in basis points, so anything above `10_000` (100%) is clamped
/// before weighting. Without the clamp, an out-of-range confidence would make the
/// weighted miss cost, and therefore the reported savings, exceed `miss_cost_ns`
/// itself.
///
/// # Returns
///
/// [`TraceJitDecision::Speculate`] with `expected_savings_ns` equal to the weighted
/// miss cost minus `speculative_spec_cost_ns` (always non-zero), or
/// [`TraceJitDecision::HoldSteady`] if any condition fails. The function never
/// panics: the weighting is done in `u128` and cannot overflow.
#[must_use]
pub fn decide_trace_jit_speculation(inputs: TraceJitInputs) -> TraceJitDecision {
    // 100% confidence expressed in basis points.
    const FULL_CONFIDENCE_BPS: u32 = 10_000;

    let is_hot = inputs.shader_hit_count >= TRACE_JIT_HOT_SHAPE_THRESHOLD;
    let is_confident = inputs.prediction_confidence_bps >= TRACE_JIT_MIN_CONFIDENCE_BPS;
    // A zero miss cost leaves nothing to save, so speculation can never pay off.
    if !is_hot || !is_confident || inputs.miss_cost_ns == 0 {
        return TraceJitDecision::HoldSteady;
    }

    let confidence_bps = inputs.prediction_confidence_bps.min(FULL_CONFIDENCE_BPS);
    // u64 * u32 always fits in u128, so the product cannot overflow. After the clamp
    // the quotient is at most `miss_cost_ns`, so narrowing back to u64 always
    // succeeds; the `unwrap_or` is purely defensive.
    let weighted = u128::from(inputs.miss_cost_ns) * u128::from(confidence_bps)
        / u128::from(FULL_CONFIDENCE_BPS);
    let weighted_miss_ns = u64::try_from(weighted).unwrap_or(u64::MAX);

    // `checked_sub` is `None` when speculating costs more than it saves; `Some(0)`
    // (break-even) is not worth speculating on either.
    match weighted_miss_ns.checked_sub(inputs.speculative_spec_cost_ns) {
        Some(expected_savings_ns) if expected_savings_ns > 0 => {
            TraceJitDecision::Speculate { expected_savings_ns }
        }
        _ => TraceJitDecision::HoldSteady,
    }
}
#[cfg(test)]
mod tests {
    use super::TraceJitDecision::{HoldSteady, Speculate};
    use super::*;

    /// Assembles a [`TraceJitInputs`] from positional values:
    /// hit count, confidence (bps), speculation cost (ns), miss cost (ns).
    fn inp(hit: u32, conf: u32, spec_cost: u64, miss_cost: u64) -> TraceJitInputs {
        TraceJitInputs {
            shader_hit_count: hit,
            prediction_confidence_bps: conf,
            speculative_spec_cost_ns: spec_cost,
            miss_cost_ns: miss_cost,
        }
    }

    /// Shorthand: builds the inputs and runs the decision in one call.
    fn decide(hit: u32, conf: u32, spec_cost: u64, miss_cost: u64) -> TraceJitDecision {
        decide_trace_jit_speculation(inp(hit, conf, spec_cost, miss_cost))
    }

    /// Fails the test unless `dec` is `Speculate`, whatever its savings value.
    fn assert_speculates(dec: TraceJitDecision) {
        assert!(
            matches!(dec, Speculate { .. }),
            "expected Speculate; got {:?}",
            dec
        );
    }

    #[test]
    fn cold_shape_holds_steady() {
        // One hit below the hot-shape threshold.
        assert_eq!(decide(7, 9_000, 1_000, 100_000), HoldSteady);
    }

    #[test]
    fn low_confidence_holds_steady() {
        // One basis point below the minimum confidence.
        assert_eq!(decide(100, 5_999, 1_000, 1_000_000), HoldSteady);
    }

    #[test]
    fn zero_miss_cost_holds_steady() {
        assert_eq!(decide(100, 9_000, 1_000, 0), HoldSteady);
    }

    #[test]
    fn positive_savings_speculates() {
        // Full confidence: 100_000 weighted miss cost minus 10_000 spec cost.
        let expected = Speculate {
            expected_savings_ns: 90_000,
        };
        assert_eq!(decide(100, 10_000, 10_000, 100_000), expected);
    }

    #[test]
    fn confidence_weights_predicted_savings() {
        // 60% of 100_000 is 60_000; minus the 50_000 spec cost.
        let expected = Speculate {
            expected_savings_ns: 10_000,
        };
        assert_eq!(decide(100, 6_000, 50_000, 100_000), expected);
    }

    #[test]
    fn spec_cost_above_weighted_savings_holds_steady() {
        // The weighted miss cost (60_000) exactly equals the spec cost, so this
        // pins the break-even boundary: equality must hold steady.
        assert_eq!(decide(100, 6_000, 60_000, 100_000), HoldSteady);
    }

    #[test]
    fn at_threshold_speculates_when_other_inputs_pass() {
        assert_speculates(decide(8, 10_000, 1_000, 100_000));
    }

    #[test]
    fn confidence_at_threshold_speculates() {
        assert_speculates(decide(100, 6_000, 1_000, 100_000));
    }

    #[test]
    fn extreme_inputs_use_saturating_arithmetic() {
        assert_speculates(decide(100, 10_000, 1_000, u64::MAX));
    }

    #[test]
    fn calibration_constants_pinned() {
        assert_eq!(
            (TRACE_JIT_HOT_SHAPE_THRESHOLD, TRACE_JIT_MIN_CONFIDENCE_BPS),
            (8, 6_000)
        );
    }

    #[test]
    fn trace_jit_decision_displays_human_string() {
        assert_eq!(HoldSteady.to_string(), "hold-steady");
        assert_eq!(
            Speculate {
                expected_savings_ns: 77
            }
            .to_string(),
            "speculate:77"
        );
    }
}