runmat_accelerate/
precision.rs

1use once_cell::sync::{Lazy, OnceCell};
2use runmat_accelerate_api::{AccelProvider, ProviderPrecision};
3use runmat_builtins::{NumericDType, Tensor, Value};
4use std::env;
5
6/// Return the logical numeric dtype associated with the provided value, if any.
7pub fn value_numeric_dtype(value: &Value) -> Option<NumericDType> {
8    match value {
9        Value::Tensor(t) => Some(t.dtype),
10        Value::Num(_) | Value::Int(_) | Value::Bool(_) => Some(NumericDType::F64),
11        Value::LogicalArray(_) | Value::CharArray(_) => Some(NumericDType::F64),
12        Value::GpuTensor(_) => None, // already resident; assume provider handled dtype
13        _ => None,
14    }
15}
16
17/// Return the logical dtype represented by a tensor.
18pub fn tensor_numeric_dtype(tensor: &Tensor) -> NumericDType {
19    tensor.dtype
20}
21
/// Interpret a textual flag as a boolean.
///
/// Leading and trailing whitespace is ignored and ASCII letter case does not
/// matter. `1`, `true`, `yes` and `on` give `Some(true)`; `0`, `false`, `no`
/// and `off` give `Some(false)`; any other input gives `None`.
fn parse_bool(s: &str) -> Option<bool> {
    const TRUTHY: &[&str] = &["1", "true", "yes", "on"];
    const FALSY: &[&str] = &["0", "false", "no", "off"];

    /// True when `token` equals one of `words`, ignoring ASCII case.
    fn is_one_of(token: &str, words: &[&str]) -> bool {
        words.iter().any(|word| token.eq_ignore_ascii_case(word))
    }

    let token = s.trim();
    if is_one_of(token, TRUTHY) {
        Some(true)
    } else if is_one_of(token, FALSY) {
        Some(false)
    } else {
        None
    }
}
29
30static ALLOW_DOWNCAST: Lazy<bool> = Lazy::new(|| {
31    env::var("RUNMAT_ALLOW_PRECISION_DOWNCAST")
32        .ok()
33        .and_then(|value| parse_bool(&value))
34        .unwrap_or(false)
35});
36
37static DOWNCAST_WARNING: OnceCell<()> = OnceCell::new();
38
39/// True if the provider can execute kernels with the requested logical dtype.
40pub fn provider_supports_dtype(provider: &dyn AccelProvider, dtype: NumericDType) -> bool {
41    match dtype {
42        NumericDType::F32 => true,
43        NumericDType::F64 => provider.precision() == ProviderPrecision::F64,
44    }
45}
46
47fn downcast_permitted_for(dtype: NumericDType) -> bool {
48    matches!(dtype, NumericDType::F64) && *ALLOW_DOWNCAST
49}
50
51/// Returns an error message if the provider cannot execute the requested dtype.
52pub fn ensure_provider_supports_dtype(
53    provider: &dyn AccelProvider,
54    dtype: NumericDType,
55) -> Result<(), String> {
56    if provider_supports_dtype(provider, dtype) {
57        Ok(())
58    } else if downcast_permitted_for(dtype) {
59        DOWNCAST_WARNING.get_or_init(|| {
60            log::warn!(
61                "RUNMAT_ALLOW_PRECISION_DOWNCAST enabled: implicitly converting double inputs to the provider's native precision"
62            );
63        });
64        Ok(())
65    } else {
66        Err(match dtype {
67            NumericDType::F64 => {
68                "active provider does not advertise f64 kernels; refusing implicit downcast"
69                    .to_string()
70            }
71            NumericDType::F32 => "active provider does not support f32 kernels".to_string(),
72        })
73    }
74}
75
76pub fn downcast_permitted(dtype: NumericDType) -> bool {
77    downcast_permitted_for(dtype)
78}