// runmat_accelerate/native_auto.rs — native auto-offload heuristics.

1use runmat_time::{system_time_now, Instant};
2use std::collections::HashMap;
3use std::env;
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::sync::Mutex;
7use std::time::{Duration, UNIX_EPOCH};
8
9use crate::{
10    auto_offload_options,
11    fusion::{active_fusion, FusionKind},
12    fusion_residency,
13    precision::ensure_provider_supports_dtype,
14    AutoOffloadLogLevel,
15};
16use anyhow::{anyhow, Result};
17use futures::lock::Mutex as AsyncMutex;
18use log::{debug, info, trace, warn};
19use once_cell::sync::{Lazy, OnceCell};
20use runmat_accelerate_api::{AccelProvider, ApiDeviceInfo, HostTensorView, ProviderPrecision};
21use runmat_builtins::{builtin_functions, AccelTag, Tensor, Value};
22use runmat_runtime::builtins::common::spec::{builtin_residency_policy, ResidencyPolicy};
23use runmat_runtime::gather_if_needed_async;
24use serde::{Deserialize, Serialize};
25
// Fallback CPU cost-model coefficients (seconds per unit of work), used until
// calibration supplies device-specific values.
const DEFAULT_CPU_ELEM_PER_ELEM: f64 = 1.0e-7;
const DEFAULT_CPU_REDUCTION_PER_ELEM: f64 = 1.2e-7;
const DEFAULT_CPU_MATMUL_PER_FLOP: f64 = 2.5e-11;
// Small-batch guard defaults: a batch (trailing) dim <= max_dim combined with
// an element count >= min_elems keeps the op on CPU (see `small_batch_guard`).
const SMALL_BATCH_DEFAULT_MAX_DIM: usize = 8;
const SMALL_BATCH_DEFAULT_MIN_ELEMS: usize = 1_048_576;
// Ring capacity of the in-memory decision log.
const DECISION_LOG_CAPACITY: usize = 128;
// Version tag for persisted calibration data — presumably written/checked by
// the persist/load helpers (not visible here); confirm at the persist site.
const CALIBRATION_VERSION: u32 = 1;
33
/// Binary operation categories the offload heuristics distinguish.
#[derive(Clone, Copy, Debug)]
pub enum BinaryOp {
    Elementwise,
    MatMul,
}

/// Unary operation categories; `Transpose` gets its own label in the
/// decision log.
#[derive(Clone, Copy, Debug)]
pub enum UnaryOp {
    Generic,
    Transpose,
}

/// Reduction kinds accepted by `promote_reduction` (the specific kind does
/// not currently influence the decision — see the `_op` parameter there).
#[derive(Clone, Copy, Debug)]
pub enum ReductionOp {
    Sum,
    Mean,
    Min,
    Max,
}
53
/// Tunable knobs for the offload heuristics: minimum work sizes per op class,
/// CPU cost-model coefficients (seconds per unit), and the small-batch guard.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ThresholdConfig {
    // Minimum element count before a unary op is considered for GPU.
    unary_min_elems: usize,
    // Minimum element count before a binary elementwise op is considered.
    binary_min_elems: usize,
    // Minimum element count before a reduction is considered.
    reduction_min_elems: usize,
    // Minimum FLOP estimate before a matmul is considered.
    matmul_min_flops: usize,
    // Estimated CPU seconds per element for elementwise work.
    cpu_elem_per_elem: f64,
    // Estimated CPU seconds per element for reductions.
    cpu_reduction_per_elem: f64,
    // Estimated CPU seconds per FLOP for matmul.
    cpu_matmul_per_flop: f64,
    // Small-batch guard: largest trailing batch dim the guard applies to.
    small_batch_max_dim: usize,
    // Small-batch guard: smallest element count the guard applies to.
    small_batch_min_elems: usize,
}
66
impl Default for ThresholdConfig {
    /// Conservative built-in defaults, used until a cached or freshly run
    /// calibration replaces them.
    fn default() -> Self {
        Self {
            unary_min_elems: 4_096,
            binary_min_elems: 4_096,
            reduction_min_elems: 256,
            matmul_min_flops: 1_000_000, // roughly 100x100x100
            cpu_elem_per_elem: DEFAULT_CPU_ELEM_PER_ELEM,
            cpu_reduction_per_elem: DEFAULT_CPU_REDUCTION_PER_ELEM,
            cpu_matmul_per_flop: DEFAULT_CPU_MATMUL_PER_FLOP,
            small_batch_max_dim: SMALL_BATCH_DEFAULT_MAX_DIM,
            small_batch_min_elems: SMALL_BATCH_DEFAULT_MIN_ELEMS,
        }
    }
}
82
/// One recorded offload decision, kept in the in-memory decision log and
/// surfaced through `AutoOffloadReport`.
#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadDecisionEntry {
    // Wall-clock time of the decision, ms since the Unix epoch.
    pub timestamp_ms: u128,
    // Operation label, e.g. "elementwise", "matmul", "reduction".
    pub operation: String,
    // Element count, for element-sized decisions.
    pub elements: Option<usize>,
    // FLOP estimate, for matmul decisions.
    pub flops: Option<usize>,
    // Trailing batch dimension, when operands have rank >= 3.
    pub batch: Option<usize>,
    pub decision: AutoOffloadDisposition,
    pub reason: DecisionReason,
    // Cost-model estimates in milliseconds, when the model had coefficients.
    pub cpu_estimate_ms: Option<f64>,
    pub gpu_estimate_ms: Option<f64>,
    // The threshold that applied to this decision, if any.
    pub threshold: Option<usize>,
    // Active fusion group kind at decision time, if any.
    pub fusion_kind: Option<FusionKind>,
}

/// Where an operation was routed.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum AutoOffloadDisposition {
    Gpu,
    Cpu,
}

/// Why a disposition was chosen.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum DecisionReason {
    // An active fusion group forced GPU to keep the chain resident.
    FusionOverride,
    // An operand was already resident on the GPU.
    Residency,
    // Small trailing batch dim with a large element count kept work on CPU.
    SmallBatchGuard,
    // A profile-derived cost model made the call.
    ProfileModel,
    // A plain element/FLOP-count threshold made the call.
    Threshold,
    // Auto-offload was disabled.
    Disabled,
}
115
/// Read-only copy of the active thresholds, used in reports.
#[derive(Debug, Clone, Serialize)]
pub struct ThresholdSnapshot {
    pub unary_min_elems: usize,
    pub binary_min_elems: usize,
    pub reduction_min_elems: usize,
    pub matmul_min_flops: usize,
    pub cpu_elem_per_elem: f64,
    pub cpu_reduction_per_elem: f64,
    pub cpu_matmul_per_flop: f64,
    pub small_batch_max_dim: usize,
    pub small_batch_min_elems: usize,
}

/// Calibration result for reports: thresholds before the update plus the
/// per-coefficient changes.
#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadCalibrationSummary {
    pub previous: ThresholdSnapshot,
    pub delta: ThresholdDelta,
}

/// Per-coefficient changes produced by calibration; coefficients that did not
/// change serialize as absent.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ThresholdDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_elem_per_elem: Option<ThresholdDeltaEntry>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_reduction_per_elem: Option<ThresholdDeltaEntry>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_matmul_per_flop: Option<ThresholdDeltaEntry>,
}

/// Before/after record for one coefficient change.
#[derive(Debug, Clone, Serialize)]
pub struct ThresholdDeltaEntry {
    pub before: f64,
    pub after: f64,
    // after - before
    pub absolute: f64,
    // after / before; omitted when `before` is (near) zero.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ratio: Option<f64>,
}
153
154impl ThresholdDeltaEntry {
155    fn new(before: f64, after: f64) -> Self {
156        let absolute = after - before;
157        let ratio = if before.abs() > f64::EPSILON {
158            Some(after / before)
159        } else {
160            None
161        };
162        Self {
163            before,
164            after,
165            absolute,
166            ratio,
167        }
168    }
169}
170
/// Full diagnostic report: provider identity, active thresholds, where they
/// came from, and the recent decision log.
#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadReport {
    pub provider: Option<CachedProviderInfo>,
    pub thresholds: ThresholdSnapshot,
    // Provenance of the base thresholds (before env overrides).
    pub base_source: ThresholdBase,
    // True when environment overrides changed the thresholds
    // (see `apply_env_overrides` at init).
    pub env_overrides_applied: bool,
    // Path of the on-disk calibration cache, when one was read or written.
    pub cache_path: Option<String>,
    // Wall time of the in-process calibration run, when one happened.
    pub calibrate_duration_ms: Option<u128>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub calibration: Option<AutoOffloadCalibrationSummary>,
    pub decisions: Vec<AutoOffloadDecisionEntry>,
}

/// Provenance of the active thresholds.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum ThresholdBase {
    BuiltInDefault,
    LoadedFromCache,
    Calibrated,
}
191
192impl ThresholdBase {
193    pub fn as_str(&self) -> &'static str {
194        match self {
195            ThresholdBase::BuiltInDefault => "built-in-default",
196            ThresholdBase::LoadedFromCache => "loaded-from-cache",
197            ThresholdBase::Calibrated => "calibrated",
198        }
199    }
200}
201
/// Identity of the acceleration provider, recorded alongside thresholds and
/// calibration results so mismatches can be detected later.
#[derive(Debug, Clone, Serialize)]
pub struct CachedProviderInfo {
    pub name: String,
    pub vendor: String,
    pub backend: Option<String>,
    pub device_id: u32,
}
209
/// Mutable global state behind `AUTO_STATE`: resolved thresholds plus the
/// metadata needed to build an `AutoOffloadReport`.
#[derive(Debug, Clone)]
struct AutoOffloadState {
    provider: Option<CachedProviderInfo>,
    thresholds: ThresholdConfig,
    base_source: ThresholdBase,
    env_overrides_applied: bool,
    cache_path: Option<String>,
    calibrate_duration_ms: Option<u128>,
    // Thresholds in effect before the most recent calibration, if any.
    previous_thresholds: Option<ThresholdConfig>,
    // Coefficient changes from the most recent calibration, if any.
    calibration_delta: Option<ThresholdDelta>,
}

/// Outcome of one heuristic evaluation, before `decision_entry` turns it
/// into a log record.
#[derive(Clone)]
struct DecisionEvaluation {
    recommend_gpu: bool,
    reason: DecisionReason,
    // Cost-model estimates in seconds (converted to ms for the log).
    cpu_secs: Option<f64>,
    gpu_secs: Option<f64>,
    threshold: Option<usize>,
    fusion_kind: Option<FusionKind>,
    batch: Option<usize>,
}

/// Bounded FIFO of recent decisions (capacity `DECISION_LOG_CAPACITY`).
struct DecisionLog {
    entries: Vec<AutoOffloadDecisionEntry>,
}
236
237impl DecisionLog {
238    fn new() -> Self {
239        Self {
240            entries: Vec::new(),
241        }
242    }
243
244    fn push(&mut self, entry: AutoOffloadDecisionEntry) {
245        self.entries.push(entry);
246        if self.entries.len() > DECISION_LOG_CAPACITY {
247            let overflow = self.entries.len() - DECISION_LOG_CAPACITY;
248            self.entries.drain(0..overflow);
249        }
250    }
251
252    fn snapshot(&self) -> Vec<AutoOffloadDecisionEntry> {
253        self.entries.clone()
254    }
255
256    fn clear(&mut self) {
257        self.entries.clear();
258    }
259}
260
// Process-wide decision log; mutex-guarded because decisions may be recorded
// from multiple threads.
static DECISION_LOG: Lazy<Mutex<DecisionLog>> = Lazy::new(|| Mutex::new(DecisionLog::new()));
// Populated once during initialization; later mutated by calibration updates.
static AUTO_STATE: OnceCell<Mutex<AutoOffloadState>> = OnceCell::new();
263
264fn record_decision(entry: AutoOffloadDecisionEntry) {
265    if let Ok(mut log) = DECISION_LOG.lock() {
266        log.push(entry);
267    }
268}
269
270fn snapshot_decisions() -> Vec<AutoOffloadDecisionEntry> {
271    DECISION_LOG
272        .lock()
273        .map(|log| log.snapshot())
274        .unwrap_or_default()
275}
276
277fn clear_decisions() {
278    if let Ok(mut log) = DECISION_LOG.lock() {
279        log.clear();
280    }
281}
282
283fn now_millis() -> u128 {
284    system_time_now()
285        .duration_since(UNIX_EPOCH)
286        .unwrap_or_else(|_| Duration::from_secs(0))
287        .as_millis()
288}
289
/// Copy the active config into its serializable snapshot form.
fn threshold_snapshot(cfg: &ThresholdConfig) -> ThresholdSnapshot {
    ThresholdSnapshot {
        unary_min_elems: cfg.unary_min_elems,
        binary_min_elems: cfg.binary_min_elems,
        reduction_min_elems: cfg.reduction_min_elems,
        matmul_min_flops: cfg.matmul_min_flops,
        cpu_elem_per_elem: cfg.cpu_elem_per_elem,
        cpu_reduction_per_elem: cfg.cpu_reduction_per_elem,
        cpu_matmul_per_flop: cfg.cpu_matmul_per_flop,
        small_batch_max_dim: cfg.small_batch_max_dim,
        small_batch_min_elems: cfg.small_batch_min_elems,
    }
}
303
304fn compute_delta(before: &ThresholdConfig, after: &ThresholdConfig) -> ThresholdDelta {
305    let mut delta = ThresholdDelta::default();
306
307    if (before.cpu_elem_per_elem - after.cpu_elem_per_elem).abs() > f64::EPSILON {
308        delta.cpu_elem_per_elem = Some(ThresholdDeltaEntry::new(
309            before.cpu_elem_per_elem,
310            after.cpu_elem_per_elem,
311        ));
312    }
313
314    if (before.cpu_reduction_per_elem - after.cpu_reduction_per_elem).abs() > f64::EPSILON {
315        delta.cpu_reduction_per_elem = Some(ThresholdDeltaEntry::new(
316            before.cpu_reduction_per_elem,
317            after.cpu_reduction_per_elem,
318        ));
319    }
320
321    if (before.cpu_matmul_per_flop - after.cpu_matmul_per_flop).abs() > f64::EPSILON {
322        delta.cpu_matmul_per_flop = Some(ThresholdDeltaEntry::new(
323            before.cpu_matmul_per_flop,
324            after.cpu_matmul_per_flop,
325        ));
326    }
327
328    delta
329}
330
/// On-disk calibration payload. The sample may appear either at the top level
/// or nested under a `suite` section; `load_calibration_sample` checks both.
#[derive(Debug, Deserialize)]
struct CalibrationFile {
    #[serde(default)]
    suite: Option<CalibrationSuiteSection>,
    #[serde(default)]
    auto_offload_calibration: Option<CalibrationSample>,
}

/// Wrapper for the nested `suite.auto_offload_calibration` layout.
#[derive(Debug, Deserialize)]
struct CalibrationSuiteSection {
    #[serde(default)]
    auto_offload_calibration: Option<CalibrationSample>,
}

/// One measured calibration sample: CPU timings (ms) and the work amounts the
/// timings were measured over.
#[derive(Debug, Clone, Deserialize)]
struct CalibrationSample {
    // Number of runs the sample aggregates; zero is rejected by the applier.
    #[serde(default)]
    runs: usize,
    // Measured CPU times in milliseconds per workload class.
    #[serde(default, rename = "cpu_time_ms")]
    cpu_time: CalibrationTimes,
    // Work sizes (elements / FLOPs) matching `cpu_time`.
    #[serde(default)]
    units: CalibrationUnits,
    // Provider the sample was recorded on, used for mismatch warnings.
    #[serde(default)]
    provider: Option<CalibrationProviderInfo>,
    // Set when the sample's cases disagreed about the provider.
    #[serde(default)]
    provider_conflict: bool,
}

/// CPU wall times in milliseconds for each calibration workload.
#[derive(Debug, Clone, Deserialize, Default)]
struct CalibrationTimes {
    #[serde(default)]
    elementwise: f64,
    #[serde(default)]
    reduction: f64,
    #[serde(default)]
    matmul: f64,
}

/// Work amounts matching `CalibrationTimes`: element counts for elementwise
/// and reduction, FLOPs for matmul.
#[derive(Debug, Clone, Deserialize, Default)]
struct CalibrationUnits {
    #[serde(default)]
    elementwise: f64,
    #[serde(default)]
    reduction: f64,
    #[serde(default, rename = "matmul_flops")]
    matmul_flops: f64,
}

/// Provider identity recorded in a calibration sample.
#[derive(Debug, Clone, Deserialize)]
struct CalibrationProviderInfo {
    name: String,
    vendor: String,
    #[serde(default)]
    backend: Option<String>,
    device_id: u32,
}

/// Result of `apply_auto_offload_calibration_from_file`.
#[derive(Debug, Serialize)]
pub struct AutoOffloadCalibrationOutcome {
    pub runs: usize,
    pub before: ThresholdSnapshot,
    pub after: ThresholdSnapshot,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delta: Option<ThresholdDelta>,
    // Cache path the new thresholds were written to (commit mode only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub persisted_to: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<CachedProviderInfo>,
    // Whether the update was committed to global state / disk.
    pub commit: bool,
}
401
402fn load_calibration_sample(path: &Path) -> Result<CalibrationSample> {
403    let payload = fs::read_to_string(path).map_err(|e| anyhow!(e.to_string()))?;
404    let file: CalibrationFile = serde_json::from_str(&payload)
405        .map_err(|e| anyhow!(format!("failed to parse calibration file: {e}")))?;
406    if let Some(suite) = file.suite {
407        if let Some(sample) = suite.auto_offload_calibration {
408            return Ok(sample);
409        }
410    }
411    if let Some(sample) = file.auto_offload_calibration {
412        return Ok(sample);
413    }
414    Err(anyhow!(
415        "calibration file does not contain an auto_offload_calibration section"
416    ))
417}
418
419fn apply_calibration_sample(
420    cfg: &mut ThresholdConfig,
421    sample: &CalibrationSample,
422) -> Option<ThresholdDelta> {
423    let mut delta = ThresholdDelta::default();
424    let mut changed = false;
425
426    if sample.units.elementwise > 0.0 && sample.cpu_time.elementwise > 0.0 {
427        let secs_per_elem = (sample.cpu_time.elementwise / 1_000.0) / sample.units.elementwise;
428        if secs_per_elem.is_finite()
429            && secs_per_elem > 0.0
430            && (cfg.cpu_elem_per_elem - secs_per_elem).abs() > f64::EPSILON
431        {
432            delta.cpu_elem_per_elem = Some(ThresholdDeltaEntry::new(
433                cfg.cpu_elem_per_elem,
434                secs_per_elem,
435            ));
436            cfg.cpu_elem_per_elem = secs_per_elem;
437            changed = true;
438        }
439    }
440
441    if sample.units.reduction > 0.0 && sample.cpu_time.reduction > 0.0 {
442        let secs_per_elem = (sample.cpu_time.reduction / 1_000.0) / sample.units.reduction;
443        if secs_per_elem.is_finite()
444            && secs_per_elem > 0.0
445            && (cfg.cpu_reduction_per_elem - secs_per_elem).abs() > f64::EPSILON
446        {
447            delta.cpu_reduction_per_elem = Some(ThresholdDeltaEntry::new(
448                cfg.cpu_reduction_per_elem,
449                secs_per_elem,
450            ));
451            cfg.cpu_reduction_per_elem = secs_per_elem;
452            changed = true;
453        }
454    }
455
456    if sample.units.matmul_flops > 0.0 && sample.cpu_time.matmul > 0.0 {
457        let secs_per_flop = (sample.cpu_time.matmul / 1_000.0) / sample.units.matmul_flops;
458        if secs_per_flop.is_finite()
459            && secs_per_flop > 0.0
460            && (cfg.cpu_matmul_per_flop - secs_per_flop).abs() > f64::EPSILON
461        {
462            delta.cpu_matmul_per_flop = Some(ThresholdDeltaEntry::new(
463                cfg.cpu_matmul_per_flop,
464                secs_per_flop,
465            ));
466            cfg.cpu_matmul_per_flop = secs_per_flop;
467            changed = true;
468        }
469    }
470
471    if changed {
472        Some(delta)
473    } else {
474        None
475    }
476}
477
/// Apply a calibration sample read from `path` to the thresholds.
///
/// When `commit` is true the updated thresholds are persisted to the cache
/// and installed into the global state; otherwise this is a dry run that
/// still records the would-be delta in the state.
///
/// # Errors
/// Fails when the file has no usable sample, the sample has zero runs, no
/// acceleration provider is registered, persisting fails, or the sample does
/// not change any coefficient.
pub fn apply_auto_offload_calibration_from_file(
    path: &Path,
    commit: bool,
) -> Result<AutoOffloadCalibrationOutcome> {
    let sample = load_calibration_sample(path)?;
    if sample.runs == 0 {
        return Err(anyhow!("calibration sample contains zero runs"));
    }

    let provider = runmat_accelerate_api::provider()
        .ok_or_else(|| anyhow!("no acceleration provider registered"))?;
    let device_info = provider.device_info_struct();

    // Warn (but proceed) when the sample was recorded on a different device.
    if let Some(ref prov) = sample.provider {
        if prov.name != device_info.name
            || prov.vendor != device_info.vendor
            || prov.backend.as_deref() != device_info.backend.as_deref()
            || prov.device_id != device_info.device_id
        {
            warn!(
                "Calibration provider mismatch: sample='{} ({})' device='{} ({})'",
                prov.name, prov.vendor, device_info.name, device_info.vendor
            );
        }
        // NOTE(review): this warning only fires when the sample also carries
        // provider info; if `provider_conflict` should warn on its own, hoist
        // it out of this `if let` — confirm intent.
        if sample.provider_conflict {
            warn!("Calibration sample reported provider conflict across cases");
        }
    }

    // Start from cached thresholds when present, else built-in defaults
    // (the cache path returned by the loader is not needed here).
    let (mut cfg, _) = load_cached_thresholds(&device_info)
        .unwrap_or_else(|| (ThresholdConfig::default(), PathBuf::new()));
    let before_cfg = cfg.clone();

    let delta = apply_calibration_sample(&mut cfg, &sample)
        .ok_or_else(|| anyhow!("calibration sample did not produce coefficient updates"))?;

    let mut persisted_to: Option<PathBuf> = None;
    if commit {
        persisted_to = Some(persist_thresholds(&device_info, &cfg)?);
    }

    // Mirror the result into the global state (when initialized) so later
    // reports reflect this calibration.
    if let Some(state_mutex) = AUTO_STATE.get() {
        if let Ok(mut state) = state_mutex.lock() {
            state.previous_thresholds = Some(before_cfg.clone());
            state.calibration_delta = Some(delta.clone());
            if commit {
                state.thresholds = cfg.clone();
                state.base_source = ThresholdBase::Calibrated;
                if let Some(ref path_buf) = persisted_to {
                    state.cache_path = Some(path_buf.to_string_lossy().into_owned());
                }
                // This calibration came from a file, not a timed in-process
                // run, so no duration is recorded.
                state.calibrate_duration_ms = None;
            }
        }
    }

    Ok(AutoOffloadCalibrationOutcome {
        runs: sample.runs,
        before: threshold_snapshot(&before_cfg),
        after: threshold_snapshot(&cfg),
        delta: Some(delta),
        persisted_to: persisted_to.map(|p| p.to_string_lossy().into_owned()),
        provider: Some(cached_provider_info(&device_info)),
        commit,
    })
}
544
/// Convert provider device info into the serializable cached form.
fn cached_provider_info(info: &ApiDeviceInfo) -> CachedProviderInfo {
    CachedProviderInfo {
        name: info.name.clone(),
        vendor: info.vendor.clone(),
        backend: info.backend.clone(),
        device_id: info.device_id,
    }
}
553
/// Seconds estimate for `units` work items at `per_unit` seconds each.
/// `None` when the coefficient is non-positive or not finite (no usable
/// cost model).
fn cpu_estimate(per_unit: f64, units: usize) -> Option<f64> {
    (per_unit.is_finite() && per_unit > 0.0).then(|| per_unit * units as f64)
}
561
/// Shape of a (host or GPU) tensor value; `None` for non-tensor values.
fn value_shape(value: &Value) -> Option<&[usize]> {
    match value {
        Value::Tensor(t) => Some(&t.shape),
        Value::GpuTensor(handle) => Some(&handle.shape),
        _ => None,
    }
}
569
570fn batch_dimension_from_value(value: &Value) -> Option<usize> {
571    let shape = value_shape(value)?;
572    if shape.len() < 3 {
573        return None;
574    }
575    shape.last().copied()
576}
577
578fn batch_dimension_from_values(values: &[&Value]) -> Option<usize> {
579    values
580        .iter()
581        .filter_map(|value| batch_dimension_from_value(value))
582        .min()
583}
584
/// Build a decision-log entry from an evaluation, stamping the current time
/// and converting second estimates to milliseconds.
fn decision_entry(
    operation: &str,
    elements: Option<usize>,
    flops: Option<usize>,
    eval: &DecisionEvaluation,
) -> AutoOffloadDecisionEntry {
    AutoOffloadDecisionEntry {
        timestamp_ms: now_millis(),
        operation: operation.to_string(),
        elements,
        flops,
        batch: eval.batch,
        decision: if eval.recommend_gpu {
            AutoOffloadDisposition::Gpu
        } else {
            AutoOffloadDisposition::Cpu
        },
        reason: eval.reason,
        cpu_estimate_ms: eval.cpu_secs.map(|secs| secs * 1_000.0),
        gpu_estimate_ms: eval.gpu_secs.map(|secs| secs * 1_000.0),
        threshold: eval.threshold,
        fusion_kind: eval.fusion_kind.clone(),
    }
}
609
/// The auto-offload engine: a registered provider plus resolved thresholds.
pub struct NativeAutoOffload {
    provider: &'static dyn AccelProvider,
    thresholds: ThresholdConfig,
    enabled: bool,
}

// Lazily initialized singleton; holds `None` when offload is disabled or no
// provider is registered (see `global`).
static GLOBAL: OnceCell<Option<NativeAutoOffload>> = OnceCell::new();
// Async lock serializing one-time initialization (see `global`).
static GLOBAL_INIT_LOCK: Lazy<AsyncMutex<()>> = Lazy::new(|| AsyncMutex::new(()));
// Optional profile-derived cost model; `None` when unavailable.
static PROFILE_MODEL: OnceCell<Option<ProfileCostModel>> = OnceCell::new();
619
620fn env_bool(key: &str) -> Option<bool> {
621    env::var(key).ok().and_then(|v| parse_bool(&v))
622}
623
/// Lenient boolean parsing: trims whitespace, ignores ASCII case, and accepts
/// 1/true/yes/on and 0/false/no/off. Anything else yields `None`.
fn parse_bool(s: &str) -> Option<bool> {
    let normalized = s.trim().to_ascii_lowercase();
    if matches!(normalized.as_str(), "1" | "true" | "yes" | "on") {
        Some(true)
    } else if matches!(normalized.as_str(), "0" | "false" | "no" | "off") {
        Some(false)
    } else {
        None
    }
}
631
/// Emit a promotion log line at the configured auto-offload verbosity.
/// The message closure is only invoked when logging is enabled, so message
/// formatting costs nothing when logging is off.
fn log_promotion<F>(builder: F)
where
    F: FnOnce() -> String,
{
    match auto_offload_options().log_level {
        AutoOffloadLogLevel::Off => {}
        AutoOffloadLogLevel::Info => info!("{}", builder()),
        AutoOffloadLogLevel::Trace => trace!("{}", builder()),
    }
}
642
/// Lower `slot` to `candidate` when the candidate is a valid (finite,
/// positive) cost that improves on the current value.
fn update_cpu_cost(slot: &mut f64, candidate: f64) {
    let improves = candidate.is_finite() && candidate > 0.0 && candidate < *slot;
    if improves {
        *slot = candidate;
    }
}
648
/// Logical element count of a value: tensor length (product of dims for GPU
/// tensors), 1 for scalar kinds, `None` for other value kinds.
fn value_len(value: &Value) -> Option<usize> {
    match value {
        Value::Tensor(t) => Some(t.data.len()),
        Value::GpuTensor(handle) => Some(handle.shape.iter().product()),
        Value::Num(_) | Value::Bool(_) | Value::Int(_) => Some(1),
        Value::Complex(_, _) => Some(1),
        _ => None,
    }
}
658
659fn element_count_pair(a: &Value, b: &Value) -> Option<usize> {
660    let la = value_len(a)?;
661    let lb = value_len(b)?;
662    Some(la.max(lb))
663}
664
/// Return the process-wide auto-offload engine, initializing it on first use.
/// Yields `None` when auto-offload is disabled or no provider is registered.
pub async fn global() -> Option<&'static NativeAutoOffload> {
    if let Some(existing) = GLOBAL.get() {
        return existing.as_ref();
    }
    // If auto-offload is disabled or there is no GPU provider registered,
    // initialize_async() would return None immediately (no I/O, no blocking).
    // Return None directly without acquiring the async lock so single-poll
    // callers (e.g. the turbine JIT interpreter fallback) never observe a
    // spurious Pending.  We intentionally do NOT write to GLOBAL here: doing
    // so without holding GLOBAL_INIT_LOCK would race with a concurrent thread
    // that is partway through initialize_async() and has found a valid
    // provider.  That thread's subsequent GLOBAL.set(Some(offload)) would
    // silently fail (OnceCell is set-once), permanently disabling the
    // accelerator for the lifetime of the process.  These two checks are
    // cheap (no I/O), so re-evaluating them on each call is acceptable.
    if !auto_enabled() || runmat_accelerate_api::provider().is_none() {
        return None;
    }
    let _guard = GLOBAL_INIT_LOCK.lock().await;
    // Double-checked under the lock: another task may have initialized while
    // we waited.
    if let Some(existing) = GLOBAL.get() {
        return existing.as_ref();
    }
    let initialized = initialize_async().await;
    let _ = GLOBAL.set(initialized);
    GLOBAL.get().and_then(|value| value.as_ref())
}
691
/// One-time construction of the auto-offload engine.
///
/// Resolution order for thresholds: built-in defaults, then the on-disk
/// calibration cache (unless a forced refresh is requested), then an
/// in-process calibration run (when enabled and needed), then environment
/// overrides on top. Also installs the diagnostic `AUTO_STATE`.
/// Returns `None` when auto-offload is disabled or no provider exists.
async fn initialize_async() -> Option<NativeAutoOffload> {
    if !auto_enabled() {
        // Drop any stale decisions from a previous enabled phase.
        clear_decisions();
        return None;
    }
    let provider = runmat_accelerate_api::provider()?;
    let device_info = provider.device_info_struct();
    let mut config = ThresholdConfig::default();
    let mut base_source = ThresholdBase::BuiltInDefault;
    let mut cache_path: Option<String> = None;
    let mut calibrate_duration_ms: Option<u128> = None;
    let refresh_calibration = calibrate_refresh_enabled();

    // A forced refresh skips the cache so calibration runs from defaults.
    if !refresh_calibration {
        if let Some((cached, path)) = load_cached_thresholds_async(&device_info).await {
            info!(
                "Native auto-offload: loaded cached calibration for '{}' from {}",
                device_info.name, path
            );
            config = cached;
            cache_path = Some(path);
            base_source = ThresholdBase::LoadedFromCache;
        }
    }

    // Calibrate when enabled and either a refresh was forced or no cache hit.
    let needs_calibration = calibrate_enabled() && (refresh_calibration || cache_path.is_none());
    if needs_calibration {
        let start = Instant::now();
        match auto_calibrate(provider, &mut config) {
            Ok(()) => {
                calibrate_duration_ms = Some(start.elapsed().as_millis());
                base_source = ThresholdBase::Calibrated;
                // Persist best-effort; failure only costs a re-calibration
                // next startup.
                match persist_thresholds_async(&device_info, &config).await {
                    Ok(path) => {
                        cache_path = Some(path.clone());
                        info!(
                            "Native auto-offload: persisted calibration for '{}' to {}",
                            device_info.name, path
                        );
                    }
                    Err(err) => {
                        debug!("Native auto-offload: failed to persist calibration: {err}");
                    }
                }
            }
            Err(err) => {
                // Calibration failure is non-fatal; keep the current config.
                debug!("Native auto-offload calibration failed: {err}");
            }
        }
    }

    // Environment overrides are applied last, on top of whatever base won.
    let env_overrides_applied = apply_env_overrides(&mut config);
    let model_status = if profile_cost_model().is_some() {
        "profile"
    } else {
        "fallback"
    };
    info!(
        "Native auto-offload thresholds: unary={} binary={} reduction={} matmul_flops={} small_batch_dim={} small_batch_min_elems={} (model: {}, source: {}, env_overrides={})",
        config.unary_min_elems,
        config.binary_min_elems,
        config.reduction_min_elems,
        config.matmul_min_flops,
        config.small_batch_max_dim,
        config.small_batch_min_elems,
        model_status,
        base_source.as_str(),
        env_overrides_applied
    );

    // Record the resolved configuration for diagnostics/reports.
    let cache_path_str = cache_path.clone();
    let state = AutoOffloadState {
        provider: Some(cached_provider_info(&device_info)),
        thresholds: config.clone(),
        base_source,
        env_overrides_applied,
        cache_path: cache_path_str,
        calibrate_duration_ms,
        previous_thresholds: None,
        calibration_delta: None,
    };
    let _ = AUTO_STATE.set(Mutex::new(state));

    Some(NativeAutoOffload::new(provider, config))
}
777
778impl NativeAutoOffload {
779    fn new(provider: &'static dyn AccelProvider, thresholds: ThresholdConfig) -> Self {
780        let enabled = true;
781        Self {
782            provider,
783            thresholds,
784            enabled,
785        }
786    }
787
    /// Upload a host tensor to the GPU when it has at least `threshold`
    /// elements (and `threshold` > 0). Non-tensor values, tensors with
    /// unsupported dtypes, and already-resident GPU tensors are returned
    /// unchanged.
    fn promote_tensor_if_large(&self, value: &Value, threshold: usize) -> Result<Value> {
        match value {
            // Already resident on the GPU — nothing to do.
            Value::GpuTensor(_) => Ok(value.clone()),
            Value::Tensor(t) => {
                // Leave tensors the provider can't represent on CPU.
                if ensure_provider_supports_dtype(self.provider, t.dtype).is_err() {
                    return Ok(value.clone());
                }
                if t.data.len() >= threshold && threshold > 0 {
                    log_promotion(|| {
                        format!(
                            "Promoting tensor to GPU (len={}, threshold={})",
                            t.data.len(),
                            threshold
                        )
                    });
                    self.tensor_to_gpu(t)
                } else {
                    Ok(value.clone())
                }
            }
            _ => Ok(value.clone()),
        }
    }
811
    /// Upload a host tensor through the provider and wrap the resulting
    /// handle as a `Value::GpuTensor`.
    fn tensor_to_gpu(&self, tensor: &Tensor) -> Result<Value> {
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = self
            .provider
            .upload(&view)
            .map_err(|e| anyhow!(e.to_string()))?;
        Ok(Value::GpuTensor(handle))
    }
823
824    fn small_batch_guard(&self, elements: usize, batch: Option<usize>) -> bool {
825        if !self.enabled {
826            return false;
827        }
828        let Some(batch) = batch else {
829            return false;
830        };
831        if batch == 0 {
832            return false;
833        }
834        let thresholds = &self.thresholds;
835        thresholds.small_batch_max_dim > 0
836            && thresholds.small_batch_min_elems > 0
837            && batch <= thresholds.small_batch_max_dim
838            && elements >= thresholds.small_batch_min_elems
839    }
840
    /// Decide whether to run a binary op on GPU and, if accepted, upload both
    /// operands (threshold 1 forces the upload of any non-empty host tensor).
    /// Operands are returned unchanged when the op stays on CPU, or — for
    /// matmul — when shapes are unavailable or the inner dimensions mismatch.
    fn promote_binary(&self, op: BinaryOp, a: &Value, b: &Value) -> Result<(Value, Value)> {
        if !self.enabled {
            return Ok((a.clone(), b.clone()));
        }
        match op {
            BinaryOp::Elementwise => {
                let elems = element_count_pair(a, b).unwrap_or(0);
                let eval = self.evaluate_elementwise(elems, &[a, b]);
                record_decision(decision_entry("elementwise", Some(elems), None, &eval));
                if eval.recommend_gpu {
                    log_promotion(|| format!("Elementwise offload accepted ({} elems)", elems));
                    let a_p = self.promote_tensor_if_large(a, 1)?;
                    let b_p = self.promote_tensor_if_large(b, 1)?;
                    Ok((a_p, b_p))
                } else {
                    Ok((a.clone(), b.clone()))
                }
            }
            BinaryOp::MatMul => {
                if let (Some((ra, ca)), Some((rb, cb))) = (tensor_rows_cols(a), tensor_rows_cols(b))
                {
                    // Incompatible inner dims: let the CPU path report the error.
                    if ca != rb {
                        return Ok((a.clone(), b.clone()));
                    }
                    // FLOP estimate ~ M*K*N, saturating to avoid overflow.
                    let flops = ra.saturating_mul(ca).saturating_mul(cb);
                    let eval = self.evaluate_matmul(flops);
                    record_decision(decision_entry("matmul", None, Some(flops), &eval));
                    if eval.recommend_gpu {
                        log_promotion(|| {
                            format!(
                                "Promoting matmul operands (flops={}, threshold={})",
                                flops, self.thresholds.matmul_min_flops
                            )
                        });
                        let a_p = self.promote_tensor_if_large(a, 1)?;
                        let b_p = self.promote_tensor_if_large(b, 1)?;
                        return Ok((a_p, b_p));
                    }
                }
                Ok((a.clone(), b.clone()))
            }
        }
    }
884
885    fn promote_unary(&self, op: UnaryOp, v: &Value) -> Result<Value> {
886        if !self.enabled {
887            return Ok(v.clone());
888        }
889        let elems = value_len(v).unwrap_or(0);
890        let eval = self.evaluate_unary(elems, op, v);
891        let op_label = match op {
892            UnaryOp::Transpose => "transpose",
893            UnaryOp::Generic => "unary",
894        };
895        record_decision(decision_entry(op_label, Some(elems), None, &eval));
896        if eval.recommend_gpu {
897            log_promotion(|| format!("Unary offload accepted ({:?}, {} elems)", op, elems));
898            self.promote_tensor_if_large(v, 1)
899        } else {
900            Ok(v.clone())
901        }
902    }
903
904    fn promote_reduction(&self, _op: ReductionOp, args: &[Value]) -> Result<Vec<Value>> {
905        if !self.enabled || args.is_empty() {
906            return Ok(args.to_vec());
907        }
908        let elems = value_len(&args[0]).unwrap_or(0);
909        let eval = self.evaluate_reduction(elems);
910        record_decision(decision_entry("reduction", Some(elems), None, &eval));
911        if !eval.recommend_gpu {
912            return Ok(args.to_vec());
913        }
914        log_promotion(|| format!("Reduction offload accepted ({} elems)", elems));
915        let mut out = Vec::with_capacity(args.len());
916        if let Some(first) = args.first() {
917            out.push(self.promote_tensor_if_large(first, 1)?);
918            out.extend(args.iter().skip(1).cloned());
919        }
920        Ok(out)
921    }
922
    /// Decide whether an elementwise op over `elements` items should run on
    /// the GPU.
    ///
    /// Decision order (first match wins):
    /// 1. Residency — any `GpuTensor` input keeps the op on the GPU.
    /// 2. Fusion override — an active, supported elementwise/reduction
    ///    fusion forces GPU so the chain stays resident.
    /// 3. Small-batch guard — declines tiny batched workloads.
    /// 4. Profile cost model — GPU must beat the CPU estimate by >= 5%.
    /// 5. Static element-count threshold (`binary_min_elems`).
    fn evaluate_elementwise(&self, elements: usize, values: &[&Value]) -> DecisionEvaluation {
        let fusion = active_fusion();
        let fusion_kind = fusion.as_ref().map(|f| f.kind.clone());
        let batch = batch_dimension_from_values(values);
        let cpu_secs = cpu_estimate(self.thresholds.cpu_elem_per_elem, elements);

        // Chain-aware residency: if any input is already on GPU, keep the op on GPU
        if values.iter().any(|v| matches!(v, Value::GpuTensor(_))) {
            return DecisionEvaluation {
                recommend_gpu: true,
                reason: DecisionReason::Residency,
                cpu_secs,
                gpu_secs: None,
                threshold: Some(self.thresholds.binary_min_elems),
                fusion_kind,
                batch,
            };
        }

        if let Some(active) = fusion.as_ref() {
            // If an elementwise chain is actively fused OR this elementwise op
            // participates in a fused reduction group, force GPU to keep the
            // whole chain resident and avoid host round-trips.
            if (active.kind.is_elementwise() || active.kind.is_reduction()) && active.supported {
                return DecisionEvaluation {
                    recommend_gpu: true,
                    reason: DecisionReason::FusionOverride,
                    cpu_secs,
                    gpu_secs: None,
                    threshold: Some(self.thresholds.binary_min_elems),
                    fusion_kind,
                    batch,
                };
            }
        }

        // Small batched workloads are cheaper on the CPU than the dispatch cost.
        if self.small_batch_guard(elements, batch) {
            return DecisionEvaluation {
                recommend_gpu: false,
                reason: DecisionReason::SmallBatchGuard,
                cpu_secs,
                gpu_secs: None,
                threshold: Some(self.thresholds.binary_min_elems),
                fusion_kind,
                batch,
            };
        }

        // Measured profile model, when available: require a 5% GPU win margin.
        if let Some(model) = profile_cost_model() {
            if let Some(gpu_duration) = model.estimate_elemwise(elements) {
                let gpu_secs = Some(gpu_duration.as_secs_f64());
                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
                return DecisionEvaluation {
                    recommend_gpu: recommend,
                    reason: DecisionReason::ProfileModel,
                    cpu_secs,
                    gpu_secs,
                    threshold: Some(self.thresholds.binary_min_elems),
                    fusion_kind,
                    batch,
                };
            }
        }

        // Fallback: static element-count threshold.
        DecisionEvaluation {
            recommend_gpu: elements >= self.thresholds.binary_min_elems,
            reason: DecisionReason::Threshold,
            cpu_secs,
            gpu_secs: None,
            threshold: Some(self.thresholds.binary_min_elems),
            fusion_kind,
            batch,
        }
    }
998
999    fn evaluate_matmul(&self, flops: usize) -> DecisionEvaluation {
1000        let cpu_secs = cpu_estimate(self.thresholds.cpu_matmul_per_flop, flops);
1001        if let Some(model) = profile_cost_model() {
1002            if let Some(gpu_duration) = model.estimate_matmul_flops(flops) {
1003                let gpu_secs = Some(gpu_duration.as_secs_f64());
1004                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1005                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1006                return DecisionEvaluation {
1007                    recommend_gpu: recommend,
1008                    reason: DecisionReason::ProfileModel,
1009                    cpu_secs,
1010                    gpu_secs,
1011                    threshold: Some(self.thresholds.matmul_min_flops),
1012                    fusion_kind: None,
1013                    batch: None,
1014                };
1015            }
1016        }
1017
1018        DecisionEvaluation {
1019            recommend_gpu: flops >= self.thresholds.matmul_min_flops,
1020            reason: DecisionReason::Threshold,
1021            cpu_secs,
1022            gpu_secs: None,
1023            threshold: Some(self.thresholds.matmul_min_flops),
1024            fusion_kind: None,
1025            batch: None,
1026        }
1027    }
1028
1029    fn evaluate_reduction(&self, elements: usize) -> DecisionEvaluation {
1030        let fusion_kind = active_fusion().map(|f| f.kind.clone());
1031        let cpu_secs = cpu_estimate(self.thresholds.cpu_reduction_per_elem, elements);
1032        if let Some(model) = profile_cost_model() {
1033            if let Some(gpu_duration) = model.estimate_reduction(elements) {
1034                let gpu_secs = Some(gpu_duration.as_secs_f64());
1035                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1036                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1037                return DecisionEvaluation {
1038                    recommend_gpu: recommend,
1039                    reason: DecisionReason::ProfileModel,
1040                    cpu_secs,
1041                    gpu_secs,
1042                    threshold: Some(self.thresholds.reduction_min_elems),
1043                    fusion_kind,
1044                    batch: None,
1045                };
1046            }
1047        }
1048
1049        DecisionEvaluation {
1050            recommend_gpu: elements >= self.thresholds.reduction_min_elems,
1051            reason: DecisionReason::Threshold,
1052            cpu_secs,
1053            gpu_secs: None,
1054            threshold: Some(self.thresholds.reduction_min_elems),
1055            fusion_kind,
1056            batch: None,
1057        }
1058    }
1059
1060    fn evaluate_unary(&self, elements: usize, op: UnaryOp, value: &Value) -> DecisionEvaluation {
1061        let fusion_kind = active_fusion().map(|f| f.kind.clone());
1062        let batch = batch_dimension_from_values(&[value]);
1063        // Chain-aware residency for unary ops: if operand is already on GPU, keep it on GPU
1064        if matches!(value, Value::GpuTensor(_)) {
1065            return DecisionEvaluation {
1066                recommend_gpu: true,
1067                reason: DecisionReason::Residency,
1068                cpu_secs: cpu_estimate(self.thresholds.cpu_elem_per_elem, elements),
1069                gpu_secs: None,
1070                threshold: Some(self.thresholds.unary_min_elems),
1071                fusion_kind,
1072                batch,
1073            };
1074        }
1075        if matches!(op, UnaryOp::Generic) && self.small_batch_guard(elements, batch) {
1076            return DecisionEvaluation {
1077                recommend_gpu: false,
1078                reason: DecisionReason::SmallBatchGuard,
1079                cpu_secs: cpu_estimate(self.thresholds.cpu_elem_per_elem, elements),
1080                gpu_secs: None,
1081                threshold: Some(self.thresholds.unary_min_elems),
1082                fusion_kind,
1083                batch,
1084            };
1085        }
1086
1087        let cpu_secs = cpu_estimate(self.thresholds.cpu_elem_per_elem, elements);
1088        if let Some(model) = profile_cost_model() {
1089            let gpu_duration = match op {
1090                UnaryOp::Transpose => model.estimate_transpose(elements),
1091                UnaryOp::Generic => model.estimate_elemwise(elements),
1092            };
1093            if let Some(gpu_duration) = gpu_duration {
1094                let gpu_secs = Some(gpu_duration.as_secs_f64());
1095                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1096                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1097                return DecisionEvaluation {
1098                    recommend_gpu: recommend,
1099                    reason: DecisionReason::ProfileModel,
1100                    cpu_secs,
1101                    gpu_secs,
1102                    threshold: Some(self.thresholds.unary_min_elems),
1103                    fusion_kind,
1104                    batch,
1105                };
1106            }
1107        }
1108
1109        DecisionEvaluation {
1110            recommend_gpu: elements >= self.thresholds.unary_min_elems,
1111            reason: DecisionReason::Threshold,
1112            cpu_secs,
1113            gpu_secs: None,
1114            threshold: Some(self.thresholds.unary_min_elems),
1115            fusion_kind,
1116            batch,
1117        }
1118    }
1119
    /// Prepare arguments for a builtin call, promoting them to the GPU when
    /// the builtin's acceleration policy and the offload heuristics agree.
    ///
    /// Check order:
    /// - disabled planner, or a `double` cast on a provider without f64
    ///   storage: pass arguments through unchanged;
    /// - sink builtins: clear fusion residency on GPU args and, depending
    ///   on the builtin's residency policy, gather them back to the host;
    /// - otherwise dispatch on the builtin's accel tags — Reduction,
    ///   MatMul, Elementwise, Transpose, Unary — first matching tag wins.
    async fn prepare_builtin(&self, name: &str, args: &[Value]) -> Result<Vec<Value>> {
        if !self.enabled {
            return Ok(args.to_vec());
        }
        // Do not attempt to promote 'double' on providers that cannot store f64.
        // Offloading a cast to double requires device-side f64; otherwise keep host.
        if name.eq_ignore_ascii_case("double")
            && self.provider.precision() != runmat_accelerate_api::ProviderPrecision::F64
        {
            return Ok(args.to_vec());
        }
        if let Some(policy) = builtin_policy(name) {
            // Sink builtins act as fusion barriers: residency bookkeeping is
            // dropped for all GPU inputs, and the residency policy decides
            // whether the values are additionally gathered to the host.
            if policy.is_sink {
                clear_sink_inputs(args);
                if should_gather_sink_args(name) {
                    trace!(
                        "auto-offload: prepare_builtin(name={:?}) is_sink=true residency=GatherImmediately -> gathering {} arg(s)",
                        name,
                        args.len()
                    );
                    return gather_args(args).await;
                }
                trace!(
                    "auto-offload: prepare_builtin(name={:?}) is_sink=true residency!=GatherImmediately -> no gather (fusion barrier only)",
                    name
                );
                return Ok(args.to_vec());
            }

            let mut processed = args.to_vec();

            if policy
                .accel_tags
                .iter()
                .any(|tag| matches!(tag, AccelTag::Reduction))
            {
                // max/min are overloaded: with a non-placeholder second
                // argument they are the two-operand elementwise form, not a
                // reduction, so skip reduction promotion in that case.
                if (name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min"))
                    && !max_or_min_reduction_call(args)
                {
                    trace!(
                        "Skipping reduction promotion for builtin '{}' (detected elementwise form)",
                        name
                    );
                } else {
                    log_promotion(|| format!("Promoting builtin '{}' as reduction", name));
                    return self.promote_reduction(reduction_op_hint(name), args);
                }
            }

            // Binary promotions operate on the first two arguments only.
            if policy
                .accel_tags
                .iter()
                .any(|tag| matches!(tag, AccelTag::MatMul))
                && processed.len() >= 2
            {
                log_promotion(|| format!("Promoting builtin '{}' as matmul", name));
                let (a_p, b_p) =
                    self.promote_binary(BinaryOp::MatMul, &processed[0], &processed[1])?;
                processed[0] = a_p;
                processed[1] = b_p;
                return Ok(processed);
            }

            if policy
                .accel_tags
                .iter()
                .any(|tag| matches!(tag, AccelTag::Elementwise))
                && processed.len() >= 2
            {
                log_promotion(|| format!("Promoting builtin '{}' as elementwise", name));
                let (a_p, b_p) =
                    self.promote_binary(BinaryOp::Elementwise, &processed[0], &processed[1])?;
                processed[0] = a_p;
                processed[1] = b_p;
                return Ok(processed);
            }

            // Unary promotions replace the first argument in place.
            if let Some(first) = processed.first_mut() {
                if policy
                    .accel_tags
                    .iter()
                    .any(|tag| matches!(tag, AccelTag::Transpose))
                {
                    log_promotion(|| format!("Promoting builtin '{}' as transpose", name));
                    *first = self.promote_unary(UnaryOp::Transpose, first)?;
                    return Ok(processed);
                }

                if policy
                    .accel_tags
                    .iter()
                    .any(|tag| matches!(tag, AccelTag::Unary))
                {
                    log_promotion(|| format!("Promoting builtin '{}' as unary", name));
                    *first = self.promote_unary(UnaryOp::Generic, first)?;
                    return Ok(processed);
                }
            }
        }
        // No policy or no matching tag: arguments pass through unchanged.
        Ok(args.to_vec())
    }
1221}
1222
1223fn tensor_rows_cols(value: &Value) -> Option<(usize, usize)> {
1224    match value {
1225        Value::Tensor(t) => Some((t.rows(), t.cols())),
1226        Value::GpuTensor(handle) => {
1227            if handle.shape.len() == 2 {
1228                Some((handle.shape[0], handle.shape[1]))
1229            } else {
1230                None
1231            }
1232        }
1233        _ => None,
1234    }
1235}
1236
1237#[allow(dead_code)]
1238fn should_skip_reduction_promotion(name: &str, args: &[Value]) -> bool {
1239    (name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min"))
1240        && !max_or_min_reduction_call(args)
1241}
1242
1243fn reduction_op_hint(name: &str) -> ReductionOp {
1244    if name.eq_ignore_ascii_case("max") {
1245        ReductionOp::Max
1246    } else if name.eq_ignore_ascii_case("min") {
1247        ReductionOp::Min
1248    } else {
1249        ReductionOp::Sum
1250    }
1251}
1252
1253fn max_or_min_reduction_call(args: &[Value]) -> bool {
1254    if args.len() <= 1 {
1255        return true;
1256    }
1257    args.get(1).map(is_empty_placeholder_value).unwrap_or(false)
1258}
1259
1260fn is_empty_placeholder_value(value: &Value) -> bool {
1261    match value {
1262        Value::Tensor(t) => t.data.is_empty(),
1263        Value::LogicalArray(l) => l.data.is_empty(),
1264        Value::StringArray(sa) => sa.data.is_empty(),
1265        Value::CharArray(ca) => ca.data.is_empty(),
1266        Value::Cell(cell) => cell.data.is_empty(),
1267        Value::String(s) => s.is_empty(),
1268        _ => false,
1269    }
1270}
1271
/// Gather every argument back to the host, downloading GPU tensors as
/// needed via `gather_if_needed_async`. Per-argument trace logging records
/// residency before and after the gather for debugging offload decisions.
async fn gather_args(args: &[Value]) -> Result<Vec<Value>> {
    let mut out = Vec::with_capacity(args.len());
    for (idx, value) in args.iter().enumerate() {
        // Log GPU tensors with full handle detail; everything else just by kind.
        if let Value::GpuTensor(handle) = value {
            trace!(
                "auto-offload: gather_args arg[{}]=GpuTensor device_id={} buffer_id={} shape={:?}",
                idx,
                handle.device_id,
                handle.buffer_id,
                handle.shape
            );
        } else {
            trace!(
                "auto-offload: gather_args arg[{}]={:?}",
                idx,
                value_kind(value)
            );
        }
        let gathered = gather_if_needed_async(value)
            .await
            .map_err(|e| anyhow!(e))?;
        trace!(
            "auto-offload: gather_args arg[{}] -> {:?}",
            idx,
            value_kind(&gathered)
        );
        out.push(gathered);
    }
    Ok(out)
}
1302
1303fn clear_sink_inputs(args: &[Value]) {
1304    for value in args {
1305        if let Value::GpuTensor(handle) = value {
1306            fusion_residency::clear(handle);
1307        }
1308    }
1309}
1310
1311fn should_gather_sink_args(name: &str) -> bool {
1312    matches!(
1313        builtin_residency_policy(name),
1314        Some(ResidencyPolicy::GatherImmediately) | None
1315    )
1316}
1317
/// Short static label for a `Value` variant, used only in trace logs.
/// Deliberately exhaustive (no `_` arm) so new variants force an update here.
fn value_kind(value: &Value) -> &'static str {
    match value {
        Value::GpuTensor(_) => "GpuTensor",
        Value::Tensor(_) => "Tensor",
        Value::Num(_) => "Num",
        Value::Int(_) => "Int",
        Value::Bool(_) => "Bool",
        Value::LogicalArray(_) => "LogicalArray",
        Value::CharArray(_) => "CharArray",
        Value::String(_) => "String",
        Value::StringArray(_) => "StringArray",
        Value::Cell(_) => "Cell",
        Value::Struct(_) => "Struct",
        Value::Object(_) => "Object",
        Value::HandleObject(_) => "HandleObject",
        Value::FunctionHandle(_) => "FunctionHandle",
        Value::Closure(_) => "Closure",
        Value::ClassRef(_) => "ClassRef",
        Value::Complex(_, _) => "Complex",
        Value::ComplexTensor(_) => "ComplexTensor",
        Value::Listener(_) => "Listener",
        Value::MException(_) => "MException",
        Value::OutputList(_) => "OutputList",
    }
}
1343
#[cfg(test)]
mod tests {
    use super::*;

    /// `max`/`min` are reductions when called with a single argument or an
    /// empty-placeholder second argument; a non-empty second argument means
    /// the elementwise two-operand form.
    #[test]
    fn max_detection_handles_placeholders() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let placeholder = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        let data = Value::Tensor(tensor);
        let empty = Value::Tensor(placeholder);

        // Single argument -> reduction form.
        assert!(max_or_min_reduction_call(std::slice::from_ref(&data)));
        // Empty placeholder in slot 1 -> reduction form (e.g. max(x, [], dim)).
        assert!(max_or_min_reduction_call(&[
            data.clone(),
            empty.clone(),
            Value::Num(1.0)
        ]));
        // Real second operand -> elementwise form.
        assert!(!max_or_min_reduction_call(&[data.clone(), Value::Num(0.0)]));
    }
}
1364
/// Per-builtin acceleration metadata, derived from the builtin registry.
#[derive(Clone, Copy)]
struct BuiltinPolicy {
    /// Acceleration tags declared by the builtin (Reduction, MatMul,
    /// Elementwise, Transpose, Unary, ...); `prepare_builtin` dispatches on
    /// these in priority order.
    accel_tags: &'static [AccelTag],
    /// Sink builtins terminate GPU residency: their GPU inputs are cleared
    /// from fusion-residency tracking and possibly gathered to the host.
    is_sink: bool,
}

/// Lazily-built lookup table from lower-cased builtin name to its policy.
static BUILTIN_POLICIES: OnceCell<HashMap<String, BuiltinPolicy>> = OnceCell::new();
1372
1373fn build_builtin_policy_map() -> HashMap<String, BuiltinPolicy> {
1374    let mut map = HashMap::new();
1375    for func in builtin_functions() {
1376        map.insert(
1377            func.name.to_ascii_lowercase(),
1378            BuiltinPolicy {
1379                accel_tags: func.accel_tags,
1380                is_sink: func.is_sink,
1381            },
1382        );
1383    }
1384    map
1385}
1386
1387fn builtin_policy(name: &str) -> Option<BuiltinPolicy> {
1388    let map = BUILTIN_POLICIES.get_or_init(build_builtin_policy_map);
1389    map.get(&name.to_ascii_lowercase()).copied()
1390}
1391
1392fn auto_enabled() -> bool {
1393    if let Some(flag) = env_bool("RUNMAT_ACCEL_AUTO_OFFLOAD") {
1394        return flag;
1395    }
1396    auto_offload_options().enabled
1397}
1398
1399fn calibrate_enabled() -> bool {
1400    if let Some(flag) = env_bool("RUNMAT_ACCEL_CALIBRATE") {
1401        return flag;
1402    }
1403    auto_offload_options().calibrate
1404}
1405
1406fn calibrate_refresh_enabled() -> bool {
1407    env_bool("RUNMAT_ACCEL_CALIBRATE_REFRESH").unwrap_or(false)
1408}
1409
1410fn apply_env_overrides(cfg: &mut ThresholdConfig) -> bool {
1411    let mut applied = false;
1412    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_UNARY") {
1413        cfg.unary_min_elems = val;
1414        applied = true;
1415    }
1416    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_ELEMWISE") {
1417        cfg.binary_min_elems = val;
1418        applied = true;
1419    }
1420    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_REDUCTION") {
1421        cfg.reduction_min_elems = val;
1422        applied = true;
1423    }
1424    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_MATMUL") {
1425        cfg.matmul_min_flops = val;
1426        applied = true;
1427    }
1428    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_ALL") {
1429        cfg.unary_min_elems = val;
1430        cfg.binary_min_elems = val;
1431        cfg.reduction_min_elems = val;
1432        applied = true;
1433    }
1434    if let Some(val) = env_usize("RUNMAT_ACCEL_SMALL_BATCH_MAX_DIM") {
1435        cfg.small_batch_max_dim = val;
1436        applied = true;
1437    }
1438    if let Some(val) = env_usize("RUNMAT_ACCEL_SMALL_BATCH_MIN_ELEMS") {
1439        cfg.small_batch_min_elems = val;
1440        applied = true;
1441    }
1442    applied
1443}
1444
/// Read an environment variable as `usize`; unset, non-UTF-8, or
/// unparseable values all yield `None`.
fn env_usize(key: &str) -> Option<usize> {
    match env::var(key) {
        Ok(raw) => raw.parse::<usize>().ok(),
        Err(_) => None,
    }
}
1448
/// Persisted calibration cache entry (JSON on disk natively, IndexedDB on
/// wasm).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CalibrationRecord {
    /// Cache format version; records with a different version are discarded.
    version: u32,
    /// Unix timestamp (seconds) when the record was written.
    recorded_at: u64,
    /// Identity of the device these thresholds were calibrated for.
    provider: CalibrationProviderDetails,
    /// The calibrated thresholds to restore on the next run.
    thresholds: ThresholdConfig,
}

/// Device identity stored alongside a calibration record.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CalibrationProviderDetails {
    name: String,
    vendor: String,
    backend: Option<String>,
    device_id: u32,
}
1464
1465#[cfg(target_arch = "wasm32")]
1466fn calibration_cache_key(info: &ApiDeviceInfo) -> String {
1467    let vendor = slugify(&info.vendor);
1468    let name = slugify(&info.name);
1469    let backend = slugify(info.backend.as_deref().unwrap_or("unknown"));
1470    format!("{}-{}-{}-{}.json", vendor, name, backend, info.device_id)
1471}
1472
/// Load cached calibration thresholds for `info`, returning the thresholds
/// plus a human-readable cache location (IndexedDB key on wasm, file path
/// on native). Returns `None` on a missing, unparseable, or
/// version-mismatched record.
async fn load_cached_thresholds_async(info: &ApiDeviceInfo) -> Option<(ThresholdConfig, String)> {
    // wasm path: the cache lives in IndexedDB, keyed by device identity.
    #[cfg(target_arch = "wasm32")]
    {
        let key = calibration_cache_key(info);
        let contents = crate::web_auto_offload_store::load(&key).await?;
        match serde_json::from_str::<CalibrationRecord>(&contents) {
            Ok(record) => {
                if record.version != CALIBRATION_VERSION {
                    debug!(
                        "Native auto-offload calibration cache version mismatch (found {}, expected {})",
                        record.version,
                        CALIBRATION_VERSION
                    );
                    None
                } else {
                    Some((record.thresholds, key))
                }
            }
            Err(err) => {
                debug!(
                    "Native auto-offload failed to parse cached calibration for '{}': {err}",
                    info.name
                );
                None
            }
        }
    }
    // Native path: delegate to the synchronous file-based loader.
    #[cfg(not(target_arch = "wasm32"))]
    {
        load_cached_thresholds(info).map(|(cfg, path)| (cfg, path.display().to_string()))
    }
}
1505
/// Persist calibrated thresholds for `info`, returning the cache location
/// written (IndexedDB key on wasm, file path on native).
async fn persist_thresholds_async(info: &ApiDeviceInfo, cfg: &ThresholdConfig) -> Result<String> {
    // wasm path: serialize the record and store it in IndexedDB.
    #[cfg(target_arch = "wasm32")]
    {
        let key = calibration_cache_key(info);
        let record = CalibrationRecord {
            version: CALIBRATION_VERSION,
            // Clock errors degrade to a zero timestamp rather than failing
            // the persist.
            recorded_at: system_time_now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_else(|_| Duration::from_secs(0))
                .as_secs(),
            provider: CalibrationProviderDetails {
                name: info.name.clone(),
                vendor: info.vendor.clone(),
                backend: info.backend.clone(),
                device_id: info.device_id,
            },
            thresholds: cfg.clone(),
        };
        let payload = serde_json::to_string_pretty(&record).map_err(|e| anyhow!(e.to_string()))?;
        crate::web_auto_offload_store::save(&key, &payload)
            .await
            .map_err(|e| anyhow!(format!("indexeddb persist failed: {e:?}")))?;
        Ok(key)
    }
    // Native path: delegate to the synchronous file-based writer.
    #[cfg(not(target_arch = "wasm32"))]
    {
        persist_thresholds(info, cfg).map(|path| path.display().to_string())
    }
}
1535
/// Native loader: read and validate the calibration record from the cache
/// file for `info`. Returns `None` when the file is missing, unreadable,
/// unparseable, or carries a different cache-format version.
fn load_cached_thresholds(info: &ApiDeviceInfo) -> Option<(ThresholdConfig, PathBuf)> {
    let path = calibration_cache_file(info)?;
    let contents = fs::read_to_string(&path).ok()?;
    match serde_json::from_str::<CalibrationRecord>(&contents) {
        Ok(record) => {
            if record.version != CALIBRATION_VERSION {
                debug!(
                    "Native auto-offload calibration cache version mismatch (found {}, expected {})",
                    record.version,
                    CALIBRATION_VERSION
                );
                None
            } else {
                Some((record.thresholds, path))
            }
        }
        Err(err) => {
            debug!(
                "Native auto-offload failed to parse cached calibration for '{}': {err}",
                info.name
            );
            None
        }
    }
}
1561
/// Native writer: serialize a calibration record for `info` as pretty JSON
/// and write it to the per-device cache file, creating parent directories
/// as needed. Returns the path written.
fn persist_thresholds(info: &ApiDeviceInfo, cfg: &ThresholdConfig) -> Result<PathBuf> {
    let path = calibration_cache_file(info)
        .ok_or_else(|| anyhow!("unable to determine calibration cache directory"))?;
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent).map_err(|e| anyhow!(e.to_string()))?;
    }
    let record = CalibrationRecord {
        version: CALIBRATION_VERSION,
        // Clock errors degrade to a zero timestamp rather than failing the persist.
        recorded_at: system_time_now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_else(|_| Duration::from_secs(0))
            .as_secs(),
        provider: CalibrationProviderDetails {
            name: info.name.clone(),
            vendor: info.vendor.clone(),
            backend: info.backend.clone(),
            device_id: info.device_id,
        },
        thresholds: cfg.clone(),
    };
    let payload = serde_json::to_string_pretty(&record).map_err(|e| anyhow!(e.to_string()))?;
    fs::write(&path, payload).map_err(|e| anyhow!(e.to_string()))?;
    Ok(path)
}
1586
1587fn calibration_cache_file(info: &ApiDeviceInfo) -> Option<PathBuf> {
1588    let mut dir = calibration_cache_dir()?;
1589    let vendor = slugify(&info.vendor);
1590    let name = slugify(&info.name);
1591    let backend = slugify(info.backend.as_deref().unwrap_or("unknown"));
1592    let file = format!("{}-{}-{}-{}.json", vendor, name, backend, info.device_id);
1593    dir.push(file);
1594    Some(dir)
1595}
1596
1597fn calibration_cache_dir() -> Option<PathBuf> {
1598    dirs::cache_dir().map(|base| base.join("runmat").join("auto_offload"))
1599}
1600
/// Reduce an arbitrary device/vendor string to a filesystem-safe slug:
/// ASCII alphanumerics are lower-cased, every other run of characters
/// collapses to a single `_`, and leading/trailing underscores are
/// trimmed. An input with no alphanumerics yields `"device"`.
fn slugify(input: &str) -> String {
    let mut slug = String::with_capacity(input.len());
    for ch in input.chars() {
        if ch.is_ascii_alphanumeric() {
            slug.push(ch.to_ascii_lowercase());
        } else if !slug.ends_with('_') {
            // Collapse consecutive separators into one underscore.
            slug.push('_');
        }
    }
    let core = slug.trim_matches('_');
    if core.is_empty() {
        "device".to_string()
    } else {
        core.to_string()
    }
}
1620
1621fn auto_calibrate(provider: &'static dyn AccelProvider, cfg: &mut ThresholdConfig) -> Result<()> {
1622    if let Some(elem_threshold) = calibrate_elemwise(provider, cfg).transpose()? {
1623        if elem_threshold != usize::MAX {
1624            cfg.binary_min_elems = elem_threshold;
1625            cfg.unary_min_elems = cfg.unary_min_elems.min(elem_threshold);
1626        }
1627    }
1628    if let Some(red_threshold) = calibrate_reduction(provider, cfg).transpose()? {
1629        if red_threshold != usize::MAX {
1630            cfg.reduction_min_elems = red_threshold;
1631        }
1632    }
1633    if let Some(matmul_threshold) = calibrate_matmul(provider, cfg).transpose()? {
1634        if matmul_threshold != usize::MAX {
1635            cfg.matmul_min_flops = matmul_threshold;
1636        }
1637    }
1638    Ok(())
1639}
1640
1641fn calibrate_elemwise(
1642    provider: &'static dyn AccelProvider,
1643    cfg: &mut ThresholdConfig,
1644) -> Option<Result<usize>> {
1645    let sizes = [256usize, 1_024, 4_096, 16_384, 65_536];
1646    for size in sizes {
1647        match compare_elemwise(provider, size, &mut cfg.cpu_elem_per_elem) {
1648            Ok(Some(true)) => return Some(Ok(size)),
1649            Ok(Some(false)) => continue,
1650            Ok(None) => return None,
1651            Err(e) => return Some(Err(e)),
1652        }
1653    }
1654    Some(Ok(usize::MAX))
1655}
1656
/// Benchmark a host elementwise `plus` against the provider for `elements`
/// items.
///
/// Returns `Ok(Some(true))` when the GPU was faster, `Ok(Some(false))` when
/// the CPU was faster (or `elements == 0`), `Ok(None)` when the provider's
/// `elem_add` failed (aborts calibration for this op), and `Err(_)` when
/// the host-side benchmark itself failed. Also refreshes `cpu_cost_slot`
/// with the measured per-element CPU cost.
fn compare_elemwise(
    provider: &'static dyn AccelProvider,
    elements: usize,
    cpu_cost_slot: &mut f64,
) -> Result<Option<bool>> {
    // Zero-element guard: also protects the per-element division below.
    if elements == 0 {
        return Ok(Some(false));
    }
    let shape = vec![elements, 1];
    // Build the probe tensor in the provider's native precision.
    let template = match provider.precision() {
        ProviderPrecision::F64 => {
            Tensor::new((0..elements).map(|i| i as f64).collect(), shape.clone())
                .map_err(|e| anyhow!(e))?
        }
        ProviderPrecision::F32 => {
            Tensor::from_f32((0..elements).map(|i| i as f32).collect(), shape.clone())
                .map_err(|e| anyhow!(e))?
        }
    };
    let a = Value::Tensor(template.clone());
    let b = Value::Tensor(template.clone());
    let cpu_time = time(|| runmat_runtime::call_builtin("plus", &[a.clone(), b.clone()]))?;
    let cpu_per_elem = cpu_time.as_secs_f64() / elements as f64;
    update_cpu_cost(cpu_cost_slot, cpu_per_elem);
    // Prefer the profile cost model's estimate over a live device run.
    if let Some(model) = profile_cost_model() {
        if let Some(gpu_time) = model.estimate_elemwise(elements) {
            trace!(
                "Elemwise calibration ({} elems): cpu={:?}, gpu_est={:?}",
                elements,
                cpu_time,
                gpu_time
            );
            return Ok(Some(gpu_time < cpu_time));
        }
    }
    // Fall back to timing a real upload + elem_add on the device; buffers
    // are freed on every exit path.
    let view = HostTensorView {
        data: template.data.as_slice(),
        shape: template.shape.as_slice(),
    };
    let ha = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
    let hb = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
    let start = Instant::now();
    let hc = match futures::executor::block_on(provider.elem_add(&ha, &hb)) {
        Ok(h) => h,
        Err(_) => {
            let _ = provider.free(&ha);
            let _ = provider.free(&hb);
            return Ok(None);
        }
    };
    let gpu_time = start.elapsed();
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    let _ = provider.free(&hc);
    Ok(Some(gpu_time < cpu_time))
}
1713
1714fn calibrate_reduction(
1715    provider: &'static dyn AccelProvider,
1716    cfg: &mut ThresholdConfig,
1717) -> Option<Result<usize>> {
1718    let sizes = [256usize, 1_024, 4_096, 16_384, 65_536];
1719    for size in sizes {
1720        match compare_reduction(provider, size, &mut cfg.cpu_reduction_per_elem) {
1721            Ok(Some(true)) => return Some(Ok(size)),
1722            Ok(Some(false)) => continue,
1723            Ok(None) => return None,
1724            Err(e) => return Some(Err(e)),
1725        }
1726    }
1727    Some(Ok(usize::MAX))
1728}
1729
1730fn compare_reduction(
1731    provider: &'static dyn AccelProvider,
1732    elements: usize,
1733    cpu_cost_slot: &mut f64,
1734) -> Result<Option<bool>> {
1735    let shape = vec![elements, 1];
1736    let template = match provider.precision() {
1737        ProviderPrecision::F64 => {
1738            Tensor::new((0..elements).map(|i| i as f64).collect(), shape.clone())
1739                .map_err(|e| anyhow!(e))?
1740        }
1741        ProviderPrecision::F32 => {
1742            Tensor::from_f32((0..elements).map(|i| i as f32).collect(), shape.clone())
1743                .map_err(|e| anyhow!(e))?
1744        }
1745    };
1746    let value = Value::Tensor(template.clone());
1747    let cpu_time = time(|| runmat_runtime::call_builtin("sum", std::slice::from_ref(&value)))?;
1748    let cpu_per_elem = cpu_time.as_secs_f64() / elements as f64;
1749    update_cpu_cost(cpu_cost_slot, cpu_per_elem);
1750    if let Some(model) = profile_cost_model() {
1751        if let Some(gpu_time) = model.estimate_reduction(elements) {
1752            trace!(
1753                "Reduction calibration ({} elems): cpu={:?}, gpu_est={:?}",
1754                elements,
1755                cpu_time,
1756                gpu_time
1757            );
1758            return Ok(Some(gpu_time < cpu_time));
1759        }
1760    }
1761    let view = HostTensorView {
1762        data: template.data.as_slice(),
1763        shape: template.shape.as_slice(),
1764    };
1765    let h = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1766    let start = Instant::now();
1767    let out = match futures::executor::block_on(provider.reduce_sum(&h)) {
1768        Ok(hc) => hc,
1769        Err(_) => {
1770            provider.free(&h).ok();
1771            return Ok(None);
1772        }
1773    };
1774    let gpu_time = start.elapsed();
1775    let _ = provider.free(&h);
1776    let _ = provider.free(&out);
1777    Ok(Some(gpu_time < cpu_time))
1778}
1779
1780fn calibrate_matmul(
1781    provider: &'static dyn AccelProvider,
1782    cfg: &mut ThresholdConfig,
1783) -> Option<Result<usize>> {
1784    let dims = [32usize, 64, 96, 128, 192];
1785    for n in dims {
1786        match compare_matmul(provider, n, &mut cfg.cpu_matmul_per_flop) {
1787            Ok(Some(true)) => {
1788                let flops = n * n * n;
1789                return Some(Ok(flops));
1790            }
1791            Ok(Some(false)) => continue,
1792            Ok(None) => return None,
1793            Err(e) => return Some(Err(e)),
1794        }
1795    }
1796    Some(Ok(usize::MAX))
1797}
1798
/// Measures whether an `n x n` matmul runs faster on the provider than on the
/// CPU path.
///
/// Side effect: folds the measured CPU cost-per-FLOP into `cpu_cost_slot` via
/// `update_cpu_cost`. Returns `Ok(None)` when the provider cannot execute the
/// matmul (calibration should abort), otherwise `Ok(Some(gpu_wins))`.
fn compare_matmul(
    provider: &'static dyn AccelProvider,
    n: usize,
    cpu_cost_slot: &mut f64,
) -> Result<Option<bool>> {
    // Degenerate probe: nothing to measure, never worth offloading.
    if n == 0 {
        return Ok(Some(false));
    }
    let total = n * n;
    let shape = vec![n, n];
    // Build two deterministic operand matrices in the provider's precision.
    // The `% 13` / `% 7` fills just give non-uniform data; values don't
    // affect timing.
    let (ta, tb) = match provider.precision() {
        ProviderPrecision::F64 => {
            let data_a: Vec<f64> = (0..total).map(|i| (i % 13) as f64).collect();
            let data_b: Vec<f64> = (0..total).map(|i| (i % 7) as f64).collect();
            let ta = Tensor::new(data_a, shape.clone()).map_err(|e| anyhow!(e))?;
            let tb = Tensor::new(data_b, shape.clone()).map_err(|e| anyhow!(e))?;
            (ta, tb)
        }
        ProviderPrecision::F32 => {
            let data_a: Vec<f32> = (0..total).map(|i| (i % 13) as f32).collect();
            let data_b: Vec<f32> = (0..total).map(|i| (i % 7) as f32).collect();
            let ta = Tensor::from_f32(data_a, shape.clone()).map_err(|e| anyhow!(e))?;
            let tb = Tensor::from_f32(data_b, shape.clone()).map_err(|e| anyhow!(e))?;
            (ta, tb)
        }
    };
    let a = Value::Tensor(ta.clone());
    let b = Value::Tensor(tb.clone());
    // CPU baseline: the runtime's matmul is async, so drive it to completion
    // synchronously while timing.
    let cpu_time =
        time(|| futures::executor::block_on(runmat_runtime::matrix::value_matmul(&a, &b)))?;
    // n^3 multiply-accumulates for a square matmul (probe sizes are small, so
    // the usize product cannot overflow here).
    let flops = (n * n * n) as f64;
    update_cpu_cost(cpu_cost_slot, cpu_time.as_secs_f64() / flops);
    // Prefer a pre-recorded profile model over a live GPU probe when available.
    if let Some(model) = profile_cost_model() {
        if let Some(gpu_time) = model.estimate_matmul(n, n, n) {
            trace!(
                "Matmul calibration ({}^3 flops): cpu={:?}, gpu_est={:?}",
                n,
                cpu_time,
                gpu_time
            );
            return Ok(Some(gpu_time < cpu_time));
        }
    }
    // Live probe: upload both operands first so only the on-device matmul is
    // inside the timed region.
    let view_a = HostTensorView {
        data: ta.data.as_slice(),
        shape: ta.shape.as_slice(),
    };
    let view_b = HostTensorView {
        data: tb.data.as_slice(),
        shape: tb.shape.as_slice(),
    };
    let ha = provider
        .upload(&view_a)
        .map_err(|e| anyhow!(e.to_string()))?;
    let hb = provider
        .upload(&view_b)
        .map_err(|e| anyhow!(e.to_string()))?;
    let start = Instant::now();
    let hc = match futures::executor::block_on(provider.matmul(&ha, &hb)) {
        Ok(h) => h,
        Err(_) => {
            // Best-effort cleanup; a failed matmul maps to Ok(None) so the
            // caller aborts calibration rather than erroring out.
            let _ = provider.free(&ha);
            let _ = provider.free(&hb);
            return Ok(None);
        }
    };
    let gpu_time = start.elapsed();
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    let _ = provider.free(&hc);
    Ok(Some(gpu_time < cpu_time))
}
1871
1872fn time<F, T>(mut f: F) -> Result<Duration>
1873where
1874    F: FnMut() -> runmat_runtime::BuiltinResult<T>,
1875{
1876    let start = Instant::now();
1877    let _ = f().map_err(|err| anyhow!(err))?;
1878    Ok(start.elapsed())
1879}
1880
1881pub fn auto_offload_report() -> Option<AutoOffloadReport> {
1882    let state_guard = AUTO_STATE.get()?;
1883    let state = state_guard.lock().ok()?;
1884    let calibration = state.previous_thresholds.as_ref().map(|prev| {
1885        let delta = state
1886            .calibration_delta
1887            .clone()
1888            .unwrap_or_else(|| compute_delta(prev, &state.thresholds));
1889        AutoOffloadCalibrationSummary {
1890            previous: threshold_snapshot(prev),
1891            delta,
1892        }
1893    });
1894    Some(AutoOffloadReport {
1895        provider: state.provider.clone(),
1896        thresholds: threshold_snapshot(&state.thresholds),
1897        base_source: state.base_source,
1898        env_overrides_applied: state.env_overrides_applied,
1899        cache_path: state.cache_path.clone(),
1900        calibrate_duration_ms: state.calibrate_duration_ms,
1901        calibration,
1902        decisions: snapshot_decisions(),
1903    })
1904}
1905
1906pub fn sequence_threshold_hint() -> Option<usize> {
1907    AUTO_STATE
1908        .get()
1909        .and_then(|state| state.lock().ok())
1910        .map(|state| state.thresholds.unary_min_elems)
1911}
1912
/// Clears the recorded auto-offload decision log.
pub fn reset_auto_offload_log() {
    clear_decisions();
}
1916
/// Timing summary parsed from a GPU profile JSON report (milliseconds).
#[derive(Clone, Deserialize, Debug)]
struct ProfileDurationSummary {
    // Average wall-clock time in milliseconds; 0.0 when absent from the JSON.
    #[serde(default)]
    avg_ms: f64,
}
1922
/// One benchmark entry from a GPU profile JSON file.
#[derive(Clone, Deserialize, Debug)]
struct ProfileReport {
    // Workload category; "elementwise", "reduction", "transpose" and "matmul"
    // are the ones consumed by `ProfileCostModel::from_reports`.
    category: String,
    // Shapes of the benchmark inputs; empty when not recorded.
    #[serde(default)]
    input_shapes: Vec<Vec<usize>>,
    // Timing summary for the whole operation.
    total_ms: ProfileDurationSummary,
}
1930
/// First-order cost model: predicted_seconds = intercept + slope * x.
#[derive(Clone, Copy, Default, Debug)]
struct LinearModel {
    slope: f64,
    intercept: f64,
}

impl LinearModel {
    /// Predicts a duration for workload size `x`, or `None` when the model
    /// has no usable slope or the prediction is not a positive finite time.
    fn estimate(&self, x: f64) -> Option<Duration> {
        // A flat, negative, or ill-conditioned fit carries no information.
        let usable = self.slope.is_finite() && self.slope > 0.0;
        if !usable {
            return None;
        }
        let predicted_secs = self.intercept + self.slope * x;
        // Guard before constructing: Duration::from_secs_f64 panics on
        // negative or non-finite input.
        (predicted_secs.is_finite() && predicted_secs > 0.0)
            .then(|| Duration::from_secs_f64(predicted_secs))
    }
}
1950
/// Per-category linear GPU cost models fitted from profile reports.
/// A `None` entry means no usable samples were available for that category.
#[derive(Default)]
struct ProfileCostModel {
    elem: Option<LinearModel>,
    reduction: Option<LinearModel>,
    transpose: Option<LinearModel>,
    matmul: Option<LinearModel>,
}
1958
1959impl ProfileCostModel {
1960    fn from_reports(reports: &[ProfileReport]) -> Self {
1961        let mut elem_samples = Vec::<(f64, f64)>::new();
1962        let mut reduction_samples = Vec::<(f64, f64)>::new();
1963        let mut transpose_samples = Vec::<(f64, f64)>::new();
1964        let mut matmul_samples = Vec::<(f64, f64)>::new();
1965
1966        for report in reports {
1967            let total_secs = report.total_ms.avg_ms / 1_000.0;
1968            match report.category.as_str() {
1969                "elementwise" | "reduction" | "transpose" => {
1970                    if let Some(shape) = report.input_shapes.first() {
1971                        let elems: usize = shape.iter().copied().product();
1972                        if elems == 0 {
1973                            continue;
1974                        }
1975                        let sample = (elems as f64, total_secs);
1976                        match report.category.as_str() {
1977                            "elementwise" => elem_samples.push(sample),
1978                            "reduction" => reduction_samples.push(sample),
1979                            "transpose" => transpose_samples.push(sample),
1980                            _ => {}
1981                        }
1982                    }
1983                }
1984                "matmul" => {
1985                    if report.input_shapes.len() >= 2 {
1986                        let a = &report.input_shapes[0];
1987                        let b = &report.input_shapes[1];
1988                        if a.len() == 2 && b.len() == 2 {
1989                            let m = a[0];
1990                            let k = a[1];
1991                            let n = b[1];
1992                            let flops = m.checked_mul(k).and_then(|val| val.checked_mul(n));
1993                            if let Some(flops) = flops {
1994                                matmul_samples.push((flops as f64, total_secs));
1995                            }
1996                        }
1997                    }
1998                }
1999                _ => {}
2000            }
2001        }
2002
2003        ProfileCostModel {
2004            elem: fit_linear_model(&elem_samples),
2005            reduction: fit_linear_model(&reduction_samples),
2006            transpose: fit_linear_model(&transpose_samples),
2007            matmul: fit_linear_model(&matmul_samples),
2008        }
2009    }
2010
2011    fn estimate_elemwise(&self, elements: usize) -> Option<Duration> {
2012        self.elem.and_then(|model| model.estimate(elements as f64))
2013    }
2014
2015    fn estimate_reduction(&self, elements: usize) -> Option<Duration> {
2016        self.reduction
2017            .and_then(|model| model.estimate(elements as f64))
2018    }
2019
2020    fn estimate_matmul(&self, m: usize, k: usize, n: usize) -> Option<Duration> {
2021        let flops = m.checked_mul(k)?.checked_mul(n)?;
2022        self.matmul.and_then(|model| model.estimate(flops as f64))
2023    }
2024
2025    fn estimate_matmul_flops(&self, flops: usize) -> Option<Duration> {
2026        self.matmul.and_then(|model| model.estimate(flops as f64))
2027    }
2028
2029    fn estimate_transpose(&self, elements: usize) -> Option<Duration> {
2030        self.transpose
2031            .and_then(|model| model.estimate(elements as f64))
2032    }
2033}
2034
2035fn fit_linear_model(samples: &[(f64, f64)]) -> Option<LinearModel> {
2036    if samples.is_empty() {
2037        return None;
2038    }
2039    if samples.len() == 1 {
2040        let (x, y) = samples[0];
2041        if x > 0.0 {
2042            return Some(LinearModel {
2043                slope: (y / x).max(0.0),
2044                intercept: 0.0,
2045            });
2046        }
2047        return None;
2048    }
2049
2050    let sum_x: f64 = samples.iter().map(|(x, _)| *x).sum();
2051    let sum_y: f64 = samples.iter().map(|(_, y)| *y).sum();
2052    let sum_xx: f64 = samples.iter().map(|(x, _)| x * x).sum();
2053    let sum_xy: f64 = samples.iter().map(|(x, y)| x * y).sum();
2054    let n = samples.len() as f64;
2055    let denom = (n * sum_xx) - (sum_x * sum_x);
2056    if denom.abs() < f64::EPSILON {
2057        return None;
2058    }
2059    let slope = ((n * sum_xy) - (sum_x * sum_y)) / denom;
2060    let mean_x = sum_x / n;
2061    let mean_y = sum_y / n;
2062    let mut intercept = mean_y - slope * mean_x;
2063    if intercept < 0.0 {
2064        intercept = 0.0;
2065    }
2066    if !slope.is_finite() || slope <= 0.0 {
2067        return None;
2068    }
2069    Some(LinearModel { slope, intercept })
2070}
2071
/// Lazily loads and caches the optional GPU profile cost model, process-wide.
fn profile_cost_model() -> Option<&'static ProfileCostModel> {
    PROFILE_MODEL.get_or_init(load_profile_cost_model).as_ref()
}
2075
2076fn load_profile_cost_model() -> Option<ProfileCostModel> {
2077    let mut candidates = Vec::new();
2078    if let Ok(path) = env::var("RUNMAT_ACCEL_PROFILE") {
2079        candidates.push(PathBuf::from(path));
2080    }
2081    if let Some(path) = auto_offload_options().profile_path.clone() {
2082        candidates.push(path);
2083    }
2084    candidates.push(PathBuf::from("benchmarks/wgpu_profile/mac_m2.json"));
2085    candidates.push(PathBuf::from("wgpu_profile.json"));
2086
2087    for path in candidates {
2088        if !path.exists() {
2089            continue;
2090        }
2091        match fs::read_to_string(&path) {
2092            Ok(contents) => match serde_json::from_str::<Vec<ProfileReport>>(&contents) {
2093                Ok(reports) => {
2094                    debug!(
2095                        "Loaded {} GPU profile reports from {}",
2096                        reports.len(),
2097                        path.display()
2098                    );
2099                    return Some(ProfileCostModel::from_reports(&reports));
2100                }
2101                Err(err) => {
2102                    debug!("Failed to parse GPU profile {}: {err}", path.display());
2103                }
2104            },
2105            Err(err) => {
2106                debug!("Failed to read GPU profile {}: {err}", path.display());
2107            }
2108        }
2109    }
2110    None
2111}
2112
2113pub async fn promote_binary(op: BinaryOp, a: &Value, b: &Value) -> Result<(Value, Value)> {
2114    if !auto_enabled() {
2115        return Ok((a.clone(), b.clone()));
2116    }
2117    if let Some(auto) = global().await {
2118        auto.promote_binary(op, a, b)
2119    } else {
2120        Ok((a.clone(), b.clone()))
2121    }
2122}
2123
2124pub async fn promote_unary(op: UnaryOp, value: &Value) -> Result<Value> {
2125    if !auto_enabled() {
2126        return Ok(value.clone());
2127    }
2128    if let Some(auto) = global().await {
2129        auto.promote_unary(op, value)
2130    } else {
2131        Ok(value.clone())
2132    }
2133}
2134
/// Pre-processes builtin arguments before dispatch.
///
/// Order matters: the sink-policy check runs even when auto-offload is
/// disabled, so sink builtins always act as fusion barriers and — when the
/// residency policy demands it — gather device-resident args back to the host.
pub async fn prepare_builtin_args(name: &str, args: &[Value]) -> Result<Vec<Value>> {
    if let Some(policy) = builtin_policy(name) {
        if policy.is_sink {
            // A sink terminates any pending fusion over these inputs.
            clear_sink_inputs(args);
            if should_gather_sink_args(name) {
                trace!(
                    "auto-offload: prepare_builtin_args(name={:?}) is_sink=true residency=GatherImmediately -> gathering {} arg(s)",
                    name,
                    args.len()
                );
                return gather_args(args).await;
            }
            trace!(
                "auto-offload: prepare_builtin_args(name={:?}) is_sink=true residency!=GatherImmediately -> no gather (fusion barrier only)",
                name
            );
            return Ok(args.to_vec());
        }
    }
    // Non-sink path: only consult the auto-offload planner when enabled and
    // initialized; otherwise the args pass through unchanged.
    if !auto_enabled() {
        return Ok(args.to_vec());
    }
    if let Some(auto) = global().await {
        auto.prepare_builtin(name, args).await
    } else {
        Ok(args.to_vec())
    }
}
2163
2164pub fn is_sink(name: &str) -> bool {
2165    builtin_policy(name).map(|p| p.is_sink).unwrap_or(false)
2166}
2167
2168pub async fn promote_reduction_args(op: ReductionOp, args: &[Value]) -> Result<Vec<Value>> {
2169    if !auto_enabled() {
2170        return Ok(args.to_vec());
2171    }
2172    if let Some(auto) = global().await {
2173        auto.promote_reduction(op, args)
2174    } else {
2175        Ok(args.to_vec())
2176    }
2177}