alith_devices/devices/
cuda.rs1use super::gpu::GpuDevice;
2use nvml_wrapper::Nvml;
3
/// Number of bytes (500 MiB) subtracted from each device's total memory when
/// computing its available VRAM.
pub const CUDA_OVERHEAD: u64 = 500 * (1 << 20);
6
7#[derive(Debug, Clone, Default)]
8pub struct CudaConfig {
9 pub main_gpu: Option<u32>,
11 pub use_cuda_devices: Vec<u32>,
13 pub(crate) cuda_devices: Vec<CudaDevice>,
14 pub(crate) total_vram_bytes: u64,
15}
16
17impl CudaConfig {
18 pub fn new_from_cuda_devices(use_cuda_devices: Vec<u32>) -> Self {
19 Self {
20 use_cuda_devices,
21 ..Default::default()
22 }
23 }
24
25 pub fn new_with_main_device(use_cuda_devices: Vec<u32>, main_gpu: u32) -> Self {
26 Self {
27 main_gpu: Some(main_gpu),
28 use_cuda_devices,
29 ..Default::default()
30 }
31 }
32
33 pub(crate) fn initialize(&mut self, error_on_config_issue: bool) -> crate::Result<()> {
34 let nvml: Nvml = init_nvml_wrapper()?;
35 if self.use_cuda_devices.is_empty() {
36 self.cuda_devices = get_all_cuda_devices(Some(&nvml))?;
37 } else {
38 for ordinal in &self.use_cuda_devices {
39 match CudaDevice::new(*ordinal, Some(&nvml)) {
40 Ok(cuda_device) => self.cuda_devices.push(cuda_device),
41 Err(e) => {
42 if error_on_config_issue {
43 crate::bail!(
44 "Failed to get device {} specified in cuda_devices: {}",
45 ordinal,
46 e
47 );
48 } else {
49 crate::warn!(
50 "Failed to get device {} specified in cuda_devices: {}",
51 ordinal,
52 e
53 );
54 }
55 }
56 }
57 }
58 }
59 if self.cuda_devices.is_empty() {
60 crate::bail!("No CUDA devices found");
61 }
62
63 self.main_gpu = Some(self.main_gpu(error_on_config_issue)?);
64
65 self.total_vram_bytes = self
66 .cuda_devices
67 .iter()
68 .map(|d| (d.available_vram_bytes))
69 .sum();
70 Ok(())
71 }
72
73 pub(crate) fn device_count(&self) -> usize {
74 self.cuda_devices.len()
75 }
76
77 pub(crate) fn main_gpu(&self, error_on_config_issue: bool) -> crate::Result<u32> {
78 if let Some(main_gpu) = self.main_gpu {
79 for device in &self.cuda_devices {
80 if device.ordinal == main_gpu {
81 return Ok(main_gpu);
82 }
83 }
84 if error_on_config_issue {
85 crate::bail!(
86 "Main GPU set by user {} not found in CUDA devices",
87 main_gpu
88 );
89 } else {
90 crate::warn!(
91 "Main GPU set by user {} not found in CUDA devices. Using largest VRAM device.",
92 main_gpu
93 );
94 }
95 };
96 let main_gpu = self
97 .cuda_devices
98 .iter()
99 .max_by_key(|d| d.available_vram_bytes)
100 .ok_or_else(|| crate::anyhow!("No devices found when setting main gpu"))?
101 .ordinal;
102 for device in &self.cuda_devices {
103 if device.ordinal == main_gpu {
104 return Ok(main_gpu);
105 }
106 }
107 crate::bail!("Main GPU {} not found in CUDA devices", main_gpu);
108 }
109
110 pub(crate) fn to_generic_gpu_devices(
111 &self,
112 error_on_config_issue: bool,
113 ) -> crate::Result<Vec<GpuDevice>> {
114 let mut gpu_devices: Vec<GpuDevice> = self
115 .cuda_devices
116 .iter()
117 .map(|d| d.to_generic_gpu())
118 .collect();
119 let main_gpu = self.main_gpu(error_on_config_issue)?;
120 for gpu in &mut gpu_devices {
121 if gpu.ordinal == main_gpu {
122 gpu.is_main_gpu = true;
123 }
124 }
125 Ok(gpu_devices)
126 }
127}
128
129pub fn get_all_cuda_devices(nvml: Option<&Nvml>) -> crate::Result<Vec<CudaDevice>> {
130 let nvml = match nvml {
131 Some(nvml) => nvml,
132 None => &init_nvml_wrapper()?,
133 };
134 let device_count = nvml.device_count()?;
135 let mut cuda_devices: Vec<CudaDevice> = Vec::new();
136 let mut ordinal = 0;
137 while cuda_devices.len() < device_count as usize {
138 if let Ok(nvml_device) = CudaDevice::new(ordinal, Some(nvml)) {
139 cuda_devices.push(nvml_device);
140 }
141 if ordinal > 100 {
142 crate::warn!(
143 "nvml_wrapper reported {device_count} devices, but we were only able to get {}",
144 cuda_devices.len()
145 );
146 }
147 ordinal += 1;
148 }
149 if cuda_devices.is_empty() {
150 crate::bail!("No CUDA devices found");
151 }
152 Ok(cuda_devices)
153}
154
/// A single CUDA device as described by NVML.
#[derive(Clone, Debug)]
pub struct CudaDevice {
    /// NVML index of the device.
    pub ordinal: u32,
    /// VRAM considered usable on the device, in bytes.
    pub available_vram_bytes: u64,
    /// Device name, when NVML could provide it.
    pub name: Option<String>,
    /// Enforced power limit as reported by NVML, when available.
    pub power_limit: Option<u32>,
    /// Major component of the value shown as "Driver version".
    /// NOTE(review): `CudaDevice::new` fills this from the CUDA compute
    /// capability, not from a driver version.
    pub driver_major: Option<i32>,
    /// Minor component; see `driver_major`.
    pub driver_minor: Option<i32>,
}
164
165impl CudaDevice {
166 pub fn new(ordinal: u32, nvml: Option<&Nvml>) -> crate::Result<Self> {
167 let nvml = match nvml {
168 Some(nvml) => nvml,
169 None => &init_nvml_wrapper()?,
170 };
171 if let Ok(nvml_device) = nvml.device_by_index(ordinal) {
172 if let Ok(memory_info) = nvml_device.memory_info() {
173 if memory_info.total != 0 {
174 let name = if let Ok(name) = nvml_device.name() {
175 Some(name)
176 } else {
177 None
178 };
179 let power_limit = if let Ok(power_limit) = nvml_device.enforced_power_limit() {
180 Some(power_limit)
181 } else {
182 None
183 };
184 let (driver_major, driver_minor) = if let Ok(cuda_compute_capability) =
185 nvml_device.cuda_compute_capability()
186 {
187 (
188 Some(cuda_compute_capability.major),
189 Some(cuda_compute_capability.minor),
190 )
191 } else {
192 (None, None)
193 };
194 let cuda_device = CudaDevice {
195 ordinal,
196 available_vram_bytes: memory_info.total - CUDA_OVERHEAD,
197 name,
198 power_limit,
199 driver_major,
200 driver_minor,
201 };
202
203 Ok(cuda_device)
204 } else {
205 crate::bail!("Device {} has 0 bytes of VRAM. Skipping device.", ordinal);
206 }
207 } else {
208 crate::bail!("Failed to get device {}", ordinal);
209 }
210 } else {
211 crate::bail!("Failed to get device {}", ordinal);
212 }
213 }
214
215 pub fn to_generic_gpu(&self) -> GpuDevice {
216 GpuDevice {
217 ordinal: self.ordinal,
218 available_vram_bytes: self.available_vram_bytes,
219 ..Default::default()
220 }
221 }
222}
223
224pub fn init_nvml_wrapper() -> crate::Result<Nvml> {
225 let library_names = vec![
226 "libnvidia-ml.so", "libnvidia-ml.so.1", "nvml.dll", ];
230 for library_name in library_names {
231 match Nvml::builder().lib_path(library_name.as_ref()).init() {
232 Ok(nvml) => return Ok(nvml),
233 Err(_) => {
234 continue;
235 }
236 }
237 }
238 crate::bail!("Failed to initialize nvml_wrapper::Nvml")
239}
240
241impl std::fmt::Display for CudaConfig {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 writeln!(f)?;
244 writeln!(f, "CudaConfig:")?;
245 crate::i_nlns(
246 f,
247 &[
248 format_args!("Main GPU: {:?}", self.main_gpu),
249 format_args!(
250 "Total vram size: {:.2} GB",
251 (self.total_vram_bytes as f64) / 1_073_741_824.0
252 ),
253 ],
254 )?;
255 for device in &self.cuda_devices {
256 crate::i_ln(f, format_args!("{}", device))?;
257 }
258 Ok(())
259 }
260}
261
262impl std::fmt::Display for CudaDevice {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 writeln!(f, "CudaDevice:")?;
265 crate::i_nlns(
266 f,
267 &[
268 format_args!("Ordinal: {:?}", self.ordinal),
269 format_args!(
270 "Available VRAM: {:.2} GB",
271 (self.available_vram_bytes as f64) / 1_073_741_824.0
272 ),
273 format_args!("Name: {:?}", self.name),
274 format_args!("Power limit: {:?}", self.power_limit),
275 format_args!(
276 "Driver version: {}.{}",
277 self.driver_major.unwrap_or(-1),
278 self.driver_minor.unwrap_or(-1)
279 ),
280 ],
281 )
282 }
283}