Skip to main content

runmat_accelerate/
precision.rs

1use once_cell::sync::{Lazy, OnceCell};
2use runmat_accelerate_api::{AccelProvider, ProviderPrecision};
3use runmat_builtins::{NumericDType, Tensor, Value};
4use std::env;
5
6/// Return the logical numeric dtype associated with the provided value, if any.
7pub fn value_numeric_dtype(value: &Value) -> Option<NumericDType> {
8    match value {
9        Value::Tensor(t) => Some(t.dtype),
10        Value::Num(_) | Value::Int(_) | Value::Bool(_) => Some(NumericDType::F64),
11        Value::LogicalArray(_) | Value::CharArray(_) => Some(NumericDType::F64),
12        Value::GpuTensor(_) => None, // already resident; assume provider handled dtype
13        _ => None,
14    }
15}
16
17/// Return the logical dtype represented by a tensor.
18pub fn tensor_numeric_dtype(tensor: &Tensor) -> NumericDType {
19    tensor.dtype
20}
21
/// Interpret a loosely formatted boolean flag (e.g. an environment variable value).
///
/// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
/// `1`/`true`/`yes`/`on` map to `Some(true)`, `0`/`false`/`no`/`off` map to
/// `Some(false)`, and anything else yields `None`.
fn parse_bool(s: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let token = s.trim();
    let is_one_of = |words: [&str; 4]| words.iter().any(|w| token.eq_ignore_ascii_case(w));

    if is_one_of(TRUTHY) {
        Some(true)
    } else if is_one_of(FALSY) {
        Some(false)
    } else {
        None
    }
}
29
30static ALLOW_DOWNCAST: Lazy<bool> = Lazy::new(|| {
31    env::var("RUNMAT_ALLOW_PRECISION_DOWNCAST")
32        .ok()
33        .and_then(|value| parse_bool(&value))
34        .unwrap_or(false)
35});
36
37static DOWNCAST_WARNING: OnceCell<()> = OnceCell::new();
38
39/// True if the provider can execute kernels with the requested logical dtype.
40pub fn provider_supports_dtype(provider: &dyn AccelProvider, dtype: NumericDType) -> bool {
41    match dtype {
42        NumericDType::F32 => true,
43        NumericDType::F64 => provider.precision() == ProviderPrecision::F64,
44        NumericDType::U8 | NumericDType::U16 => false,
45    }
46}
47
48fn downcast_permitted_for(dtype: NumericDType) -> bool {
49    matches!(dtype, NumericDType::F64) && *ALLOW_DOWNCAST
50}
51
52/// Returns an error message if the provider cannot execute the requested dtype.
53pub fn ensure_provider_supports_dtype(
54    provider: &dyn AccelProvider,
55    dtype: NumericDType,
56) -> Result<(), String> {
57    if provider_supports_dtype(provider, dtype) {
58        Ok(())
59    } else if downcast_permitted_for(dtype) {
60        DOWNCAST_WARNING.get_or_init(|| {
61            log::warn!(
62                "RUNMAT_ALLOW_PRECISION_DOWNCAST enabled: implicitly converting double inputs to the provider's native precision"
63            );
64        });
65        Ok(())
66    } else {
67        Err(match dtype {
68            NumericDType::F64 => {
69                "active provider does not advertise f64 kernels; refusing implicit downcast"
70                    .to_string()
71            }
72            NumericDType::F32 => "active provider does not support f32 kernels".to_string(),
73            NumericDType::U8 | NumericDType::U16 => {
74                format!(
75                    "active provider does not support {} kernels",
76                    dtype.class_name()
77                )
78            }
79        })
80    }
81}
82
83pub fn downcast_permitted(dtype: NumericDType) -> bool {
84    downcast_permitted_for(dtype)
85}