// alith_devices/devices/cuda.rs

1use super::gpu::GpuDevice;
2use nvml_wrapper::Nvml;
3
/// Bytes (500 MiB) subtracted from each device's total VRAM to account for
/// CUDA overhead when computing the usable VRAM budget.
///
/// Source of the estimate:
/// https://gist.github.com/jrruethe/8974d2c8b4ece242a071d1a1526aa763#file-vram-rb-L64
pub const CUDA_OVERHEAD: u64 = 500 << 20;
6
7#[derive(Debug, Clone, Default)]
8pub struct CudaConfig {
9    /// The main GPU device ordinal. Defaults to the largest VRAM device.
10    pub main_gpu: Option<u32>,
11    /// Ordinals of the devices to use.
12    pub use_cuda_devices: Vec<u32>,
13    pub(crate) cuda_devices: Vec<CudaDevice>,
14    pub(crate) total_vram_bytes: u64,
15}
16
17impl CudaConfig {
18    pub fn new_from_cuda_devices(use_cuda_devices: Vec<u32>) -> Self {
19        Self {
20            use_cuda_devices,
21            ..Default::default()
22        }
23    }
24
25    pub fn new_with_main_device(use_cuda_devices: Vec<u32>, main_gpu: u32) -> Self {
26        Self {
27            main_gpu: Some(main_gpu),
28            use_cuda_devices,
29            ..Default::default()
30        }
31    }
32
33    pub(crate) fn initialize(&mut self, error_on_config_issue: bool) -> crate::Result<()> {
34        let nvml: Nvml = init_nvml_wrapper()?;
35        if self.use_cuda_devices.is_empty() {
36            self.cuda_devices = get_all_cuda_devices(Some(&nvml))?;
37        } else {
38            for ordinal in &self.use_cuda_devices {
39                match CudaDevice::new(*ordinal, Some(&nvml)) {
40                    Ok(cuda_device) => self.cuda_devices.push(cuda_device),
41                    Err(e) => {
42                        if error_on_config_issue {
43                            crate::bail!(
44                                "Failed to get device {} specified in cuda_devices: {}",
45                                ordinal,
46                                e
47                            );
48                        } else {
49                            crate::warn!(
50                                "Failed to get device {} specified in cuda_devices: {}",
51                                ordinal,
52                                e
53                            );
54                        }
55                    }
56                }
57            }
58        }
59        if self.cuda_devices.is_empty() {
60            crate::bail!("No CUDA devices found");
61        }
62
63        self.main_gpu = Some(self.main_gpu(error_on_config_issue)?);
64
65        self.total_vram_bytes = self
66            .cuda_devices
67            .iter()
68            .map(|d| (d.available_vram_bytes))
69            .sum();
70        Ok(())
71    }
72
73    pub(crate) fn device_count(&self) -> usize {
74        self.cuda_devices.len()
75    }
76
77    pub(crate) fn main_gpu(&self, error_on_config_issue: bool) -> crate::Result<u32> {
78        if let Some(main_gpu) = self.main_gpu {
79            for device in &self.cuda_devices {
80                if device.ordinal == main_gpu {
81                    return Ok(main_gpu);
82                }
83            }
84            if error_on_config_issue {
85                crate::bail!(
86                    "Main GPU set by user {} not found in CUDA devices",
87                    main_gpu
88                );
89            } else {
90                crate::warn!(
91                    "Main GPU set by user {} not found in CUDA devices. Using largest VRAM device.",
92                    main_gpu
93                );
94            }
95        };
96        let main_gpu = self
97            .cuda_devices
98            .iter()
99            .max_by_key(|d| d.available_vram_bytes)
100            .ok_or_else(|| crate::anyhow!("No devices found when setting main gpu"))?
101            .ordinal;
102        for device in &self.cuda_devices {
103            if device.ordinal == main_gpu {
104                return Ok(main_gpu);
105            }
106        }
107        crate::bail!("Main GPU {} not found in CUDA devices", main_gpu);
108    }
109
110    pub(crate) fn to_generic_gpu_devices(
111        &self,
112        error_on_config_issue: bool,
113    ) -> crate::Result<Vec<GpuDevice>> {
114        let mut gpu_devices: Vec<GpuDevice> = self
115            .cuda_devices
116            .iter()
117            .map(|d| d.to_generic_gpu())
118            .collect();
119        let main_gpu = self.main_gpu(error_on_config_issue)?;
120        for gpu in &mut gpu_devices {
121            if gpu.ordinal == main_gpu {
122                gpu.is_main_gpu = true;
123            }
124        }
125        Ok(gpu_devices)
126    }
127}
128
129pub fn get_all_cuda_devices(nvml: Option<&Nvml>) -> crate::Result<Vec<CudaDevice>> {
130    let nvml = match nvml {
131        Some(nvml) => nvml,
132        None => &init_nvml_wrapper()?,
133    };
134    let device_count = nvml.device_count()?;
135    let mut cuda_devices: Vec<CudaDevice> = Vec::new();
136    let mut ordinal = 0;
137    while cuda_devices.len() < device_count as usize {
138        if let Ok(nvml_device) = CudaDevice::new(ordinal, Some(nvml)) {
139            cuda_devices.push(nvml_device);
140        }
141        if ordinal > 100 {
142            crate::warn!(
143                "nvml_wrapper reported {device_count} devices, but we were only able to get {}",
144                cuda_devices.len()
145            );
146        }
147        ordinal += 1;
148    }
149    if cuda_devices.is_empty() {
150        crate::bail!("No CUDA devices found");
151    }
152    Ok(cuda_devices)
153}
154
/// A single CUDA device as reported by NVML (see `CudaDevice::new`).
#[derive(Clone, Debug)]
pub struct CudaDevice {
    /// NVML device index.
    pub ordinal: u32,
    /// Total VRAM in bytes minus `CUDA_OVERHEAD`.
    pub available_vram_bytes: u64,
    /// Device name, if NVML reported one.
    pub name: Option<String>,
    /// Enforced power limit reported by NVML, if available.
    /// NOTE(review): NVML documents this value in milliwatts — confirm before relying on units.
    pub power_limit: Option<u32>,
    /// Major CUDA compute capability. Despite the field name this is not the
    /// driver version: `CudaDevice::new` fills it from NVML's compute capability.
    pub driver_major: Option<i32>,
    /// Minor CUDA compute capability (see `driver_major`).
    pub driver_minor: Option<i32>,
}
164
165impl CudaDevice {
166    pub fn new(ordinal: u32, nvml: Option<&Nvml>) -> crate::Result<Self> {
167        let nvml = match nvml {
168            Some(nvml) => nvml,
169            None => &init_nvml_wrapper()?,
170        };
171        if let Ok(nvml_device) = nvml.device_by_index(ordinal) {
172            if let Ok(memory_info) = nvml_device.memory_info() {
173                if memory_info.total != 0 {
174                    let name = if let Ok(name) = nvml_device.name() {
175                        Some(name)
176                    } else {
177                        None
178                    };
179                    let power_limit = if let Ok(power_limit) = nvml_device.enforced_power_limit() {
180                        Some(power_limit)
181                    } else {
182                        None
183                    };
184                    let (driver_major, driver_minor) = if let Ok(cuda_compute_capability) =
185                        nvml_device.cuda_compute_capability()
186                    {
187                        (
188                            Some(cuda_compute_capability.major),
189                            Some(cuda_compute_capability.minor),
190                        )
191                    } else {
192                        (None, None)
193                    };
194                    let cuda_device = CudaDevice {
195                        ordinal,
196                        available_vram_bytes: memory_info.total - CUDA_OVERHEAD,
197                        name,
198                        power_limit,
199                        driver_major,
200                        driver_minor,
201                    };
202
203                    Ok(cuda_device)
204                } else {
205                    crate::bail!("Device {} has 0 bytes of VRAM. Skipping device.", ordinal);
206                }
207            } else {
208                crate::bail!("Failed to get device {}", ordinal);
209            }
210        } else {
211            crate::bail!("Failed to get device {}", ordinal);
212        }
213    }
214
215    pub fn to_generic_gpu(&self) -> GpuDevice {
216        GpuDevice {
217            ordinal: self.ordinal,
218            available_vram_bytes: self.available_vram_bytes,
219            ..Default::default()
220        }
221    }
222}
223
224pub fn init_nvml_wrapper() -> crate::Result<Nvml> {
225    let library_names = vec![
226        "libnvidia-ml.so",   // For Linux
227        "libnvidia-ml.so.1", // For WSL
228        "nvml.dll",          // For Windows
229    ];
230    for library_name in library_names {
231        match Nvml::builder().lib_path(library_name.as_ref()).init() {
232            Ok(nvml) => return Ok(nvml),
233            Err(_) => {
234                continue;
235            }
236        }
237    }
238    crate::bail!("Failed to initialize nvml_wrapper::Nvml")
239}
240
241impl std::fmt::Display for CudaConfig {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        writeln!(f)?;
244        writeln!(f, "CudaConfig:")?;
245        crate::i_nlns(
246            f,
247            &[
248                format_args!("Main GPU: {:?}", self.main_gpu),
249                format_args!(
250                    "Total vram size: {:.2} GB",
251                    (self.total_vram_bytes as f64) / 1_073_741_824.0
252                ),
253            ],
254        )?;
255        for device in &self.cuda_devices {
256            crate::i_ln(f, format_args!("{}", device))?;
257        }
258        Ok(())
259    }
260}
261
262impl std::fmt::Display for CudaDevice {
263    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264        writeln!(f, "CudaDevice:")?;
265        crate::i_nlns(
266            f,
267            &[
268                format_args!("Ordinal: {:?}", self.ordinal),
269                format_args!(
270                    "Available VRAM: {:.2} GB",
271                    (self.available_vram_bytes as f64) / 1_073_741_824.0
272                ),
273                format_args!("Name: {:?}", self.name),
274                format_args!("Power limit: {:?}", self.power_limit),
275                format_args!(
276                    "Driver version: {}.{}",
277                    self.driver_major.unwrap_or(-1),
278                    self.driver_minor.unwrap_or(-1)
279                ),
280            ],
281        )
282    }
283}