// runmat_accelerate/precision.rs
use once_cell::sync::{Lazy, OnceCell};
use runmat_accelerate_api::{AccelProvider, ProviderPrecision};
use runmat_builtins::{NumericDType, Tensor, Value};
use std::env;

6pub fn value_numeric_dtype(value: &Value) -> Option<NumericDType> {
8 match value {
9 Value::Tensor(t) => Some(t.dtype),
10 Value::Num(_) | Value::Int(_) | Value::Bool(_) => Some(NumericDType::F64),
11 Value::LogicalArray(_) | Value::CharArray(_) => Some(NumericDType::F64),
12 Value::GpuTensor(_) => None, _ => None,
14 }
15}
16
17pub fn tensor_numeric_dtype(tensor: &Tensor) -> NumericDType {
19 tensor.dtype
20}
21
/// Interprets a human-friendly boolean string, such as an environment flag value.
///
/// Surrounding whitespace is ignored and ASCII letters are compared
/// case-insensitively. `1`/`true`/`yes`/`on` give `Some(true)`,
/// `0`/`false`/`no`/`off` give `Some(false)`, and anything else gives `None`.
fn parse_bool(s: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let normalized = s.trim().to_ascii_lowercase();
    let token = normalized.as_str();
    if TRUTHY.contains(&token) {
        Some(true)
    } else if FALSY.contains(&token) {
        Some(false)
    } else {
        None
    }
}

30static ALLOW_DOWNCAST: Lazy<bool> = Lazy::new(|| {
31 env::var("RUNMAT_ALLOW_PRECISION_DOWNCAST")
32 .ok()
33 .and_then(|value| parse_bool(&value))
34 .unwrap_or(false)
35});
36
37static DOWNCAST_WARNING: OnceCell<()> = OnceCell::new();
38
39pub fn provider_supports_dtype(provider: &dyn AccelProvider, dtype: NumericDType) -> bool {
41 match dtype {
42 NumericDType::F32 => true,
43 NumericDType::F64 => provider.precision() == ProviderPrecision::F64,
44 }
45}
46
47fn downcast_permitted_for(dtype: NumericDType) -> bool {
48 matches!(dtype, NumericDType::F64) && *ALLOW_DOWNCAST
49}
50
51pub fn ensure_provider_supports_dtype(
53 provider: &dyn AccelProvider,
54 dtype: NumericDType,
55) -> Result<(), String> {
56 if provider_supports_dtype(provider, dtype) {
57 Ok(())
58 } else if downcast_permitted_for(dtype) {
59 DOWNCAST_WARNING.get_or_init(|| {
60 log::warn!(
61 "RUNMAT_ALLOW_PRECISION_DOWNCAST enabled: implicitly converting double inputs to the provider's native precision"
62 );
63 });
64 Ok(())
65 } else {
66 Err(match dtype {
67 NumericDType::F64 => {
68 "active provider does not advertise f64 kernels; refusing implicit downcast"
69 .to_string()
70 }
71 NumericDType::F32 => "active provider does not support f32 kernels".to_string(),
72 })
73 }
74}
75
76pub fn downcast_permitted(dtype: NumericDType) -> bool {
77 downcast_permitted_for(dtype)
78}