Skip to main content

runmat_runtime/builtins/common/
residency.rs

1//! Shared heuristics for deciding when newly constructed arrays should prefer
2//! GPU residency, even when none of their inputs already live on the device.
3
4use runmat_accelerate_api::{provider, sequence_threshold_hint};
5
/// Base GPU threshold, in elements, used when neither `RUNMAT_SEQUENCE_GPU_MIN`
/// nor `sequence_threshold_hint()` supplies one.
const DEFAULT_SEQUENCE_MIN_LEN: usize = 4096;
/// Floor applied to every threshold after intent scaling; no intent can
/// request GPU residency for sequences shorter than this.
const MIN_THRESHOLD: usize = 1024;
8
/// Which family of sequence-producing builtin is asking for a residency
/// decision; each family may scale the base GPU threshold differently.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SequenceIntent {
    Linspace,
    Logspace,
    Colon,
    MeshAxis,
    Generic,
}

impl SequenceIntent {
    /// Factor applied to the base GPU threshold for this intent.
    ///
    /// `MeshAxis` halves the base threshold; every other intent keeps it as-is.
    fn scale(self) -> f64 {
        match self {
            Self::MeshAxis => 0.5,
            Self::Colon | Self::Linspace | Self::Logspace | Self::Generic => 1.0,
        }
    }
}
30
/// The rule that produced a [`ResidencyDecision`], as reported by
/// [`sequence_gpu_preference`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ResidencyReason {
    /// At least one argument already resided on the GPU.
    ExplicitGpuInput,
    /// `RUNMAT_SEQUENCE_GPU_DISABLE` switched the heuristic off.
    DisabledByEnv,
    /// `provider()` returned `None`, so there is no device to target.
    ProviderUnavailable,
    /// The requested sequence has no elements.
    ZeroLength,
    /// The sequence is shorter than the intent's GPU threshold.
    BelowThreshold,
    /// The sequence length reached the intent's GPU threshold.
    ThresholdHint,
}

/// Outcome of a residency query: where a new array should live, and why.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ResidencyDecision {
    /// `true` when the array should be created with GPU residency.
    pub prefer_gpu: bool,
    /// Which rule decided `prefer_gpu`.
    pub reason: ResidencyReason,
}

impl ResidencyDecision {
    /// Shared constructor behind [`Self::gpu`] and [`Self::cpu`].
    const fn with_preference(prefer_gpu: bool, reason: ResidencyReason) -> Self {
        Self { prefer_gpu, reason }
    }

    /// A decision recommending GPU residency for `reason`.
    fn gpu(reason: ResidencyReason) -> Self {
        Self::with_preference(true, reason)
    }

    /// A decision recommending CPU residency for `reason`.
    fn cpu(reason: ResidencyReason) -> Self {
        Self::with_preference(false, reason)
    }
}
64
65/// Decide whether a sequence of `len` elements for the supplied intent should
66/// prefer GPU residency.
67///
68/// `explicit_gpu_input` should be `true` when any of the arguments already
69/// reside on the GPU (for example, `gpuArray.linspace(...)`).
70pub fn sequence_gpu_preference(
71    len: usize,
72    intent: SequenceIntent,
73    explicit_gpu_input: bool,
74) -> ResidencyDecision {
75    if explicit_gpu_input {
76        return ResidencyDecision::gpu(ResidencyReason::ExplicitGpuInput);
77    }
78
79    if len == 0 {
80        return ResidencyDecision::cpu(ResidencyReason::ZeroLength);
81    }
82
83    if sequence_heuristics_disabled() {
84        return ResidencyDecision::cpu(ResidencyReason::DisabledByEnv);
85    }
86
87    if provider().is_none() {
88        return ResidencyDecision::cpu(ResidencyReason::ProviderUnavailable);
89    }
90
91    let threshold = threshold_for_intent(intent);
92    if len >= threshold {
93        return ResidencyDecision::gpu(ResidencyReason::ThresholdHint);
94    }
95
96    ResidencyDecision::cpu(ResidencyReason::BelowThreshold)
97}
98
99fn threshold_for_intent(intent: SequenceIntent) -> usize {
100    let env_override = std::env::var("RUNMAT_SEQUENCE_GPU_MIN")
101        .ok()
102        .and_then(|raw| raw.trim().parse::<usize>().ok());
103
104    let base = env_override
105        .or_else(sequence_threshold_hint)
106        .unwrap_or(DEFAULT_SEQUENCE_MIN_LEN);
107
108    let scaled = (base as f64 * intent.scale()).round() as isize;
109    scaled.max(MIN_THRESHOLD as isize) as usize
110}
111
/// Report whether `RUNMAT_SEQUENCE_GPU_DISABLE` switches the sequence
/// residency heuristic off.
///
/// The variable counts as set when its value, after trimming surrounding
/// whitespace, equals `1`, `true`, or `yes` (ASCII case-insensitive). An unset
/// or non-UTF-8 variable, or any other value, leaves the heuristic enabled.
fn sequence_heuristics_disabled() -> bool {
    match std::env::var("RUNMAT_SEQUENCE_GPU_DISABLE") {
        Ok(raw) => {
            let flag = raw.trim();
            ["1", "true", "yes"]
                .iter()
                .any(|&accepted| flag.eq_ignore_ascii_case(accepted))
        }
        Err(_) => false,
    }
}
120
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_accelerate::simple_provider;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises every test here that touches the `RUNMAT_SEQUENCE_GPU_*`
    /// variables, since the process environment is shared by all test threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    fn reset_env() {
        std::env::remove_var("RUNMAT_SEQUENCE_GPU_DISABLE");
        std::env::remove_var("RUNMAT_SEQUENCE_GPU_MIN");
    }

    /// Holds `ENV_LOCK` for the lifetime of a test and clears the residency env
    /// vars both on entry and on drop, so a failing test cannot leak its
    /// overrides into the next one.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // A panicking test poisons the mutex; the guarded data is `()`, so
            // recovering the guard is sound and avoids cascading failures.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            reset_env();
            Self { _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // `Drop::drop` runs before the `_lock` field is released, so this
            // reset is still serialised with other tests.
            reset_env();
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn explicit_gpu_short_circuits() {
        let _env = EnvGuard::acquire();
        let decision = sequence_gpu_preference(4, SequenceIntent::Linspace, true);
        assert!(decision.prefer_gpu);
        assert_eq!(decision.reason, ResidencyReason::ExplicitGpuInput);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn zero_length_stays_on_cpu() {
        let _env = EnvGuard::acquire();
        let decision = sequence_gpu_preference(0, SequenceIntent::Colon, false);
        assert!(!decision.prefer_gpu);
        assert_eq!(decision.reason, ResidencyReason::ZeroLength);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn env_disable_blocks_gpu() {
        let _env = EnvGuard::acquire();
        std::env::set_var("RUNMAT_SEQUENCE_GPU_DISABLE", "1");
        let decision = sequence_gpu_preference(10_000, SequenceIntent::Linspace, false);
        assert!(!decision.prefer_gpu);
        assert_eq!(decision.reason, ResidencyReason::DisabledByEnv);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn env_min_len_controls_threshold() {
        let _env = EnvGuard::acquire();
        simple_provider::register_inprocess_provider();
        std::env::set_var("RUNMAT_SEQUENCE_GPU_MIN", "8192");

        // 8_000 clears the built-in default (4_096), so staying on the CPU here
        // shows the override is actually consulted.
        let below = sequence_gpu_preference(8_000, SequenceIntent::Linspace, false);
        assert!(!below.prefer_gpu);
        assert_eq!(below.reason, ResidencyReason::BelowThreshold);

        let above = sequence_gpu_preference(10_000, SequenceIntent::Linspace, false);
        assert!(above.prefer_gpu);
        assert_eq!(above.reason, ResidencyReason::ThresholdHint);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mesh_axis_halves_threshold() {
        let _env = EnvGuard::acquire();
        simple_provider::register_inprocess_provider();
        std::env::set_var("RUNMAT_SEQUENCE_GPU_MIN", "8192");

        // MeshAxis scales the base by 0.5 (threshold 4_096); Linspace does not.
        let mesh = sequence_gpu_preference(5_000, SequenceIntent::MeshAxis, false);
        assert!(mesh.prefer_gpu);
        assert_eq!(mesh.reason, ResidencyReason::ThresholdHint);

        let linspace = sequence_gpu_preference(5_000, SequenceIntent::Linspace, false);
        assert!(!linspace.prefer_gpu);
        assert_eq!(linspace.reason, ResidencyReason::BelowThreshold);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn threshold_never_drops_below_floor() {
        let _env = EnvGuard::acquire();
        simple_provider::register_inprocess_provider();
        std::env::set_var("RUNMAT_SEQUENCE_GPU_MIN", "0");

        let below = sequence_gpu_preference(MIN_THRESHOLD - 1, SequenceIntent::Linspace, false);
        assert!(!below.prefer_gpu);
        assert_eq!(below.reason, ResidencyReason::BelowThreshold);

        let at_floor = sequence_gpu_preference(MIN_THRESHOLD, SequenceIntent::Linspace, false);
        assert!(at_floor.prefer_gpu);
        assert_eq!(at_floor.reason, ResidencyReason::ThresholdHint);
    }
}