runmat_accelerate/
precision.rs1use once_cell::sync::{Lazy, OnceCell};
2use runmat_accelerate_api::{AccelProvider, ProviderPrecision};
3use runmat_builtins::{NumericDType, Tensor, Value};
4use std::env;
5
6pub fn value_numeric_dtype(value: &Value) -> Option<NumericDType> {
8 match value {
9 Value::Tensor(t) => Some(t.dtype),
10 Value::Num(_) | Value::Int(_) | Value::Bool(_) => Some(NumericDType::F64),
11 Value::LogicalArray(_) | Value::CharArray(_) => Some(NumericDType::F64),
12 Value::GpuTensor(_) => None, _ => None,
14 }
15}
16
17pub fn tensor_numeric_dtype(tensor: &Tensor) -> NumericDType {
19 tensor.dtype
20}
21
/// Interprets a loosely formatted boolean flag (e.g. an environment variable value).
///
/// Leading/trailing whitespace is ignored and comparison is ASCII case-insensitive.
/// `1`, `true`, `yes`, `on` map to `Some(true)`; `0`, `false`, `no`, `off` map to
/// `Some(false)`; anything else yields `None`.
fn parse_bool(s: &str) -> Option<bool> {
    const FLAG_WORDS: [(&str, bool); 8] = [
        ("1", true),
        ("true", true),
        ("yes", true),
        ("on", true),
        ("0", false),
        ("false", false),
        ("no", false),
        ("off", false),
    ];

    let token = s.trim();
    // Every table entry is lowercase ASCII, so a case-insensitive ASCII comparison is
    // equivalent to lowercasing `token` first, without allocating.
    FLAG_WORDS
        .iter()
        .find(|(word, _)| token.eq_ignore_ascii_case(word))
        .map(|&(_, value)| value)
}
29
30static ALLOW_DOWNCAST: Lazy<bool> = Lazy::new(|| {
31 env::var("RUNMAT_ALLOW_PRECISION_DOWNCAST")
32 .ok()
33 .and_then(|value| parse_bool(&value))
34 .unwrap_or(false)
35});
36
37static DOWNCAST_WARNING: OnceCell<()> = OnceCell::new();
38
39pub fn provider_supports_dtype(provider: &dyn AccelProvider, dtype: NumericDType) -> bool {
41 match dtype {
42 NumericDType::F32 => true,
43 NumericDType::F64 => provider.precision() == ProviderPrecision::F64,
44 NumericDType::U8 | NumericDType::U16 => false,
45 }
46}
47
48fn downcast_permitted_for(dtype: NumericDType) -> bool {
49 matches!(dtype, NumericDType::F64) && *ALLOW_DOWNCAST
50}
51
52pub fn ensure_provider_supports_dtype(
54 provider: &dyn AccelProvider,
55 dtype: NumericDType,
56) -> Result<(), String> {
57 if provider_supports_dtype(provider, dtype) {
58 Ok(())
59 } else if downcast_permitted_for(dtype) {
60 DOWNCAST_WARNING.get_or_init(|| {
61 log::warn!(
62 "RUNMAT_ALLOW_PRECISION_DOWNCAST enabled: implicitly converting double inputs to the provider's native precision"
63 );
64 });
65 Ok(())
66 } else {
67 Err(match dtype {
68 NumericDType::F64 => {
69 "active provider does not advertise f64 kernels; refusing implicit downcast"
70 .to_string()
71 }
72 NumericDType::F32 => "active provider does not support f32 kernels".to_string(),
73 NumericDType::U8 | NumericDType::U16 => {
74 format!(
75 "active provider does not support {} kernels",
76 dtype.class_name()
77 )
78 }
79 })
80 }
81}
82
83pub fn downcast_permitted(dtype: NumericDType) -> bool {
84 downcast_permitted_for(dtype)
85}