Skip to main content

runmat_accelerate/
native_auto.rs

1use runmat_time::{system_time_now, Instant};
2use std::collections::HashMap;
3use std::env;
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::sync::Mutex;
7use std::time::{Duration, UNIX_EPOCH};
8
9use crate::{
10    auto_offload_options,
11    fusion::{active_fusion, FusionKind},
12    fusion_residency,
13    precision::ensure_provider_supports_dtype,
14    AutoOffloadLogLevel,
15};
16use anyhow::{anyhow, Result};
17use futures::lock::Mutex as AsyncMutex;
18use log::{debug, info, trace, warn};
19use once_cell::sync::{Lazy, OnceCell};
20use runmat_accelerate_api::{AccelProvider, ApiDeviceInfo, HostTensorView, ProviderPrecision};
21use runmat_builtins::{builtin_functions, AccelTag, Tensor, Value};
22use runmat_runtime::builtins::common::spec::{builtin_residency_policy, ResidencyPolicy};
23use runmat_runtime::gather_if_needed_async;
24use serde::{Deserialize, Serialize};
25
/// Built-in CPU cost of one elementwise element, in seconds (used until calibrated).
const DEFAULT_CPU_ELEM_PER_ELEM: f64 = 1.0e-7;
/// Built-in CPU cost of one reduced element, in seconds (used until calibrated).
const DEFAULT_CPU_REDUCTION_PER_ELEM: f64 = 1.2e-7;
/// Built-in CPU cost of one matmul flop, in seconds (used until calibrated).
const DEFAULT_CPU_MATMUL_PER_FLOP: f64 = 2.5e-11;
/// Default upper bound on the batch dimension for the small-batch guard.
const SMALL_BATCH_DEFAULT_MAX_DIM: usize = 8;
/// Default minimum element count at which the small-batch guard applies.
const SMALL_BATCH_DEFAULT_MIN_ELEMS: usize = 1_048_576;
/// Maximum number of entries retained by the in-memory decision log.
const DECISION_LOG_CAPACITY: usize = 128;
/// Version tag for calibration data. NOTE(review): its consumer is not visible
/// in this part of the file — confirm where it is checked.
const CALIBRATION_VERSION: u32 = 1;
33
/// Category of a two-operand operation considered for GPU promotion.
#[derive(Clone, Copy, Debug)]
pub enum BinaryOp {
    /// Elementwise operation, sized by the larger operand's element count.
    Elementwise,
    /// Matrix multiplication, sized by an estimated flop count.
    MatMul,
}
39
/// Category of a single-operand operation considered for GPU promotion.
#[derive(Clone, Copy, Debug)]
pub enum UnaryOp {
    /// Any unary operation without special handling (logged as `"unary"`).
    Generic,
    /// Transpose (logged as `"transpose"`).
    Transpose,
}
45
/// Kind of reduction requested by the caller.
///
/// The promotion logic visible in this file does not branch on the variant.
#[derive(Clone, Copy, Debug)]
pub enum ReductionOp {
    Sum,
    Mean,
    Min,
    Max,
}
53
/// Tunable thresholds and CPU cost coefficients driving offload decisions.
///
/// Serializable so it can be cached and reloaded per device.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ThresholdConfig {
    /// Minimum element count for unary operations.
    unary_min_elems: usize,
    /// Minimum element count for binary elementwise operations.
    binary_min_elems: usize,
    /// Minimum element count for reductions.
    reduction_min_elems: usize,
    /// Minimum estimated flop count for matrix multiplication.
    matmul_min_flops: usize,
    /// CPU seconds per element for elementwise work.
    cpu_elem_per_elem: f64,
    /// CPU seconds per element for reductions.
    cpu_reduction_per_elem: f64,
    /// CPU seconds per flop for matrix multiplication.
    cpu_matmul_per_flop: f64,
    /// Small-batch guard: largest batch dimension still treated as "small" (0 disables).
    small_batch_max_dim: usize,
    /// Small-batch guard: minimum element count for the guard to fire (0 disables).
    small_batch_min_elems: usize,
}
66
/// Built-in thresholds used when no cached or calibrated values are available.
impl Default for ThresholdConfig {
    fn default() -> Self {
        Self {
            unary_min_elems: 4_096,
            binary_min_elems: 4_096,
            reduction_min_elems: 256,
            matmul_min_flops: 1_000_000, // roughly 100x100x100
            cpu_elem_per_elem: DEFAULT_CPU_ELEM_PER_ELEM,
            cpu_reduction_per_elem: DEFAULT_CPU_REDUCTION_PER_ELEM,
            cpu_matmul_per_flop: DEFAULT_CPU_MATMUL_PER_FLOP,
            small_batch_max_dim: SMALL_BATCH_DEFAULT_MAX_DIM,
            small_batch_min_elems: SMALL_BATCH_DEFAULT_MIN_ELEMS,
        }
    }
}
82
/// One recorded offload decision, as kept in the in-memory decision log.
#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadDecisionEntry {
    /// Wall-clock time of the decision, in milliseconds since the Unix epoch.
    pub timestamp_ms: u128,
    /// Operation label, e.g. `"elementwise"`, `"matmul"`, `"reduction"`.
    pub operation: String,
    /// Element count the decision was based on, when applicable.
    pub elements: Option<usize>,
    /// Estimated flop count the decision was based on, when applicable.
    pub flops: Option<usize>,
    /// Batch dimension detected on the operands, if any.
    pub batch: Option<usize>,
    /// Where the operation was routed.
    pub decision: AutoOffloadDisposition,
    /// Why it was routed there.
    pub reason: DecisionReason,
    /// Estimated CPU time in milliseconds, if an estimate was available.
    pub cpu_estimate_ms: Option<f64>,
    /// Estimated GPU time in milliseconds, if an estimate was available.
    pub gpu_estimate_ms: Option<f64>,
    /// Threshold that applied to this decision, if any.
    pub threshold: Option<usize>,
    /// Kind of the fusion group active when the decision was made, if any.
    pub fusion_kind: Option<FusionKind>,
}
97
/// Target chosen for an operation; serialized in kebab-case (`"gpu"` / `"cpu"`).
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum AutoOffloadDisposition {
    /// Operands are promoted to the accelerator.
    Gpu,
    /// Operands stay on the host.
    Cpu,
}
104
/// Rule that produced an offload decision; serialized in kebab-case.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum DecisionReason {
    /// A supported fusion group is active, so the op is kept on the GPU.
    FusionOverride,
    /// At least one input already lives on the GPU.
    Residency,
    /// The small-batch guard fired and kept the op on the CPU.
    SmallBatchGuard,
    /// NOTE(review): presumably decided by the profile cost model — the code
    /// producing this variant is outside the visible part of the file.
    ProfileModel,
    /// NOTE(review): presumably decided by comparing against a size threshold —
    /// the code producing this variant is outside the visible part of the file.
    Threshold,
    /// NOTE(review): presumably emitted when auto-offload is disabled — confirm.
    Disabled,
}
115
/// Public, serializable copy of the internal threshold configuration.
///
/// Fields mirror `ThresholdConfig` one-to-one.
#[derive(Debug, Clone, Serialize)]
pub struct ThresholdSnapshot {
    pub unary_min_elems: usize,
    pub binary_min_elems: usize,
    pub reduction_min_elems: usize,
    pub matmul_min_flops: usize,
    pub cpu_elem_per_elem: f64,
    pub cpu_reduction_per_elem: f64,
    pub cpu_matmul_per_flop: f64,
    pub small_batch_max_dim: usize,
    pub small_batch_min_elems: usize,
}
128
/// Summary of the most recent calibration: thresholds before it, plus what changed.
#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadCalibrationSummary {
    /// Thresholds in effect before the calibration was applied.
    pub previous: ThresholdSnapshot,
    /// Coefficients that the calibration changed.
    pub delta: ThresholdDelta,
}
134
/// Per-coefficient change set; a field is `None` (and omitted from the
/// serialized output) when that coefficient did not change.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ThresholdDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_elem_per_elem: Option<ThresholdDeltaEntry>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_reduction_per_elem: Option<ThresholdDeltaEntry>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_matmul_per_flop: Option<ThresholdDeltaEntry>,
}
144
/// Before/after record for a single coefficient.
#[derive(Debug, Clone, Serialize)]
pub struct ThresholdDeltaEntry {
    /// Value before the change.
    pub before: f64,
    /// Value after the change.
    pub after: f64,
    /// `after - before`.
    pub absolute: f64,
    /// `after / before`; absent when `before` is too close to zero to divide by.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ratio: Option<f64>,
}
153
154impl ThresholdDeltaEntry {
155    fn new(before: f64, after: f64) -> Self {
156        let absolute = after - before;
157        let ratio = if before.abs() > f64::EPSILON {
158            Some(after / before)
159        } else {
160            None
161        };
162        Self {
163            before,
164            after,
165            absolute,
166            ratio,
167        }
168    }
169}
170
/// Serializable report of the auto-offload state and recent decisions.
#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadReport {
    /// Identity of the acceleration provider, when one is known.
    pub provider: Option<CachedProviderInfo>,
    /// Thresholds currently recorded in the auto-offload state.
    pub thresholds: ThresholdSnapshot,
    /// Where the base thresholds came from (defaults, cache, or calibration).
    pub base_source: ThresholdBase,
    /// Whether environment overrides modified the thresholds.
    pub env_overrides_applied: bool,
    /// Location of the threshold cache file, if one was read or written.
    pub cache_path: Option<String>,
    /// Duration of the start-up calibration in milliseconds, if it ran.
    pub calibrate_duration_ms: Option<u128>,
    /// Summary of the last applied calibration; omitted when there is none.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub calibration: Option<AutoOffloadCalibrationSummary>,
    /// Snapshot of the recorded offload decisions.
    pub decisions: Vec<AutoOffloadDecisionEntry>,
}
183
/// Origin of the base threshold values; serialized in kebab-case.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum ThresholdBase {
    /// Compiled-in defaults.
    BuiltInDefault,
    /// Values read from a cached calibration.
    LoadedFromCache,
    /// Values produced by a calibration run or an applied calibration sample.
    Calibrated,
}
191
192impl ThresholdBase {
193    pub fn as_str(&self) -> &'static str {
194        match self {
195            ThresholdBase::BuiltInDefault => "built-in-default",
196            ThresholdBase::LoadedFromCache => "loaded-from-cache",
197            ThresholdBase::Calibrated => "calibrated",
198        }
199    }
200}
201
/// Owned copy of the provider's device identity, suitable for reports.
#[derive(Debug, Clone, Serialize)]
pub struct CachedProviderInfo {
    /// Device name as reported by the provider.
    pub name: String,
    /// Device vendor as reported by the provider.
    pub vendor: String,
    /// Backend identifier, when the provider reports one.
    pub backend: Option<String>,
    /// Provider-assigned device id.
    pub device_id: u32,
}
209
/// Process-wide bookkeeping about how the thresholds were obtained.
#[derive(Debug, Clone)]
struct AutoOffloadState {
    /// Provider the thresholds belong to.
    provider: Option<CachedProviderInfo>,
    /// Thresholds recorded at initialization or after a committed calibration.
    thresholds: ThresholdConfig,
    /// Origin of `thresholds`.
    base_source: ThresholdBase,
    /// Whether environment overrides were applied at initialization.
    env_overrides_applied: bool,
    /// Cache file the thresholds were loaded from or persisted to.
    cache_path: Option<String>,
    /// Duration of the start-up calibration in milliseconds, if it ran.
    calibrate_duration_ms: Option<u128>,
    /// Thresholds before the last applied calibration sample.
    previous_thresholds: Option<ThresholdConfig>,
    /// Changes introduced by the last applied calibration sample.
    calibration_delta: Option<ThresholdDelta>,
}
221
/// Result of evaluating whether one operation should run on the GPU.
#[derive(Clone)]
struct DecisionEvaluation {
    /// `true` when the operands should be promoted to the GPU.
    recommend_gpu: bool,
    /// Rule that produced the recommendation.
    reason: DecisionReason,
    /// Estimated CPU time in seconds, if available.
    cpu_secs: Option<f64>,
    /// Estimated GPU time in seconds, if available.
    gpu_secs: Option<f64>,
    /// Threshold relevant to the decision, if any.
    threshold: Option<usize>,
    /// Kind of the active fusion group, if any.
    fusion_kind: Option<FusionKind>,
    /// Batch dimension detected on the operands, if any.
    batch: Option<usize>,
}
232
/// Bounded, oldest-first log of recent offload decisions.
struct DecisionLog {
    /// Entries in insertion order; capped at `DECISION_LOG_CAPACITY` by `push`.
    entries: Vec<AutoOffloadDecisionEntry>,
}
236
237impl DecisionLog {
238    fn new() -> Self {
239        Self {
240            entries: Vec::new(),
241        }
242    }
243
244    fn push(&mut self, entry: AutoOffloadDecisionEntry) {
245        self.entries.push(entry);
246        if self.entries.len() > DECISION_LOG_CAPACITY {
247            let overflow = self.entries.len() - DECISION_LOG_CAPACITY;
248            self.entries.drain(0..overflow);
249        }
250    }
251
252    fn snapshot(&self) -> Vec<AutoOffloadDecisionEntry> {
253        self.entries.clone()
254    }
255
256    fn clear(&mut self) {
257        self.entries.clear();
258    }
259}
260
/// Process-wide decision log, created on first use.
static DECISION_LOG: Lazy<Mutex<DecisionLog>> = Lazy::new(|| Mutex::new(DecisionLog::new()));
/// Process-wide auto-offload state; set once during initialization and updated
/// when a calibration sample is applied.
static AUTO_STATE: OnceCell<Mutex<AutoOffloadState>> = OnceCell::new();
263
264fn record_decision(entry: AutoOffloadDecisionEntry) {
265    if let Ok(mut log) = DECISION_LOG.lock() {
266        log.push(entry);
267    }
268}
269
270fn snapshot_decisions() -> Vec<AutoOffloadDecisionEntry> {
271    DECISION_LOG
272        .lock()
273        .map(|log| log.snapshot())
274        .unwrap_or_default()
275}
276
277fn clear_decisions() {
278    if let Ok(mut log) = DECISION_LOG.lock() {
279        log.clear();
280    }
281}
282
283fn now_millis() -> u128 {
284    system_time_now()
285        .duration_since(UNIX_EPOCH)
286        .unwrap_or_else(|_| Duration::from_secs(0))
287        .as_millis()
288}
289
290fn threshold_snapshot(cfg: &ThresholdConfig) -> ThresholdSnapshot {
291    ThresholdSnapshot {
292        unary_min_elems: cfg.unary_min_elems,
293        binary_min_elems: cfg.binary_min_elems,
294        reduction_min_elems: cfg.reduction_min_elems,
295        matmul_min_flops: cfg.matmul_min_flops,
296        cpu_elem_per_elem: cfg.cpu_elem_per_elem,
297        cpu_reduction_per_elem: cfg.cpu_reduction_per_elem,
298        cpu_matmul_per_flop: cfg.cpu_matmul_per_flop,
299        small_batch_max_dim: cfg.small_batch_max_dim,
300        small_batch_min_elems: cfg.small_batch_min_elems,
301    }
302}
303
304fn compute_delta(before: &ThresholdConfig, after: &ThresholdConfig) -> ThresholdDelta {
305    let mut delta = ThresholdDelta::default();
306
307    if (before.cpu_elem_per_elem - after.cpu_elem_per_elem).abs() > f64::EPSILON {
308        delta.cpu_elem_per_elem = Some(ThresholdDeltaEntry::new(
309            before.cpu_elem_per_elem,
310            after.cpu_elem_per_elem,
311        ));
312    }
313
314    if (before.cpu_reduction_per_elem - after.cpu_reduction_per_elem).abs() > f64::EPSILON {
315        delta.cpu_reduction_per_elem = Some(ThresholdDeltaEntry::new(
316            before.cpu_reduction_per_elem,
317            after.cpu_reduction_per_elem,
318        ));
319    }
320
321    if (before.cpu_matmul_per_flop - after.cpu_matmul_per_flop).abs() > f64::EPSILON {
322        delta.cpu_matmul_per_flop = Some(ThresholdDeltaEntry::new(
323            before.cpu_matmul_per_flop,
324            after.cpu_matmul_per_flop,
325        ));
326    }
327
328    delta
329}
330
/// Top-level shape of a calibration JSON file.
///
/// The sample may appear either nested under `suite` or at the top level.
#[derive(Debug, Deserialize)]
struct CalibrationFile {
    /// Optional `suite` section that may carry the sample.
    #[serde(default)]
    suite: Option<CalibrationSuiteSection>,
    /// Optional top-level sample.
    #[serde(default)]
    auto_offload_calibration: Option<CalibrationSample>,
}
338
/// `suite` section of a calibration file.
#[derive(Debug, Deserialize)]
struct CalibrationSuiteSection {
    /// Calibration sample nested inside the suite, if present.
    #[serde(default)]
    auto_offload_calibration: Option<CalibrationSample>,
}
344
/// Measured CPU timings used to recalibrate the cost coefficients.
#[derive(Debug, Clone, Deserialize)]
struct CalibrationSample {
    /// Number of runs behind the measurements; zero is rejected by the caller.
    #[serde(default)]
    runs: usize,
    /// CPU times in milliseconds (JSON key `cpu_time_ms`).
    #[serde(default, rename = "cpu_time_ms")]
    cpu_time: CalibrationTimes,
    /// Amount of work corresponding to each measured time.
    #[serde(default)]
    units: CalibrationUnits,
    /// Provider the sample was measured with, if recorded.
    #[serde(default)]
    provider: Option<CalibrationProviderInfo>,
    /// Set when the sample's cases disagreed about the provider.
    #[serde(default)]
    provider_conflict: bool,
}
358
/// Measured CPU times per operation class, in milliseconds (missing keys default to 0).
#[derive(Debug, Clone, Deserialize, Default)]
struct CalibrationTimes {
    #[serde(default)]
    elementwise: f64,
    #[serde(default)]
    reduction: f64,
    #[serde(default)]
    matmul: f64,
}
368
/// Work units matching each measured time (missing keys default to 0).
#[derive(Debug, Clone, Deserialize, Default)]
struct CalibrationUnits {
    /// Elements processed by the elementwise measurement.
    #[serde(default)]
    elementwise: f64,
    /// Elements processed by the reduction measurement.
    #[serde(default)]
    reduction: f64,
    /// Flops performed by the matmul measurement.
    #[serde(default, rename = "matmul_flops")]
    matmul_flops: f64,
}
378
/// Provider identity recorded inside a calibration sample.
#[derive(Debug, Clone, Deserialize)]
struct CalibrationProviderInfo {
    name: String,
    vendor: String,
    /// Backend identifier; optional in the file.
    #[serde(default)]
    backend: Option<String>,
    device_id: u32,
}
387
/// Result of applying a calibration sample from a file.
#[derive(Debug, Serialize)]
pub struct AutoOffloadCalibrationOutcome {
    /// Number of runs reported by the sample.
    pub runs: usize,
    /// Thresholds before the sample was applied.
    pub before: ThresholdSnapshot,
    /// Thresholds after the sample was applied.
    pub after: ThresholdSnapshot,
    /// Coefficients changed by the sample.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delta: Option<ThresholdDelta>,
    /// Path the updated thresholds were written to (only when committed).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub persisted_to: Option<String>,
    /// Provider the calibration was applied for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<CachedProviderInfo>,
    /// Whether the caller asked for the result to be persisted and activated.
    pub commit: bool,
}
401
402fn load_calibration_sample(path: &Path) -> Result<CalibrationSample> {
403    let payload = fs::read_to_string(path).map_err(|e| anyhow!(e.to_string()))?;
404    let file: CalibrationFile = serde_json::from_str(&payload)
405        .map_err(|e| anyhow!(format!("failed to parse calibration file: {e}")))?;
406    if let Some(suite) = file.suite {
407        if let Some(sample) = suite.auto_offload_calibration {
408            return Ok(sample);
409        }
410    }
411    if let Some(sample) = file.auto_offload_calibration {
412        return Ok(sample);
413    }
414    Err(anyhow!(
415        "calibration file does not contain an auto_offload_calibration section"
416    ))
417}
418
419fn apply_calibration_sample(
420    cfg: &mut ThresholdConfig,
421    sample: &CalibrationSample,
422) -> Option<ThresholdDelta> {
423    let mut delta = ThresholdDelta::default();
424    let mut changed = false;
425
426    if sample.units.elementwise > 0.0 && sample.cpu_time.elementwise > 0.0 {
427        let secs_per_elem = (sample.cpu_time.elementwise / 1_000.0) / sample.units.elementwise;
428        if secs_per_elem.is_finite()
429            && secs_per_elem > 0.0
430            && (cfg.cpu_elem_per_elem - secs_per_elem).abs() > f64::EPSILON
431        {
432            delta.cpu_elem_per_elem = Some(ThresholdDeltaEntry::new(
433                cfg.cpu_elem_per_elem,
434                secs_per_elem,
435            ));
436            cfg.cpu_elem_per_elem = secs_per_elem;
437            changed = true;
438        }
439    }
440
441    if sample.units.reduction > 0.0 && sample.cpu_time.reduction > 0.0 {
442        let secs_per_elem = (sample.cpu_time.reduction / 1_000.0) / sample.units.reduction;
443        if secs_per_elem.is_finite()
444            && secs_per_elem > 0.0
445            && (cfg.cpu_reduction_per_elem - secs_per_elem).abs() > f64::EPSILON
446        {
447            delta.cpu_reduction_per_elem = Some(ThresholdDeltaEntry::new(
448                cfg.cpu_reduction_per_elem,
449                secs_per_elem,
450            ));
451            cfg.cpu_reduction_per_elem = secs_per_elem;
452            changed = true;
453        }
454    }
455
456    if sample.units.matmul_flops > 0.0 && sample.cpu_time.matmul > 0.0 {
457        let secs_per_flop = (sample.cpu_time.matmul / 1_000.0) / sample.units.matmul_flops;
458        if secs_per_flop.is_finite()
459            && secs_per_flop > 0.0
460            && (cfg.cpu_matmul_per_flop - secs_per_flop).abs() > f64::EPSILON
461        {
462            delta.cpu_matmul_per_flop = Some(ThresholdDeltaEntry::new(
463                cfg.cpu_matmul_per_flop,
464                secs_per_flop,
465            ));
466            cfg.cpu_matmul_per_flop = secs_per_flop;
467            changed = true;
468        }
469    }
470
471    if changed {
472        Some(delta)
473    } else {
474        None
475    }
476}
477
478pub fn apply_auto_offload_calibration_from_file(
479    path: &Path,
480    commit: bool,
481) -> Result<AutoOffloadCalibrationOutcome> {
482    let sample = load_calibration_sample(path)?;
483    if sample.runs == 0 {
484        return Err(anyhow!("calibration sample contains zero runs"));
485    }
486
487    let provider = runmat_accelerate_api::provider()
488        .ok_or_else(|| anyhow!("no acceleration provider registered"))?;
489    let device_info = provider.device_info_struct();
490
491    if let Some(ref prov) = sample.provider {
492        if prov.name != device_info.name
493            || prov.vendor != device_info.vendor
494            || prov.backend.as_deref() != device_info.backend.as_deref()
495            || prov.device_id != device_info.device_id
496        {
497            warn!(
498                "Calibration provider mismatch: sample='{} ({})' device='{} ({})'",
499                prov.name, prov.vendor, device_info.name, device_info.vendor
500            );
501        }
502        if sample.provider_conflict {
503            warn!("Calibration sample reported provider conflict across cases");
504        }
505    }
506
507    let (mut cfg, _) = load_cached_thresholds(&device_info)
508        .unwrap_or_else(|| (ThresholdConfig::default(), PathBuf::new()));
509    let before_cfg = cfg.clone();
510
511    let delta = apply_calibration_sample(&mut cfg, &sample)
512        .ok_or_else(|| anyhow!("calibration sample did not produce coefficient updates"))?;
513
514    let mut persisted_to: Option<PathBuf> = None;
515    if commit {
516        persisted_to = Some(persist_thresholds(&device_info, &cfg)?);
517    }
518
519    if let Some(state_mutex) = AUTO_STATE.get() {
520        if let Ok(mut state) = state_mutex.lock() {
521            state.previous_thresholds = Some(before_cfg.clone());
522            state.calibration_delta = Some(delta.clone());
523            if commit {
524                state.thresholds = cfg.clone();
525                state.base_source = ThresholdBase::Calibrated;
526                if let Some(ref path_buf) = persisted_to {
527                    state.cache_path = Some(path_buf.to_string_lossy().into_owned());
528                }
529                state.calibrate_duration_ms = None;
530            }
531        }
532    }
533
534    Ok(AutoOffloadCalibrationOutcome {
535        runs: sample.runs,
536        before: threshold_snapshot(&before_cfg),
537        after: threshold_snapshot(&cfg),
538        delta: Some(delta),
539        persisted_to: persisted_to.map(|p| p.to_string_lossy().into_owned()),
540        provider: Some(cached_provider_info(&device_info)),
541        commit,
542    })
543}
544
545fn cached_provider_info(info: &ApiDeviceInfo) -> CachedProviderInfo {
546    CachedProviderInfo {
547        name: info.name.clone(),
548        vendor: info.vendor.clone(),
549        backend: info.backend.clone(),
550        device_id: info.device_id,
551    }
552}
553
/// Estimates CPU time as `per_unit * units`.
///
/// Returns `None` when `per_unit` is not a finite, strictly positive number.
fn cpu_estimate(per_unit: f64, units: usize) -> Option<f64> {
    let usable = per_unit.is_finite() && per_unit > 0.0;
    usable.then(|| per_unit * units as f64)
}
561
562fn value_shape(value: &Value) -> Option<&[usize]> {
563    match value {
564        Value::Tensor(t) => Some(&t.shape),
565        Value::GpuTensor(handle) => Some(&handle.shape),
566        _ => None,
567    }
568}
569
570fn batch_dimension_from_value(value: &Value) -> Option<usize> {
571    let shape = value_shape(value)?;
572    if shape.len() < 3 {
573        return None;
574    }
575    shape.last().copied()
576}
577
578fn batch_dimension_from_values(values: &[&Value]) -> Option<usize> {
579    values
580        .iter()
581        .filter_map(|value| batch_dimension_from_value(value))
582        .min()
583}
584
585fn decision_entry(
586    operation: &str,
587    elements: Option<usize>,
588    flops: Option<usize>,
589    eval: &DecisionEvaluation,
590) -> AutoOffloadDecisionEntry {
591    AutoOffloadDecisionEntry {
592        timestamp_ms: now_millis(),
593        operation: operation.to_string(),
594        elements,
595        flops,
596        batch: eval.batch,
597        decision: if eval.recommend_gpu {
598            AutoOffloadDisposition::Gpu
599        } else {
600            AutoOffloadDisposition::Cpu
601        },
602        reason: eval.reason,
603        cpu_estimate_ms: eval.cpu_secs.map(|secs| secs * 1_000.0),
604        gpu_estimate_ms: eval.gpu_secs.map(|secs| secs * 1_000.0),
605        threshold: eval.threshold,
606        fusion_kind: eval.fusion_kind.clone(),
607    }
608}
609
/// Decides per operation whether operands should be promoted to the GPU, and
/// performs the promotion through the acceleration provider.
pub struct NativeAutoOffload {
    /// Provider used for uploads and dtype capability checks.
    provider: &'static dyn AccelProvider,
    /// Thresholds and cost coefficients captured at construction.
    thresholds: ThresholdConfig,
    /// When `false`, every promotion method returns its inputs unchanged.
    enabled: bool,
}
615
/// Outcome of the one-time initialization: `Some(offload)` on success, `None`
/// when initialization decided auto-offload is unavailable.
static GLOBAL: OnceCell<Option<NativeAutoOffload>> = OnceCell::new();
/// Serializes initialization of `GLOBAL` across concurrent async callers.
static GLOBAL_INIT_LOCK: Lazy<AsyncMutex<()>> = Lazy::new(|| AsyncMutex::new(()));
/// Cached profile cost model. NOTE(review): it is populated outside the visible
/// part of the file — presumably by `profile_cost_model()`; confirm.
static PROFILE_MODEL: OnceCell<Option<ProfileCostModel>> = OnceCell::new();
619
620fn env_bool(key: &str) -> Option<bool> {
621    env::var(key).ok().and_then(|v| parse_bool(&v))
622}
623
/// Parses a human-friendly boolean.
///
/// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
/// `1`/`true`/`yes`/`on` map to `true`, `0`/`false`/`no`/`off` to `false`;
/// anything else yields `None`.
fn parse_bool(s: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let token = s.trim();
    let is_one_of =
        |words: &[&str]| words.iter().any(|word| token.eq_ignore_ascii_case(word));

    if is_one_of(&TRUTHY) {
        Some(true)
    } else if is_one_of(&FALSY) {
        Some(false)
    } else {
        None
    }
}
631
/// Emits a promotion message at the log level selected by the auto-offload options.
///
/// The message is produced by `builder`, which is never called when the
/// configured level is `Off`.
fn log_promotion<F>(builder: F)
where
    F: FnOnce() -> String,
{
    match auto_offload_options().log_level {
        AutoOffloadLogLevel::Off => {}
        // `builder()` stays inside the macro call so it is evaluated only
        // when the logging macro evaluates its arguments.
        AutoOffloadLogLevel::Info => info!("{}", builder()),
        AutoOffloadLogLevel::Trace => trace!("{}", builder()),
    }
}
642
/// Lowers `*slot` to `candidate` when the candidate is a finite, strictly
/// positive value smaller than the current one; otherwise leaves it untouched.
fn update_cpu_cost(slot: &mut f64, candidate: f64) {
    if !candidate.is_finite() || candidate <= 0.0 {
        return;
    }
    if candidate < *slot {
        *slot = candidate;
    }
}
648
649fn value_len(value: &Value) -> Option<usize> {
650    match value {
651        Value::Tensor(t) => Some(t.data.len()),
652        Value::GpuTensor(handle) => Some(handle.shape.iter().product()),
653        Value::Num(_) | Value::Bool(_) | Value::Int(_) => Some(1),
654        Value::Complex(_, _) => Some(1),
655        _ => None,
656    }
657}
658
659fn element_count_pair(a: &Value, b: &Value) -> Option<usize> {
660    let la = value_len(a)?;
661    let lb = value_len(b)?;
662    Some(la.max(lb))
663}
664
/// Returns the process-wide auto-offload instance, initializing it on first use.
///
/// Returns `None` when auto-offload is disabled, no provider is registered, or
/// initialization completed without producing an instance. Once initialization
/// has stored a result in `GLOBAL`, that result is returned on every later call.
pub async fn global() -> Option<&'static NativeAutoOffload> {
    // Fast path: initialization already finished.
    if let Some(existing) = GLOBAL.get() {
        return existing.as_ref();
    }
    // If auto-offload is disabled or there is no GPU provider registered,
    // initialize_async() would return None immediately (no I/O, no blocking).
    // Return None directly without acquiring the async lock so single-poll
    // callers (e.g. the turbine JIT interpreter fallback) never observe a
    // spurious Pending.  We intentionally do NOT write to GLOBAL here: doing
    // so without holding GLOBAL_INIT_LOCK would race with a concurrent thread
    // that is partway through initialize_async() and has found a valid
    // provider.  That thread's subsequent GLOBAL.set(Some(offload)) would
    // silently fail (OnceCell is set-once), permanently disabling the
    // accelerator for the lifetime of the process.  These two checks are
    // cheap (no I/O), so re-evaluating them on each call is acceptable.
    if !auto_enabled() || runmat_accelerate_api::provider().is_none() {
        return None;
    }
    let _guard = GLOBAL_INIT_LOCK.lock().await;
    // Re-check under the lock: another task may have initialized while we waited.
    if let Some(existing) = GLOBAL.get() {
        return existing.as_ref();
    }
    let initialized = initialize_async().await;
    // The set result is ignored; the value is read back from the cell either way.
    let _ = GLOBAL.set(initialized);
    GLOBAL.get().and_then(|value| value.as_ref())
}
691
692async fn initialize_async() -> Option<NativeAutoOffload> {
693    if !auto_enabled() {
694        clear_decisions();
695        return None;
696    }
697    let provider = runmat_accelerate_api::provider()?;
698    let device_info = provider.device_info_struct();
699    let mut config = ThresholdConfig::default();
700    let mut base_source = ThresholdBase::BuiltInDefault;
701    let mut cache_path: Option<String> = None;
702    let mut calibrate_duration_ms: Option<u128> = None;
703    let refresh_calibration = calibrate_refresh_enabled();
704
705    if !refresh_calibration {
706        if let Some((cached, path)) = load_cached_thresholds_async(&device_info).await {
707            info!(
708                "Native auto-offload: loaded cached calibration for '{}' from {}",
709                device_info.name, path
710            );
711            config = cached;
712            cache_path = Some(path);
713            base_source = ThresholdBase::LoadedFromCache;
714        }
715    }
716
717    let needs_calibration = calibrate_enabled() && (refresh_calibration || cache_path.is_none());
718    if needs_calibration {
719        let start = Instant::now();
720        match auto_calibrate(provider, &mut config) {
721            Ok(()) => {
722                calibrate_duration_ms = Some(start.elapsed().as_millis());
723                base_source = ThresholdBase::Calibrated;
724                match persist_thresholds_async(&device_info, &config).await {
725                    Ok(path) => {
726                        cache_path = Some(path.clone());
727                        info!(
728                            "Native auto-offload: persisted calibration for '{}' to {}",
729                            device_info.name, path
730                        );
731                    }
732                    Err(err) => {
733                        debug!("Native auto-offload: failed to persist calibration: {err}");
734                    }
735                }
736            }
737            Err(err) => {
738                debug!("Native auto-offload calibration failed: {err}");
739            }
740        }
741    }
742
743    let env_overrides_applied = apply_env_overrides(&mut config);
744    let model_status = if profile_cost_model().is_some() {
745        "profile"
746    } else {
747        "fallback"
748    };
749    info!(
750        "Native auto-offload thresholds: unary={} binary={} reduction={} matmul_flops={} small_batch_dim={} small_batch_min_elems={} (model: {}, source: {}, env_overrides={})",
751        config.unary_min_elems,
752        config.binary_min_elems,
753        config.reduction_min_elems,
754        config.matmul_min_flops,
755        config.small_batch_max_dim,
756        config.small_batch_min_elems,
757        model_status,
758        base_source.as_str(),
759        env_overrides_applied
760    );
761
762    let cache_path_str = cache_path.clone();
763    let state = AutoOffloadState {
764        provider: Some(cached_provider_info(&device_info)),
765        thresholds: config.clone(),
766        base_source,
767        env_overrides_applied,
768        cache_path: cache_path_str,
769        calibrate_duration_ms,
770        previous_thresholds: None,
771        calibration_delta: None,
772    };
773    let _ = AUTO_STATE.set(Mutex::new(state));
774
775    Some(NativeAutoOffload::new(provider, config))
776}
777
778impl NativeAutoOffload {
779    fn new(provider: &'static dyn AccelProvider, thresholds: ThresholdConfig) -> Self {
780        let enabled = true;
781        Self {
782            provider,
783            thresholds,
784            enabled,
785        }
786    }
787
788    fn promote_tensor_if_large(&self, value: &Value, threshold: usize) -> Result<Value> {
789        match value {
790            Value::GpuTensor(_) => Ok(value.clone()),
791            Value::Tensor(t) => {
792                if ensure_provider_supports_dtype(self.provider, t.dtype).is_err() {
793                    return Ok(value.clone());
794                }
795                if t.data.len() >= threshold && threshold > 0 {
796                    log_promotion(|| {
797                        format!(
798                            "Promoting tensor to GPU (len={}, threshold={})",
799                            t.data.len(),
800                            threshold
801                        )
802                    });
803                    self.tensor_to_gpu(t)
804                } else {
805                    Ok(value.clone())
806                }
807            }
808            _ => Ok(value.clone()),
809        }
810    }
811
812    fn tensor_to_gpu(&self, tensor: &Tensor) -> Result<Value> {
813        let view = HostTensorView {
814            data: &tensor.data,
815            shape: &tensor.shape,
816        };
817        let handle = self
818            .provider
819            .upload(&view)
820            .map_err(|e| anyhow!(e.to_string()))?;
821        Ok(Value::GpuTensor(handle))
822    }
823
824    fn small_batch_guard(&self, elements: usize, batch: Option<usize>) -> bool {
825        if !self.enabled {
826            return false;
827        }
828        let Some(batch) = batch else {
829            return false;
830        };
831        if batch == 0 {
832            return false;
833        }
834        let thresholds = &self.thresholds;
835        thresholds.small_batch_max_dim > 0
836            && thresholds.small_batch_min_elems > 0
837            && batch <= thresholds.small_batch_max_dim
838            && elements >= thresholds.small_batch_min_elems
839    }
840
841    fn promote_binary(&self, op: BinaryOp, a: &Value, b: &Value) -> Result<(Value, Value)> {
842        if !self.enabled {
843            return Ok((a.clone(), b.clone()));
844        }
845        match op {
846            BinaryOp::Elementwise => {
847                let elems = element_count_pair(a, b).unwrap_or(0);
848                let eval = self.evaluate_elementwise(elems, &[a, b]);
849                record_decision(decision_entry("elementwise", Some(elems), None, &eval));
850                if eval.recommend_gpu {
851                    log_promotion(|| format!("Elementwise offload accepted ({} elems)", elems));
852                    let a_p = self.promote_tensor_if_large(a, 1)?;
853                    let b_p = self.promote_tensor_if_large(b, 1)?;
854                    Ok((a_p, b_p))
855                } else {
856                    Ok((a.clone(), b.clone()))
857                }
858            }
859            BinaryOp::MatMul => {
860                if let (Some((ra, ca)), Some((rb, cb))) = (tensor_rows_cols(a), tensor_rows_cols(b))
861                {
862                    if ca != rb {
863                        return Ok((a.clone(), b.clone()));
864                    }
865                    let flops = ra.saturating_mul(ca).saturating_mul(cb);
866                    let eval = self.evaluate_matmul(flops);
867                    record_decision(decision_entry("matmul", None, Some(flops), &eval));
868                    if eval.recommend_gpu {
869                        log_promotion(|| {
870                            format!(
871                                "Promoting matmul operands (flops={}, threshold={})",
872                                flops, self.thresholds.matmul_min_flops
873                            )
874                        });
875                        let a_p = self.promote_tensor_if_large(a, 1)?;
876                        let b_p = self.promote_tensor_if_large(b, 1)?;
877                        return Ok((a_p, b_p));
878                    }
879                }
880                Ok((a.clone(), b.clone()))
881            }
882        }
883    }
884
885    fn promote_unary(&self, op: UnaryOp, v: &Value) -> Result<Value> {
886        if !self.enabled {
887            return Ok(v.clone());
888        }
889        let elems = value_len(v).unwrap_or(0);
890        let eval = self.evaluate_unary(elems, op, v);
891        let op_label = match op {
892            UnaryOp::Transpose => "transpose",
893            UnaryOp::Generic => "unary",
894        };
895        record_decision(decision_entry(op_label, Some(elems), None, &eval));
896        if eval.recommend_gpu {
897            log_promotion(|| format!("Unary offload accepted ({:?}, {} elems)", op, elems));
898            self.promote_tensor_if_large(v, 1)
899        } else {
900            Ok(v.clone())
901        }
902    }
903
904    fn promote_reduction(&self, _op: ReductionOp, args: &[Value]) -> Result<Vec<Value>> {
905        if !self.enabled || args.is_empty() {
906            return Ok(args.to_vec());
907        }
908        let elems = value_len(&args[0]).unwrap_or(0);
909        let eval = self.evaluate_reduction(elems);
910        record_decision(decision_entry("reduction", Some(elems), None, &eval));
911        if !eval.recommend_gpu {
912            return Ok(args.to_vec());
913        }
914        log_promotion(|| format!("Reduction offload accepted ({} elems)", elems));
915        let mut out = Vec::with_capacity(args.len());
916        if let Some(first) = args.first() {
917            out.push(self.promote_tensor_if_large(first, 1)?);
918            out.extend(args.iter().skip(1).cloned());
919        }
920        Ok(out)
921    }
922
923    fn evaluate_elementwise(&self, elements: usize, values: &[&Value]) -> DecisionEvaluation {
924        let fusion = active_fusion();
925        let fusion_kind = fusion.as_ref().map(|f| f.kind.clone());
926        let batch = batch_dimension_from_values(values);
927        let cpu_secs = cpu_estimate(self.thresholds.cpu_elem_per_elem, elements);
928
929        // Chain-aware residency: if any input is already on GPU, keep the op on GPU
930        if values.iter().any(|v| matches!(v, Value::GpuTensor(_))) {
931            return DecisionEvaluation {
932                recommend_gpu: true,
933                reason: DecisionReason::Residency,
934                cpu_secs,
935                gpu_secs: None,
936                threshold: Some(self.thresholds.binary_min_elems),
937                fusion_kind,
938                batch,
939            };
940        }
941
942        if let Some(active) = fusion.as_ref() {
943            // If an elementwise chain is actively fused OR this elementwise op
944            // participates in a fused reduction group, force GPU to keep the
945            // whole chain resident and avoid host round-trips.
946            if (active.kind.is_elementwise() || active.kind.is_reduction()) && active.supported {
947                return DecisionEvaluation {
948                    recommend_gpu: true,
949                    reason: DecisionReason::FusionOverride,
950                    cpu_secs,
951                    gpu_secs: None,
952                    threshold: Some(self.thresholds.binary_min_elems),
953                    fusion_kind,
954                    batch,
955                };
956            }
957        }
958
959        if self.small_batch_guard(elements, batch) {
960            return DecisionEvaluation {
961                recommend_gpu: false,
962                reason: DecisionReason::SmallBatchGuard,
963                cpu_secs,
964                gpu_secs: None,
965                threshold: Some(self.thresholds.binary_min_elems),
966                fusion_kind,
967                batch,
968            };
969        }
970
971        if let Some(model) = profile_cost_model() {
972            if let Some(gpu_duration) = model.estimate_elemwise(elements) {
973                let gpu_secs = Some(gpu_duration.as_secs_f64());
974                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
975                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
976                return DecisionEvaluation {
977                    recommend_gpu: recommend,
978                    reason: DecisionReason::ProfileModel,
979                    cpu_secs,
980                    gpu_secs,
981                    threshold: Some(self.thresholds.binary_min_elems),
982                    fusion_kind,
983                    batch,
984                };
985            }
986        }
987
988        DecisionEvaluation {
989            recommend_gpu: elements >= self.thresholds.binary_min_elems,
990            reason: DecisionReason::Threshold,
991            cpu_secs,
992            gpu_secs: None,
993            threshold: Some(self.thresholds.binary_min_elems),
994            fusion_kind,
995            batch,
996        }
997    }
998
999    fn evaluate_matmul(&self, flops: usize) -> DecisionEvaluation {
1000        let cpu_secs = cpu_estimate(self.thresholds.cpu_matmul_per_flop, flops);
1001        if let Some(model) = profile_cost_model() {
1002            if let Some(gpu_duration) = model.estimate_matmul_flops(flops) {
1003                let gpu_secs = Some(gpu_duration.as_secs_f64());
1004                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1005                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1006                return DecisionEvaluation {
1007                    recommend_gpu: recommend,
1008                    reason: DecisionReason::ProfileModel,
1009                    cpu_secs,
1010                    gpu_secs,
1011                    threshold: Some(self.thresholds.matmul_min_flops),
1012                    fusion_kind: None,
1013                    batch: None,
1014                };
1015            }
1016        }
1017
1018        DecisionEvaluation {
1019            recommend_gpu: flops >= self.thresholds.matmul_min_flops,
1020            reason: DecisionReason::Threshold,
1021            cpu_secs,
1022            gpu_secs: None,
1023            threshold: Some(self.thresholds.matmul_min_flops),
1024            fusion_kind: None,
1025            batch: None,
1026        }
1027    }
1028
1029    fn evaluate_reduction(&self, elements: usize) -> DecisionEvaluation {
1030        let fusion_kind = active_fusion().map(|f| f.kind.clone());
1031        let cpu_secs = cpu_estimate(self.thresholds.cpu_reduction_per_elem, elements);
1032        if let Some(model) = profile_cost_model() {
1033            if let Some(gpu_duration) = model.estimate_reduction(elements) {
1034                let gpu_secs = Some(gpu_duration.as_secs_f64());
1035                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1036                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1037                return DecisionEvaluation {
1038                    recommend_gpu: recommend,
1039                    reason: DecisionReason::ProfileModel,
1040                    cpu_secs,
1041                    gpu_secs,
1042                    threshold: Some(self.thresholds.reduction_min_elems),
1043                    fusion_kind,
1044                    batch: None,
1045                };
1046            }
1047        }
1048
1049        DecisionEvaluation {
1050            recommend_gpu: elements >= self.thresholds.reduction_min_elems,
1051            reason: DecisionReason::Threshold,
1052            cpu_secs,
1053            gpu_secs: None,
1054            threshold: Some(self.thresholds.reduction_min_elems),
1055            fusion_kind,
1056            batch: None,
1057        }
1058    }
1059
1060    fn evaluate_unary(&self, elements: usize, op: UnaryOp, value: &Value) -> DecisionEvaluation {
1061        let fusion_kind = active_fusion().map(|f| f.kind.clone());
1062        let batch = batch_dimension_from_values(&[value]);
1063        // Chain-aware residency for unary ops: if operand is already on GPU, keep it on GPU
1064        if matches!(value, Value::GpuTensor(_)) {
1065            return DecisionEvaluation {
1066                recommend_gpu: true,
1067                reason: DecisionReason::Residency,
1068                cpu_secs: cpu_estimate(self.thresholds.cpu_elem_per_elem, elements),
1069                gpu_secs: None,
1070                threshold: Some(self.thresholds.unary_min_elems),
1071                fusion_kind,
1072                batch,
1073            };
1074        }
1075        if matches!(op, UnaryOp::Generic) && self.small_batch_guard(elements, batch) {
1076            return DecisionEvaluation {
1077                recommend_gpu: false,
1078                reason: DecisionReason::SmallBatchGuard,
1079                cpu_secs: cpu_estimate(self.thresholds.cpu_elem_per_elem, elements),
1080                gpu_secs: None,
1081                threshold: Some(self.thresholds.unary_min_elems),
1082                fusion_kind,
1083                batch,
1084            };
1085        }
1086
1087        let cpu_secs = cpu_estimate(self.thresholds.cpu_elem_per_elem, elements);
1088        if let Some(model) = profile_cost_model() {
1089            let gpu_duration = match op {
1090                UnaryOp::Transpose => model.estimate_transpose(elements),
1091                UnaryOp::Generic => model.estimate_elemwise(elements),
1092            };
1093            if let Some(gpu_duration) = gpu_duration {
1094                let gpu_secs = Some(gpu_duration.as_secs_f64());
1095                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1096                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1097                return DecisionEvaluation {
1098                    recommend_gpu: recommend,
1099                    reason: DecisionReason::ProfileModel,
1100                    cpu_secs,
1101                    gpu_secs,
1102                    threshold: Some(self.thresholds.unary_min_elems),
1103                    fusion_kind,
1104                    batch,
1105                };
1106            }
1107        }
1108
1109        DecisionEvaluation {
1110            recommend_gpu: elements >= self.thresholds.unary_min_elems,
1111            reason: DecisionReason::Threshold,
1112            cpu_secs,
1113            gpu_secs: None,
1114            threshold: Some(self.thresholds.unary_min_elems),
1115            fusion_kind,
1116            batch,
1117        }
1118    }
1119
1120    async fn prepare_builtin(&self, name: &str, args: &[Value]) -> Result<Vec<Value>> {
1121        if !self.enabled {
1122            return Ok(args.to_vec());
1123        }
1124        // Do not attempt to promote 'double' on providers that cannot store f64.
1125        // Offloading a cast to double requires device-side f64; otherwise keep host.
1126        if name.eq_ignore_ascii_case("double")
1127            && self.provider.precision() != runmat_accelerate_api::ProviderPrecision::F64
1128        {
1129            return Ok(args.to_vec());
1130        }
1131        if let Some(policy) = builtin_policy(name) {
1132            if policy.is_sink {
1133                clear_sink_inputs(args);
1134                if should_gather_sink_args(name) {
1135                    trace!(
1136                        "auto-offload: prepare_builtin(name={:?}) is_sink=true residency=GatherImmediately -> gathering {} arg(s)",
1137                        name,
1138                        args.len()
1139                    );
1140                    return gather_args(args).await;
1141                }
1142                trace!(
1143                    "auto-offload: prepare_builtin(name={:?}) is_sink=true residency!=GatherImmediately -> no gather (fusion barrier only)",
1144                    name
1145                );
1146                return Ok(args.to_vec());
1147            }
1148
1149            let mut processed = args.to_vec();
1150
1151            if policy
1152                .accel_tags
1153                .iter()
1154                .any(|tag| matches!(tag, AccelTag::Reduction))
1155            {
1156                if (name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min"))
1157                    && !max_or_min_reduction_call(args)
1158                {
1159                    trace!(
1160                        "Skipping reduction promotion for builtin '{}' (detected elementwise form)",
1161                        name
1162                    );
1163                } else {
1164                    log_promotion(|| format!("Promoting builtin '{}' as reduction", name));
1165                    return self.promote_reduction(reduction_op_hint(name), args);
1166                }
1167            }
1168
1169            if policy
1170                .accel_tags
1171                .iter()
1172                .any(|tag| matches!(tag, AccelTag::MatMul))
1173                && processed.len() >= 2
1174            {
1175                log_promotion(|| format!("Promoting builtin '{}' as matmul", name));
1176                let (a_p, b_p) =
1177                    self.promote_binary(BinaryOp::MatMul, &processed[0], &processed[1])?;
1178                processed[0] = a_p;
1179                processed[1] = b_p;
1180                return Ok(processed);
1181            }
1182
1183            if policy
1184                .accel_tags
1185                .iter()
1186                .any(|tag| matches!(tag, AccelTag::Elementwise))
1187                && processed.len() >= 2
1188            {
1189                log_promotion(|| format!("Promoting builtin '{}' as elementwise", name));
1190                let (a_p, b_p) =
1191                    self.promote_binary(BinaryOp::Elementwise, &processed[0], &processed[1])?;
1192                processed[0] = a_p;
1193                processed[1] = b_p;
1194                return Ok(processed);
1195            }
1196
1197            if let Some(first) = processed.first_mut() {
1198                if policy
1199                    .accel_tags
1200                    .iter()
1201                    .any(|tag| matches!(tag, AccelTag::Transpose))
1202                {
1203                    log_promotion(|| format!("Promoting builtin '{}' as transpose", name));
1204                    *first = self.promote_unary(UnaryOp::Transpose, first)?;
1205                    return Ok(processed);
1206                }
1207
1208                if policy
1209                    .accel_tags
1210                    .iter()
1211                    .any(|tag| matches!(tag, AccelTag::Unary))
1212                {
1213                    log_promotion(|| format!("Promoting builtin '{}' as unary", name));
1214                    *first = self.promote_unary(UnaryOp::Generic, first)?;
1215                    return Ok(processed);
1216                }
1217            }
1218        }
1219        Ok(args.to_vec())
1220    }
1221}
1222
1223fn tensor_rows_cols(value: &Value) -> Option<(usize, usize)> {
1224    match value {
1225        Value::Tensor(t) => Some((t.rows(), t.cols())),
1226        Value::GpuTensor(handle) => {
1227            if handle.shape.len() == 2 {
1228                Some((handle.shape[0], handle.shape[1]))
1229            } else {
1230                None
1231            }
1232        }
1233        _ => None,
1234    }
1235}
1236
1237#[allow(dead_code)]
1238fn should_skip_reduction_promotion(name: &str, args: &[Value]) -> bool {
1239    (name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min"))
1240        && !max_or_min_reduction_call(args)
1241}
1242
1243fn reduction_op_hint(name: &str) -> ReductionOp {
1244    if name.eq_ignore_ascii_case("max") {
1245        ReductionOp::Max
1246    } else if name.eq_ignore_ascii_case("min") {
1247        ReductionOp::Min
1248    } else {
1249        ReductionOp::Sum
1250    }
1251}
1252
1253fn max_or_min_reduction_call(args: &[Value]) -> bool {
1254    if args.len() <= 1 {
1255        return true;
1256    }
1257    args.get(1).map(is_empty_placeholder_value).unwrap_or(false)
1258}
1259
1260fn is_empty_placeholder_value(value: &Value) -> bool {
1261    match value {
1262        Value::Tensor(t) => t.data.is_empty(),
1263        Value::LogicalArray(l) => l.data.is_empty(),
1264        Value::StringArray(sa) => sa.data.is_empty(),
1265        Value::CharArray(ca) => ca.data.is_empty(),
1266        Value::Cell(cell) => cell.data.is_empty(),
1267        Value::String(s) => s.is_empty(),
1268        _ => false,
1269    }
1270}
1271
1272async fn gather_args(args: &[Value]) -> Result<Vec<Value>> {
1273    let mut out = Vec::with_capacity(args.len());
1274    for (idx, value) in args.iter().enumerate() {
1275        if let Value::GpuTensor(handle) = value {
1276            trace!(
1277                "auto-offload: gather_args arg[{}]=GpuTensor device_id={} buffer_id={} shape={:?}",
1278                idx,
1279                handle.device_id,
1280                handle.buffer_id,
1281                handle.shape
1282            );
1283        } else {
1284            trace!(
1285                "auto-offload: gather_args arg[{}]={:?}",
1286                idx,
1287                value_kind(value)
1288            );
1289        }
1290        let gathered = gather_if_needed_async(value)
1291            .await
1292            .map_err(|e| anyhow!(e))?;
1293        trace!(
1294            "auto-offload: gather_args arg[{}] -> {:?}",
1295            idx,
1296            value_kind(&gathered)
1297        );
1298        out.push(gathered);
1299    }
1300    Ok(out)
1301}
1302
1303fn clear_sink_inputs(args: &[Value]) {
1304    for value in args {
1305        if let Value::GpuTensor(handle) = value {
1306            fusion_residency::clear(handle);
1307        }
1308    }
1309}
1310
1311fn should_gather_sink_args(name: &str) -> bool {
1312    matches!(
1313        builtin_residency_policy(name),
1314        Some(ResidencyPolicy::GatherImmediately) | None
1315    )
1316}
1317
1318fn value_kind(value: &Value) -> &'static str {
1319    match value {
1320        Value::GpuTensor(_) => "GpuTensor",
1321        Value::Tensor(_) => "Tensor",
1322        Value::SparseTensor(_) => "SparseTensor",
1323        Value::Num(_) => "Num",
1324        Value::Int(_) => "Int",
1325        Value::Bool(_) => "Bool",
1326        Value::LogicalArray(_) => "LogicalArray",
1327        Value::CharArray(_) => "CharArray",
1328        Value::String(_) => "String",
1329        Value::StringArray(_) => "StringArray",
1330        Value::Cell(_) => "Cell",
1331        Value::Struct(_) => "Struct",
1332        Value::Object(_) => "Object",
1333        Value::HandleObject(_) => "HandleObject",
1334        Value::FunctionHandle(_)
1335        | Value::ExternalFunctionHandle(_)
1336        | Value::MethodFunctionHandle(_) => "FunctionHandle",
1337        Value::BoundFunctionHandle { .. } => "FunctionHandle",
1338        Value::Closure(_) => "Closure",
1339        Value::ClassRef(_) => "ClassRef",
1340        Value::Complex(_, _) => "Complex",
1341        Value::ComplexTensor(_) => "ComplexTensor",
1342        Value::Listener(_) => "Listener",
1343        Value::MException(_) => "MException",
1344        Value::OutputList(_) => "OutputList",
1345    }
1346}
1347
#[cfg(test)]
mod tests {
    use super::*;

    /// `max`/`min` count as reductions with a single argument or an empty
    /// second argument, but not when a real second operand is supplied.
    #[test]
    fn max_detection_handles_placeholders() {
        let data = Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap());
        let empty = Value::Tensor(Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap());

        // Single-argument form.
        assert!(max_or_min_reduction_call(std::slice::from_ref(&data)));

        // Empty placeholder in second position, followed by a third argument.
        let with_placeholder = [data.clone(), empty.clone(), Value::Num(1.0)];
        assert!(max_or_min_reduction_call(&with_placeholder));

        // A real second operand means this is not a reduction call.
        let with_operand = [data.clone(), Value::Num(0.0)];
        assert!(!max_or_min_reduction_call(&with_operand));
    }
}
1368
/// Offload-relevant metadata for one builtin function.
#[derive(Clone, Copy)]
struct BuiltinPolicy {
    /// Acceleration tags attached to the builtin.
    accel_tags: &'static [AccelTag],
    /// Whether the builtin is marked as a sink.
    is_sink: bool,
}
1374
/// Process-wide cache of builtin policies keyed by builtin name, initialised at most once.
static BUILTIN_POLICIES: OnceCell<HashMap<String, BuiltinPolicy>> = OnceCell::new();
1376
1377fn build_builtin_policy_map() -> HashMap<String, BuiltinPolicy> {
1378    let mut map = HashMap::new();
1379    for func in builtin_functions() {
1380        map.insert(
1381            func.name.to_ascii_lowercase(),
1382            BuiltinPolicy {
1383                accel_tags: func.accel_tags,
1384                is_sink: func.is_sink,
1385            },
1386        );
1387    }
1388    map
1389}
1390
1391fn builtin_policy(name: &str) -> Option<BuiltinPolicy> {
1392    let map = BUILTIN_POLICIES.get_or_init(build_builtin_policy_map);
1393    map.get(&name.to_ascii_lowercase()).copied()
1394}
1395
1396fn auto_enabled() -> bool {
1397    if let Some(flag) = env_bool("RUNMAT_ACCEL_AUTO_OFFLOAD") {
1398        return flag;
1399    }
1400    auto_offload_options().enabled
1401}
1402
1403fn calibrate_enabled() -> bool {
1404    if let Some(flag) = env_bool("RUNMAT_ACCEL_CALIBRATE") {
1405        return flag;
1406    }
1407    auto_offload_options().calibrate
1408}
1409
1410fn calibrate_refresh_enabled() -> bool {
1411    env_bool("RUNMAT_ACCEL_CALIBRATE_REFRESH").unwrap_or(false)
1412}
1413
1414fn apply_env_overrides(cfg: &mut ThresholdConfig) -> bool {
1415    let mut applied = false;
1416    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_UNARY") {
1417        cfg.unary_min_elems = val;
1418        applied = true;
1419    }
1420    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_ELEMWISE") {
1421        cfg.binary_min_elems = val;
1422        applied = true;
1423    }
1424    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_REDUCTION") {
1425        cfg.reduction_min_elems = val;
1426        applied = true;
1427    }
1428    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_MATMUL") {
1429        cfg.matmul_min_flops = val;
1430        applied = true;
1431    }
1432    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_ALL") {
1433        cfg.unary_min_elems = val;
1434        cfg.binary_min_elems = val;
1435        cfg.reduction_min_elems = val;
1436        applied = true;
1437    }
1438    if let Some(val) = env_usize("RUNMAT_ACCEL_SMALL_BATCH_MAX_DIM") {
1439        cfg.small_batch_max_dim = val;
1440        applied = true;
1441    }
1442    if let Some(val) = env_usize("RUNMAT_ACCEL_SMALL_BATCH_MIN_ELEMS") {
1443        cfg.small_batch_min_elems = val;
1444        applied = true;
1445    }
1446    applied
1447}
1448
/// Reads environment variable `key` and parses it as a `usize`.
///
/// Leading and trailing whitespace is ignored, so a value such as `" 4096"`
/// or `"4096\n"` is accepted rather than silently discarded. Returns `None`
/// when the variable is unset, is not valid Unicode, or does not parse as a
/// `usize`.
fn env_usize(key: &str) -> Option<usize> {
    let raw = env::var(key).ok()?;
    raw.trim().parse::<usize>().ok()
}
1452
/// Serialized form of a calibration result as stored in the calibration cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CalibrationRecord {
    /// Format version; records whose version differs from `CALIBRATION_VERSION` are rejected on load.
    version: u32,
    /// Time the record was written, in whole seconds since the Unix epoch.
    recorded_at: u64,
    /// Identity of the device the thresholds were recorded for.
    provider: CalibrationProviderDetails,
    /// The thresholds being cached.
    thresholds: ThresholdConfig,
}
1460
/// Device identity stored alongside cached thresholds; the fields are copied
/// from the same-named fields of `ApiDeviceInfo` when a record is persisted.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CalibrationProviderDetails {
    /// Device name.
    name: String,
    /// Device vendor.
    vendor: String,
    /// Backend name, if the provider reports one.
    backend: Option<String>,
    /// Provider-assigned device id.
    device_id: u32,
}
1468
/// Builds the storage key for a device's cached calibration on wasm32 targets.
///
/// The key has the form `<vendor>-<name>-<backend>-<device_id>.json`, with the
/// textual parts slugified and a missing backend replaced by `"unknown"`.
#[cfg(target_arch = "wasm32")]
fn calibration_cache_key(info: &ApiDeviceInfo) -> String {
    let backend = info.backend.as_deref().unwrap_or("unknown");
    format!(
        "{}-{}-{}-{}.json",
        slugify(&info.vendor),
        slugify(&info.name),
        slugify(backend),
        info.device_id
    )
}
1476
1477async fn load_cached_thresholds_async(info: &ApiDeviceInfo) -> Option<(ThresholdConfig, String)> {
1478    #[cfg(target_arch = "wasm32")]
1479    {
1480        let key = calibration_cache_key(info);
1481        let contents = crate::web_auto_offload_store::load(&key).await?;
1482        match serde_json::from_str::<CalibrationRecord>(&contents) {
1483            Ok(record) => {
1484                if record.version != CALIBRATION_VERSION {
1485                    debug!(
1486                        "Native auto-offload calibration cache version mismatch (found {}, expected {})",
1487                        record.version,
1488                        CALIBRATION_VERSION
1489                    );
1490                    None
1491                } else {
1492                    Some((record.thresholds, key))
1493                }
1494            }
1495            Err(err) => {
1496                debug!(
1497                    "Native auto-offload failed to parse cached calibration for '{}': {err}",
1498                    info.name
1499                );
1500                None
1501            }
1502        }
1503    }
1504    #[cfg(not(target_arch = "wasm32"))]
1505    {
1506        load_cached_thresholds(info).map(|(cfg, path)| (cfg, path.display().to_string()))
1507    }
1508}
1509
1510async fn persist_thresholds_async(info: &ApiDeviceInfo, cfg: &ThresholdConfig) -> Result<String> {
1511    #[cfg(target_arch = "wasm32")]
1512    {
1513        let key = calibration_cache_key(info);
1514        let record = CalibrationRecord {
1515            version: CALIBRATION_VERSION,
1516            recorded_at: system_time_now()
1517                .duration_since(UNIX_EPOCH)
1518                .unwrap_or_else(|_| Duration::from_secs(0))
1519                .as_secs(),
1520            provider: CalibrationProviderDetails {
1521                name: info.name.clone(),
1522                vendor: info.vendor.clone(),
1523                backend: info.backend.clone(),
1524                device_id: info.device_id,
1525            },
1526            thresholds: cfg.clone(),
1527        };
1528        let payload = serde_json::to_string_pretty(&record).map_err(|e| anyhow!(e.to_string()))?;
1529        crate::web_auto_offload_store::save(&key, &payload)
1530            .await
1531            .map_err(|e| anyhow!(format!("indexeddb persist failed: {e:?}")))?;
1532        Ok(key)
1533    }
1534    #[cfg(not(target_arch = "wasm32"))]
1535    {
1536        persist_thresholds(info, cfg).map(|path| path.display().to_string())
1537    }
1538}
1539
1540fn load_cached_thresholds(info: &ApiDeviceInfo) -> Option<(ThresholdConfig, PathBuf)> {
1541    let path = calibration_cache_file(info)?;
1542    let contents = fs::read_to_string(&path).ok()?;
1543    match serde_json::from_str::<CalibrationRecord>(&contents) {
1544        Ok(record) => {
1545            if record.version != CALIBRATION_VERSION {
1546                debug!(
1547                    "Native auto-offload calibration cache version mismatch (found {}, expected {})",
1548                    record.version,
1549                    CALIBRATION_VERSION
1550                );
1551                None
1552            } else {
1553                Some((record.thresholds, path))
1554            }
1555        }
1556        Err(err) => {
1557            debug!(
1558                "Native auto-offload failed to parse cached calibration for '{}': {err}",
1559                info.name
1560            );
1561            None
1562        }
1563    }
1564}
1565
1566fn persist_thresholds(info: &ApiDeviceInfo, cfg: &ThresholdConfig) -> Result<PathBuf> {
1567    let path = calibration_cache_file(info)
1568        .ok_or_else(|| anyhow!("unable to determine calibration cache directory"))?;
1569    if let Some(parent) = path.parent() {
1570        fs::create_dir_all(parent).map_err(|e| anyhow!(e.to_string()))?;
1571    }
1572    let record = CalibrationRecord {
1573        version: CALIBRATION_VERSION,
1574        recorded_at: system_time_now()
1575            .duration_since(UNIX_EPOCH)
1576            .unwrap_or_else(|_| Duration::from_secs(0))
1577            .as_secs(),
1578        provider: CalibrationProviderDetails {
1579            name: info.name.clone(),
1580            vendor: info.vendor.clone(),
1581            backend: info.backend.clone(),
1582            device_id: info.device_id,
1583        },
1584        thresholds: cfg.clone(),
1585    };
1586    let payload = serde_json::to_string_pretty(&record).map_err(|e| anyhow!(e.to_string()))?;
1587    fs::write(&path, payload).map_err(|e| anyhow!(e.to_string()))?;
1588    Ok(path)
1589}
1590
1591fn calibration_cache_file(info: &ApiDeviceInfo) -> Option<PathBuf> {
1592    let mut dir = calibration_cache_dir()?;
1593    let vendor = slugify(&info.vendor);
1594    let name = slugify(&info.name);
1595    let backend = slugify(info.backend.as_deref().unwrap_or("unknown"));
1596    let file = format!("{}-{}-{}-{}.json", vendor, name, backend, info.device_id);
1597    dir.push(file);
1598    Some(dir)
1599}
1600
1601fn calibration_cache_dir() -> Option<PathBuf> {
1602    dirs::cache_dir().map(|base| base.join("runmat").join("auto_offload"))
1603}
1604
/// Converts `input` into a lowercase, filename-friendly slug.
///
/// ASCII letters and digits are kept (lowercased). Every run of other
/// characters becomes a single `_`, and leading or trailing underscores are
/// dropped. If nothing remains, the slug is `"device"`.
fn slugify(input: &str) -> String {
    let slug = input
        .split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|segment| !segment.is_empty())
        .map(|segment| segment.to_ascii_lowercase())
        .collect::<Vec<_>>()
        .join("_");

    if slug.is_empty() {
        "device".to_string()
    } else {
        slug
    }
}
1624
1625fn auto_calibrate(provider: &'static dyn AccelProvider, cfg: &mut ThresholdConfig) -> Result<()> {
1626    if let Some(elem_threshold) = calibrate_elemwise(provider, cfg).transpose()? {
1627        if elem_threshold != usize::MAX {
1628            cfg.binary_min_elems = elem_threshold;
1629            cfg.unary_min_elems = cfg.unary_min_elems.min(elem_threshold);
1630        }
1631    }
1632    if let Some(red_threshold) = calibrate_reduction(provider, cfg).transpose()? {
1633        if red_threshold != usize::MAX {
1634            cfg.reduction_min_elems = red_threshold;
1635        }
1636    }
1637    if let Some(matmul_threshold) = calibrate_matmul(provider, cfg).transpose()? {
1638        if matmul_threshold != usize::MAX {
1639            cfg.matmul_min_flops = matmul_threshold;
1640        }
1641    }
1642    Ok(())
1643}
1644
1645fn calibrate_elemwise(
1646    provider: &'static dyn AccelProvider,
1647    cfg: &mut ThresholdConfig,
1648) -> Option<Result<usize>> {
1649    let sizes = [256usize, 1_024, 4_096, 16_384, 65_536];
1650    for size in sizes {
1651        match compare_elemwise(provider, size, &mut cfg.cpu_elem_per_elem) {
1652            Ok(Some(true)) => return Some(Ok(size)),
1653            Ok(Some(false)) => continue,
1654            Ok(None) => return None,
1655            Err(e) => return Some(Err(e)),
1656        }
1657    }
1658    Some(Ok(usize::MAX))
1659}
1660
1661fn compare_elemwise(
1662    provider: &'static dyn AccelProvider,
1663    elements: usize,
1664    cpu_cost_slot: &mut f64,
1665) -> Result<Option<bool>> {
1666    if elements == 0 {
1667        return Ok(Some(false));
1668    }
1669    let shape = vec![elements, 1];
1670    let template = match provider.precision() {
1671        ProviderPrecision::F64 => {
1672            Tensor::new((0..elements).map(|i| i as f64).collect(), shape.clone())
1673                .map_err(|e| anyhow!(e))?
1674        }
1675        ProviderPrecision::F32 => {
1676            Tensor::from_f32((0..elements).map(|i| i as f32).collect(), shape.clone())
1677                .map_err(|e| anyhow!(e))?
1678        }
1679    };
1680    let a = Value::Tensor(template.clone());
1681    let b = Value::Tensor(template.clone());
1682    let cpu_time = time(|| runmat_runtime::call_builtin("plus", &[a.clone(), b.clone()]))?;
1683    let cpu_per_elem = cpu_time.as_secs_f64() / elements as f64;
1684    update_cpu_cost(cpu_cost_slot, cpu_per_elem);
1685    if let Some(model) = profile_cost_model() {
1686        if let Some(gpu_time) = model.estimate_elemwise(elements) {
1687            trace!(
1688                "Elemwise calibration ({} elems): cpu={:?}, gpu_est={:?}",
1689                elements,
1690                cpu_time,
1691                gpu_time
1692            );
1693            return Ok(Some(gpu_time < cpu_time));
1694        }
1695    }
1696    let view = HostTensorView {
1697        data: template.data.as_slice(),
1698        shape: template.shape.as_slice(),
1699    };
1700    let ha = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1701    let hb = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1702    let start = Instant::now();
1703    let hc = match futures::executor::block_on(provider.elem_add(&ha, &hb)) {
1704        Ok(h) => h,
1705        Err(_) => {
1706            let _ = provider.free(&ha);
1707            let _ = provider.free(&hb);
1708            return Ok(None);
1709        }
1710    };
1711    let gpu_time = start.elapsed();
1712    let _ = provider.free(&ha);
1713    let _ = provider.free(&hb);
1714    let _ = provider.free(&hc);
1715    Ok(Some(gpu_time < cpu_time))
1716}
1717
1718fn calibrate_reduction(
1719    provider: &'static dyn AccelProvider,
1720    cfg: &mut ThresholdConfig,
1721) -> Option<Result<usize>> {
1722    let sizes = [256usize, 1_024, 4_096, 16_384, 65_536];
1723    for size in sizes {
1724        match compare_reduction(provider, size, &mut cfg.cpu_reduction_per_elem) {
1725            Ok(Some(true)) => return Some(Ok(size)),
1726            Ok(Some(false)) => continue,
1727            Ok(None) => return None,
1728            Err(e) => return Some(Err(e)),
1729        }
1730    }
1731    Some(Ok(usize::MAX))
1732}
1733
1734fn compare_reduction(
1735    provider: &'static dyn AccelProvider,
1736    elements: usize,
1737    cpu_cost_slot: &mut f64,
1738) -> Result<Option<bool>> {
1739    let shape = vec![elements, 1];
1740    let template = match provider.precision() {
1741        ProviderPrecision::F64 => {
1742            Tensor::new((0..elements).map(|i| i as f64).collect(), shape.clone())
1743                .map_err(|e| anyhow!(e))?
1744        }
1745        ProviderPrecision::F32 => {
1746            Tensor::from_f32((0..elements).map(|i| i as f32).collect(), shape.clone())
1747                .map_err(|e| anyhow!(e))?
1748        }
1749    };
1750    let value = Value::Tensor(template.clone());
1751    let cpu_time = time(|| runmat_runtime::call_builtin("sum", std::slice::from_ref(&value)))?;
1752    let cpu_per_elem = cpu_time.as_secs_f64() / elements as f64;
1753    update_cpu_cost(cpu_cost_slot, cpu_per_elem);
1754    if let Some(model) = profile_cost_model() {
1755        if let Some(gpu_time) = model.estimate_reduction(elements) {
1756            trace!(
1757                "Reduction calibration ({} elems): cpu={:?}, gpu_est={:?}",
1758                elements,
1759                cpu_time,
1760                gpu_time
1761            );
1762            return Ok(Some(gpu_time < cpu_time));
1763        }
1764    }
1765    let view = HostTensorView {
1766        data: template.data.as_slice(),
1767        shape: template.shape.as_slice(),
1768    };
1769    let h = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1770    let start = Instant::now();
1771    let out = match futures::executor::block_on(provider.reduce_sum(&h)) {
1772        Ok(hc) => hc,
1773        Err(_) => {
1774            provider.free(&h).ok();
1775            return Ok(None);
1776        }
1777    };
1778    let gpu_time = start.elapsed();
1779    let _ = provider.free(&h);
1780    let _ = provider.free(&out);
1781    Ok(Some(gpu_time < cpu_time))
1782}
1783
1784fn calibrate_matmul(
1785    provider: &'static dyn AccelProvider,
1786    cfg: &mut ThresholdConfig,
1787) -> Option<Result<usize>> {
1788    let dims = [32usize, 64, 96, 128, 192];
1789    for n in dims {
1790        match compare_matmul(provider, n, &mut cfg.cpu_matmul_per_flop) {
1791            Ok(Some(true)) => {
1792                let flops = n * n * n;
1793                return Some(Ok(flops));
1794            }
1795            Ok(Some(false)) => continue,
1796            Ok(None) => return None,
1797            Err(e) => return Some(Err(e)),
1798        }
1799    }
1800    Some(Ok(usize::MAX))
1801}
1802
1803fn compare_matmul(
1804    provider: &'static dyn AccelProvider,
1805    n: usize,
1806    cpu_cost_slot: &mut f64,
1807) -> Result<Option<bool>> {
1808    if n == 0 {
1809        return Ok(Some(false));
1810    }
1811    let total = n * n;
1812    let shape = vec![n, n];
1813    let (ta, tb) = match provider.precision() {
1814        ProviderPrecision::F64 => {
1815            let data_a: Vec<f64> = (0..total).map(|i| (i % 13) as f64).collect();
1816            let data_b: Vec<f64> = (0..total).map(|i| (i % 7) as f64).collect();
1817            let ta = Tensor::new(data_a, shape.clone()).map_err(|e| anyhow!(e))?;
1818            let tb = Tensor::new(data_b, shape.clone()).map_err(|e| anyhow!(e))?;
1819            (ta, tb)
1820        }
1821        ProviderPrecision::F32 => {
1822            let data_a: Vec<f32> = (0..total).map(|i| (i % 13) as f32).collect();
1823            let data_b: Vec<f32> = (0..total).map(|i| (i % 7) as f32).collect();
1824            let ta = Tensor::from_f32(data_a, shape.clone()).map_err(|e| anyhow!(e))?;
1825            let tb = Tensor::from_f32(data_b, shape.clone()).map_err(|e| anyhow!(e))?;
1826            (ta, tb)
1827        }
1828    };
1829    let a = Value::Tensor(ta.clone());
1830    let b = Value::Tensor(tb.clone());
1831    let cpu_time = time(|| futures::executor::block_on(runmat_runtime::value_matmul(&a, &b)))?;
1832    let flops = (n * n * n) as f64;
1833    update_cpu_cost(cpu_cost_slot, cpu_time.as_secs_f64() / flops);
1834    if let Some(model) = profile_cost_model() {
1835        if let Some(gpu_time) = model.estimate_matmul(n, n, n) {
1836            trace!(
1837                "Matmul calibration ({}^3 flops): cpu={:?}, gpu_est={:?}",
1838                n,
1839                cpu_time,
1840                gpu_time
1841            );
1842            return Ok(Some(gpu_time < cpu_time));
1843        }
1844    }
1845    let view_a = HostTensorView {
1846        data: ta.data.as_slice(),
1847        shape: ta.shape.as_slice(),
1848    };
1849    let view_b = HostTensorView {
1850        data: tb.data.as_slice(),
1851        shape: tb.shape.as_slice(),
1852    };
1853    let ha = provider
1854        .upload(&view_a)
1855        .map_err(|e| anyhow!(e.to_string()))?;
1856    let hb = provider
1857        .upload(&view_b)
1858        .map_err(|e| anyhow!(e.to_string()))?;
1859    let start = Instant::now();
1860    let hc = match futures::executor::block_on(provider.matmul(&ha, &hb)) {
1861        Ok(h) => h,
1862        Err(_) => {
1863            let _ = provider.free(&ha);
1864            let _ = provider.free(&hb);
1865            return Ok(None);
1866        }
1867    };
1868    let gpu_time = start.elapsed();
1869    let _ = provider.free(&ha);
1870    let _ = provider.free(&hb);
1871    let _ = provider.free(&hc);
1872    Ok(Some(gpu_time < cpu_time))
1873}
1874
1875fn time<F, T>(mut f: F) -> Result<Duration>
1876where
1877    F: FnMut() -> runmat_runtime::BuiltinResult<T>,
1878{
1879    let start = Instant::now();
1880    let _ = f().map_err(|err| anyhow!(err))?;
1881    Ok(start.elapsed())
1882}
1883
1884pub fn auto_offload_report() -> Option<AutoOffloadReport> {
1885    let state_guard = AUTO_STATE.get()?;
1886    let state = state_guard.lock().ok()?;
1887    let calibration = state.previous_thresholds.as_ref().map(|prev| {
1888        let delta = state
1889            .calibration_delta
1890            .clone()
1891            .unwrap_or_else(|| compute_delta(prev, &state.thresholds));
1892        AutoOffloadCalibrationSummary {
1893            previous: threshold_snapshot(prev),
1894            delta,
1895        }
1896    });
1897    Some(AutoOffloadReport {
1898        provider: state.provider.clone(),
1899        thresholds: threshold_snapshot(&state.thresholds),
1900        base_source: state.base_source,
1901        env_overrides_applied: state.env_overrides_applied,
1902        cache_path: state.cache_path.clone(),
1903        calibrate_duration_ms: state.calibrate_duration_ms,
1904        calibration,
1905        decisions: snapshot_decisions(),
1906    })
1907}
1908
1909pub fn sequence_threshold_hint() -> Option<usize> {
1910    AUTO_STATE
1911        .get()
1912        .and_then(|state| state.lock().ok())
1913        .map(|state| state.thresholds.unary_min_elems)
1914}
1915
1916pub fn reset_auto_offload_log() {
1917    clear_decisions();
1918}
1919
1920#[derive(Clone, Deserialize, Debug)]
1921struct ProfileDurationSummary {
1922    #[serde(default)]
1923    avg_ms: f64,
1924}
1925
1926#[derive(Clone, Deserialize, Debug)]
1927struct ProfileReport {
1928    category: String,
1929    #[serde(default)]
1930    input_shapes: Vec<Vec<usize>>,
1931    total_ms: ProfileDurationSummary,
1932}
1933
/// Straight-line cost model: `seconds = intercept + slope * x`.
#[derive(Clone, Copy, Default, Debug)]
struct LinearModel {
    /// Cost in seconds per unit of `x`.
    slope: f64,
    /// Fixed cost in seconds.
    intercept: f64,
}

impl LinearModel {
    /// Predicts the duration for a workload of size `x`.
    ///
    /// Returns `None` when the slope is not a finite positive number, when the
    /// predicted total is not a finite positive number, or when the total is
    /// too large to be represented as a [`Duration`].
    fn estimate(&self, x: f64) -> Option<Duration> {
        if !self.slope.is_finite() || self.slope <= 0.0 {
            return None;
        }
        let total = self.intercept + self.slope * x;
        if !total.is_finite() || total <= 0.0 {
            return None;
        }
        // `Duration::from_secs_f64` panics on finite values that overflow
        // `Duration`; the fallible constructor turns that into `None`.
        Duration::try_from_secs_f64(total).ok()
    }
}
1953
1954#[derive(Default)]
1955struct ProfileCostModel {
1956    elem: Option<LinearModel>,
1957    reduction: Option<LinearModel>,
1958    transpose: Option<LinearModel>,
1959    matmul: Option<LinearModel>,
1960}
1961
1962impl ProfileCostModel {
1963    fn from_reports(reports: &[ProfileReport]) -> Self {
1964        let mut elem_samples = Vec::<(f64, f64)>::new();
1965        let mut reduction_samples = Vec::<(f64, f64)>::new();
1966        let mut transpose_samples = Vec::<(f64, f64)>::new();
1967        let mut matmul_samples = Vec::<(f64, f64)>::new();
1968
1969        for report in reports {
1970            let total_secs = report.total_ms.avg_ms / 1_000.0;
1971            match report.category.as_str() {
1972                "elementwise" | "reduction" | "transpose" => {
1973                    if let Some(shape) = report.input_shapes.first() {
1974                        let elems: usize = shape.iter().copied().product();
1975                        if elems == 0 {
1976                            continue;
1977                        }
1978                        let sample = (elems as f64, total_secs);
1979                        match report.category.as_str() {
1980                            "elementwise" => elem_samples.push(sample),
1981                            "reduction" => reduction_samples.push(sample),
1982                            "transpose" => transpose_samples.push(sample),
1983                            _ => {}
1984                        }
1985                    }
1986                }
1987                "matmul" => {
1988                    if report.input_shapes.len() >= 2 {
1989                        let a = &report.input_shapes[0];
1990                        let b = &report.input_shapes[1];
1991                        if a.len() == 2 && b.len() == 2 {
1992                            let m = a[0];
1993                            let k = a[1];
1994                            let n = b[1];
1995                            let flops = m.checked_mul(k).and_then(|val| val.checked_mul(n));
1996                            if let Some(flops) = flops {
1997                                matmul_samples.push((flops as f64, total_secs));
1998                            }
1999                        }
2000                    }
2001                }
2002                _ => {}
2003            }
2004        }
2005
2006        ProfileCostModel {
2007            elem: fit_linear_model(&elem_samples),
2008            reduction: fit_linear_model(&reduction_samples),
2009            transpose: fit_linear_model(&transpose_samples),
2010            matmul: fit_linear_model(&matmul_samples),
2011        }
2012    }
2013
2014    fn estimate_elemwise(&self, elements: usize) -> Option<Duration> {
2015        self.elem.and_then(|model| model.estimate(elements as f64))
2016    }
2017
2018    fn estimate_reduction(&self, elements: usize) -> Option<Duration> {
2019        self.reduction
2020            .and_then(|model| model.estimate(elements as f64))
2021    }
2022
2023    fn estimate_matmul(&self, m: usize, k: usize, n: usize) -> Option<Duration> {
2024        let flops = m.checked_mul(k)?.checked_mul(n)?;
2025        self.matmul.and_then(|model| model.estimate(flops as f64))
2026    }
2027
2028    fn estimate_matmul_flops(&self, flops: usize) -> Option<Duration> {
2029        self.matmul.and_then(|model| model.estimate(flops as f64))
2030    }
2031
2032    fn estimate_transpose(&self, elements: usize) -> Option<Duration> {
2033        self.transpose
2034            .and_then(|model| model.estimate(elements as f64))
2035    }
2036}
2037
2038fn fit_linear_model(samples: &[(f64, f64)]) -> Option<LinearModel> {
2039    if samples.is_empty() {
2040        return None;
2041    }
2042    if samples.len() == 1 {
2043        let (x, y) = samples[0];
2044        if x > 0.0 {
2045            return Some(LinearModel {
2046                slope: (y / x).max(0.0),
2047                intercept: 0.0,
2048            });
2049        }
2050        return None;
2051    }
2052
2053    let sum_x: f64 = samples.iter().map(|(x, _)| *x).sum();
2054    let sum_y: f64 = samples.iter().map(|(_, y)| *y).sum();
2055    let sum_xx: f64 = samples.iter().map(|(x, _)| x * x).sum();
2056    let sum_xy: f64 = samples.iter().map(|(x, y)| x * y).sum();
2057    let n = samples.len() as f64;
2058    let denom = (n * sum_xx) - (sum_x * sum_x);
2059    if denom.abs() < f64::EPSILON {
2060        return None;
2061    }
2062    let slope = ((n * sum_xy) - (sum_x * sum_y)) / denom;
2063    let mean_x = sum_x / n;
2064    let mean_y = sum_y / n;
2065    let mut intercept = mean_y - slope * mean_x;
2066    if intercept < 0.0 {
2067        intercept = 0.0;
2068    }
2069    if !slope.is_finite() || slope <= 0.0 {
2070        return None;
2071    }
2072    Some(LinearModel { slope, intercept })
2073}
2074
2075fn profile_cost_model() -> Option<&'static ProfileCostModel> {
2076    PROFILE_MODEL.get_or_init(load_profile_cost_model).as_ref()
2077}
2078
2079fn load_profile_cost_model() -> Option<ProfileCostModel> {
2080    let mut candidates = Vec::new();
2081    if let Ok(path) = env::var("RUNMAT_ACCEL_PROFILE") {
2082        candidates.push(PathBuf::from(path));
2083    }
2084    if let Some(path) = auto_offload_options().profile_path.clone() {
2085        candidates.push(path);
2086    }
2087    candidates.push(PathBuf::from("benchmarks/wgpu_profile/mac_m2.json"));
2088    candidates.push(PathBuf::from("wgpu_profile.json"));
2089
2090    for path in candidates {
2091        if !path.exists() {
2092            continue;
2093        }
2094        match fs::read_to_string(&path) {
2095            Ok(contents) => match serde_json::from_str::<Vec<ProfileReport>>(&contents) {
2096                Ok(reports) => {
2097                    debug!(
2098                        "Loaded {} GPU profile reports from {}",
2099                        reports.len(),
2100                        path.display()
2101                    );
2102                    return Some(ProfileCostModel::from_reports(&reports));
2103                }
2104                Err(err) => {
2105                    debug!("Failed to parse GPU profile {}: {err}", path.display());
2106                }
2107            },
2108            Err(err) => {
2109                debug!("Failed to read GPU profile {}: {err}", path.display());
2110            }
2111        }
2112    }
2113    None
2114}
2115
2116pub async fn promote_binary(op: BinaryOp, a: &Value, b: &Value) -> Result<(Value, Value)> {
2117    if !auto_enabled() {
2118        return Ok((a.clone(), b.clone()));
2119    }
2120    if let Some(auto) = global().await {
2121        auto.promote_binary(op, a, b)
2122    } else {
2123        Ok((a.clone(), b.clone()))
2124    }
2125}
2126
2127pub async fn promote_unary(op: UnaryOp, value: &Value) -> Result<Value> {
2128    if !auto_enabled() {
2129        return Ok(value.clone());
2130    }
2131    if let Some(auto) = global().await {
2132        auto.promote_unary(op, value)
2133    } else {
2134        Ok(value.clone())
2135    }
2136}
2137
2138pub async fn prepare_builtin_args(name: &str, args: &[Value]) -> Result<Vec<Value>> {
2139    if let Some(policy) = builtin_policy(name) {
2140        if policy.is_sink {
2141            clear_sink_inputs(args);
2142            if should_gather_sink_args(name) {
2143                trace!(
2144                    "auto-offload: prepare_builtin_args(name={:?}) is_sink=true residency=GatherImmediately -> gathering {} arg(s)",
2145                    name,
2146                    args.len()
2147                );
2148                return gather_args(args).await;
2149            }
2150            trace!(
2151                "auto-offload: prepare_builtin_args(name={:?}) is_sink=true residency!=GatherImmediately -> no gather (fusion barrier only)",
2152                name
2153            );
2154            return Ok(args.to_vec());
2155        }
2156    }
2157    if !auto_enabled() {
2158        return Ok(args.to_vec());
2159    }
2160    if let Some(auto) = global().await {
2161        auto.prepare_builtin(name, args).await
2162    } else {
2163        Ok(args.to_vec())
2164    }
2165}
2166
2167pub fn is_sink(name: &str) -> bool {
2168    builtin_policy(name).map(|p| p.is_sink).unwrap_or(false)
2169}
2170
2171pub async fn promote_reduction_args(op: ReductionOp, args: &[Value]) -> Result<Vec<Value>> {
2172    if !auto_enabled() {
2173        return Ok(args.to_vec());
2174    }
2175    if let Some(auto) = global().await {
2176        auto.promote_reduction(op, args)
2177    } else {
2178        Ok(args.to_vec())
2179    }
2180}