Skip to main content

runmat_accelerate/
native_auto.rs

1use runmat_time::{system_time_now, Instant};
2use std::collections::HashMap;
3use std::env;
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::sync::Mutex;
7use std::time::{Duration, UNIX_EPOCH};
8
9use crate::{
10    auto_offload_options,
11    fusion::{active_fusion, FusionKind},
12    fusion_residency,
13    precision::ensure_provider_supports_dtype,
14    AutoOffloadLogLevel,
15};
16use anyhow::{anyhow, Result};
17use futures::lock::Mutex as AsyncMutex;
18use log::{debug, info, trace, warn};
19use once_cell::sync::{Lazy, OnceCell};
20use runmat_accelerate_api::{AccelProvider, ApiDeviceInfo, HostTensorView, ProviderPrecision};
21use runmat_builtins::{builtin_functions, AccelTag, Tensor, Value};
22use runmat_runtime::builtins::common::spec::{builtin_residency_policy, ResidencyPolicy};
23use runmat_runtime::gather_if_needed_async;
24use serde::{Deserialize, Serialize};
25
26const DEFAULT_CPU_ELEM_PER_ELEM: f64 = 1.0e-7;
27const DEFAULT_CPU_REDUCTION_PER_ELEM: f64 = 1.2e-7;
28const DEFAULT_CPU_MATMUL_PER_FLOP: f64 = 2.5e-11;
29const SMALL_BATCH_DEFAULT_MAX_DIM: usize = 8;
30const SMALL_BATCH_DEFAULT_MIN_ELEMS: usize = 1_048_576;
31const DECISION_LOG_CAPACITY: usize = 128;
32const CALIBRATION_VERSION: u32 = 1;
33
34#[derive(Clone, Copy, Debug)]
35pub enum BinaryOp {
36    Elementwise,
37    MatMul,
38}
39
40#[derive(Clone, Copy, Debug)]
41pub enum UnaryOp {
42    Generic,
43    Transpose,
44}
45
46#[derive(Clone, Copy, Debug)]
47pub enum ReductionOp {
48    Sum,
49    Mean,
50    Min,
51    Max,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55struct ThresholdConfig {
56    unary_min_elems: usize,
57    binary_min_elems: usize,
58    reduction_min_elems: usize,
59    matmul_min_flops: usize,
60    cpu_elem_per_elem: f64,
61    cpu_reduction_per_elem: f64,
62    cpu_matmul_per_flop: f64,
63    small_batch_max_dim: usize,
64    small_batch_min_elems: usize,
65}
66
67impl Default for ThresholdConfig {
68    fn default() -> Self {
69        Self {
70            unary_min_elems: 4_096,
71            binary_min_elems: 4_096,
72            reduction_min_elems: 256,
73            matmul_min_flops: 1_000_000, // roughly 100x100x100
74            cpu_elem_per_elem: DEFAULT_CPU_ELEM_PER_ELEM,
75            cpu_reduction_per_elem: DEFAULT_CPU_REDUCTION_PER_ELEM,
76            cpu_matmul_per_flop: DEFAULT_CPU_MATMUL_PER_FLOP,
77            small_batch_max_dim: SMALL_BATCH_DEFAULT_MAX_DIM,
78            small_batch_min_elems: SMALL_BATCH_DEFAULT_MIN_ELEMS,
79        }
80    }
81}
82
83#[derive(Debug, Clone, Serialize)]
84pub struct AutoOffloadDecisionEntry {
85    pub timestamp_ms: u128,
86    pub operation: String,
87    pub elements: Option<usize>,
88    pub flops: Option<usize>,
89    pub batch: Option<usize>,
90    pub decision: AutoOffloadDisposition,
91    pub reason: DecisionReason,
92    pub cpu_estimate_ms: Option<f64>,
93    pub gpu_estimate_ms: Option<f64>,
94    pub threshold: Option<usize>,
95    pub fusion_kind: Option<FusionKind>,
96}
97
98#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
99#[serde(rename_all = "kebab-case")]
100pub enum AutoOffloadDisposition {
101    Gpu,
102    Cpu,
103}
104
105#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
106#[serde(rename_all = "kebab-case")]
107pub enum DecisionReason {
108    FusionOverride,
109    Residency,
110    SmallBatchGuard,
111    ProfileModel,
112    Threshold,
113    Disabled,
114}
115
116#[derive(Debug, Clone, Serialize)]
117pub struct ThresholdSnapshot {
118    pub unary_min_elems: usize,
119    pub binary_min_elems: usize,
120    pub reduction_min_elems: usize,
121    pub matmul_min_flops: usize,
122    pub cpu_elem_per_elem: f64,
123    pub cpu_reduction_per_elem: f64,
124    pub cpu_matmul_per_flop: f64,
125    pub small_batch_max_dim: usize,
126    pub small_batch_min_elems: usize,
127}
128
129#[derive(Debug, Clone, Serialize)]
130pub struct AutoOffloadCalibrationSummary {
131    pub previous: ThresholdSnapshot,
132    pub delta: ThresholdDelta,
133}
134
135#[derive(Debug, Clone, Serialize, Default)]
136pub struct ThresholdDelta {
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub cpu_elem_per_elem: Option<ThresholdDeltaEntry>,
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub cpu_reduction_per_elem: Option<ThresholdDeltaEntry>,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub cpu_matmul_per_flop: Option<ThresholdDeltaEntry>,
143}
144
145#[derive(Debug, Clone, Serialize)]
146pub struct ThresholdDeltaEntry {
147    pub before: f64,
148    pub after: f64,
149    pub absolute: f64,
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub ratio: Option<f64>,
152}
153
154impl ThresholdDeltaEntry {
155    fn new(before: f64, after: f64) -> Self {
156        let absolute = after - before;
157        let ratio = if before.abs() > f64::EPSILON {
158            Some(after / before)
159        } else {
160            None
161        };
162        Self {
163            before,
164            after,
165            absolute,
166            ratio,
167        }
168    }
169}
170
171#[derive(Debug, Clone, Serialize)]
172pub struct AutoOffloadReport {
173    pub provider: Option<CachedProviderInfo>,
174    pub thresholds: ThresholdSnapshot,
175    pub base_source: ThresholdBase,
176    pub env_overrides_applied: bool,
177    pub cache_path: Option<String>,
178    pub calibrate_duration_ms: Option<u128>,
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub calibration: Option<AutoOffloadCalibrationSummary>,
181    pub decisions: Vec<AutoOffloadDecisionEntry>,
182}
183
184#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
185#[serde(rename_all = "kebab-case")]
186pub enum ThresholdBase {
187    BuiltInDefault,
188    LoadedFromCache,
189    Calibrated,
190}
191
192impl ThresholdBase {
193    pub fn as_str(&self) -> &'static str {
194        match self {
195            ThresholdBase::BuiltInDefault => "built-in-default",
196            ThresholdBase::LoadedFromCache => "loaded-from-cache",
197            ThresholdBase::Calibrated => "calibrated",
198        }
199    }
200}
201
202#[derive(Debug, Clone, Serialize)]
203pub struct CachedProviderInfo {
204    pub name: String,
205    pub vendor: String,
206    pub backend: Option<String>,
207    pub device_id: u32,
208}
209
210#[derive(Debug, Clone)]
211struct AutoOffloadState {
212    provider: Option<CachedProviderInfo>,
213    thresholds: ThresholdConfig,
214    base_source: ThresholdBase,
215    env_overrides_applied: bool,
216    cache_path: Option<String>,
217    calibrate_duration_ms: Option<u128>,
218    previous_thresholds: Option<ThresholdConfig>,
219    calibration_delta: Option<ThresholdDelta>,
220}
221
222#[derive(Clone)]
223struct DecisionEvaluation {
224    recommend_gpu: bool,
225    reason: DecisionReason,
226    cpu_secs: Option<f64>,
227    gpu_secs: Option<f64>,
228    threshold: Option<usize>,
229    fusion_kind: Option<FusionKind>,
230    batch: Option<usize>,
231}
232
233struct DecisionLog {
234    entries: Vec<AutoOffloadDecisionEntry>,
235}
236
237impl DecisionLog {
238    fn new() -> Self {
239        Self {
240            entries: Vec::new(),
241        }
242    }
243
244    fn push(&mut self, entry: AutoOffloadDecisionEntry) {
245        self.entries.push(entry);
246        if self.entries.len() > DECISION_LOG_CAPACITY {
247            let overflow = self.entries.len() - DECISION_LOG_CAPACITY;
248            self.entries.drain(0..overflow);
249        }
250    }
251
252    fn snapshot(&self) -> Vec<AutoOffloadDecisionEntry> {
253        self.entries.clone()
254    }
255
256    fn clear(&mut self) {
257        self.entries.clear();
258    }
259}
260
261static DECISION_LOG: Lazy<Mutex<DecisionLog>> = Lazy::new(|| Mutex::new(DecisionLog::new()));
262static AUTO_STATE: OnceCell<Mutex<AutoOffloadState>> = OnceCell::new();
263
264fn record_decision(entry: AutoOffloadDecisionEntry) {
265    if let Ok(mut log) = DECISION_LOG.lock() {
266        log.push(entry);
267    }
268}
269
270fn snapshot_decisions() -> Vec<AutoOffloadDecisionEntry> {
271    DECISION_LOG
272        .lock()
273        .map(|log| log.snapshot())
274        .unwrap_or_default()
275}
276
277fn clear_decisions() {
278    if let Ok(mut log) = DECISION_LOG.lock() {
279        log.clear();
280    }
281}
282
283fn now_millis() -> u128 {
284    system_time_now()
285        .duration_since(UNIX_EPOCH)
286        .unwrap_or_else(|_| Duration::from_secs(0))
287        .as_millis()
288}
289
290fn threshold_snapshot(cfg: &ThresholdConfig) -> ThresholdSnapshot {
291    ThresholdSnapshot {
292        unary_min_elems: cfg.unary_min_elems,
293        binary_min_elems: cfg.binary_min_elems,
294        reduction_min_elems: cfg.reduction_min_elems,
295        matmul_min_flops: cfg.matmul_min_flops,
296        cpu_elem_per_elem: cfg.cpu_elem_per_elem,
297        cpu_reduction_per_elem: cfg.cpu_reduction_per_elem,
298        cpu_matmul_per_flop: cfg.cpu_matmul_per_flop,
299        small_batch_max_dim: cfg.small_batch_max_dim,
300        small_batch_min_elems: cfg.small_batch_min_elems,
301    }
302}
303
304fn compute_delta(before: &ThresholdConfig, after: &ThresholdConfig) -> ThresholdDelta {
305    let mut delta = ThresholdDelta::default();
306
307    if (before.cpu_elem_per_elem - after.cpu_elem_per_elem).abs() > f64::EPSILON {
308        delta.cpu_elem_per_elem = Some(ThresholdDeltaEntry::new(
309            before.cpu_elem_per_elem,
310            after.cpu_elem_per_elem,
311        ));
312    }
313
314    if (before.cpu_reduction_per_elem - after.cpu_reduction_per_elem).abs() > f64::EPSILON {
315        delta.cpu_reduction_per_elem = Some(ThresholdDeltaEntry::new(
316            before.cpu_reduction_per_elem,
317            after.cpu_reduction_per_elem,
318        ));
319    }
320
321    if (before.cpu_matmul_per_flop - after.cpu_matmul_per_flop).abs() > f64::EPSILON {
322        delta.cpu_matmul_per_flop = Some(ThresholdDeltaEntry::new(
323            before.cpu_matmul_per_flop,
324            after.cpu_matmul_per_flop,
325        ));
326    }
327
328    delta
329}
330
331#[derive(Debug, Deserialize)]
332struct CalibrationFile {
333    #[serde(default)]
334    suite: Option<CalibrationSuiteSection>,
335    #[serde(default)]
336    auto_offload_calibration: Option<CalibrationSample>,
337}
338
339#[derive(Debug, Deserialize)]
340struct CalibrationSuiteSection {
341    #[serde(default)]
342    auto_offload_calibration: Option<CalibrationSample>,
343}
344
345#[derive(Debug, Clone, Deserialize)]
346struct CalibrationSample {
347    #[serde(default)]
348    runs: usize,
349    #[serde(default, rename = "cpu_time_ms")]
350    cpu_time: CalibrationTimes,
351    #[serde(default)]
352    units: CalibrationUnits,
353    #[serde(default)]
354    provider: Option<CalibrationProviderInfo>,
355    #[serde(default)]
356    provider_conflict: bool,
357}
358
359#[derive(Debug, Clone, Deserialize, Default)]
360struct CalibrationTimes {
361    #[serde(default)]
362    elementwise: f64,
363    #[serde(default)]
364    reduction: f64,
365    #[serde(default)]
366    matmul: f64,
367}
368
369#[derive(Debug, Clone, Deserialize, Default)]
370struct CalibrationUnits {
371    #[serde(default)]
372    elementwise: f64,
373    #[serde(default)]
374    reduction: f64,
375    #[serde(default, rename = "matmul_flops")]
376    matmul_flops: f64,
377}
378
379#[derive(Debug, Clone, Deserialize)]
380struct CalibrationProviderInfo {
381    name: String,
382    vendor: String,
383    #[serde(default)]
384    backend: Option<String>,
385    device_id: u32,
386}
387
388#[derive(Debug, Serialize)]
389pub struct AutoOffloadCalibrationOutcome {
390    pub runs: usize,
391    pub before: ThresholdSnapshot,
392    pub after: ThresholdSnapshot,
393    #[serde(skip_serializing_if = "Option::is_none")]
394    pub delta: Option<ThresholdDelta>,
395    #[serde(skip_serializing_if = "Option::is_none")]
396    pub persisted_to: Option<String>,
397    #[serde(skip_serializing_if = "Option::is_none")]
398    pub provider: Option<CachedProviderInfo>,
399    pub commit: bool,
400}
401
402fn load_calibration_sample(path: &Path) -> Result<CalibrationSample> {
403    let payload = fs::read_to_string(path).map_err(|e| anyhow!(e.to_string()))?;
404    let file: CalibrationFile = serde_json::from_str(&payload)
405        .map_err(|e| anyhow!(format!("failed to parse calibration file: {e}")))?;
406    if let Some(suite) = file.suite {
407        if let Some(sample) = suite.auto_offload_calibration {
408            return Ok(sample);
409        }
410    }
411    if let Some(sample) = file.auto_offload_calibration {
412        return Ok(sample);
413    }
414    Err(anyhow!(
415        "calibration file does not contain an auto_offload_calibration section"
416    ))
417}
418
419fn apply_calibration_sample(
420    cfg: &mut ThresholdConfig,
421    sample: &CalibrationSample,
422) -> Option<ThresholdDelta> {
423    let mut delta = ThresholdDelta::default();
424    let mut changed = false;
425
426    if sample.units.elementwise > 0.0 && sample.cpu_time.elementwise > 0.0 {
427        let secs_per_elem = (sample.cpu_time.elementwise / 1_000.0) / sample.units.elementwise;
428        if secs_per_elem.is_finite()
429            && secs_per_elem > 0.0
430            && (cfg.cpu_elem_per_elem - secs_per_elem).abs() > f64::EPSILON
431        {
432            delta.cpu_elem_per_elem = Some(ThresholdDeltaEntry::new(
433                cfg.cpu_elem_per_elem,
434                secs_per_elem,
435            ));
436            cfg.cpu_elem_per_elem = secs_per_elem;
437            changed = true;
438        }
439    }
440
441    if sample.units.reduction > 0.0 && sample.cpu_time.reduction > 0.0 {
442        let secs_per_elem = (sample.cpu_time.reduction / 1_000.0) / sample.units.reduction;
443        if secs_per_elem.is_finite()
444            && secs_per_elem > 0.0
445            && (cfg.cpu_reduction_per_elem - secs_per_elem).abs() > f64::EPSILON
446        {
447            delta.cpu_reduction_per_elem = Some(ThresholdDeltaEntry::new(
448                cfg.cpu_reduction_per_elem,
449                secs_per_elem,
450            ));
451            cfg.cpu_reduction_per_elem = secs_per_elem;
452            changed = true;
453        }
454    }
455
456    if sample.units.matmul_flops > 0.0 && sample.cpu_time.matmul > 0.0 {
457        let secs_per_flop = (sample.cpu_time.matmul / 1_000.0) / sample.units.matmul_flops;
458        if secs_per_flop.is_finite()
459            && secs_per_flop > 0.0
460            && (cfg.cpu_matmul_per_flop - secs_per_flop).abs() > f64::EPSILON
461        {
462            delta.cpu_matmul_per_flop = Some(ThresholdDeltaEntry::new(
463                cfg.cpu_matmul_per_flop,
464                secs_per_flop,
465            ));
466            cfg.cpu_matmul_per_flop = secs_per_flop;
467            changed = true;
468        }
469    }
470
471    if changed {
472        Some(delta)
473    } else {
474        None
475    }
476}
477
478pub fn apply_auto_offload_calibration_from_file(
479    path: &Path,
480    commit: bool,
481) -> Result<AutoOffloadCalibrationOutcome> {
482    let sample = load_calibration_sample(path)?;
483    if sample.runs == 0 {
484        return Err(anyhow!("calibration sample contains zero runs"));
485    }
486
487    let provider = runmat_accelerate_api::provider()
488        .ok_or_else(|| anyhow!("no acceleration provider registered"))?;
489    let device_info = provider.device_info_struct();
490
491    if let Some(ref prov) = sample.provider {
492        if prov.name != device_info.name
493            || prov.vendor != device_info.vendor
494            || prov.backend.as_deref() != device_info.backend.as_deref()
495            || prov.device_id != device_info.device_id
496        {
497            warn!(
498                "Calibration provider mismatch: sample='{} ({})' device='{} ({})'",
499                prov.name, prov.vendor, device_info.name, device_info.vendor
500            );
501        }
502        if sample.provider_conflict {
503            warn!("Calibration sample reported provider conflict across cases");
504        }
505    }
506
507    let (mut cfg, _) = load_cached_thresholds(&device_info)
508        .unwrap_or_else(|| (ThresholdConfig::default(), PathBuf::new()));
509    let before_cfg = cfg.clone();
510
511    let delta = apply_calibration_sample(&mut cfg, &sample)
512        .ok_or_else(|| anyhow!("calibration sample did not produce coefficient updates"))?;
513
514    let mut persisted_to: Option<PathBuf> = None;
515    if commit {
516        persisted_to = Some(persist_thresholds(&device_info, &cfg)?);
517    }
518
519    if let Some(state_mutex) = AUTO_STATE.get() {
520        if let Ok(mut state) = state_mutex.lock() {
521            state.previous_thresholds = Some(before_cfg.clone());
522            state.calibration_delta = Some(delta.clone());
523            if commit {
524                state.thresholds = cfg.clone();
525                state.base_source = ThresholdBase::Calibrated;
526                if let Some(ref path_buf) = persisted_to {
527                    state.cache_path = Some(path_buf.to_string_lossy().into_owned());
528                }
529                state.calibrate_duration_ms = None;
530            }
531        }
532    }
533
534    Ok(AutoOffloadCalibrationOutcome {
535        runs: sample.runs,
536        before: threshold_snapshot(&before_cfg),
537        after: threshold_snapshot(&cfg),
538        delta: Some(delta),
539        persisted_to: persisted_to.map(|p| p.to_string_lossy().into_owned()),
540        provider: Some(cached_provider_info(&device_info)),
541        commit,
542    })
543}
544
545fn cached_provider_info(info: &ApiDeviceInfo) -> CachedProviderInfo {
546    CachedProviderInfo {
547        name: info.name.clone(),
548        vendor: info.vendor.clone(),
549        backend: info.backend.clone(),
550        device_id: info.device_id,
551    }
552}
553
554fn cpu_estimate(per_unit: f64, units: usize) -> Option<f64> {
555    if per_unit.is_finite() && per_unit > 0.0 {
556        Some(per_unit * units as f64)
557    } else {
558        None
559    }
560}
561
562fn value_shape(value: &Value) -> Option<&[usize]> {
563    match value {
564        Value::Tensor(t) => Some(&t.shape),
565        Value::GpuTensor(handle) => Some(&handle.shape),
566        _ => None,
567    }
568}
569
570fn batch_dimension_from_value(value: &Value) -> Option<usize> {
571    let shape = value_shape(value)?;
572    if shape.len() < 3 {
573        return None;
574    }
575    shape.last().copied()
576}
577
578fn batch_dimension_from_values(values: &[&Value]) -> Option<usize> {
579    values
580        .iter()
581        .filter_map(|value| batch_dimension_from_value(value))
582        .min()
583}
584
585fn decision_entry(
586    operation: &str,
587    elements: Option<usize>,
588    flops: Option<usize>,
589    eval: &DecisionEvaluation,
590) -> AutoOffloadDecisionEntry {
591    AutoOffloadDecisionEntry {
592        timestamp_ms: now_millis(),
593        operation: operation.to_string(),
594        elements,
595        flops,
596        batch: eval.batch,
597        decision: if eval.recommend_gpu {
598            AutoOffloadDisposition::Gpu
599        } else {
600            AutoOffloadDisposition::Cpu
601        },
602        reason: eval.reason,
603        cpu_estimate_ms: eval.cpu_secs.map(|secs| secs * 1_000.0),
604        gpu_estimate_ms: eval.gpu_secs.map(|secs| secs * 1_000.0),
605        threshold: eval.threshold,
606        fusion_kind: eval.fusion_kind.clone(),
607    }
608}
609
610pub struct NativeAutoOffload {
611    provider: &'static dyn AccelProvider,
612    thresholds: ThresholdConfig,
613    enabled: bool,
614}
615
616static GLOBAL: OnceCell<Option<NativeAutoOffload>> = OnceCell::new();
617static GLOBAL_INIT_LOCK: Lazy<AsyncMutex<()>> = Lazy::new(|| AsyncMutex::new(()));
618static PROFILE_MODEL: OnceCell<Option<ProfileCostModel>> = OnceCell::new();
619
620fn env_bool(key: &str) -> Option<bool> {
621    env::var(key).ok().and_then(|v| parse_bool(&v))
622}
623
624fn parse_bool(s: &str) -> Option<bool> {
625    match s.trim().to_ascii_lowercase().as_str() {
626        "1" | "true" | "yes" | "on" => Some(true),
627        "0" | "false" | "no" | "off" => Some(false),
628        _ => None,
629    }
630}
631
632fn log_promotion<F>(builder: F)
633where
634    F: FnOnce() -> String,
635{
636    match auto_offload_options().log_level {
637        AutoOffloadLogLevel::Off => {}
638        AutoOffloadLogLevel::Info => info!("{}", builder()),
639        AutoOffloadLogLevel::Trace => trace!("{}", builder()),
640    }
641}
642
643fn update_cpu_cost(slot: &mut f64, candidate: f64) {
644    if candidate.is_finite() && candidate > 0.0 && candidate < *slot {
645        *slot = candidate;
646    }
647}
648
649fn value_len(value: &Value) -> Option<usize> {
650    match value {
651        Value::Tensor(t) => Some(t.data.len()),
652        Value::GpuTensor(handle) => Some(handle.shape.iter().product()),
653        Value::Num(_) | Value::Bool(_) | Value::Int(_) => Some(1),
654        Value::Complex(_, _) => Some(1),
655        _ => None,
656    }
657}
658
659fn element_count_pair(a: &Value, b: &Value) -> Option<usize> {
660    let la = value_len(a)?;
661    let lb = value_len(b)?;
662    Some(la.max(lb))
663}
664
665pub async fn global() -> Option<&'static NativeAutoOffload> {
666    if let Some(existing) = GLOBAL.get() {
667        return existing.as_ref();
668    }
669    // If auto-offload is disabled or there is no GPU provider registered,
670    // initialize_async() would return None immediately (no I/O, no blocking).
671    // Return None directly without acquiring the async lock so single-poll
672    // callers (e.g. the turbine JIT interpreter fallback) never observe a
673    // spurious Pending.  We intentionally do NOT write to GLOBAL here: doing
674    // so without holding GLOBAL_INIT_LOCK would race with a concurrent thread
675    // that is partway through initialize_async() and has found a valid
676    // provider.  That thread's subsequent GLOBAL.set(Some(offload)) would
677    // silently fail (OnceCell is set-once), permanently disabling the
678    // accelerator for the lifetime of the process.  These two checks are
679    // cheap (no I/O), so re-evaluating them on each call is acceptable.
680    if !auto_enabled() || runmat_accelerate_api::provider().is_none() {
681        return None;
682    }
683    let _guard = GLOBAL_INIT_LOCK.lock().await;
684    if let Some(existing) = GLOBAL.get() {
685        return existing.as_ref();
686    }
687    let initialized = initialize_async().await;
688    let _ = GLOBAL.set(initialized);
689    GLOBAL.get().and_then(|value| value.as_ref())
690}
691
692async fn initialize_async() -> Option<NativeAutoOffload> {
693    if !auto_enabled() {
694        clear_decisions();
695        return None;
696    }
697    let provider = runmat_accelerate_api::provider()?;
698    let device_info = provider.device_info_struct();
699    let mut config = ThresholdConfig::default();
700    let mut base_source = ThresholdBase::BuiltInDefault;
701    let mut cache_path: Option<String> = None;
702    let mut calibrate_duration_ms: Option<u128> = None;
703    let refresh_calibration = calibrate_refresh_enabled();
704
705    if !refresh_calibration {
706        if let Some((cached, path)) = load_cached_thresholds_async(&device_info).await {
707            info!(
708                "Native auto-offload: loaded cached calibration for '{}' from {}",
709                device_info.name, path
710            );
711            config = cached;
712            cache_path = Some(path);
713            base_source = ThresholdBase::LoadedFromCache;
714        }
715    }
716
717    let needs_calibration = calibrate_enabled() && (refresh_calibration || cache_path.is_none());
718    if needs_calibration {
719        let start = Instant::now();
720        match auto_calibrate(provider, &mut config) {
721            Ok(()) => {
722                calibrate_duration_ms = Some(start.elapsed().as_millis());
723                base_source = ThresholdBase::Calibrated;
724                match persist_thresholds_async(&device_info, &config).await {
725                    Ok(path) => {
726                        cache_path = Some(path.clone());
727                        info!(
728                            "Native auto-offload: persisted calibration for '{}' to {}",
729                            device_info.name, path
730                        );
731                    }
732                    Err(err) => {
733                        debug!("Native auto-offload: failed to persist calibration: {err}");
734                    }
735                }
736            }
737            Err(err) => {
738                debug!("Native auto-offload calibration failed: {err}");
739            }
740        }
741    }
742
743    let env_overrides_applied = apply_env_overrides(&mut config);
744    let model_status = if profile_cost_model().is_some() {
745        "profile"
746    } else {
747        "fallback"
748    };
749    info!(
750        "Native auto-offload thresholds: unary={} binary={} reduction={} matmul_flops={} small_batch_dim={} small_batch_min_elems={} (model: {}, source: {}, env_overrides={})",
751        config.unary_min_elems,
752        config.binary_min_elems,
753        config.reduction_min_elems,
754        config.matmul_min_flops,
755        config.small_batch_max_dim,
756        config.small_batch_min_elems,
757        model_status,
758        base_source.as_str(),
759        env_overrides_applied
760    );
761
762    let cache_path_str = cache_path.clone();
763    let state = AutoOffloadState {
764        provider: Some(cached_provider_info(&device_info)),
765        thresholds: config.clone(),
766        base_source,
767        env_overrides_applied,
768        cache_path: cache_path_str,
769        calibrate_duration_ms,
770        previous_thresholds: None,
771        calibration_delta: None,
772    };
773    let _ = AUTO_STATE.set(Mutex::new(state));
774
775    Some(NativeAutoOffload::new(provider, config))
776}
777
778impl NativeAutoOffload {
779    fn new(provider: &'static dyn AccelProvider, thresholds: ThresholdConfig) -> Self {
780        let enabled = true;
781        Self {
782            provider,
783            thresholds,
784            enabled,
785        }
786    }
787
788    fn promote_tensor_if_large(&self, value: &Value, threshold: usize) -> Result<Value> {
789        match value {
790            Value::GpuTensor(_) => Ok(value.clone()),
791            Value::Tensor(t) => {
792                if ensure_provider_supports_dtype(self.provider, t.dtype).is_err() {
793                    return Ok(value.clone());
794                }
795                if t.data.len() >= threshold && threshold > 0 {
796                    log_promotion(|| {
797                        format!(
798                            "Promoting tensor to GPU (len={}, threshold={})",
799                            t.data.len(),
800                            threshold
801                        )
802                    });
803                    self.tensor_to_gpu(t)
804                } else {
805                    Ok(value.clone())
806                }
807            }
808            _ => Ok(value.clone()),
809        }
810    }
811
812    fn tensor_to_gpu(&self, tensor: &Tensor) -> Result<Value> {
813        let view = HostTensorView {
814            data: &tensor.data,
815            shape: &tensor.shape,
816        };
817        let handle = self
818            .provider
819            .upload(&view)
820            .map_err(|e| anyhow!(e.to_string()))?;
821        Ok(Value::GpuTensor(handle))
822    }
823
824    fn small_batch_guard(&self, elements: usize, batch: Option<usize>) -> bool {
825        if !self.enabled {
826            return false;
827        }
828        let Some(batch) = batch else {
829            return false;
830        };
831        if batch == 0 {
832            return false;
833        }
834        let thresholds = &self.thresholds;
835        thresholds.small_batch_max_dim > 0
836            && thresholds.small_batch_min_elems > 0
837            && batch <= thresholds.small_batch_max_dim
838            && elements >= thresholds.small_batch_min_elems
839    }
840
841    fn promote_binary(&self, op: BinaryOp, a: &Value, b: &Value) -> Result<(Value, Value)> {
842        if !self.enabled {
843            return Ok((a.clone(), b.clone()));
844        }
845        match op {
846            BinaryOp::Elementwise => {
847                let elems = element_count_pair(a, b).unwrap_or(0);
848                let eval = self.evaluate_elementwise(elems, &[a, b]);
849                record_decision(decision_entry("elementwise", Some(elems), None, &eval));
850                if eval.recommend_gpu {
851                    log_promotion(|| format!("Elementwise offload accepted ({} elems)", elems));
852                    let a_p = self.promote_tensor_if_large(a, 1)?;
853                    let b_p = self.promote_tensor_if_large(b, 1)?;
854                    Ok((a_p, b_p))
855                } else {
856                    Ok((a.clone(), b.clone()))
857                }
858            }
859            BinaryOp::MatMul => {
860                if let (Some((ra, ca)), Some((rb, cb))) = (tensor_rows_cols(a), tensor_rows_cols(b))
861                {
862                    if ca != rb {
863                        return Ok((a.clone(), b.clone()));
864                    }
865                    let flops = ra.saturating_mul(ca).saturating_mul(cb);
866                    let eval = self.evaluate_matmul(flops);
867                    record_decision(decision_entry("matmul", None, Some(flops), &eval));
868                    if eval.recommend_gpu {
869                        log_promotion(|| {
870                            format!(
871                                "Promoting matmul operands (flops={}, threshold={})",
872                                flops, self.thresholds.matmul_min_flops
873                            )
874                        });
875                        let a_p = self.promote_tensor_if_large(a, 1)?;
876                        let b_p = self.promote_tensor_if_large(b, 1)?;
877                        return Ok((a_p, b_p));
878                    }
879                }
880                Ok((a.clone(), b.clone()))
881            }
882        }
883    }
884
885    fn promote_unary(&self, op: UnaryOp, v: &Value) -> Result<Value> {
886        if !self.enabled {
887            return Ok(v.clone());
888        }
889        let elems = value_len(v).unwrap_or(0);
890        let eval = self.evaluate_unary(elems, op, v);
891        let op_label = match op {
892            UnaryOp::Transpose => "transpose",
893            UnaryOp::Generic => "unary",
894        };
895        record_decision(decision_entry(op_label, Some(elems), None, &eval));
896        if eval.recommend_gpu {
897            log_promotion(|| format!("Unary offload accepted ({:?}, {} elems)", op, elems));
898            self.promote_tensor_if_large(v, 1)
899        } else {
900            Ok(v.clone())
901        }
902    }
903
904    fn promote_reduction(&self, _op: ReductionOp, args: &[Value]) -> Result<Vec<Value>> {
905        if !self.enabled || args.is_empty() {
906            return Ok(args.to_vec());
907        }
908        let elems = value_len(&args[0]).unwrap_or(0);
909        let eval = self.evaluate_reduction(elems);
910        record_decision(decision_entry("reduction", Some(elems), None, &eval));
911        if !eval.recommend_gpu {
912            return Ok(args.to_vec());
913        }
914        log_promotion(|| format!("Reduction offload accepted ({} elems)", elems));
915        let mut out = Vec::with_capacity(args.len());
916        if let Some(first) = args.first() {
917            out.push(self.promote_tensor_if_large(first, 1)?);
918            out.extend(args.iter().skip(1).cloned());
919        }
920        Ok(out)
921    }
922
923    fn evaluate_elementwise(&self, elements: usize, values: &[&Value]) -> DecisionEvaluation {
924        let fusion = active_fusion();
925        let fusion_kind = fusion.as_ref().map(|f| f.kind.clone());
926        let batch = batch_dimension_from_values(values);
927        let cpu_secs = cpu_estimate(self.thresholds.cpu_elem_per_elem, elements);
928
929        // Chain-aware residency: if any input is already on GPU, keep the op on GPU
930        if values.iter().any(|v| matches!(v, Value::GpuTensor(_))) {
931            return DecisionEvaluation {
932                recommend_gpu: true,
933                reason: DecisionReason::Residency,
934                cpu_secs,
935                gpu_secs: None,
936                threshold: Some(self.thresholds.binary_min_elems),
937                fusion_kind,
938                batch,
939            };
940        }
941
942        if let Some(active) = fusion.as_ref() {
943            // If an elementwise chain is actively fused OR this elementwise op
944            // participates in a fused reduction group, force GPU to keep the
945            // whole chain resident and avoid host round-trips.
946            if (active.kind.is_elementwise() || active.kind.is_reduction()) && active.supported {
947                return DecisionEvaluation {
948                    recommend_gpu: true,
949                    reason: DecisionReason::FusionOverride,
950                    cpu_secs,
951                    gpu_secs: None,
952                    threshold: Some(self.thresholds.binary_min_elems),
953                    fusion_kind,
954                    batch,
955                };
956            }
957        }
958
959        if self.small_batch_guard(elements, batch) {
960            return DecisionEvaluation {
961                recommend_gpu: false,
962                reason: DecisionReason::SmallBatchGuard,
963                cpu_secs,
964                gpu_secs: None,
965                threshold: Some(self.thresholds.binary_min_elems),
966                fusion_kind,
967                batch,
968            };
969        }
970
971        if let Some(model) = profile_cost_model() {
972            if let Some(gpu_duration) = model.estimate_elemwise(elements) {
973                let gpu_secs = Some(gpu_duration.as_secs_f64());
974                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
975                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
976                return DecisionEvaluation {
977                    recommend_gpu: recommend,
978                    reason: DecisionReason::ProfileModel,
979                    cpu_secs,
980                    gpu_secs,
981                    threshold: Some(self.thresholds.binary_min_elems),
982                    fusion_kind,
983                    batch,
984                };
985            }
986        }
987
988        DecisionEvaluation {
989            recommend_gpu: elements >= self.thresholds.binary_min_elems,
990            reason: DecisionReason::Threshold,
991            cpu_secs,
992            gpu_secs: None,
993            threshold: Some(self.thresholds.binary_min_elems),
994            fusion_kind,
995            batch,
996        }
997    }
998
999    fn evaluate_matmul(&self, flops: usize) -> DecisionEvaluation {
1000        let cpu_secs = cpu_estimate(self.thresholds.cpu_matmul_per_flop, flops);
1001        if let Some(model) = profile_cost_model() {
1002            if let Some(gpu_duration) = model.estimate_matmul_flops(flops) {
1003                let gpu_secs = Some(gpu_duration.as_secs_f64());
1004                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1005                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1006                return DecisionEvaluation {
1007                    recommend_gpu: recommend,
1008                    reason: DecisionReason::ProfileModel,
1009                    cpu_secs,
1010                    gpu_secs,
1011                    threshold: Some(self.thresholds.matmul_min_flops),
1012                    fusion_kind: None,
1013                    batch: None,
1014                };
1015            }
1016        }
1017
1018        DecisionEvaluation {
1019            recommend_gpu: flops >= self.thresholds.matmul_min_flops,
1020            reason: DecisionReason::Threshold,
1021            cpu_secs,
1022            gpu_secs: None,
1023            threshold: Some(self.thresholds.matmul_min_flops),
1024            fusion_kind: None,
1025            batch: None,
1026        }
1027    }
1028
1029    fn evaluate_reduction(&self, elements: usize) -> DecisionEvaluation {
1030        let fusion_kind = active_fusion().map(|f| f.kind.clone());
1031        let cpu_secs = cpu_estimate(self.thresholds.cpu_reduction_per_elem, elements);
1032        if let Some(model) = profile_cost_model() {
1033            if let Some(gpu_duration) = model.estimate_reduction(elements) {
1034                let gpu_secs = Some(gpu_duration.as_secs_f64());
1035                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1036                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1037                return DecisionEvaluation {
1038                    recommend_gpu: recommend,
1039                    reason: DecisionReason::ProfileModel,
1040                    cpu_secs,
1041                    gpu_secs,
1042                    threshold: Some(self.thresholds.reduction_min_elems),
1043                    fusion_kind,
1044                    batch: None,
1045                };
1046            }
1047        }
1048
1049        DecisionEvaluation {
1050            recommend_gpu: elements >= self.thresholds.reduction_min_elems,
1051            reason: DecisionReason::Threshold,
1052            cpu_secs,
1053            gpu_secs: None,
1054            threshold: Some(self.thresholds.reduction_min_elems),
1055            fusion_kind,
1056            batch: None,
1057        }
1058    }
1059
1060    fn evaluate_unary(&self, elements: usize, op: UnaryOp, value: &Value) -> DecisionEvaluation {
1061        let fusion_kind = active_fusion().map(|f| f.kind.clone());
1062        let batch = batch_dimension_from_values(&[value]);
1063        // Chain-aware residency for unary ops: if operand is already on GPU, keep it on GPU
1064        if matches!(value, Value::GpuTensor(_)) {
1065            return DecisionEvaluation {
1066                recommend_gpu: true,
1067                reason: DecisionReason::Residency,
1068                cpu_secs: cpu_estimate(self.thresholds.cpu_elem_per_elem, elements),
1069                gpu_secs: None,
1070                threshold: Some(self.thresholds.unary_min_elems),
1071                fusion_kind,
1072                batch,
1073            };
1074        }
1075        if matches!(op, UnaryOp::Generic) && self.small_batch_guard(elements, batch) {
1076            return DecisionEvaluation {
1077                recommend_gpu: false,
1078                reason: DecisionReason::SmallBatchGuard,
1079                cpu_secs: cpu_estimate(self.thresholds.cpu_elem_per_elem, elements),
1080                gpu_secs: None,
1081                threshold: Some(self.thresholds.unary_min_elems),
1082                fusion_kind,
1083                batch,
1084            };
1085        }
1086
1087        let cpu_secs = cpu_estimate(self.thresholds.cpu_elem_per_elem, elements);
1088        if let Some(model) = profile_cost_model() {
1089            let gpu_duration = match op {
1090                UnaryOp::Transpose => model.estimate_transpose(elements),
1091                UnaryOp::Generic => model.estimate_elemwise(elements),
1092            };
1093            if let Some(gpu_duration) = gpu_duration {
1094                let gpu_secs = Some(gpu_duration.as_secs_f64());
1095                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
1096                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
1097                return DecisionEvaluation {
1098                    recommend_gpu: recommend,
1099                    reason: DecisionReason::ProfileModel,
1100                    cpu_secs,
1101                    gpu_secs,
1102                    threshold: Some(self.thresholds.unary_min_elems),
1103                    fusion_kind,
1104                    batch,
1105                };
1106            }
1107        }
1108
1109        DecisionEvaluation {
1110            recommend_gpu: elements >= self.thresholds.unary_min_elems,
1111            reason: DecisionReason::Threshold,
1112            cpu_secs,
1113            gpu_secs: None,
1114            threshold: Some(self.thresholds.unary_min_elems),
1115            fusion_kind,
1116            batch,
1117        }
1118    }
1119
1120    async fn prepare_builtin(&self, name: &str, args: &[Value]) -> Result<Vec<Value>> {
1121        if !self.enabled {
1122            return Ok(args.to_vec());
1123        }
1124        // Do not attempt to promote 'double' on providers that cannot store f64.
1125        // Offloading a cast to double requires device-side f64; otherwise keep host.
1126        if name.eq_ignore_ascii_case("double")
1127            && self.provider.precision() != runmat_accelerate_api::ProviderPrecision::F64
1128        {
1129            return Ok(args.to_vec());
1130        }
1131        if let Some(policy) = builtin_policy(name) {
1132            if policy.is_sink {
1133                clear_sink_inputs(args);
1134                if should_gather_sink_args(name) {
1135                    trace!(
1136                        "auto-offload: prepare_builtin(name={:?}) is_sink=true residency=GatherImmediately -> gathering {} arg(s)",
1137                        name,
1138                        args.len()
1139                    );
1140                    return gather_args(args).await;
1141                }
1142                trace!(
1143                    "auto-offload: prepare_builtin(name={:?}) is_sink=true residency!=GatherImmediately -> no gather (fusion barrier only)",
1144                    name
1145                );
1146                return Ok(args.to_vec());
1147            }
1148
1149            let mut processed = args.to_vec();
1150
1151            if policy
1152                .accel_tags
1153                .iter()
1154                .any(|tag| matches!(tag, AccelTag::Reduction))
1155            {
1156                if (name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min"))
1157                    && !max_or_min_reduction_call(args)
1158                {
1159                    trace!(
1160                        "Skipping reduction promotion for builtin '{}' (detected elementwise form)",
1161                        name
1162                    );
1163                } else {
1164                    log_promotion(|| format!("Promoting builtin '{}' as reduction", name));
1165                    return self.promote_reduction(reduction_op_hint(name), args);
1166                }
1167            }
1168
1169            if policy
1170                .accel_tags
1171                .iter()
1172                .any(|tag| matches!(tag, AccelTag::MatMul))
1173                && processed.len() >= 2
1174            {
1175                log_promotion(|| format!("Promoting builtin '{}' as matmul", name));
1176                let (a_p, b_p) =
1177                    self.promote_binary(BinaryOp::MatMul, &processed[0], &processed[1])?;
1178                processed[0] = a_p;
1179                processed[1] = b_p;
1180                return Ok(processed);
1181            }
1182
1183            if policy
1184                .accel_tags
1185                .iter()
1186                .any(|tag| matches!(tag, AccelTag::Elementwise))
1187                && processed.len() >= 2
1188            {
1189                log_promotion(|| format!("Promoting builtin '{}' as elementwise", name));
1190                let (a_p, b_p) =
1191                    self.promote_binary(BinaryOp::Elementwise, &processed[0], &processed[1])?;
1192                processed[0] = a_p;
1193                processed[1] = b_p;
1194                return Ok(processed);
1195            }
1196
1197            if let Some(first) = processed.first_mut() {
1198                if policy
1199                    .accel_tags
1200                    .iter()
1201                    .any(|tag| matches!(tag, AccelTag::Transpose))
1202                {
1203                    log_promotion(|| format!("Promoting builtin '{}' as transpose", name));
1204                    *first = self.promote_unary(UnaryOp::Transpose, first)?;
1205                    return Ok(processed);
1206                }
1207
1208                if policy
1209                    .accel_tags
1210                    .iter()
1211                    .any(|tag| matches!(tag, AccelTag::Unary))
1212                {
1213                    log_promotion(|| format!("Promoting builtin '{}' as unary", name));
1214                    *first = self.promote_unary(UnaryOp::Generic, first)?;
1215                    return Ok(processed);
1216                }
1217            }
1218        }
1219        Ok(args.to_vec())
1220    }
1221}
1222
1223fn tensor_rows_cols(value: &Value) -> Option<(usize, usize)> {
1224    match value {
1225        Value::Tensor(t) => Some((t.rows(), t.cols())),
1226        Value::GpuTensor(handle) => {
1227            if handle.shape.len() == 2 {
1228                Some((handle.shape[0], handle.shape[1]))
1229            } else {
1230                None
1231            }
1232        }
1233        _ => None,
1234    }
1235}
1236
1237#[allow(dead_code)]
1238fn should_skip_reduction_promotion(name: &str, args: &[Value]) -> bool {
1239    (name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min"))
1240        && !max_or_min_reduction_call(args)
1241}
1242
1243fn reduction_op_hint(name: &str) -> ReductionOp {
1244    if name.eq_ignore_ascii_case("max") {
1245        ReductionOp::Max
1246    } else if name.eq_ignore_ascii_case("min") {
1247        ReductionOp::Min
1248    } else {
1249        ReductionOp::Sum
1250    }
1251}
1252
1253fn max_or_min_reduction_call(args: &[Value]) -> bool {
1254    if args.len() <= 1 {
1255        return true;
1256    }
1257    args.get(1).map(is_empty_placeholder_value).unwrap_or(false)
1258}
1259
1260fn is_empty_placeholder_value(value: &Value) -> bool {
1261    match value {
1262        Value::Tensor(t) => t.data.is_empty(),
1263        Value::LogicalArray(l) => l.data.is_empty(),
1264        Value::StringArray(sa) => sa.data.is_empty(),
1265        Value::CharArray(ca) => ca.data.is_empty(),
1266        Value::Cell(cell) => cell.data.is_empty(),
1267        Value::String(s) => s.is_empty(),
1268        _ => false,
1269    }
1270}
1271
1272async fn gather_args(args: &[Value]) -> Result<Vec<Value>> {
1273    let mut out = Vec::with_capacity(args.len());
1274    for (idx, value) in args.iter().enumerate() {
1275        if let Value::GpuTensor(handle) = value {
1276            trace!(
1277                "auto-offload: gather_args arg[{}]=GpuTensor device_id={} buffer_id={} shape={:?}",
1278                idx,
1279                handle.device_id,
1280                handle.buffer_id,
1281                handle.shape
1282            );
1283        } else {
1284            trace!(
1285                "auto-offload: gather_args arg[{}]={:?}",
1286                idx,
1287                value_kind(value)
1288            );
1289        }
1290        let gathered = gather_if_needed_async(value)
1291            .await
1292            .map_err(|e| anyhow!(e))?;
1293        trace!(
1294            "auto-offload: gather_args arg[{}] -> {:?}",
1295            idx,
1296            value_kind(&gathered)
1297        );
1298        out.push(gathered);
1299    }
1300    Ok(out)
1301}
1302
1303fn clear_sink_inputs(args: &[Value]) {
1304    for value in args {
1305        if let Value::GpuTensor(handle) = value {
1306            fusion_residency::clear(handle);
1307        }
1308    }
1309}
1310
1311fn should_gather_sink_args(name: &str) -> bool {
1312    matches!(
1313        builtin_residency_policy(name),
1314        Some(ResidencyPolicy::GatherImmediately) | None
1315    )
1316}
1317
1318fn value_kind(value: &Value) -> &'static str {
1319    match value {
1320        Value::GpuTensor(_) => "GpuTensor",
1321        Value::Tensor(_) => "Tensor",
1322        Value::Num(_) => "Num",
1323        Value::Int(_) => "Int",
1324        Value::Bool(_) => "Bool",
1325        Value::LogicalArray(_) => "LogicalArray",
1326        Value::CharArray(_) => "CharArray",
1327        Value::String(_) => "String",
1328        Value::StringArray(_) => "StringArray",
1329        Value::Cell(_) => "Cell",
1330        Value::Struct(_) => "Struct",
1331        Value::Object(_) => "Object",
1332        Value::HandleObject(_) => "HandleObject",
1333        Value::FunctionHandle(_)
1334        | Value::ExternalFunctionHandle(_)
1335        | Value::MethodFunctionHandle(_) => "FunctionHandle",
1336        Value::BoundFunctionHandle { .. } => "FunctionHandle",
1337        Value::Closure(_) => "Closure",
1338        Value::ClassRef(_) => "ClassRef",
1339        Value::Complex(_, _) => "Complex",
1340        Value::ComplexTensor(_) => "ComplexTensor",
1341        Value::Listener(_) => "Listener",
1342        Value::MException(_) => "MException",
1343        Value::OutputList(_) => "OutputList",
1344    }
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn max_detection_handles_placeholders() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let placeholder = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        let data = Value::Tensor(tensor);
        let empty = Value::Tensor(placeholder);

        assert!(max_or_min_reduction_call(std::slice::from_ref(&data)));
        assert!(max_or_min_reduction_call(&[
            data.clone(),
            empty.clone(),
            Value::Num(1.0)
        ]));
        assert!(!max_or_min_reduction_call(&[data.clone(), Value::Num(0.0)]));
    }

    #[test]
    fn max_detection_treats_no_arguments_as_reduction() {
        let none: [Value; 0] = [];
        assert!(max_or_min_reduction_call(&none));
    }

    #[test]
    fn reduction_promotion_skipped_only_for_elementwise_max_min() {
        let data = Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap());
        let pair = [data.clone(), Value::Num(0.0)];

        assert!(should_skip_reduction_promotion("MAX", &pair));
        assert!(should_skip_reduction_promotion("min", &pair));
        assert!(!should_skip_reduction_promotion("sum", &pair));
        assert!(!should_skip_reduction_promotion(
            "max",
            std::slice::from_ref(&data)
        ));
    }

    #[test]
    fn reduction_hint_recognises_extrema() {
        assert!(matches!(reduction_op_hint("Max"), ReductionOp::Max));
        assert!(matches!(reduction_op_hint("min"), ReductionOp::Min));
        assert!(matches!(reduction_op_hint("sum"), ReductionOp::Sum));
    }

    #[test]
    fn slugify_collapses_separators_and_falls_back() {
        assert_eq!(slugify("NVIDIA GeForce RTX 4090"), "nvidia_geforce_rtx_4090");
        assert_eq!(slugify("  Apple M1 Pro  "), "apple_m1_pro");
        assert_eq!(slugify("Vulkan (llvmpipe)"), "vulkan_llvmpipe");
        assert_eq!(slugify("---"), "device");
        assert_eq!(slugify(""), "device");
    }
}
1367
/// Acceleration metadata cached per builtin. It is copied from the `builtin_functions()`
/// registry by `build_builtin_policy_map` and consulted by `prepare_builtin`.
#[derive(Clone, Copy)]
struct BuiltinPolicy {
    /// Acceleration tags declared by the builtin. `prepare_builtin` checks them for
    /// `Reduction`, `MatMul`, `Elementwise`, `Transpose` and `Unary` to pick a promotion path.
    accel_tags: &'static [AccelTag],
    /// Whether the builtin is a sink. For sinks, `prepare_builtin` clears fusion residency for
    /// GPU inputs and, depending on the residency policy, gathers them to the host.
    is_sink: bool,
}
1373
/// Lookup table from ASCII-lowercased builtin name to its `BuiltinPolicy`. It is built lazily
/// by `build_builtin_policy_map` on the first `builtin_policy` call.
static BUILTIN_POLICIES: OnceCell<HashMap<String, BuiltinPolicy>> = OnceCell::new();
1375
1376fn build_builtin_policy_map() -> HashMap<String, BuiltinPolicy> {
1377    let mut map = HashMap::new();
1378    for func in builtin_functions() {
1379        map.insert(
1380            func.name.to_ascii_lowercase(),
1381            BuiltinPolicy {
1382                accel_tags: func.accel_tags,
1383                is_sink: func.is_sink,
1384            },
1385        );
1386    }
1387    map
1388}
1389
1390fn builtin_policy(name: &str) -> Option<BuiltinPolicy> {
1391    let map = BUILTIN_POLICIES.get_or_init(build_builtin_policy_map);
1392    map.get(&name.to_ascii_lowercase()).copied()
1393}
1394
1395fn auto_enabled() -> bool {
1396    if let Some(flag) = env_bool("RUNMAT_ACCEL_AUTO_OFFLOAD") {
1397        return flag;
1398    }
1399    auto_offload_options().enabled
1400}
1401
1402fn calibrate_enabled() -> bool {
1403    if let Some(flag) = env_bool("RUNMAT_ACCEL_CALIBRATE") {
1404        return flag;
1405    }
1406    auto_offload_options().calibrate
1407}
1408
1409fn calibrate_refresh_enabled() -> bool {
1410    env_bool("RUNMAT_ACCEL_CALIBRATE_REFRESH").unwrap_or(false)
1411}
1412
1413fn apply_env_overrides(cfg: &mut ThresholdConfig) -> bool {
1414    let mut applied = false;
1415    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_UNARY") {
1416        cfg.unary_min_elems = val;
1417        applied = true;
1418    }
1419    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_ELEMWISE") {
1420        cfg.binary_min_elems = val;
1421        applied = true;
1422    }
1423    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_REDUCTION") {
1424        cfg.reduction_min_elems = val;
1425        applied = true;
1426    }
1427    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_MATMUL") {
1428        cfg.matmul_min_flops = val;
1429        applied = true;
1430    }
1431    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_ALL") {
1432        cfg.unary_min_elems = val;
1433        cfg.binary_min_elems = val;
1434        cfg.reduction_min_elems = val;
1435        applied = true;
1436    }
1437    if let Some(val) = env_usize("RUNMAT_ACCEL_SMALL_BATCH_MAX_DIM") {
1438        cfg.small_batch_max_dim = val;
1439        applied = true;
1440    }
1441    if let Some(val) = env_usize("RUNMAT_ACCEL_SMALL_BATCH_MIN_ELEMS") {
1442        cfg.small_batch_min_elems = val;
1443        applied = true;
1444    }
1445    applied
1446}
1447
/// Reads environment variable `key` as a `usize`. Returns `None` when the variable is unset,
/// not valid Unicode, or does not parse (no whitespace trimming is applied).
fn env_usize(key: &str) -> Option<usize> {
    match env::var(key) {
        Ok(raw) => raw.parse::<usize>().ok(),
        Err(_) => None,
    }
}
1451
/// Persisted calibration result for one device, stored as pretty-printed JSON. On native
/// targets it goes to the cache file; on wasm32 it goes to the web auto-offload store.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CalibrationRecord {
    /// Record format version; loaders discard records whose version differs from
    /// `CALIBRATION_VERSION`.
    version: u32,
    /// Seconds since the UNIX epoch at which the record was written (0 if the clock reads
    /// earlier than the epoch).
    recorded_at: u64,
    /// Identity of the device the thresholds were measured on.
    provider: CalibrationProviderDetails,
    /// The calibrated thresholds.
    thresholds: ThresholdConfig,
}
1459
/// Device identity stored alongside a calibration record, copied field-by-field from the
/// provider's `ApiDeviceInfo` when the record is persisted.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CalibrationProviderDetails {
    /// Device name as reported by the provider.
    name: String,
    /// Device vendor as reported by the provider.
    vendor: String,
    /// Backend name, when the provider reports one.
    backend: Option<String>,
    /// Provider-assigned device identifier.
    device_id: u32,
}
1467
/// Storage key for a device's calibration record on wasm32, shaped as
/// `vendor-name-backend-deviceid.json`. Each text part is slugified, and a missing backend
/// becomes `unknown`.
#[cfg(target_arch = "wasm32")]
fn calibration_cache_key(info: &ApiDeviceInfo) -> String {
    let slug_parts = [
        slugify(&info.vendor),
        slugify(&info.name),
        slugify(info.backend.as_deref().unwrap_or("unknown")),
    ];
    format!("{}-{}.json", slug_parts.join("-"), info.device_id)
}
1475
1476async fn load_cached_thresholds_async(info: &ApiDeviceInfo) -> Option<(ThresholdConfig, String)> {
1477    #[cfg(target_arch = "wasm32")]
1478    {
1479        let key = calibration_cache_key(info);
1480        let contents = crate::web_auto_offload_store::load(&key).await?;
1481        match serde_json::from_str::<CalibrationRecord>(&contents) {
1482            Ok(record) => {
1483                if record.version != CALIBRATION_VERSION {
1484                    debug!(
1485                        "Native auto-offload calibration cache version mismatch (found {}, expected {})",
1486                        record.version,
1487                        CALIBRATION_VERSION
1488                    );
1489                    None
1490                } else {
1491                    Some((record.thresholds, key))
1492                }
1493            }
1494            Err(err) => {
1495                debug!(
1496                    "Native auto-offload failed to parse cached calibration for '{}': {err}",
1497                    info.name
1498                );
1499                None
1500            }
1501        }
1502    }
1503    #[cfg(not(target_arch = "wasm32"))]
1504    {
1505        load_cached_thresholds(info).map(|(cfg, path)| (cfg, path.display().to_string()))
1506    }
1507}
1508
1509async fn persist_thresholds_async(info: &ApiDeviceInfo, cfg: &ThresholdConfig) -> Result<String> {
1510    #[cfg(target_arch = "wasm32")]
1511    {
1512        let key = calibration_cache_key(info);
1513        let record = CalibrationRecord {
1514            version: CALIBRATION_VERSION,
1515            recorded_at: system_time_now()
1516                .duration_since(UNIX_EPOCH)
1517                .unwrap_or_else(|_| Duration::from_secs(0))
1518                .as_secs(),
1519            provider: CalibrationProviderDetails {
1520                name: info.name.clone(),
1521                vendor: info.vendor.clone(),
1522                backend: info.backend.clone(),
1523                device_id: info.device_id,
1524            },
1525            thresholds: cfg.clone(),
1526        };
1527        let payload = serde_json::to_string_pretty(&record).map_err(|e| anyhow!(e.to_string()))?;
1528        crate::web_auto_offload_store::save(&key, &payload)
1529            .await
1530            .map_err(|e| anyhow!(format!("indexeddb persist failed: {e:?}")))?;
1531        Ok(key)
1532    }
1533    #[cfg(not(target_arch = "wasm32"))]
1534    {
1535        persist_thresholds(info, cfg).map(|path| path.display().to_string())
1536    }
1537}
1538
1539fn load_cached_thresholds(info: &ApiDeviceInfo) -> Option<(ThresholdConfig, PathBuf)> {
1540    let path = calibration_cache_file(info)?;
1541    let contents = fs::read_to_string(&path).ok()?;
1542    match serde_json::from_str::<CalibrationRecord>(&contents) {
1543        Ok(record) => {
1544            if record.version != CALIBRATION_VERSION {
1545                debug!(
1546                    "Native auto-offload calibration cache version mismatch (found {}, expected {})",
1547                    record.version,
1548                    CALIBRATION_VERSION
1549                );
1550                None
1551            } else {
1552                Some((record.thresholds, path))
1553            }
1554        }
1555        Err(err) => {
1556            debug!(
1557                "Native auto-offload failed to parse cached calibration for '{}': {err}",
1558                info.name
1559            );
1560            None
1561        }
1562    }
1563}
1564
1565fn persist_thresholds(info: &ApiDeviceInfo, cfg: &ThresholdConfig) -> Result<PathBuf> {
1566    let path = calibration_cache_file(info)
1567        .ok_or_else(|| anyhow!("unable to determine calibration cache directory"))?;
1568    if let Some(parent) = path.parent() {
1569        fs::create_dir_all(parent).map_err(|e| anyhow!(e.to_string()))?;
1570    }
1571    let record = CalibrationRecord {
1572        version: CALIBRATION_VERSION,
1573        recorded_at: system_time_now()
1574            .duration_since(UNIX_EPOCH)
1575            .unwrap_or_else(|_| Duration::from_secs(0))
1576            .as_secs(),
1577        provider: CalibrationProviderDetails {
1578            name: info.name.clone(),
1579            vendor: info.vendor.clone(),
1580            backend: info.backend.clone(),
1581            device_id: info.device_id,
1582        },
1583        thresholds: cfg.clone(),
1584    };
1585    let payload = serde_json::to_string_pretty(&record).map_err(|e| anyhow!(e.to_string()))?;
1586    fs::write(&path, payload).map_err(|e| anyhow!(e.to_string()))?;
1587    Ok(path)
1588}
1589
1590fn calibration_cache_file(info: &ApiDeviceInfo) -> Option<PathBuf> {
1591    let mut dir = calibration_cache_dir()?;
1592    let vendor = slugify(&info.vendor);
1593    let name = slugify(&info.name);
1594    let backend = slugify(info.backend.as_deref().unwrap_or("unknown"));
1595    let file = format!("{}-{}-{}-{}.json", vendor, name, backend, info.device_id);
1596    dir.push(file);
1597    Some(dir)
1598}
1599
1600fn calibration_cache_dir() -> Option<PathBuf> {
1601    dirs::cache_dir().map(|base| base.join("runmat").join("auto_offload"))
1602}
1603
1604fn slugify(input: &str) -> String {
1605    let mut out = String::with_capacity(input.len());
1606    let mut last_underscore = false;
1607    for ch in input.chars() {
1608        if ch.is_ascii_alphanumeric() {
1609            out.push(ch.to_ascii_lowercase());
1610            last_underscore = false;
1611        } else if !last_underscore {
1612            out.push('_');
1613            last_underscore = true;
1614        }
1615    }
1616    let trimmed = out.trim_matches('_');
1617    if trimmed.is_empty() {
1618        "device".to_string()
1619    } else {
1620        trimmed.to_string()
1621    }
1622}
1623
1624fn auto_calibrate(provider: &'static dyn AccelProvider, cfg: &mut ThresholdConfig) -> Result<()> {
1625    if let Some(elem_threshold) = calibrate_elemwise(provider, cfg).transpose()? {
1626        if elem_threshold != usize::MAX {
1627            cfg.binary_min_elems = elem_threshold;
1628            cfg.unary_min_elems = cfg.unary_min_elems.min(elem_threshold);
1629        }
1630    }
1631    if let Some(red_threshold) = calibrate_reduction(provider, cfg).transpose()? {
1632        if red_threshold != usize::MAX {
1633            cfg.reduction_min_elems = red_threshold;
1634        }
1635    }
1636    if let Some(matmul_threshold) = calibrate_matmul(provider, cfg).transpose()? {
1637        if matmul_threshold != usize::MAX {
1638            cfg.matmul_min_flops = matmul_threshold;
1639        }
1640    }
1641    Ok(())
1642}
1643
1644fn calibrate_elemwise(
1645    provider: &'static dyn AccelProvider,
1646    cfg: &mut ThresholdConfig,
1647) -> Option<Result<usize>> {
1648    let sizes = [256usize, 1_024, 4_096, 16_384, 65_536];
1649    for size in sizes {
1650        match compare_elemwise(provider, size, &mut cfg.cpu_elem_per_elem) {
1651            Ok(Some(true)) => return Some(Ok(size)),
1652            Ok(Some(false)) => continue,
1653            Ok(None) => return None,
1654            Err(e) => return Some(Err(e)),
1655        }
1656    }
1657    Some(Ok(usize::MAX))
1658}
1659
1660fn compare_elemwise(
1661    provider: &'static dyn AccelProvider,
1662    elements: usize,
1663    cpu_cost_slot: &mut f64,
1664) -> Result<Option<bool>> {
1665    if elements == 0 {
1666        return Ok(Some(false));
1667    }
1668    let shape = vec![elements, 1];
1669    let template = match provider.precision() {
1670        ProviderPrecision::F64 => {
1671            Tensor::new((0..elements).map(|i| i as f64).collect(), shape.clone())
1672                .map_err(|e| anyhow!(e))?
1673        }
1674        ProviderPrecision::F32 => {
1675            Tensor::from_f32((0..elements).map(|i| i as f32).collect(), shape.clone())
1676                .map_err(|e| anyhow!(e))?
1677        }
1678    };
1679    let a = Value::Tensor(template.clone());
1680    let b = Value::Tensor(template.clone());
1681    let cpu_time = time(|| runmat_runtime::call_builtin("plus", &[a.clone(), b.clone()]))?;
1682    let cpu_per_elem = cpu_time.as_secs_f64() / elements as f64;
1683    update_cpu_cost(cpu_cost_slot, cpu_per_elem);
1684    if let Some(model) = profile_cost_model() {
1685        if let Some(gpu_time) = model.estimate_elemwise(elements) {
1686            trace!(
1687                "Elemwise calibration ({} elems): cpu={:?}, gpu_est={:?}",
1688                elements,
1689                cpu_time,
1690                gpu_time
1691            );
1692            return Ok(Some(gpu_time < cpu_time));
1693        }
1694    }
1695    let view = HostTensorView {
1696        data: template.data.as_slice(),
1697        shape: template.shape.as_slice(),
1698    };
1699    let ha = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1700    let hb = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1701    let start = Instant::now();
1702    let hc = match futures::executor::block_on(provider.elem_add(&ha, &hb)) {
1703        Ok(h) => h,
1704        Err(_) => {
1705            let _ = provider.free(&ha);
1706            let _ = provider.free(&hb);
1707            return Ok(None);
1708        }
1709    };
1710    let gpu_time = start.elapsed();
1711    let _ = provider.free(&ha);
1712    let _ = provider.free(&hb);
1713    let _ = provider.free(&hc);
1714    Ok(Some(gpu_time < cpu_time))
1715}
1716
1717fn calibrate_reduction(
1718    provider: &'static dyn AccelProvider,
1719    cfg: &mut ThresholdConfig,
1720) -> Option<Result<usize>> {
1721    let sizes = [256usize, 1_024, 4_096, 16_384, 65_536];
1722    for size in sizes {
1723        match compare_reduction(provider, size, &mut cfg.cpu_reduction_per_elem) {
1724            Ok(Some(true)) => return Some(Ok(size)),
1725            Ok(Some(false)) => continue,
1726            Ok(None) => return None,
1727            Err(e) => return Some(Err(e)),
1728        }
1729    }
1730    Some(Ok(usize::MAX))
1731}
1732
1733fn compare_reduction(
1734    provider: &'static dyn AccelProvider,
1735    elements: usize,
1736    cpu_cost_slot: &mut f64,
1737) -> Result<Option<bool>> {
1738    let shape = vec![elements, 1];
1739    let template = match provider.precision() {
1740        ProviderPrecision::F64 => {
1741            Tensor::new((0..elements).map(|i| i as f64).collect(), shape.clone())
1742                .map_err(|e| anyhow!(e))?
1743        }
1744        ProviderPrecision::F32 => {
1745            Tensor::from_f32((0..elements).map(|i| i as f32).collect(), shape.clone())
1746                .map_err(|e| anyhow!(e))?
1747        }
1748    };
1749    let value = Value::Tensor(template.clone());
1750    let cpu_time = time(|| runmat_runtime::call_builtin("sum", std::slice::from_ref(&value)))?;
1751    let cpu_per_elem = cpu_time.as_secs_f64() / elements as f64;
1752    update_cpu_cost(cpu_cost_slot, cpu_per_elem);
1753    if let Some(model) = profile_cost_model() {
1754        if let Some(gpu_time) = model.estimate_reduction(elements) {
1755            trace!(
1756                "Reduction calibration ({} elems): cpu={:?}, gpu_est={:?}",
1757                elements,
1758                cpu_time,
1759                gpu_time
1760            );
1761            return Ok(Some(gpu_time < cpu_time));
1762        }
1763    }
1764    let view = HostTensorView {
1765        data: template.data.as_slice(),
1766        shape: template.shape.as_slice(),
1767    };
1768    let h = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1769    let start = Instant::now();
1770    let out = match futures::executor::block_on(provider.reduce_sum(&h)) {
1771        Ok(hc) => hc,
1772        Err(_) => {
1773            provider.free(&h).ok();
1774            return Ok(None);
1775        }
1776    };
1777    let gpu_time = start.elapsed();
1778    let _ = provider.free(&h);
1779    let _ = provider.free(&out);
1780    Ok(Some(gpu_time < cpu_time))
1781}
1782
1783fn calibrate_matmul(
1784    provider: &'static dyn AccelProvider,
1785    cfg: &mut ThresholdConfig,
1786) -> Option<Result<usize>> {
1787    let dims = [32usize, 64, 96, 128, 192];
1788    for n in dims {
1789        match compare_matmul(provider, n, &mut cfg.cpu_matmul_per_flop) {
1790            Ok(Some(true)) => {
1791                let flops = n * n * n;
1792                return Some(Ok(flops));
1793            }
1794            Ok(Some(false)) => continue,
1795            Ok(None) => return None,
1796            Err(e) => return Some(Err(e)),
1797        }
1798    }
1799    Some(Ok(usize::MAX))
1800}
1801
1802fn compare_matmul(
1803    provider: &'static dyn AccelProvider,
1804    n: usize,
1805    cpu_cost_slot: &mut f64,
1806) -> Result<Option<bool>> {
1807    if n == 0 {
1808        return Ok(Some(false));
1809    }
1810    let total = n * n;
1811    let shape = vec![n, n];
1812    let (ta, tb) = match provider.precision() {
1813        ProviderPrecision::F64 => {
1814            let data_a: Vec<f64> = (0..total).map(|i| (i % 13) as f64).collect();
1815            let data_b: Vec<f64> = (0..total).map(|i| (i % 7) as f64).collect();
1816            let ta = Tensor::new(data_a, shape.clone()).map_err(|e| anyhow!(e))?;
1817            let tb = Tensor::new(data_b, shape.clone()).map_err(|e| anyhow!(e))?;
1818            (ta, tb)
1819        }
1820        ProviderPrecision::F32 => {
1821            let data_a: Vec<f32> = (0..total).map(|i| (i % 13) as f32).collect();
1822            let data_b: Vec<f32> = (0..total).map(|i| (i % 7) as f32).collect();
1823            let ta = Tensor::from_f32(data_a, shape.clone()).map_err(|e| anyhow!(e))?;
1824            let tb = Tensor::from_f32(data_b, shape.clone()).map_err(|e| anyhow!(e))?;
1825            (ta, tb)
1826        }
1827    };
1828    let a = Value::Tensor(ta.clone());
1829    let b = Value::Tensor(tb.clone());
1830    let cpu_time = time(|| futures::executor::block_on(runmat_runtime::value_matmul(&a, &b)))?;
1831    let flops = (n * n * n) as f64;
1832    update_cpu_cost(cpu_cost_slot, cpu_time.as_secs_f64() / flops);
1833    if let Some(model) = profile_cost_model() {
1834        if let Some(gpu_time) = model.estimate_matmul(n, n, n) {
1835            trace!(
1836                "Matmul calibration ({}^3 flops): cpu={:?}, gpu_est={:?}",
1837                n,
1838                cpu_time,
1839                gpu_time
1840            );
1841            return Ok(Some(gpu_time < cpu_time));
1842        }
1843    }
1844    let view_a = HostTensorView {
1845        data: ta.data.as_slice(),
1846        shape: ta.shape.as_slice(),
1847    };
1848    let view_b = HostTensorView {
1849        data: tb.data.as_slice(),
1850        shape: tb.shape.as_slice(),
1851    };
1852    let ha = provider
1853        .upload(&view_a)
1854        .map_err(|e| anyhow!(e.to_string()))?;
1855    let hb = provider
1856        .upload(&view_b)
1857        .map_err(|e| anyhow!(e.to_string()))?;
1858    let start = Instant::now();
1859    let hc = match futures::executor::block_on(provider.matmul(&ha, &hb)) {
1860        Ok(h) => h,
1861        Err(_) => {
1862            let _ = provider.free(&ha);
1863            let _ = provider.free(&hb);
1864            return Ok(None);
1865        }
1866    };
1867    let gpu_time = start.elapsed();
1868    let _ = provider.free(&ha);
1869    let _ = provider.free(&hb);
1870    let _ = provider.free(&hc);
1871    Ok(Some(gpu_time < cpu_time))
1872}
1873
1874fn time<F, T>(mut f: F) -> Result<Duration>
1875where
1876    F: FnMut() -> runmat_runtime::BuiltinResult<T>,
1877{
1878    let start = Instant::now();
1879    let _ = f().map_err(|err| anyhow!(err))?;
1880    Ok(start.elapsed())
1881}
1882
1883pub fn auto_offload_report() -> Option<AutoOffloadReport> {
1884    let state_guard = AUTO_STATE.get()?;
1885    let state = state_guard.lock().ok()?;
1886    let calibration = state.previous_thresholds.as_ref().map(|prev| {
1887        let delta = state
1888            .calibration_delta
1889            .clone()
1890            .unwrap_or_else(|| compute_delta(prev, &state.thresholds));
1891        AutoOffloadCalibrationSummary {
1892            previous: threshold_snapshot(prev),
1893            delta,
1894        }
1895    });
1896    Some(AutoOffloadReport {
1897        provider: state.provider.clone(),
1898        thresholds: threshold_snapshot(&state.thresholds),
1899        base_source: state.base_source,
1900        env_overrides_applied: state.env_overrides_applied,
1901        cache_path: state.cache_path.clone(),
1902        calibrate_duration_ms: state.calibrate_duration_ms,
1903        calibration,
1904        decisions: snapshot_decisions(),
1905    })
1906}
1907
1908pub fn sequence_threshold_hint() -> Option<usize> {
1909    AUTO_STATE
1910        .get()
1911        .and_then(|state| state.lock().ok())
1912        .map(|state| state.thresholds.unary_min_elems)
1913}
1914
/// Resets the auto-offload decision log.
///
/// This function only delegates to `clear_decisions`. Presumably that empties
/// the log later returned by `snapshot_decisions` (TODO confirm). Thresholds
/// are not touched here.
pub fn reset_auto_offload_log() {
    clear_decisions();
}
1918
/// Timing summary of one profiled operation, deserialized from a GPU profile
/// JSON report.
#[derive(Clone, Deserialize, Debug)]
struct ProfileDurationSummary {
    /// Average duration in milliseconds; `0.0` when the field is absent.
    #[serde(default)]
    avg_ms: f64,
}
1924
/// One entry of a GPU profile JSON file, which is parsed as an array of these.
#[derive(Clone, Deserialize, Debug)]
struct ProfileReport {
    /// Operation category. `ProfileCostModel::from_reports` uses "elementwise",
    /// "reduction", "transpose" and "matmul" and ignores anything else.
    category: String,
    /// Shapes of the operation's inputs; empty when absent from the JSON.
    #[serde(default)]
    input_shapes: Vec<Vec<usize>>,
    /// Total-time summary for the operation (required field).
    total_ms: ProfileDurationSummary,
}
1932
/// A fitted cost model `seconds = intercept + slope * x`.
///
/// `x` is the workload size of the fitted samples: an element count or a
/// flop count, depending on the operation category.
#[derive(Clone, Copy, Default, Debug)]
struct LinearModel {
    /// Seconds per unit of `x`.
    slope: f64,
    /// Fixed overhead in seconds.
    intercept: f64,
}

impl LinearModel {
    /// Predicts the duration of a workload of size `x`.
    ///
    /// Returns `None` when the slope is non-finite or non-positive. It also
    /// returns `None` when the prediction is not a positive duration that
    /// `Duration` can represent: NaN, infinite, zero, negative, or too large.
    fn estimate(&self, x: f64) -> Option<Duration> {
        if !self.slope.is_finite() || self.slope <= 0.0 {
            return None;
        }
        let total = self.intercept + self.slope * x;
        // `Duration::from_secs_f64` panics when `total` overflows `Duration`.
        // An extrapolated prediction must not abort calibration, so use the
        // fallible constructor.
        if total > 0.0 {
            Duration::try_from_secs_f64(total).ok()
        } else {
            None
        }
    }
}
1952
/// Per-category GPU cost models fitted from profile reports.
///
/// A `None` entry means no usable model could be fitted for that category.
#[derive(Default)]
struct ProfileCostModel {
    /// Elementwise ops, keyed by element count of the first input.
    elem: Option<LinearModel>,
    /// Reductions, keyed by element count of the first input.
    reduction: Option<LinearModel>,
    /// Transposes, keyed by element count of the first input.
    transpose: Option<LinearModel>,
    /// Matrix multiplies, keyed by `m * k * n`.
    matmul: Option<LinearModel>,
}
1960
1961impl ProfileCostModel {
1962    fn from_reports(reports: &[ProfileReport]) -> Self {
1963        let mut elem_samples = Vec::<(f64, f64)>::new();
1964        let mut reduction_samples = Vec::<(f64, f64)>::new();
1965        let mut transpose_samples = Vec::<(f64, f64)>::new();
1966        let mut matmul_samples = Vec::<(f64, f64)>::new();
1967
1968        for report in reports {
1969            let total_secs = report.total_ms.avg_ms / 1_000.0;
1970            match report.category.as_str() {
1971                "elementwise" | "reduction" | "transpose" => {
1972                    if let Some(shape) = report.input_shapes.first() {
1973                        let elems: usize = shape.iter().copied().product();
1974                        if elems == 0 {
1975                            continue;
1976                        }
1977                        let sample = (elems as f64, total_secs);
1978                        match report.category.as_str() {
1979                            "elementwise" => elem_samples.push(sample),
1980                            "reduction" => reduction_samples.push(sample),
1981                            "transpose" => transpose_samples.push(sample),
1982                            _ => {}
1983                        }
1984                    }
1985                }
1986                "matmul" => {
1987                    if report.input_shapes.len() >= 2 {
1988                        let a = &report.input_shapes[0];
1989                        let b = &report.input_shapes[1];
1990                        if a.len() == 2 && b.len() == 2 {
1991                            let m = a[0];
1992                            let k = a[1];
1993                            let n = b[1];
1994                            let flops = m.checked_mul(k).and_then(|val| val.checked_mul(n));
1995                            if let Some(flops) = flops {
1996                                matmul_samples.push((flops as f64, total_secs));
1997                            }
1998                        }
1999                    }
2000                }
2001                _ => {}
2002            }
2003        }
2004
2005        ProfileCostModel {
2006            elem: fit_linear_model(&elem_samples),
2007            reduction: fit_linear_model(&reduction_samples),
2008            transpose: fit_linear_model(&transpose_samples),
2009            matmul: fit_linear_model(&matmul_samples),
2010        }
2011    }
2012
2013    fn estimate_elemwise(&self, elements: usize) -> Option<Duration> {
2014        self.elem.and_then(|model| model.estimate(elements as f64))
2015    }
2016
2017    fn estimate_reduction(&self, elements: usize) -> Option<Duration> {
2018        self.reduction
2019            .and_then(|model| model.estimate(elements as f64))
2020    }
2021
2022    fn estimate_matmul(&self, m: usize, k: usize, n: usize) -> Option<Duration> {
2023        let flops = m.checked_mul(k)?.checked_mul(n)?;
2024        self.matmul.and_then(|model| model.estimate(flops as f64))
2025    }
2026
2027    fn estimate_matmul_flops(&self, flops: usize) -> Option<Duration> {
2028        self.matmul.and_then(|model| model.estimate(flops as f64))
2029    }
2030
2031    fn estimate_transpose(&self, elements: usize) -> Option<Duration> {
2032        self.transpose
2033            .and_then(|model| model.estimate(elements as f64))
2034    }
2035}
2036
2037fn fit_linear_model(samples: &[(f64, f64)]) -> Option<LinearModel> {
2038    if samples.is_empty() {
2039        return None;
2040    }
2041    if samples.len() == 1 {
2042        let (x, y) = samples[0];
2043        if x > 0.0 {
2044            return Some(LinearModel {
2045                slope: (y / x).max(0.0),
2046                intercept: 0.0,
2047            });
2048        }
2049        return None;
2050    }
2051
2052    let sum_x: f64 = samples.iter().map(|(x, _)| *x).sum();
2053    let sum_y: f64 = samples.iter().map(|(_, y)| *y).sum();
2054    let sum_xx: f64 = samples.iter().map(|(x, _)| x * x).sum();
2055    let sum_xy: f64 = samples.iter().map(|(x, y)| x * y).sum();
2056    let n = samples.len() as f64;
2057    let denom = (n * sum_xx) - (sum_x * sum_x);
2058    if denom.abs() < f64::EPSILON {
2059        return None;
2060    }
2061    let slope = ((n * sum_xy) - (sum_x * sum_y)) / denom;
2062    let mean_x = sum_x / n;
2063    let mean_y = sum_y / n;
2064    let mut intercept = mean_y - slope * mean_x;
2065    if intercept < 0.0 {
2066        intercept = 0.0;
2067    }
2068    if !slope.is_finite() || slope <= 0.0 {
2069        return None;
2070    }
2071    Some(LinearModel { slope, intercept })
2072}
2073
/// Returns the process-wide GPU cost model built from profile reports.
///
/// The model is initialized through `PROFILE_MODEL.get_or_init`, so
/// `load_profile_cost_model` runs on first use. Its result is cached and
/// returned on later calls, including a `None` meaning no profile could be loaded.
fn profile_cost_model() -> Option<&'static ProfileCostModel> {
    PROFILE_MODEL.get_or_init(load_profile_cost_model).as_ref()
}
2077
2078fn load_profile_cost_model() -> Option<ProfileCostModel> {
2079    let mut candidates = Vec::new();
2080    if let Ok(path) = env::var("RUNMAT_ACCEL_PROFILE") {
2081        candidates.push(PathBuf::from(path));
2082    }
2083    if let Some(path) = auto_offload_options().profile_path.clone() {
2084        candidates.push(path);
2085    }
2086    candidates.push(PathBuf::from("benchmarks/wgpu_profile/mac_m2.json"));
2087    candidates.push(PathBuf::from("wgpu_profile.json"));
2088
2089    for path in candidates {
2090        if !path.exists() {
2091            continue;
2092        }
2093        match fs::read_to_string(&path) {
2094            Ok(contents) => match serde_json::from_str::<Vec<ProfileReport>>(&contents) {
2095                Ok(reports) => {
2096                    debug!(
2097                        "Loaded {} GPU profile reports from {}",
2098                        reports.len(),
2099                        path.display()
2100                    );
2101                    return Some(ProfileCostModel::from_reports(&reports));
2102                }
2103                Err(err) => {
2104                    debug!("Failed to parse GPU profile {}: {err}", path.display());
2105                }
2106            },
2107            Err(err) => {
2108                debug!("Failed to read GPU profile {}: {err}", path.display());
2109            }
2110        }
2111    }
2112    None
2113}
2114
2115pub async fn promote_binary(op: BinaryOp, a: &Value, b: &Value) -> Result<(Value, Value)> {
2116    if !auto_enabled() {
2117        return Ok((a.clone(), b.clone()));
2118    }
2119    if let Some(auto) = global().await {
2120        auto.promote_binary(op, a, b)
2121    } else {
2122        Ok((a.clone(), b.clone()))
2123    }
2124}
2125
2126pub async fn promote_unary(op: UnaryOp, value: &Value) -> Result<Value> {
2127    if !auto_enabled() {
2128        return Ok(value.clone());
2129    }
2130    if let Some(auto) = global().await {
2131        auto.promote_unary(op, value)
2132    } else {
2133        Ok(value.clone())
2134    }
2135}
2136
2137pub async fn prepare_builtin_args(name: &str, args: &[Value]) -> Result<Vec<Value>> {
2138    if let Some(policy) = builtin_policy(name) {
2139        if policy.is_sink {
2140            clear_sink_inputs(args);
2141            if should_gather_sink_args(name) {
2142                trace!(
2143                    "auto-offload: prepare_builtin_args(name={:?}) is_sink=true residency=GatherImmediately -> gathering {} arg(s)",
2144                    name,
2145                    args.len()
2146                );
2147                return gather_args(args).await;
2148            }
2149            trace!(
2150                "auto-offload: prepare_builtin_args(name={:?}) is_sink=true residency!=GatherImmediately -> no gather (fusion barrier only)",
2151                name
2152            );
2153            return Ok(args.to_vec());
2154        }
2155    }
2156    if !auto_enabled() {
2157        return Ok(args.to_vec());
2158    }
2159    if let Some(auto) = global().await {
2160        auto.prepare_builtin(name, args).await
2161    } else {
2162        Ok(args.to_vec())
2163    }
2164}
2165
2166pub fn is_sink(name: &str) -> bool {
2167    builtin_policy(name).map(|p| p.is_sink).unwrap_or(false)
2168}
2169
2170pub async fn promote_reduction_args(op: ReductionOp, args: &[Value]) -> Result<Vec<Value>> {
2171    if !auto_enabled() {
2172        return Ok(args.to_vec());
2173    }
2174    if let Some(auto) = global().await {
2175        auto.promote_reduction(op, args)
2176    } else {
2177        Ok(args.to_vec())
2178    }
2179}