runmat_accelerate/native_auto.rs

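//! Native auto-offload heuristics: decides, per builtin call, whether operands
//! should be promoted to the GPU provider or kept on the host. Decisions combine
//! fusion/residency overrides, a small-batch guard, an optional profile-based
//! cost model, and element/FLOP thresholds that can be calibrated and cached
//! per device. Recent decisions are kept in a bounded in-memory log that is
//! surfaced through `AutoOffloadReport`.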
use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

use crate::{
    auto_offload_options,
    fusion::{active_fusion, FusionKind},
    fusion_residency,
    precision::ensure_provider_supports_dtype,
    AutoOffloadLogLevel,
};
use anyhow::{anyhow, Result};
use log::{debug, info, trace, warn};
use once_cell::sync::{Lazy, OnceCell};
use runmat_accelerate_api::{AccelProvider, ApiDeviceInfo, HostTensorView, ProviderPrecision};
use runmat_builtins::{builtin_functions, AccelTag, Tensor, Value};
use runmat_runtime::gather_if_needed;
use serde::{Deserialize, Serialize};

const DEFAULT_CPU_ELEM_PER_ELEM: f64 = 1.0e-7;
const DEFAULT_CPU_REDUCTION_PER_ELEM: f64 = 1.2e-7;
const DEFAULT_CPU_MATMUL_PER_FLOP: f64 = 2.5e-11;
const SMALL_BATCH_DEFAULT_MAX_DIM: usize = 8;
const SMALL_BATCH_DEFAULT_MIN_ELEMS: usize = 1_048_576;
const DECISION_LOG_CAPACITY: usize = 128;
const CALIBRATION_VERSION: u32 = 1;

#[derive(Clone, Copy, Debug)]
pub enum BinaryOp {
    Elementwise,
    MatMul,
}

#[derive(Clone, Copy, Debug)]
pub enum UnaryOp {
    Generic,
    Transpose,
}

#[derive(Clone, Copy, Debug)]
pub enum ReductionOp {
    Sum,
    Mean,
    Min,
    Max,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct ThresholdConfig {
    unary_min_elems: usize,
    binary_min_elems: usize,
    reduction_min_elems: usize,
    matmul_min_flops: usize,
    cpu_elem_per_elem: f64,
    cpu_reduction_per_elem: f64,
    cpu_matmul_per_flop: f64,
    small_batch_max_dim: usize,
    small_batch_min_elems: usize,
}

impl Default for ThresholdConfig {
    fn default() -> Self {
        Self {
            unary_min_elems: 4_096,
            binary_min_elems: 4_096,
            reduction_min_elems: 256,
            matmul_min_flops: 1_000_000, // roughly 100x100x100
            cpu_elem_per_elem: DEFAULT_CPU_ELEM_PER_ELEM,
            cpu_reduction_per_elem: DEFAULT_CPU_REDUCTION_PER_ELEM,
            cpu_matmul_per_flop: DEFAULT_CPU_MATMUL_PER_FLOP,
            small_batch_max_dim: SMALL_BATCH_DEFAULT_MAX_DIM,
            small_batch_min_elems: SMALL_BATCH_DEFAULT_MIN_ELEMS,
        }
    }
}
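
// Worked example with the defaults above (illustrative): an elementwise op on
// 4_096 elements is estimated at 4_096 * 1.0e-7 s, roughly 0.41 ms of CPU time,
// and a 256x256x256 matmul (about 16.8 MFLOP) at 16.8e6 * 2.5e-11 s, roughly
// 0.42 ms. The calibration paths below refine these per-unit coefficients per
// device.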

#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadDecisionEntry {
    pub timestamp_ms: u128,
    pub operation: String,
    pub elements: Option<usize>,
    pub flops: Option<usize>,
    pub batch: Option<usize>,
    pub decision: AutoOffloadDisposition,
    pub reason: DecisionReason,
    pub cpu_estimate_ms: Option<f64>,
    pub gpu_estimate_ms: Option<f64>,
    pub threshold: Option<usize>,
    pub fusion_kind: Option<FusionKind>,
}

#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum AutoOffloadDisposition {
    Gpu,
    Cpu,
}

#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum DecisionReason {
    FusionOverride,
    Residency,
    SmallBatchGuard,
    ProfileModel,
    Threshold,
    Disabled,
}

#[derive(Debug, Clone, Serialize)]
pub struct ThresholdSnapshot {
    pub unary_min_elems: usize,
    pub binary_min_elems: usize,
    pub reduction_min_elems: usize,
    pub matmul_min_flops: usize,
    pub cpu_elem_per_elem: f64,
    pub cpu_reduction_per_elem: f64,
    pub cpu_matmul_per_flop: f64,
    pub small_batch_max_dim: usize,
    pub small_batch_min_elems: usize,
}

#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadCalibrationSummary {
    pub previous: ThresholdSnapshot,
    pub delta: ThresholdDelta,
}

#[derive(Debug, Clone, Serialize, Default)]
pub struct ThresholdDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_elem_per_elem: Option<ThresholdDeltaEntry>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_reduction_per_elem: Option<ThresholdDeltaEntry>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu_matmul_per_flop: Option<ThresholdDeltaEntry>,
}

#[derive(Debug, Clone, Serialize)]
pub struct ThresholdDeltaEntry {
    pub before: f64,
    pub after: f64,
    pub absolute: f64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ratio: Option<f64>,
}

impl ThresholdDeltaEntry {
    fn new(before: f64, after: f64) -> Self {
        let absolute = after - before;
        let ratio = if before.abs() > f64::EPSILON {
            Some(after / before)
        } else {
            None
        };
        Self {
            before,
            after,
            absolute,
            ratio,
        }
    }
}
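
// Example: `ThresholdDeltaEntry::new(1.0e-7, 8.0e-8)` yields `absolute == -2.0e-8`
// and `ratio == Some(0.8)`, i.e. the calibrated CPU coefficient came in 20%
// lower than the previous value.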

#[derive(Debug, Clone, Serialize)]
pub struct AutoOffloadReport {
    pub provider: Option<CachedProviderInfo>,
    pub thresholds: ThresholdSnapshot,
    pub base_source: ThresholdBase,
    pub env_overrides_applied: bool,
    pub cache_path: Option<String>,
    pub calibrate_duration_ms: Option<u128>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub calibration: Option<AutoOffloadCalibrationSummary>,
    pub decisions: Vec<AutoOffloadDecisionEntry>,
}

#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum ThresholdBase {
    BuiltInDefault,
    LoadedFromCache,
    Calibrated,
}

impl ThresholdBase {
    pub fn as_str(&self) -> &'static str {
        match self {
            ThresholdBase::BuiltInDefault => "built-in-default",
            ThresholdBase::LoadedFromCache => "loaded-from-cache",
            ThresholdBase::Calibrated => "calibrated",
        }
    }
}

#[derive(Debug, Clone, Serialize)]
pub struct CachedProviderInfo {
    pub name: String,
    pub vendor: String,
    pub backend: Option<String>,
    pub device_id: u32,
}

#[derive(Debug, Clone)]
struct AutoOffloadState {
    provider: Option<CachedProviderInfo>,
    thresholds: ThresholdConfig,
    base_source: ThresholdBase,
    env_overrides_applied: bool,
    cache_path: Option<String>,
    calibrate_duration_ms: Option<u128>,
    previous_thresholds: Option<ThresholdConfig>,
    calibration_delta: Option<ThresholdDelta>,
}

#[derive(Clone)]
struct DecisionEvaluation {
    recommend_gpu: bool,
    reason: DecisionReason,
    cpu_secs: Option<f64>,
    gpu_secs: Option<f64>,
    threshold: Option<usize>,
    fusion_kind: Option<FusionKind>,
    batch: Option<usize>,
}

struct DecisionLog {
    entries: Vec<AutoOffloadDecisionEntry>,
}

impl DecisionLog {
    fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    fn push(&mut self, entry: AutoOffloadDecisionEntry) {
        self.entries.push(entry);
        if self.entries.len() > DECISION_LOG_CAPACITY {
            let overflow = self.entries.len() - DECISION_LOG_CAPACITY;
            self.entries.drain(0..overflow);
        }
    }

    fn snapshot(&self) -> Vec<AutoOffloadDecisionEntry> {
        self.entries.clone()
    }

    fn clear(&mut self) {
        self.entries.clear();
    }
}

static DECISION_LOG: Lazy<Mutex<DecisionLog>> = Lazy::new(|| Mutex::new(DecisionLog::new()));
static AUTO_STATE: OnceCell<Mutex<AutoOffloadState>> = OnceCell::new();

fn record_decision(entry: AutoOffloadDecisionEntry) {
    if let Ok(mut log) = DECISION_LOG.lock() {
        log.push(entry);
    }
}

fn snapshot_decisions() -> Vec<AutoOffloadDecisionEntry> {
    DECISION_LOG
        .lock()
        .map(|log| log.snapshot())
        .unwrap_or_default()
}

fn clear_decisions() {
    if let Ok(mut log) = DECISION_LOG.lock() {
        log.clear();
    }
}

fn now_millis() -> u128 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_else(|_| Duration::from_secs(0))
        .as_millis()
}

fn threshold_snapshot(cfg: &ThresholdConfig) -> ThresholdSnapshot {
    ThresholdSnapshot {
        unary_min_elems: cfg.unary_min_elems,
        binary_min_elems: cfg.binary_min_elems,
        reduction_min_elems: cfg.reduction_min_elems,
        matmul_min_flops: cfg.matmul_min_flops,
        cpu_elem_per_elem: cfg.cpu_elem_per_elem,
        cpu_reduction_per_elem: cfg.cpu_reduction_per_elem,
        cpu_matmul_per_flop: cfg.cpu_matmul_per_flop,
        small_batch_max_dim: cfg.small_batch_max_dim,
        small_batch_min_elems: cfg.small_batch_min_elems,
    }
}

fn compute_delta(before: &ThresholdConfig, after: &ThresholdConfig) -> ThresholdDelta {
    let mut delta = ThresholdDelta::default();

    if (before.cpu_elem_per_elem - after.cpu_elem_per_elem).abs() > f64::EPSILON {
        delta.cpu_elem_per_elem = Some(ThresholdDeltaEntry::new(
            before.cpu_elem_per_elem,
            after.cpu_elem_per_elem,
        ));
    }

    if (before.cpu_reduction_per_elem - after.cpu_reduction_per_elem).abs() > f64::EPSILON {
        delta.cpu_reduction_per_elem = Some(ThresholdDeltaEntry::new(
            before.cpu_reduction_per_elem,
            after.cpu_reduction_per_elem,
        ));
    }

    if (before.cpu_matmul_per_flop - after.cpu_matmul_per_flop).abs() > f64::EPSILON {
        delta.cpu_matmul_per_flop = Some(ThresholdDeltaEntry::new(
            before.cpu_matmul_per_flop,
            after.cpu_matmul_per_flop,
        ));
    }

    delta
}

#[derive(Debug, Deserialize)]
struct CalibrationFile {
    #[serde(default)]
    suite: Option<CalibrationSuiteSection>,
    #[serde(default)]
    auto_offload_calibration: Option<CalibrationSample>,
}

#[derive(Debug, Deserialize)]
struct CalibrationSuiteSection {
    #[serde(default)]
    auto_offload_calibration: Option<CalibrationSample>,
}

#[derive(Debug, Clone, Deserialize)]
struct CalibrationSample {
    #[serde(default)]
    runs: usize,
    #[serde(default, rename = "cpu_time_ms")]
    cpu_time: CalibrationTimes,
    #[serde(default)]
    units: CalibrationUnits,
    #[serde(default)]
    provider: Option<CalibrationProviderInfo>,
    #[serde(default)]
    provider_conflict: bool,
}

#[derive(Debug, Clone, Deserialize, Default)]
struct CalibrationTimes {
    #[serde(default)]
    elementwise: f64,
    #[serde(default)]
    reduction: f64,
    #[serde(default)]
    matmul: f64,
}

#[derive(Debug, Clone, Deserialize, Default)]
struct CalibrationUnits {
    #[serde(default)]
    elementwise: f64,
    #[serde(default)]
    reduction: f64,
    #[serde(default, rename = "matmul_flops")]
    matmul_flops: f64,
}

#[derive(Debug, Clone, Deserialize)]
struct CalibrationProviderInfo {
    name: String,
    vendor: String,
    #[serde(default)]
    backend: Option<String>,
    device_id: u32,
}

#[derive(Debug, Serialize)]
pub struct AutoOffloadCalibrationOutcome {
    pub runs: usize,
    pub before: ThresholdSnapshot,
    pub after: ThresholdSnapshot,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delta: Option<ThresholdDelta>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub persisted_to: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<CachedProviderInfo>,
    pub commit: bool,
}

fn load_calibration_sample(path: &Path) -> Result<CalibrationSample> {
    let payload = fs::read_to_string(path).map_err(|e| anyhow!(e.to_string()))?;
    let file: CalibrationFile = serde_json::from_str(&payload)
        .map_err(|e| anyhow!(format!("failed to parse calibration file: {e}")))?;
    if let Some(suite) = file.suite {
        if let Some(sample) = suite.auto_offload_calibration {
            return Ok(sample);
        }
    }
    if let Some(sample) = file.auto_offload_calibration {
        return Ok(sample);
    }
    Err(anyhow!(
        "calibration file does not contain an auto_offload_calibration section"
    ))
}
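
// The loader accepts the sample either nested under `suite` or at the top
// level. An illustrative payload (field values are made up) that matches the
// serde shapes above:
//
// {
//   "suite": {
//     "auto_offload_calibration": {
//       "runs": 3,
//       "cpu_time_ms": { "elementwise": 2.0, "reduction": 1.5, "matmul": 40.0 },
//       "units": { "elementwise": 1.0e6, "reduction": 1.0e6, "matmul_flops": 2.0e9 },
//       "provider": { "name": "...", "vendor": "...", "backend": "metal", "device_id": 0 }
//     }
//   }
// }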

fn apply_calibration_sample(
    cfg: &mut ThresholdConfig,
    sample: &CalibrationSample,
) -> Option<ThresholdDelta> {
    let mut delta = ThresholdDelta::default();
    let mut changed = false;

    if sample.units.elementwise > 0.0 && sample.cpu_time.elementwise > 0.0 {
        let secs_per_elem = (sample.cpu_time.elementwise / 1_000.0) / sample.units.elementwise;
        if secs_per_elem.is_finite()
            && secs_per_elem > 0.0
            && (cfg.cpu_elem_per_elem - secs_per_elem).abs() > f64::EPSILON
        {
            delta.cpu_elem_per_elem = Some(ThresholdDeltaEntry::new(
                cfg.cpu_elem_per_elem,
                secs_per_elem,
            ));
            cfg.cpu_elem_per_elem = secs_per_elem;
            changed = true;
        }
    }

    if sample.units.reduction > 0.0 && sample.cpu_time.reduction > 0.0 {
        let secs_per_elem = (sample.cpu_time.reduction / 1_000.0) / sample.units.reduction;
        if secs_per_elem.is_finite()
            && secs_per_elem > 0.0
            && (cfg.cpu_reduction_per_elem - secs_per_elem).abs() > f64::EPSILON
        {
            delta.cpu_reduction_per_elem = Some(ThresholdDeltaEntry::new(
                cfg.cpu_reduction_per_elem,
                secs_per_elem,
            ));
            cfg.cpu_reduction_per_elem = secs_per_elem;
            changed = true;
        }
    }

    if sample.units.matmul_flops > 0.0 && sample.cpu_time.matmul > 0.0 {
        let secs_per_flop = (sample.cpu_time.matmul / 1_000.0) / sample.units.matmul_flops;
        if secs_per_flop.is_finite()
            && secs_per_flop > 0.0
            && (cfg.cpu_matmul_per_flop - secs_per_flop).abs() > f64::EPSILON
        {
            delta.cpu_matmul_per_flop = Some(ThresholdDeltaEntry::new(
                cfg.cpu_matmul_per_flop,
                secs_per_flop,
            ));
            cfg.cpu_matmul_per_flop = secs_per_flop;
            changed = true;
        }
    }

    if changed {
        Some(delta)
    } else {
        None
    }
}
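
// Example conversion: `cpu_time_ms.elementwise = 2.0` over
// `units.elementwise = 1.0e6` elements gives (2.0 / 1000.0) / 1.0e6 = 2.0e-9
// seconds per element, which replaces `cpu_elem_per_elem` only when it differs
// from the current coefficient.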

pub fn apply_auto_offload_calibration_from_file(
    path: &Path,
    commit: bool,
) -> Result<AutoOffloadCalibrationOutcome> {
    let sample = load_calibration_sample(path)?;
    if sample.runs == 0 {
        return Err(anyhow!("calibration sample contains zero runs"));
    }

    let provider = runmat_accelerate_api::provider()
        .ok_or_else(|| anyhow!("no acceleration provider registered"))?;
    let device_info = provider.device_info_struct();

    if let Some(ref prov) = sample.provider {
        if prov.name != device_info.name
            || prov.vendor != device_info.vendor
            || prov.backend.as_deref() != device_info.backend.as_deref()
            || prov.device_id != device_info.device_id
        {
            warn!(
                "Calibration provider mismatch: sample='{} ({})' device='{} ({})'",
                prov.name, prov.vendor, device_info.name, device_info.vendor
            );
        }
        if sample.provider_conflict {
            warn!("Calibration sample reported provider conflict across cases");
        }
    }

    let (mut cfg, _) = load_cached_thresholds(&device_info)
        .unwrap_or_else(|| (ThresholdConfig::default(), PathBuf::new()));
    let before_cfg = cfg.clone();

    let delta = apply_calibration_sample(&mut cfg, &sample)
        .ok_or_else(|| anyhow!("calibration sample did not produce coefficient updates"))?;

    let mut persisted_to: Option<PathBuf> = None;
    if commit {
        persisted_to = Some(persist_thresholds(&device_info, &cfg)?);
    }

    if let Some(state_mutex) = AUTO_STATE.get() {
        if let Ok(mut state) = state_mutex.lock() {
            state.previous_thresholds = Some(before_cfg.clone());
            state.calibration_delta = Some(delta.clone());
            if commit {
                state.thresholds = cfg.clone();
                state.base_source = ThresholdBase::Calibrated;
                if let Some(ref path_buf) = persisted_to {
                    state.cache_path = Some(path_buf.to_string_lossy().into_owned());
                }
                state.calibrate_duration_ms = None;
            }
        }
    }

    Ok(AutoOffloadCalibrationOutcome {
        runs: sample.runs,
        before: threshold_snapshot(&before_cfg),
        after: threshold_snapshot(&cfg),
        delta: Some(delta),
        persisted_to: persisted_to.map(|p| p.to_string_lossy().into_owned()),
        provider: Some(cached_provider_info(&device_info)),
        commit,
    })
}
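
// Usage sketch (illustrative path; `commit = false` computes the new
// coefficients and records the delta without persisting them):
//
//     let outcome = apply_auto_offload_calibration_from_file(
//         Path::new("target/bench/calibration.json"),
//         false,
//     )?;
//     println!("coefficient delta: {:?}", outcome.delta);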

fn cached_provider_info(info: &ApiDeviceInfo) -> CachedProviderInfo {
    CachedProviderInfo {
        name: info.name.clone(),
        vendor: info.vendor.clone(),
        backend: info.backend.clone(),
        device_id: info.device_id,
    }
}

fn cpu_estimate(per_unit: f64, units: usize) -> Option<f64> {
    if per_unit.is_finite() && per_unit > 0.0 {
        Some(per_unit * units as f64)
    } else {
        None
    }
}

fn value_shape(value: &Value) -> Option<&[usize]> {
    match value {
        Value::Tensor(t) => Some(&t.shape),
        Value::GpuTensor(handle) => Some(&handle.shape),
        _ => None,
    }
}

fn batch_dimension_from_value(value: &Value) -> Option<usize> {
    let shape = value_shape(value)?;
    if shape.len() < 3 {
        return None;
    }
    shape.last().copied()
}

fn batch_dimension_from_values(values: &[&Value]) -> Option<usize> {
    values
        .iter()
        .filter_map(|value| batch_dimension_from_value(value))
        .min()
}
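
// Examples: a [64, 64, 4] tensor reports a batch of Some(4) (the trailing
// dimension of an N-D value), a plain [64, 64] matrix reports None, and for a
// mixed operand list the smallest batch across operands wins.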

fn decision_entry(
    operation: &str,
    elements: Option<usize>,
    flops: Option<usize>,
    eval: &DecisionEvaluation,
) -> AutoOffloadDecisionEntry {
    AutoOffloadDecisionEntry {
        timestamp_ms: now_millis(),
        operation: operation.to_string(),
        elements,
        flops,
        batch: eval.batch,
        decision: if eval.recommend_gpu {
            AutoOffloadDisposition::Gpu
        } else {
            AutoOffloadDisposition::Cpu
        },
        reason: eval.reason,
        cpu_estimate_ms: eval.cpu_secs.map(|secs| secs * 1_000.0),
        gpu_estimate_ms: eval.gpu_secs.map(|secs| secs * 1_000.0),
        threshold: eval.threshold,
        fusion_kind: eval.fusion_kind.clone(),
    }
}

pub struct NativeAutoOffload {
    provider: &'static dyn AccelProvider,
    thresholds: ThresholdConfig,
    enabled: bool,
}

static GLOBAL: OnceCell<Option<NativeAutoOffload>> = OnceCell::new();
static PROFILE_MODEL: OnceCell<Option<ProfileCostModel>> = OnceCell::new();

fn env_bool(key: &str) -> Option<bool> {
    env::var(key).ok().and_then(|v| parse_bool(&v))
}

fn parse_bool(s: &str) -> Option<bool> {
    match s.trim().to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => Some(true),
        "0" | "false" | "no" | "off" => Some(false),
        _ => None,
    }
}

fn log_promotion<F>(builder: F)
where
    F: FnOnce() -> String,
{
    match auto_offload_options().log_level {
        AutoOffloadLogLevel::Off => {}
        AutoOffloadLogLevel::Info => info!("{}", builder()),
        AutoOffloadLogLevel::Trace => trace!("{}", builder()),
    }
}

fn update_cpu_cost(slot: &mut f64, candidate: f64) {
    if candidate.is_finite() && candidate > 0.0 && candidate < *slot {
        *slot = candidate;
    }
}

fn value_len(value: &Value) -> Option<usize> {
    match value {
        Value::Tensor(t) => Some(t.data.len()),
        Value::GpuTensor(handle) => Some(handle.shape.iter().product()),
        Value::Num(_) | Value::Bool(_) | Value::Int(_) => Some(1),
        Value::Complex(_, _) => Some(1),
        _ => None,
    }
}

fn element_count_pair(a: &Value, b: &Value) -> Option<usize> {
    let la = value_len(a)?;
    let lb = value_len(b)?;
    Some(la.max(lb))
}

pub fn global() -> Option<&'static NativeAutoOffload> {
    GLOBAL.get_or_init(initialize).as_ref()
}

fn initialize() -> Option<NativeAutoOffload> {
    if !auto_enabled() {
        clear_decisions();
        return None;
    }
    let provider = runmat_accelerate_api::provider()?;
    let device_info = provider.device_info_struct();
    let mut config = ThresholdConfig::default();
    let mut base_source = ThresholdBase::BuiltInDefault;
    let mut cache_path: Option<PathBuf> = None;
    let mut calibrate_duration_ms: Option<u128> = None;
    let refresh_calibration = calibrate_refresh_enabled();

    if !refresh_calibration {
        if let Some((cached, path)) = load_cached_thresholds(&device_info) {
            info!(
                "Native auto-offload: loaded cached calibration for '{}' from {}",
                device_info.name,
                path.display()
            );
            config = cached;
            cache_path = Some(path);
            base_source = ThresholdBase::LoadedFromCache;
        }
    }

    let needs_calibration = calibrate_enabled() && (refresh_calibration || cache_path.is_none());
    if needs_calibration {
        let start = Instant::now();
        match auto_calibrate(provider, &mut config) {
            Ok(()) => {
                calibrate_duration_ms = Some(start.elapsed().as_millis());
                base_source = ThresholdBase::Calibrated;
                match persist_thresholds(&device_info, &config) {
                    Ok(path) => {
                        cache_path = Some(path.clone());
                        info!(
                            "Native auto-offload: persisted calibration for '{}' to {}",
                            device_info.name,
                            path.display()
                        );
                    }
                    Err(err) => {
                        debug!("Native auto-offload: failed to persist calibration: {err}");
                    }
                }
            }
            Err(err) => {
                debug!("Native auto-offload calibration failed: {err}");
            }
        }
    }

    let env_overrides_applied = apply_env_overrides(&mut config);
    let model_status = if profile_cost_model().is_some() {
        "profile"
    } else {
        "fallback"
    };
    info!(
        "Native auto-offload thresholds: unary={} binary={} reduction={} matmul_flops={} small_batch_dim={} small_batch_min_elems={} (model: {}, source: {}, env_overrides={})",
        config.unary_min_elems,
        config.binary_min_elems,
        config.reduction_min_elems,
        config.matmul_min_flops,
        config.small_batch_max_dim,
        config.small_batch_min_elems,
        model_status,
        base_source.as_str(),
        env_overrides_applied
    );

    let cache_path_str = cache_path
        .as_ref()
        .map(|p| p.to_string_lossy().into_owned());
    let state = AutoOffloadState {
        provider: Some(cached_provider_info(&device_info)),
        thresholds: config.clone(),
        base_source,
        env_overrides_applied,
        cache_path: cache_path_str,
        calibrate_duration_ms,
        previous_thresholds: None,
        calibration_delta: None,
    };
    let _ = AUTO_STATE.set(Mutex::new(state));

    Some(NativeAutoOffload::new(provider, config))
}

impl NativeAutoOffload {
    fn new(provider: &'static dyn AccelProvider, thresholds: ThresholdConfig) -> Self {
        let enabled = true;
        Self {
            provider,
            thresholds,
            enabled,
        }
    }

    fn promote_tensor_if_large(&self, value: &Value, threshold: usize) -> Result<Value> {
        match value {
            Value::GpuTensor(_) => Ok(value.clone()),
            Value::Tensor(t) => {
                if ensure_provider_supports_dtype(self.provider, t.dtype).is_err() {
                    return Ok(value.clone());
                }
                if t.data.len() >= threshold && threshold > 0 {
                    log_promotion(|| {
                        format!(
                            "Promoting tensor to GPU (len={}, threshold={})",
                            t.data.len(),
                            threshold
                        )
                    });
                    self.tensor_to_gpu(t)
                } else {
                    Ok(value.clone())
                }
            }
            _ => Ok(value.clone()),
        }
    }

    fn tensor_to_gpu(&self, tensor: &Tensor) -> Result<Value> {
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = self
            .provider
            .upload(&view)
            .map_err(|e| anyhow!(e.to_string()))?;
        Ok(Value::GpuTensor(handle))
    }

    fn small_batch_guard(&self, elements: usize, batch: Option<usize>) -> bool {
        if !self.enabled {
            return false;
        }
        let Some(batch) = batch else {
            return false;
        };
        if batch == 0 {
            return false;
        }
        let thresholds = &self.thresholds;
        thresholds.small_batch_max_dim > 0
            && thresholds.small_batch_min_elems > 0
            && batch <= thresholds.small_batch_max_dim
            && elements >= thresholds.small_batch_min_elems
    }
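
    // Example with the default thresholds: a [2048, 2048, 4] operand has
    // batch = 4 <= small_batch_max_dim (8) and 16_777_216 elements
    // >= small_batch_min_elems (1_048_576), so the guard keeps the op on the
    // CPU even though the raw element count alone would favour the GPU.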

    fn promote_binary(&self, op: BinaryOp, a: &Value, b: &Value) -> Result<(Value, Value)> {
        if !self.enabled {
            return Ok((a.clone(), b.clone()));
        }
        match op {
            BinaryOp::Elementwise => {
                let elems = element_count_pair(a, b).unwrap_or(0);
                let eval = self.evaluate_elementwise(elems, &[a, b]);
                record_decision(decision_entry("elementwise", Some(elems), None, &eval));
                if eval.recommend_gpu {
                    log_promotion(|| format!("Elementwise offload accepted ({} elems)", elems));
                    let a_p = self.promote_tensor_if_large(a, 1)?;
                    let b_p = self.promote_tensor_if_large(b, 1)?;
                    Ok((a_p, b_p))
                } else {
                    Ok((a.clone(), b.clone()))
                }
            }
            BinaryOp::MatMul => {
                if let (Some((ra, ca)), Some((rb, cb))) = (tensor_rows_cols(a), tensor_rows_cols(b))
                {
                    if ca != rb {
                        return Ok((a.clone(), b.clone()));
                    }
                    let flops = ra.saturating_mul(ca).saturating_mul(cb);
                    let eval = self.evaluate_matmul(flops);
                    record_decision(decision_entry("matmul", None, Some(flops), &eval));
                    if eval.recommend_gpu {
                        log_promotion(|| {
                            format!(
                                "Promoting matmul operands (flops={}, threshold={})",
                                flops, self.thresholds.matmul_min_flops
                            )
                        });
                        let a_p = self.promote_tensor_if_large(a, 1)?;
                        let b_p = self.promote_tensor_if_large(b, 1)?;
                        return Ok((a_p, b_p));
                    }
                }
                Ok((a.clone(), b.clone()))
            }
        }
    }

    fn promote_unary(&self, op: UnaryOp, v: &Value) -> Result<Value> {
        if !self.enabled {
            return Ok(v.clone());
        }
        let elems = value_len(v).unwrap_or(0);
        let eval = self.evaluate_unary(elems, op, v);
        let op_label = match op {
            UnaryOp::Transpose => "transpose",
            UnaryOp::Generic => "unary",
        };
        record_decision(decision_entry(op_label, Some(elems), None, &eval));
        if eval.recommend_gpu {
            log_promotion(|| format!("Unary offload accepted ({:?}, {} elems)", op, elems));
            self.promote_tensor_if_large(v, 1)
        } else {
            Ok(v.clone())
        }
    }

    fn promote_reduction(&self, _op: ReductionOp, args: &[Value]) -> Result<Vec<Value>> {
        if !self.enabled || args.is_empty() {
            return Ok(args.to_vec());
        }
        let elems = value_len(&args[0]).unwrap_or(0);
        let eval = self.evaluate_reduction(elems);
        record_decision(decision_entry("reduction", Some(elems), None, &eval));
        if !eval.recommend_gpu {
            return Ok(args.to_vec());
        }
        log_promotion(|| format!("Reduction offload accepted ({} elems)", elems));
        let mut out = Vec::with_capacity(args.len());
        if let Some(first) = args.first() {
            out.push(self.promote_tensor_if_large(first, 1)?);
            out.extend(args.iter().skip(1).cloned());
        }
        Ok(out)
    }

    fn evaluate_elementwise(&self, elements: usize, values: &[&Value]) -> DecisionEvaluation {
        let fusion = active_fusion();
        let fusion_kind = fusion.as_ref().map(|f| f.kind.clone());
        let batch = batch_dimension_from_values(values);
        let cpu_secs = cpu_estimate(self.thresholds.cpu_elem_per_elem, elements);

        // Chain-aware residency: if any input is already on GPU, keep the op on GPU
        if values.iter().any(|v| matches!(v, Value::GpuTensor(_))) {
            return DecisionEvaluation {
                recommend_gpu: true,
                reason: DecisionReason::Residency,
                cpu_secs,
                gpu_secs: None,
                threshold: Some(self.thresholds.binary_min_elems),
                fusion_kind,
                batch,
            };
        }

        if let Some(active) = fusion.as_ref() {
            // If an elementwise chain is actively fused OR this elementwise op
            // participates in a fused reduction group, force GPU to keep the
            // whole chain resident and avoid host round-trips.
            if (active.kind.is_elementwise() || active.kind.is_reduction()) && active.supported {
                return DecisionEvaluation {
                    recommend_gpu: true,
                    reason: DecisionReason::FusionOverride,
                    cpu_secs,
                    gpu_secs: None,
                    threshold: Some(self.thresholds.binary_min_elems),
                    fusion_kind,
                    batch,
                };
            }
        }

        if self.small_batch_guard(elements, batch) {
            return DecisionEvaluation {
                recommend_gpu: false,
                reason: DecisionReason::SmallBatchGuard,
                cpu_secs,
                gpu_secs: None,
                threshold: Some(self.thresholds.binary_min_elems),
                fusion_kind,
                batch,
            };
        }

        if let Some(model) = profile_cost_model() {
            if let Some(gpu_duration) = model.estimate_elemwise(elements) {
                let gpu_secs = Some(gpu_duration.as_secs_f64());
                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
                return DecisionEvaluation {
                    recommend_gpu: recommend,
                    reason: DecisionReason::ProfileModel,
                    cpu_secs,
                    gpu_secs,
                    threshold: Some(self.thresholds.binary_min_elems),
                    fusion_kind,
                    batch,
                };
            }
        }

        DecisionEvaluation {
            recommend_gpu: elements >= self.thresholds.binary_min_elems,
            reason: DecisionReason::Threshold,
            cpu_secs,
            gpu_secs: None,
            threshold: Some(self.thresholds.binary_min_elems),
            fusion_kind,
            batch,
        }
    }
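
    // When the profile cost model can produce an estimate, it takes precedence
    // over the raw element threshold. The GPU estimate is discounted by 5%
    // before the comparison (gpu_secs * 0.95 < cpu_secs), so the op is
    // offloaded unless the GPU is predicted to be noticeably slower than the
    // CPU: a 1.00 ms GPU estimate offloads against a 0.96 ms CPU estimate but
    // stays on the host against a 0.94 ms one. The same rule is used by the
    // matmul, reduction and unary evaluators below.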

    fn evaluate_matmul(&self, flops: usize) -> DecisionEvaluation {
        let cpu_secs = cpu_estimate(self.thresholds.cpu_matmul_per_flop, flops);
        if let Some(model) = profile_cost_model() {
            if let Some(gpu_duration) = model.estimate_matmul_flops(flops) {
                let gpu_secs = Some(gpu_duration.as_secs_f64());
                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
                return DecisionEvaluation {
                    recommend_gpu: recommend,
                    reason: DecisionReason::ProfileModel,
                    cpu_secs,
                    gpu_secs,
                    threshold: Some(self.thresholds.matmul_min_flops),
                    fusion_kind: None,
                    batch: None,
                };
            }
        }

        DecisionEvaluation {
            recommend_gpu: flops >= self.thresholds.matmul_min_flops,
            reason: DecisionReason::Threshold,
            cpu_secs,
            gpu_secs: None,
            threshold: Some(self.thresholds.matmul_min_flops),
            fusion_kind: None,
            batch: None,
        }
    }

    fn evaluate_reduction(&self, elements: usize) -> DecisionEvaluation {
        let fusion_kind = active_fusion().map(|f| f.kind.clone());
        let cpu_secs = cpu_estimate(self.thresholds.cpu_reduction_per_elem, elements);
        if let Some(model) = profile_cost_model() {
            if let Some(gpu_duration) = model.estimate_reduction(elements) {
                let gpu_secs = Some(gpu_duration.as_secs_f64());
                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
                return DecisionEvaluation {
                    recommend_gpu: recommend,
                    reason: DecisionReason::ProfileModel,
                    cpu_secs,
                    gpu_secs,
                    threshold: Some(self.thresholds.reduction_min_elems),
                    fusion_kind,
                    batch: None,
                };
            }
        }

        DecisionEvaluation {
            recommend_gpu: elements >= self.thresholds.reduction_min_elems,
            reason: DecisionReason::Threshold,
            cpu_secs,
            gpu_secs: None,
            threshold: Some(self.thresholds.reduction_min_elems),
            fusion_kind,
            batch: None,
        }
    }

    fn evaluate_unary(&self, elements: usize, op: UnaryOp, value: &Value) -> DecisionEvaluation {
        let fusion_kind = active_fusion().map(|f| f.kind.clone());
        let batch = batch_dimension_from_values(&[value]);
        // Chain-aware residency for unary ops: if operand is already on GPU, keep it on GPU
        if matches!(value, Value::GpuTensor(_)) {
            return DecisionEvaluation {
                recommend_gpu: true,
                reason: DecisionReason::Residency,
                cpu_secs: cpu_estimate(self.thresholds.cpu_elem_per_elem, elements),
                gpu_secs: None,
                threshold: Some(self.thresholds.unary_min_elems),
                fusion_kind,
                batch,
            };
        }
        if matches!(op, UnaryOp::Generic) && self.small_batch_guard(elements, batch) {
            return DecisionEvaluation {
                recommend_gpu: false,
                reason: DecisionReason::SmallBatchGuard,
                cpu_secs: cpu_estimate(self.thresholds.cpu_elem_per_elem, elements),
                gpu_secs: None,
                threshold: Some(self.thresholds.unary_min_elems),
                fusion_kind,
                batch,
            };
        }

        let cpu_secs = cpu_estimate(self.thresholds.cpu_elem_per_elem, elements);
        if let Some(model) = profile_cost_model() {
            let gpu_duration = match op {
                UnaryOp::Transpose => model.estimate_transpose(elements),
                UnaryOp::Generic => model.estimate_elemwise(elements),
            };
            if let Some(gpu_duration) = gpu_duration {
                let gpu_secs = Some(gpu_duration.as_secs_f64());
                let cpu = cpu_secs.unwrap_or(f64::INFINITY);
                let recommend = gpu_duration.as_secs_f64() * 0.95 < cpu;
                return DecisionEvaluation {
                    recommend_gpu: recommend,
                    reason: DecisionReason::ProfileModel,
                    cpu_secs,
                    gpu_secs,
                    threshold: Some(self.thresholds.unary_min_elems),
                    fusion_kind,
                    batch,
                };
            }
        }

        DecisionEvaluation {
            recommend_gpu: elements >= self.thresholds.unary_min_elems,
            reason: DecisionReason::Threshold,
            cpu_secs,
            gpu_secs: None,
            threshold: Some(self.thresholds.unary_min_elems),
            fusion_kind,
            batch,
        }
    }

    fn prepare_builtin(&self, name: &str, args: &[Value]) -> Result<Vec<Value>> {
        if !self.enabled {
            return Ok(args.to_vec());
        }
        // Do not attempt to promote 'double' on providers that cannot store f64.
        // Offloading a cast to double requires device-side f64; otherwise keep host.
        if name.eq_ignore_ascii_case("double")
            && self.provider.precision() != runmat_accelerate_api::ProviderPrecision::F64
        {
            return Ok(args.to_vec());
        }
        if let Some(policy) = builtin_policy(name) {
            if policy.is_sink {
                return gather_args(args);
            }

            let mut processed = args.to_vec();

            if policy
                .accel_tags
                .iter()
                .any(|tag| matches!(tag, AccelTag::Reduction))
            {
                if (name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min"))
                    && !max_or_min_reduction_call(args)
                {
                    trace!(
                        "Skipping reduction promotion for builtin '{}' (detected elementwise form)",
                        name
                    );
                } else {
                    log_promotion(|| format!("Promoting builtin '{}' as reduction", name));
                    return self.promote_reduction(reduction_op_hint(name), args);
                }
            }

            if policy
                .accel_tags
                .iter()
                .any(|tag| matches!(tag, AccelTag::MatMul))
                && processed.len() >= 2
            {
                log_promotion(|| format!("Promoting builtin '{}' as matmul", name));
                let (a_p, b_p) =
                    self.promote_binary(BinaryOp::MatMul, &processed[0], &processed[1])?;
                processed[0] = a_p;
                processed[1] = b_p;
                return Ok(processed);
            }

            if policy
                .accel_tags
                .iter()
                .any(|tag| matches!(tag, AccelTag::Elementwise))
                && processed.len() >= 2
            {
                log_promotion(|| format!("Promoting builtin '{}' as elementwise", name));
                let (a_p, b_p) =
                    self.promote_binary(BinaryOp::Elementwise, &processed[0], &processed[1])?;
                processed[0] = a_p;
                processed[1] = b_p;
                return Ok(processed);
            }

            if let Some(first) = processed.first_mut() {
                if policy
                    .accel_tags
                    .iter()
                    .any(|tag| matches!(tag, AccelTag::Transpose))
                {
                    log_promotion(|| format!("Promoting builtin '{}' as transpose", name));
                    *first = self.promote_unary(UnaryOp::Transpose, first)?;
                    return Ok(processed);
                }

                if policy
                    .accel_tags
                    .iter()
                    .any(|tag| matches!(tag, AccelTag::Unary))
                {
                    log_promotion(|| format!("Promoting builtin '{}' as unary", name));
                    *first = self.promote_unary(UnaryOp::Generic, first)?;
                    return Ok(processed);
                }
            }
        }
        Ok(args.to_vec())
    }
}

fn tensor_rows_cols(value: &Value) -> Option<(usize, usize)> {
    match value {
        Value::Tensor(t) => Some((t.rows(), t.cols())),
        Value::GpuTensor(handle) => {
            if handle.shape.len() == 2 {
                Some((handle.shape[0], handle.shape[1]))
            } else {
                None
            }
        }
        _ => None,
    }
}

#[allow(dead_code)]
fn should_skip_reduction_promotion(name: &str, args: &[Value]) -> bool {
    (name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min"))
        && !max_or_min_reduction_call(args)
}

fn reduction_op_hint(name: &str) -> ReductionOp {
    if name.eq_ignore_ascii_case("max") {
        ReductionOp::Max
    } else if name.eq_ignore_ascii_case("min") {
        ReductionOp::Min
    } else {
        ReductionOp::Sum
    }
}

fn max_or_min_reduction_call(args: &[Value]) -> bool {
    if args.len() <= 1 {
        return true;
    }
    args.get(1).map(is_empty_placeholder_value).unwrap_or(false)
}

fn is_empty_placeholder_value(value: &Value) -> bool {
    match value {
        Value::Tensor(t) => t.data.is_empty(),
        Value::LogicalArray(l) => l.data.is_empty(),
        Value::StringArray(sa) => sa.data.is_empty(),
        Value::CharArray(ca) => ca.data.is_empty(),
        Value::Cell(cell) => cell.data.is_empty(),
        Value::String(s) => s.is_empty(),
        _ => false,
    }
}

fn gather_args(args: &[Value]) -> Result<Vec<Value>> {
    let mut out = Vec::with_capacity(args.len());
    for value in args {
        if let Value::GpuTensor(handle) = value {
            fusion_residency::clear(handle);
        }
        out.push(gather_if_needed(value).map_err(|e| anyhow!(e))?);
    }
    Ok(out)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn max_detection_handles_placeholders() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let placeholder = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        let data = Value::Tensor(tensor);
        let empty = Value::Tensor(placeholder);

        assert!(max_or_min_reduction_call(std::slice::from_ref(&data)));
        assert!(max_or_min_reduction_call(&[
            data.clone(),
            empty.clone(),
            Value::Num(1.0)
        ]));
        assert!(!max_or_min_reduction_call(&[data.clone(), Value::Num(0.0)]));
    }
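
    // Additional test sketch (added for illustration; it exercises only the
    // `slugify` helper defined later in this module, imported via `super::*`).
    #[test]
    fn slugify_produces_cache_friendly_names() {
        assert_eq!(slugify("NVIDIA GeForce RTX 4090"), "nvidia_geforce_rtx_4090");
        assert_eq!(slugify("Apple M3 Max"), "apple_m3_max");
        // Inputs with no alphanumeric characters fall back to a generic name.
        assert_eq!(slugify("***"), "device");
    }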
}

#[derive(Clone, Copy)]
struct BuiltinPolicy {
    accel_tags: &'static [AccelTag],
    is_sink: bool,
}

static BUILTIN_POLICIES: OnceCell<HashMap<String, BuiltinPolicy>> = OnceCell::new();

fn build_builtin_policy_map() -> HashMap<String, BuiltinPolicy> {
    let mut map = HashMap::new();
    for func in builtin_functions() {
        map.insert(
            func.name.to_ascii_lowercase(),
            BuiltinPolicy {
                accel_tags: func.accel_tags,
                is_sink: func.is_sink,
            },
        );
    }
    map
}

fn builtin_policy(name: &str) -> Option<BuiltinPolicy> {
    let map = BUILTIN_POLICIES.get_or_init(build_builtin_policy_map);
    map.get(&name.to_ascii_lowercase()).copied()
}

fn auto_enabled() -> bool {
    if let Some(flag) = env_bool("RUNMAT_ACCEL_AUTO_OFFLOAD") {
        return flag;
    }
    auto_offload_options().enabled
}

fn calibrate_enabled() -> bool {
    if let Some(flag) = env_bool("RUNMAT_ACCEL_CALIBRATE") {
        return flag;
    }
    auto_offload_options().calibrate
}

fn calibrate_refresh_enabled() -> bool {
    env_bool("RUNMAT_ACCEL_CALIBRATE_REFRESH").unwrap_or(false)
}

fn apply_env_overrides(cfg: &mut ThresholdConfig) -> bool {
    let mut applied = false;
    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_UNARY") {
        cfg.unary_min_elems = val;
        applied = true;
    }
    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_ELEMWISE") {
        cfg.binary_min_elems = val;
        applied = true;
    }
    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_REDUCTION") {
        cfg.reduction_min_elems = val;
        applied = true;
    }
    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_MATMUL") {
        cfg.matmul_min_flops = val;
        applied = true;
    }
    if let Some(val) = env_usize("RUNMAT_ACCEL_THRESHOLD_ALL") {
        cfg.unary_min_elems = val;
        cfg.binary_min_elems = val;
        cfg.reduction_min_elems = val;
        applied = true;
    }
    if let Some(val) = env_usize("RUNMAT_ACCEL_SMALL_BATCH_MAX_DIM") {
        cfg.small_batch_max_dim = val;
        applied = true;
    }
    if let Some(val) = env_usize("RUNMAT_ACCEL_SMALL_BATCH_MIN_ELEMS") {
        cfg.small_batch_min_elems = val;
        applied = true;
    }
    applied
}
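
// Example: `RUNMAT_ACCEL_THRESHOLD_ALL=65536` sets the unary, binary
// (elementwise) and reduction thresholds to 65_536 elements in one go; it does
// not touch the matmul FLOP threshold, which needs
// `RUNMAT_ACCEL_THRESHOLD_MATMUL`. Note that the `ALL` override is applied
// after the per-kind variables, so it wins when both are set.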

fn env_usize(key: &str) -> Option<usize> {
    env::var(key).ok().and_then(|v| v.parse::<usize>().ok())
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct CalibrationRecord {
    version: u32,
    recorded_at: u64,
    provider: CalibrationProviderDetails,
    thresholds: ThresholdConfig,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct CalibrationProviderDetails {
    name: String,
    vendor: String,
    backend: Option<String>,
    device_id: u32,
}

fn load_cached_thresholds(info: &ApiDeviceInfo) -> Option<(ThresholdConfig, PathBuf)> {
    let path = calibration_cache_file(info)?;
    let contents = fs::read_to_string(&path).ok()?;
    match serde_json::from_str::<CalibrationRecord>(&contents) {
        Ok(record) => {
            if record.version != CALIBRATION_VERSION {
                debug!(
                    "Native auto-offload calibration cache version mismatch (found {}, expected {})",
                    record.version,
                    CALIBRATION_VERSION
                );
                None
            } else {
                Some((record.thresholds, path))
            }
        }
        Err(err) => {
            debug!(
                "Native auto-offload failed to parse cached calibration for '{}': {err}",
                info.name
            );
            None
        }
    }
}

fn persist_thresholds(info: &ApiDeviceInfo, cfg: &ThresholdConfig) -> Result<PathBuf> {
    let path = calibration_cache_file(info)
        .ok_or_else(|| anyhow!("unable to determine calibration cache directory"))?;
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent).map_err(|e| anyhow!(e.to_string()))?;
    }
    let record = CalibrationRecord {
        version: CALIBRATION_VERSION,
        recorded_at: SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_else(|_| Duration::from_secs(0))
            .as_secs(),
        provider: CalibrationProviderDetails {
            name: info.name.clone(),
            vendor: info.vendor.clone(),
            backend: info.backend.clone(),
            device_id: info.device_id,
        },
        thresholds: cfg.clone(),
    };
    let payload = serde_json::to_string_pretty(&record).map_err(|e| anyhow!(e.to_string()))?;
    fs::write(&path, payload).map_err(|e| anyhow!(e.to_string()))?;
    Ok(path)
}

fn calibration_cache_file(info: &ApiDeviceInfo) -> Option<PathBuf> {
    let mut dir = calibration_cache_dir()?;
    let vendor = slugify(&info.vendor);
    let name = slugify(&info.name);
    let backend = slugify(info.backend.as_deref().unwrap_or("unknown"));
    let file = format!("{}-{}-{}-{}.json", vendor, name, backend, info.device_id);
    dir.push(file);
    Some(dir)
}
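
// Example cache location (illustrative device): vendor "Apple", name
// "Apple M3 Max", backend "metal", device_id 0 resolves to
// `<cache_dir>/runmat/auto_offload/apple-apple_m3_max-metal-0.json`.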
1428
1429fn calibration_cache_dir() -> Option<PathBuf> {
1430    dirs::cache_dir().map(|base| base.join("runmat").join("auto_offload"))
1431}
1432
1433fn slugify(input: &str) -> String {
1434    let mut out = String::with_capacity(input.len());
1435    let mut last_underscore = false;
1436    for ch in input.chars() {
1437        if ch.is_ascii_alphanumeric() {
1438            out.push(ch.to_ascii_lowercase());
1439            last_underscore = false;
1440        } else if !last_underscore {
1441            out.push('_');
1442            last_underscore = true;
1443        }
1444    }
1445    let trimmed = out.trim_matches('_');
1446    if trimmed.is_empty() {
1447        "device".to_string()
1448    } else {
1449        trimmed.to_string()
1450    }
1451}
1452
1453fn auto_calibrate(provider: &'static dyn AccelProvider, cfg: &mut ThresholdConfig) -> Result<()> {
1454    if let Some(elem_threshold) = calibrate_elemwise(provider, cfg).transpose()? {
1455        if elem_threshold != usize::MAX {
1456            cfg.binary_min_elems = elem_threshold;
1457            cfg.unary_min_elems = cfg.unary_min_elems.min(elem_threshold);
1458        }
1459    }
1460    if let Some(red_threshold) = calibrate_reduction(provider, cfg).transpose()? {
1461        if red_threshold != usize::MAX {
1462            cfg.reduction_min_elems = red_threshold;
1463        }
1464    }
1465    if let Some(matmul_threshold) = calibrate_matmul(provider, cfg).transpose()? {
1466        if matmul_threshold != usize::MAX {
1467            cfg.matmul_min_flops = matmul_threshold;
1468        }
1469    }
1470    Ok(())
1471}
1472
1473fn calibrate_elemwise(
1474    provider: &'static dyn AccelProvider,
1475    cfg: &mut ThresholdConfig,
1476) -> Option<Result<usize>> {
1477    let sizes = [256usize, 1_024, 4_096, 16_384, 65_536];
1478    for size in sizes {
1479        match compare_elemwise(provider, size, &mut cfg.cpu_elem_per_elem) {
1480            Ok(Some(true)) => return Some(Ok(size)),
1481            Ok(Some(false)) => continue,
1482            Ok(None) => return None,
1483            Err(e) => return Some(Err(e)),
1484        }
1485    }
1486    Some(Ok(usize::MAX))
1487}
1488
1489fn compare_elemwise(
1490    provider: &'static dyn AccelProvider,
1491    elements: usize,
1492    cpu_cost_slot: &mut f64,
1493) -> Result<Option<bool>> {
1494    if elements == 0 {
1495        return Ok(Some(false));
1496    }
1497    let shape = vec![elements, 1];
1498    let template = match provider.precision() {
1499        ProviderPrecision::F64 => {
1500            Tensor::new((0..elements).map(|i| i as f64).collect(), shape.clone())
1501                .map_err(|e| anyhow!(e))?
1502        }
1503        ProviderPrecision::F32 => {
1504            Tensor::from_f32((0..elements).map(|i| i as f32).collect(), shape.clone())
1505                .map_err(|e| anyhow!(e))?
1506        }
1507    };
1508    let a = Value::Tensor(template.clone());
1509    let b = Value::Tensor(template.clone());
1510    let cpu_time = time(|| runmat_runtime::call_builtin("plus", &[a.clone(), b.clone()]))?;
1511    let cpu_per_elem = cpu_time.as_secs_f64() / elements as f64;
1512    update_cpu_cost(cpu_cost_slot, cpu_per_elem);
1513    if let Some(model) = profile_cost_model() {
1514        if let Some(gpu_time) = model.estimate_elemwise(elements) {
1515            trace!(
1516                "Elemwise calibration ({} elems): cpu={:?}, gpu_est={:?}",
1517                elements,
1518                cpu_time,
1519                gpu_time
1520            );
1521            return Ok(Some(gpu_time < cpu_time));
1522        }
1523    }
1524    let view = HostTensorView {
1525        data: template.data.as_slice(),
1526        shape: template.shape.as_slice(),
1527    };
1528    let ha = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1529    let hb = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
1530    let start = Instant::now();
1531    let hc = match provider.elem_add(&ha, &hb) {
1532        Ok(h) => h,
1533        Err(_) => {
1534            let _ = provider.free(&ha);
1535            let _ = provider.free(&hb);
1536            return Ok(None);
1537        }
1538    };
1539    let gpu_time = start.elapsed();
1540    let _ = provider.free(&ha);
1541    let _ = provider.free(&hb);
1542    let _ = provider.free(&hc);
1543    Ok(Some(gpu_time < cpu_time))
1544}
1545
1546fn calibrate_reduction(
1547    provider: &'static dyn AccelProvider,
1548    cfg: &mut ThresholdConfig,
1549) -> Option<Result<usize>> {
1550    let sizes = [256usize, 1_024, 4_096, 16_384, 65_536];
1551    for size in sizes {
1552        match compare_reduction(provider, size, &mut cfg.cpu_reduction_per_elem) {
1553            Ok(Some(true)) => return Some(Ok(size)),
1554            Ok(Some(false)) => continue,
1555            Ok(None) => return None,
1556            Err(e) => return Some(Err(e)),
1557        }
1558    }
1559    Some(Ok(usize::MAX))
1560}
1561
/// Benchmarks a CPU `sum` against the provider's `reduce_sum` (or a
/// profile-based estimate) for the given element count. Returns `Ok(None)`
/// when the provider cannot run the probe.
fn compare_reduction(
    provider: &'static dyn AccelProvider,
    elements: usize,
    cpu_cost_slot: &mut f64,
) -> Result<Option<bool>> {
    // Guard against a zero-sized probe so the per-element cost below stays
    // finite (mirrors the check in `compare_elemwise`).
    if elements == 0 {
        return Ok(Some(false));
    }
    let shape = vec![elements, 1];
    let template = match provider.precision() {
        ProviderPrecision::F64 => {
            Tensor::new((0..elements).map(|i| i as f64).collect(), shape.clone())
                .map_err(|e| anyhow!(e))?
        }
        ProviderPrecision::F32 => {
            Tensor::from_f32((0..elements).map(|i| i as f32).collect(), shape.clone())
                .map_err(|e| anyhow!(e))?
        }
    };
    let value = Value::Tensor(template.clone());
    let cpu_time = time(|| runmat_runtime::call_builtin("sum", std::slice::from_ref(&value)))?;
    let cpu_per_elem = cpu_time.as_secs_f64() / elements as f64;
    update_cpu_cost(cpu_cost_slot, cpu_per_elem);
    if let Some(model) = profile_cost_model() {
        if let Some(gpu_time) = model.estimate_reduction(elements) {
            trace!(
                "Reduction calibration ({} elems): cpu={:?}, gpu_est={:?}",
                elements,
                cpu_time,
                gpu_time
            );
            return Ok(Some(gpu_time < cpu_time));
        }
    }
    let view = HostTensorView {
        data: template.data.as_slice(),
        shape: template.shape.as_slice(),
    };
    let h = provider.upload(&view).map_err(|e| anyhow!(e.to_string()))?;
    let start = Instant::now();
    let out = match provider.reduce_sum(&h) {
        Ok(hc) => hc,
        Err(_) => {
            let _ = provider.free(&h);
            return Ok(None);
        }
    };
    let gpu_time = start.elapsed();
    let _ = provider.free(&h);
    let _ = provider.free(&out);
    Ok(Some(gpu_time < cpu_time))
}

fn calibrate_matmul(
    provider: &'static dyn AccelProvider,
    cfg: &mut ThresholdConfig,
) -> Option<Result<usize>> {
    let dims = [32usize, 64, 96, 128, 192];
    for n in dims {
        match compare_matmul(provider, n, &mut cfg.cpu_matmul_per_flop) {
            Ok(Some(true)) => {
                let flops = n * n * n;
                return Some(Ok(flops));
            }
            Ok(Some(false)) => continue,
            Ok(None) => return None,
            Err(e) => return Some(Err(e)),
        }
    }
    Some(Ok(usize::MAX))
}

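// Worked example of the threshold mapping above (arithmetic only, not a
// measurement): if the GPU first wins at n = 96, the stored matmul threshold is
// 96 * 96 * 96 = 884_736 "flops", matching the m*k*n product convention used by
// `estimate_matmul` below.
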
/// Benchmarks a CPU matrix multiply against the provider's `matmul` (or a
/// profile-based estimate) for two n-by-n operands. Returns `Ok(None)` when the
/// provider cannot run the probe.
fn compare_matmul(
    provider: &'static dyn AccelProvider,
    n: usize,
    cpu_cost_slot: &mut f64,
) -> Result<Option<bool>> {
    if n == 0 {
        return Ok(Some(false));
    }
    let total = n * n;
    let shape = vec![n, n];
    let (ta, tb) = match provider.precision() {
        ProviderPrecision::F64 => {
            let data_a: Vec<f64> = (0..total).map(|i| (i % 13) as f64).collect();
            let data_b: Vec<f64> = (0..total).map(|i| (i % 7) as f64).collect();
            let ta = Tensor::new(data_a, shape.clone()).map_err(|e| anyhow!(e))?;
            let tb = Tensor::new(data_b, shape.clone()).map_err(|e| anyhow!(e))?;
            (ta, tb)
        }
        ProviderPrecision::F32 => {
            let data_a: Vec<f32> = (0..total).map(|i| (i % 13) as f32).collect();
            let data_b: Vec<f32> = (0..total).map(|i| (i % 7) as f32).collect();
            let ta = Tensor::from_f32(data_a, shape.clone()).map_err(|e| anyhow!(e))?;
            let tb = Tensor::from_f32(data_b, shape.clone()).map_err(|e| anyhow!(e))?;
            (ta, tb)
        }
    };
    let a = Value::Tensor(ta.clone());
    let b = Value::Tensor(tb.clone());
    let cpu_time = time(|| runmat_runtime::matrix::value_matmul(&a, &b))?;
    let flops = (n * n * n) as f64;
    update_cpu_cost(cpu_cost_slot, cpu_time.as_secs_f64() / flops);
    if let Some(model) = profile_cost_model() {
        if let Some(gpu_time) = model.estimate_matmul(n, n, n) {
            trace!(
                "Matmul calibration ({}^3 flops): cpu={:?}, gpu_est={:?}",
                n,
                cpu_time,
                gpu_time
            );
            return Ok(Some(gpu_time < cpu_time));
        }
    }
    let view_a = HostTensorView {
        data: ta.data.as_slice(),
        shape: ta.shape.as_slice(),
    };
    let view_b = HostTensorView {
        data: tb.data.as_slice(),
        shape: tb.shape.as_slice(),
    };
    let ha = provider
        .upload(&view_a)
        .map_err(|e| anyhow!(e.to_string()))?;
    let hb = provider
        .upload(&view_b)
        .map_err(|e| anyhow!(e.to_string()))?;
    let start = Instant::now();
    let hc = match provider.matmul(&ha, &hb) {
        Ok(h) => h,
        Err(_) => {
            let _ = provider.free(&ha);
            let _ = provider.free(&hb);
            return Ok(None);
        }
    };
    let gpu_time = start.elapsed();
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    let _ = provider.free(&hc);
    Ok(Some(gpu_time < cpu_time))
}

/// Runs the closure once and returns the elapsed wall-clock time, discarding
/// the closure's result and converting its error into an `anyhow` error.
fn time<F, T>(mut f: F) -> Result<Duration>
where
    F: FnMut() -> Result<T, String>,
{
    let start = Instant::now();
    let _ = f().map_err(|e| anyhow!(e))?;
    Ok(start.elapsed())
}

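// Example (illustrative): timing an arbitrary builtin call with the helper
// above; `value` is a hypothetical `Value` owned by the caller.
//
//     let dur = time(|| runmat_runtime::call_builtin("sum", std::slice::from_ref(&value)))?;
//     trace!("sum took {:?}", dur);
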
/// Snapshot of the current auto-offload state: thresholds, calibration summary
/// and the recent decision log. Returns `None` until the auto-offload state has
/// been initialised.
pub fn auto_offload_report() -> Option<AutoOffloadReport> {
    let state_guard = AUTO_STATE.get()?;
    let state = state_guard.lock().ok()?;
    let calibration = state.previous_thresholds.as_ref().map(|prev| {
        let delta = state
            .calibration_delta
            .clone()
            .unwrap_or_else(|| compute_delta(prev, &state.thresholds));
        AutoOffloadCalibrationSummary {
            previous: threshold_snapshot(prev),
            delta,
        }
    });
    Some(AutoOffloadReport {
        provider: state.provider.clone(),
        thresholds: threshold_snapshot(&state.thresholds),
        base_source: state.base_source,
        env_overrides_applied: state.env_overrides_applied,
        cache_path: state.cache_path.clone(),
        calibrate_duration_ms: state.calibrate_duration_ms,
        calibration,
        decisions: snapshot_decisions(),
    })
}

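// Illustrative consumer (e.g. a diagnostics command); the printing below is
// only a sketch and relies on the report fields shown in the structs earlier
// in this file.
//
//     if let Some(report) = auto_offload_report() {
//         println!("unary_min_elems = {}", report.thresholds.unary_min_elems);
//         for d in &report.decisions {
//             println!("{} -> {:?} ({:?})", d.operation, d.decision, d.reason);
//         }
//     }
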
/// Exposes the calibrated `unary_min_elems` threshold so other components can
/// use it as a coarse sizing hint.
pub fn sequence_threshold_hint() -> Option<usize> {
    AUTO_STATE
        .get()
        .and_then(|state| state.lock().ok())
        .map(|state| state.thresholds.unary_min_elems)
}

/// Clears the recorded auto-offload decision log.
pub fn reset_auto_offload_log() {
    clear_decisions();
}

#[derive(Clone, Deserialize, Debug)]
struct ProfileDurationSummary {
    #[serde(default)]
    avg_ms: f64,
}

#[derive(Clone, Deserialize, Debug)]
struct ProfileReport {
    category: String,
    #[serde(default)]
    input_shapes: Vec<Vec<usize>>,
    total_ms: ProfileDurationSummary,
}

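// A profile file is a JSON array of these reports. A minimal hypothetical entry
// (values invented for illustration) would deserialize as:
//
//     [{ "category": "matmul",
//        "input_shapes": [[256, 256], [256, 256]],
//        "total_ms": { "avg_ms": 0.42 } }]
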
/// Linear cost model in seconds: estimated time = intercept + slope * x.
#[derive(Clone, Copy, Default, Debug)]
struct LinearModel {
    slope: f64,
    intercept: f64,
}

impl LinearModel {
    fn estimate(&self, x: f64) -> Option<Duration> {
        if !self.slope.is_finite() || self.slope <= 0.0 {
            return None;
        }
        let total = self.intercept + self.slope * x;
        if total.is_finite() && total > 0.0 {
            Some(Duration::from_secs_f64(total))
        } else {
            None
        }
    }
}

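// Worked example with hypothetical coefficients: slope = 1.0e-9 s/element and
// intercept = 1.0e-4 s give an estimate of 1.0e-4 + 1.0e-9 * 1_000_000 =
// 1.1e-3 s (about 1.1 ms) for a million-element operation.
//
//     let model = LinearModel { slope: 1.0e-9, intercept: 1.0e-4 };
//     let est = model.estimate(1_000_000.0); // ~Some(1.1 ms)
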
#[derive(Default)]
struct ProfileCostModel {
    elem: Option<LinearModel>,
    reduction: Option<LinearModel>,
    transpose: Option<LinearModel>,
    matmul: Option<LinearModel>,
}

impl ProfileCostModel {
    fn from_reports(reports: &[ProfileReport]) -> Self {
        let mut elem_samples = Vec::<(f64, f64)>::new();
        let mut reduction_samples = Vec::<(f64, f64)>::new();
        let mut transpose_samples = Vec::<(f64, f64)>::new();
        let mut matmul_samples = Vec::<(f64, f64)>::new();

        for report in reports {
            let total_secs = report.total_ms.avg_ms / 1_000.0;
            match report.category.as_str() {
                "elementwise" | "reduction" | "transpose" => {
                    if let Some(shape) = report.input_shapes.first() {
                        let elems: usize = shape.iter().copied().product();
                        if elems == 0 {
                            continue;
                        }
                        let sample = (elems as f64, total_secs);
                        match report.category.as_str() {
                            "elementwise" => elem_samples.push(sample),
                            "reduction" => reduction_samples.push(sample),
                            "transpose" => transpose_samples.push(sample),
                            _ => {}
                        }
                    }
                }
                "matmul" => {
                    if report.input_shapes.len() >= 2 {
                        let a = &report.input_shapes[0];
                        let b = &report.input_shapes[1];
                        if a.len() == 2 && b.len() == 2 {
                            let m = a[0];
                            let k = a[1];
                            let n = b[1];
                            let flops = m.checked_mul(k).and_then(|val| val.checked_mul(n));
                            if let Some(flops) = flops {
                                matmul_samples.push((flops as f64, total_secs));
                            }
                        }
                    }
                }
                _ => {}
            }
        }

        ProfileCostModel {
            elem: fit_linear_model(&elem_samples),
            reduction: fit_linear_model(&reduction_samples),
            transpose: fit_linear_model(&transpose_samples),
            matmul: fit_linear_model(&matmul_samples),
        }
    }

    fn estimate_elemwise(&self, elements: usize) -> Option<Duration> {
        self.elem.and_then(|model| model.estimate(elements as f64))
    }

    fn estimate_reduction(&self, elements: usize) -> Option<Duration> {
        self.reduction
            .and_then(|model| model.estimate(elements as f64))
    }

    fn estimate_matmul(&self, m: usize, k: usize, n: usize) -> Option<Duration> {
        let flops = m.checked_mul(k)?.checked_mul(n)?;
        self.matmul.and_then(|model| model.estimate(flops as f64))
    }

    fn estimate_matmul_flops(&self, flops: usize) -> Option<Duration> {
        self.matmul.and_then(|model| model.estimate(flops as f64))
    }

    fn estimate_transpose(&self, elements: usize) -> Option<Duration> {
        self.transpose
            .and_then(|model| model.estimate(elements as f64))
    }
}

/// Ordinary least-squares fit of `y = intercept + slope * x` over the samples.
/// A single sample yields a through-the-origin model, negative intercepts are
/// clamped to zero, and non-positive or non-finite slopes reject the fit.
fn fit_linear_model(samples: &[(f64, f64)]) -> Option<LinearModel> {
    if samples.is_empty() {
        return None;
    }
    if samples.len() == 1 {
        let (x, y) = samples[0];
        if x > 0.0 {
            return Some(LinearModel {
                slope: (y / x).max(0.0),
                intercept: 0.0,
            });
        }
        return None;
    }

    let sum_x: f64 = samples.iter().map(|(x, _)| *x).sum();
    let sum_y: f64 = samples.iter().map(|(_, y)| *y).sum();
    let sum_xx: f64 = samples.iter().map(|(x, _)| x * x).sum();
    let sum_xy: f64 = samples.iter().map(|(x, y)| x * y).sum();
    let n = samples.len() as f64;
    let denom = (n * sum_xx) - (sum_x * sum_x);
    if denom.abs() < f64::EPSILON {
        return None;
    }
    let slope = ((n * sum_xy) - (sum_x * sum_y)) / denom;
    let mean_x = sum_x / n;
    let mean_y = sum_y / n;
    let mut intercept = mean_y - slope * mean_x;
    if intercept < 0.0 {
        intercept = 0.0;
    }
    if !slope.is_finite() || slope <= 0.0 {
        return None;
    }
    Some(LinearModel { slope, intercept })
}

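// Worked example (hypothetical samples): the points (1.0e6, 2.0e-3) and
// (4.0e6, 5.0e-3) give slope = (5.0e-3 - 2.0e-3) / (4.0e6 - 1.0e6) = 1.0e-9
// and intercept = 2.0e-3 - 1.0e-9 * 1.0e6 = 1.0e-3, so:
//
//     let model = fit_linear_model(&[(1.0e6, 2.0e-3), (4.0e6, 5.0e-3)]);
//     // => Some(LinearModel { slope: ~1.0e-9, intercept: ~1.0e-3 })
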
fn profile_cost_model() -> Option<&'static ProfileCostModel> {
    PROFILE_MODEL.get_or_init(load_profile_cost_model).as_ref()
}

/// Loads GPU profile reports from the first readable candidate path: the
/// `RUNMAT_ACCEL_PROFILE` environment variable, the configured profile path
/// from `auto_offload_options()`, then two repo-relative fallbacks.
fn load_profile_cost_model() -> Option<ProfileCostModel> {
    let mut candidates = Vec::new();
    if let Ok(path) = env::var("RUNMAT_ACCEL_PROFILE") {
        candidates.push(PathBuf::from(path));
    }
    if let Some(path) = auto_offload_options().profile_path.clone() {
        candidates.push(path);
    }
    candidates.push(PathBuf::from("benchmarks/wgpu_profile/mac_m2.json"));
    candidates.push(PathBuf::from("wgpu_profile.json"));

    for path in candidates {
        if !path.exists() {
            continue;
        }
        match fs::read_to_string(&path) {
            Ok(contents) => match serde_json::from_str::<Vec<ProfileReport>>(&contents) {
                Ok(reports) => {
                    debug!(
                        "Loaded {} GPU profile reports from {}",
                        reports.len(),
                        path.display()
                    );
                    return Some(ProfileCostModel::from_reports(&reports));
                }
                Err(err) => {
                    debug!("Failed to parse GPU profile {}: {err}", path.display());
                }
            },
            Err(err) => {
                debug!("Failed to read GPU profile {}: {err}", path.display());
            }
        }
    }
    None
}

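// Illustrative override (assuming the CLI is invoked as `runmat`; only the
// environment variable itself is defined by this file). The file must contain
// the JSON array of reports sketched above.
//
//     RUNMAT_ACCEL_PROFILE=/path/to/my_gpu_profile.json runmat script.m
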
/// Entry point for binary operations: when auto-offload is enabled and a
/// global controller exists, lets it decide where the operands should live;
/// otherwise returns the values unchanged.
pub fn promote_binary(op: BinaryOp, a: &Value, b: &Value) -> Result<(Value, Value)> {
    if !auto_enabled() {
        return Ok((a.clone(), b.clone()));
    }
    if let Some(auto) = global() {
        auto.promote_binary(op, a, b)
    } else {
        Ok((a.clone(), b.clone()))
    }
}

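// Illustrative call site (e.g. an interpreter's binary-op dispatch); `lhs` and
// `rhs` are hypothetical values owned by the caller.
//
//     let (lhs, rhs) = promote_binary(BinaryOp::MatMul, &lhs, &rhs)?;
//     // proceed with whichever backend the returned values now target
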
/// Unary counterpart of `promote_binary`.
pub fn promote_unary(op: UnaryOp, value: &Value) -> Result<Value> {
    if !auto_enabled() {
        return Ok(value.clone());
    }
    if let Some(auto) = global() {
        auto.promote_unary(op, value)
    } else {
        Ok(value.clone())
    }
}

/// Prepares arguments before a builtin call: sink builtins always get gathered
/// arguments; otherwise the auto-offload controller, when enabled, may promote
/// them.
pub fn prepare_builtin_args(name: &str, args: &[Value]) -> Result<Vec<Value>> {
    if let Some(policy) = builtin_policy(name) {
        if policy.is_sink {
            return gather_args(args);
        }
    }
    if !auto_enabled() {
        return Ok(args.to_vec());
    }
    if let Some(auto) = global() {
        auto.prepare_builtin(name, args)
    } else {
        Ok(args.to_vec())
    }
}

/// Returns whether the named builtin is registered as a sink (its arguments
/// are gathered before the call).
pub fn is_sink(name: &str) -> bool {
    builtin_policy(name).map(|p| p.is_sink).unwrap_or(false)
}

/// Reduction counterpart of `promote_binary`, applied to the full argument list.
pub fn promote_reduction_args(op: ReductionOp, args: &[Value]) -> Result<Vec<Value>> {
    if !auto_enabled() {
        return Ok(args.to_vec());
    }
    if let Some(auto) = global() {
        auto.promote_reduction(op, args)
    } else {
        Ok(args.to_vec())
    }
}