hanzo_engine/utils/
normal.rs1#![allow(dead_code, unused)]
2
3use std::{fmt::Display, str::FromStr};
4
5use anyhow::Result;
6use hanzo_ml::{DType, Device, Tensor};
7use hanzo_quant::log::once_log_info;
8use serde::{Deserialize, Serialize};
9use tracing::debug;
10
11#[derive(Clone, Copy, Default, Debug, Deserialize, Serialize, PartialEq)]
12#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
13pub enum ModelDType {
19 #[default]
20 #[serde(rename = "auto")]
21 Auto,
22 #[serde(rename = "bf16")]
23 BF16,
24 #[serde(rename = "f16")]
25 F16,
26 #[serde(rename = "f32")]
27 F32,
28}
29
30impl Display for ModelDType {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Auto => write!(f, "auto"),
34 Self::BF16 => write!(f, "bf16"),
35 Self::F16 => write!(f, "f16"),
36 Self::F32 => write!(f, "f32"),
37 }
38 }
39}
40
41impl FromStr for ModelDType {
42 type Err = String;
43 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
44 match s.to_lowercase().as_str() {
45 "auto" => Ok(Self::Auto),
46 "bf16" => Ok(Self::BF16),
47 "f16" => Ok(Self::F16),
48 "f32" => Ok(Self::F32),
49 other => Err(format!("Model DType `{other}` is not supported.")),
50 }
51 }
52}
53
/// Resolution of a dtype request into a concrete [`DType`].
pub trait TryIntoDType {
    /// Resolves `self` into a concrete [`DType`] for use on `devices`.
    ///
    /// Implementations that already hold a concrete dtype may ignore `devices`;
    /// automatic selection uses them to decide which dtype is usable.
    ///
    /// # Errors
    /// Returns an error if `self` cannot be resolved to a usable dtype.
    fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType>;
}
58
59impl TryIntoDType for DType {
60 fn try_into_dtype(&self, _: &[&Device]) -> Result<DType> {
61 if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) {
62 anyhow::bail!("DType must be one of BF16, F16, F32, F64");
63 }
64 once_log_info(format!("DType selected is {self:?}."));
65 Ok(*self)
66 }
67}
68
/// Returns the reduced-precision dtypes supported by every visible CUDA GPU.
///
/// The minimum compute capability (CC) reported by
/// `nvidia-smi --query-gpu=compute_cap --format=csv` across all GPUs decides the
/// result: BF16 is included when CC >= 8.0 and F16 when CC >= 5.3.
///
/// If `nvidia-smi` cannot be run, exits unsuccessfully, or reports no parsable
/// compute capability, a warning is logged and an empty list is returned. Automatic
/// dtype selection then falls back to F32 instead of panicking.
#[cfg(feature = "cuda")]
fn get_dtypes() -> Vec<DType> {
    use std::process::Command;

    // Compute capabilities scaled by 100 (e.g. CC 8.0 -> 800).
    const MIN_BF16_CC: usize = 800;
    const MIN_F16_CC: usize = 530;

    let output = match Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv")
        .output()
    {
        Ok(output) if output.status.success() => output,
        Ok(output) => {
            tracing::warn!(
                "`nvidia-smi` exited with {} while querying compute capability; \
                 not enabling BF16/F16 for automatic dtype selection.",
                output.status
            );
            return Vec::new();
        }
        Err(e) => {
            tracing::warn!(
                "Failed to run `nvidia-smi` but CUDA is selected ({e}); \
                 not enabling BF16/F16 for automatic dtype selection."
            );
            return Vec::new();
        }
    };

    let Some(min_cc) = min_compute_cap_x100(&String::from_utf8_lossy(&output.stdout)) else {
        tracing::warn!(
            "`nvidia-smi` reported no parsable compute capability; \
             not enabling BF16/F16 for automatic dtype selection."
        );
        return Vec::new();
    };
    debug!(
        "Detected minimum CUDA compute capability {}.{}",
        min_cc / 100,
        (min_cc % 100) / 10
    );

    let mut dtypes = Vec::new();
    if min_cc >= MIN_BF16_CC {
        dtypes.push(DType::BF16);
    } else {
        debug!("Skipping BF16 because CC < 8.0");
    }
    if min_cc >= MIN_F16_CC {
        dtypes.push(DType::F16);
    } else {
        debug!("Skipping F16 because CC < 5.3");
    }
    dtypes
}

/// Returns the smallest compute capability in the CSV output of
/// `nvidia-smi --query-gpu=compute_cap --format=csv`, scaled by 100 and rounded
/// (so `"8.6"` becomes `860`).
///
/// The first line (the CSV header) is skipped. Blank lines and values that do not
/// parse as a finite, non-negative number (for example `[N/A]`) are ignored.
/// Returns `None` if no value could be parsed.
// The float -> usize cast is safe: values are filtered to finite and non-negative first.
#[cfg(feature = "cuda")]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn min_compute_cap_x100(csv: &str) -> Option<usize> {
    csv.lines()
        .skip(1)
        .filter_map(|line| line.trim().parse::<f32>().ok())
        .filter(|cc| cc.is_finite() && *cc >= 0.0)
        // Round instead of truncating: in f32, 8.9 * 100.0 is slightly below 890.
        .map(|cc| (cc * 100.0).round() as usize)
        .min()
}
111
112fn get_dtypes_non_cuda() -> Vec<DType> {
113 vec![DType::BF16, DType::F16]
114}
115
116#[cfg(not(feature = "cuda"))]
117fn get_dtypes() -> Vec<DType> {
118 get_dtypes_non_cuda()
119}
120
121fn determine_auto_dtype_all(devices: &[&Device]) -> hanzo_ml::Result<DType> {
122 #[cfg(feature = "accelerate")]
124 return Ok(DType::BF16);
125 #[cfg(not(feature = "accelerate"))]
126 {
127 let dev_dtypes = get_dtypes();
128 for dtype in get_dtypes_non_cuda()
129 .iter()
130 .filter(|x| dev_dtypes.contains(x))
131 {
132 let mut results = Vec::new();
133 for device in devices {
134 let x = Tensor::zeros((2, 2), *dtype, device)?;
136 results.push(x.matmul(&x));
137 }
138 if results.iter().all(|x| x.is_ok()) {
139 return Ok(*dtype);
140 } else {
141 for result in results {
142 match result {
143 Ok(_) => (),
144 Err(e) => match e {
145 hanzo_ml::Error::UnsupportedDTypeForOp(_, _) => continue,
147 hanzo_ml::Error::Msg(_) => continue,
150 hanzo_ml::Error::Metal(_) => continue,
152 hanzo_ml::Error::WithBacktrace { .. } => continue,
154 other => return Err(other),
155 },
156 }
157 }
158 }
159 }
160 Ok(DType::F32)
161 }
162}
163
164impl TryIntoDType for ModelDType {
165 fn try_into_dtype(&self, devices: &[&Device]) -> Result<DType> {
166 let dtype = match self {
167 Self::Auto => determine_auto_dtype_all(devices).map_err(anyhow::Error::msg)?,
168 Self::BF16 => DType::BF16,
169 Self::F16 => DType::F16,
170 Self::F32 => DType::F32,
171 };
172 once_log_info(format!("DType selected is {dtype:?}."));
173 Ok(dtype)
174 }
175}
176
177pub fn is_integrated_gpu(device: &Device) -> bool {
184 match device {
185 #[cfg(feature = "metal")]
186 Device::Metal(_) => true,
187 #[cfg(feature = "cuda")]
188 Device::Cuda(dev) => {
189 use hanzo_ml::cuda::cudarc::driver::{result, sys};
190 let ordinal = dev.cuda_stream().context().ordinal();
191 #[allow(clippy::cast_possible_truncation)]
192 let cu_device = match result::device::get(ordinal as i32) {
193 Ok(d) => d,
194 Err(_) => return false,
195 };
196 unsafe {
197 result::device::get_attribute(
198 cu_device,
199 sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_INTEGRATED,
200 )
201 .map(|v| v != 0)
202 .unwrap_or(false)
203 }
204 }
205 _ => false,
206 }
207}