Skip to main content

hanzo_engine/utils/
normal.rs

1#![allow(dead_code, unused)]
2
3use std::{fmt::Display, str::FromStr};
4
5use anyhow::Result;
6use hanzo_ml::{DType, Device, Tensor};
7use hanzo_quant::log::once_log_info;
8use serde::{Deserialize, Serialize};
9use tracing::debug;
10
11#[derive(Clone, Copy, Default, Debug, Deserialize, Serialize, PartialEq)]
12#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
13/// DType for the model.
14///
15/// If the model is quantized, this is ignored so it is reasonable to use the [`Default`] impl.
16///
17/// Note: When using `Auto`, fallback pattern is: BF16 -> F16 -> 32
18pub enum ModelDType {
19    #[default]
20    #[serde(rename = "auto")]
21    Auto,
22    #[serde(rename = "bf16")]
23    BF16,
24    #[serde(rename = "f16")]
25    F16,
26    #[serde(rename = "f32")]
27    F32,
28}
29
30impl Display for ModelDType {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::Auto => write!(f, "auto"),
34            Self::BF16 => write!(f, "bf16"),
35            Self::F16 => write!(f, "f16"),
36            Self::F32 => write!(f, "f32"),
37        }
38    }
39}
40
41impl FromStr for ModelDType {
42    type Err = String;
43    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
44        match s.to_lowercase().as_str() {
45            "auto" => Ok(Self::Auto),
46            "bf16" => Ok(Self::BF16),
47            "f16" => Ok(Self::F16),
48            "f32" => Ok(Self::F32),
49            other => Err(format!("Model DType `{other}` is not supported.")),
50        }
51    }
52}
53
54/// Type which can be converted to a DType
55pub trait TryIntoDType {
56    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType>;
57}
58
59impl TryIntoDType for DType {
60    fn try_into_dtype(&self, _: &[&Device]) -> Result<DType> {
61        if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) {
62            anyhow::bail!("DType must be one of BF16, F16, F32, F64");
63        }
64        once_log_info(format!("DType selected is {self:?}."));
65        Ok(*self)
66    }
67}
68
/// Reduced-precision dtypes supported by every visible CUDA GPU, judged from the minimum
/// compute capability reported by `nvidia-smi`.
///
/// BF16 requires CC >= 8.0 and F16 requires CC >= 5.3.
///
/// This is best-effort. In each of the following cases a warning is logged and an empty
/// list is returned, so `Auto` dtype selection falls back to F32 instead of panicking:
/// - `nvidia-smi` cannot be run;
/// - it exits unsuccessfully;
/// - it reports no parsable compute capability.
#[cfg(feature = "cuda")]
fn get_dtypes() -> Vec<DType> {
    use std::process::Command;

    // >= is supported
    const MIN_BF16_CC: usize = 800;
    // >= is supported
    const MIN_F16_CC: usize = 530;

    let output = match Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv")
        .output()
    {
        Ok(output) => output,
        Err(e) => {
            tracing::warn!("Failed to run `nvidia-smi` ({e}); assuming no BF16/F16 support.");
            return Vec::new();
        }
    };
    // On failure nvidia-smi may print an error line on stdout; don't try to parse it.
    if !output.status.success() {
        tracing::warn!(
            "`nvidia-smi` exited with {}; assuming no BF16/F16 support.",
            output.status
        );
        return Vec::new();
    }

    let out = String::from_utf8_lossy(&output.stdout);
    let min_cc = out
        .lines()
        // The first line is the CSV header (`compute_cap`).
        .skip(1)
        .map(str::trim)
        .filter(|cell| !cell.is_empty())
        // Skip cells like `[N/A]` instead of panicking on them.
        .filter_map(|cell| match cell.parse::<f32>() {
            Ok(cc) if cc.is_finite() && cc >= 0.0 => Some(cc),
            _ => {
                debug!("Ignoring unparsable compute capability `{cell}`");
                None
            }
        })
        .reduce(f32::min);
    let min_cc = match min_cc {
        Some(cc) => cc,
        None => {
            tracing::warn!(
                "`nvidia-smi` reported no compute capability; assuming no BF16/F16 support."
            );
            return Vec::new();
        }
    };
    debug!("Detected minimum CUDA compute capability {min_cc}");
    // 7.5 -> 750. Round rather than truncate so float representation error can never push
    // a value just below a threshold.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let min_cc = (min_cc * 100.).round() as usize;

    let mut dtypes = Vec::new();
    if min_cc >= MIN_BF16_CC {
        dtypes.push(DType::BF16);
    } else {
        debug!("Skipping BF16 because CC < 8.0");
    }
    if min_cc >= MIN_F16_CC {
        dtypes.push(DType::F16);
    } else {
        debug!("Skipping F16 because CC < 5.3");
    }
    dtypes
}
111
112fn get_dtypes_non_cuda() -> Vec<DType> {
113    vec![DType::BF16, DType::F16]
114}
115
116#[cfg(not(feature = "cuda"))]
117fn get_dtypes() -> Vec<DType> {
118    get_dtypes_non_cuda()
119}
120
121fn determine_auto_dtype_all(devices: &[&Device]) -> hanzo_ml::Result<DType> {
122    // We can safely use bf16 for accelerate because we cast up to f32 in all matmuls anyway.
123    #[cfg(feature = "accelerate")]
124    return Ok(DType::BF16);
125    #[cfg(not(feature = "accelerate"))]
126    {
127        let dev_dtypes = get_dtypes();
128        for dtype in get_dtypes_non_cuda()
129            .iter()
130            .filter(|x| dev_dtypes.contains(x))
131        {
132            let mut results = Vec::new();
133            for device in devices {
134                // Try a matmul
135                let x = Tensor::zeros((2, 2), *dtype, device)?;
136                results.push(x.matmul(&x));
137            }
138            if results.iter().all(|x| x.is_ok()) {
139                return Ok(*dtype);
140            } else {
141                for result in results {
142                    match result {
143                        Ok(_) => (),
144                        Err(e) => match e {
145                            // For CUDA
146                            hanzo_ml::Error::UnsupportedDTypeForOp(_, _) => continue,
147                            // Accelerate backend doesn't support f16/bf16
148                            // Metal backend doesn't support f16
149                            hanzo_ml::Error::Msg(_) => continue,
150                            // This is when the metal backend doesn't support bf16
151                            hanzo_ml::Error::Metal(_) => continue,
152                            // If running with RUST_BACKTRACE=1
153                            hanzo_ml::Error::WithBacktrace { .. } => continue,
154                            other => return Err(other),
155                        },
156                    }
157                }
158            }
159        }
160        Ok(DType::F32)
161    }
162}
163
164impl TryIntoDType for ModelDType {
165    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType> {
166        let dtype = match self {
167            Self::Auto => determine_auto_dtype_all(devices).map_err(anyhow::Error::msg)?,
168            Self::BF16 => DType::BF16,
169            Self::F16 => DType::F16,
170            Self::F32 => DType::F32,
171        };
172        once_log_info(format!("DType selected is {dtype:?}."));
173        Ok(dtype)
174    }
175}
176
177/// Returns `true` if the given device has integrated/unified memory where CPU and GPU
178/// share the same physical memory. This includes:
179/// - Metal (Apple Silicon)
180/// - CUDA integrated GPUs (e.g. NVIDIA Grace Hopper, Grace Blackwell)
181///
182/// On such systems, loading tensors to CPU first provides no memory benefit.
183pub fn is_integrated_gpu(device: &Device) -> bool {
184    match device {
185        #[cfg(feature = "metal")]
186        Device::Metal(_) => true,
187        #[cfg(feature = "cuda")]
188        Device::Cuda(dev) => {
189            use hanzo_ml::cuda::cudarc::driver::{result, sys};
190            let ordinal = dev.cuda_stream().context().ordinal();
191            #[allow(clippy::cast_possible_truncation)]
192            let cu_device = match result::device::get(ordinal as i32) {
193                Ok(d) => d,
194                Err(_) => return false,
195            };
196            unsafe {
197                result::device::get_attribute(
198                    cu_device,
199                    sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED,
200                )
201                .map(|v| v != 0)
202                .unwrap_or(false)
203            }
204        }
205        _ => false,
206    }
207}