burn_train/metric/
cuda.rs

1use super::{Adaptor, MetricMetadata};
2use crate::metric::{Metric, MetricEntry};
3use nvml_wrapper::Nvml;
4
/// Tracks basic CUDA device information.
///
/// Device statistics are read through NVML. When NVML could not be initialized the
/// handle is `None` and the metric reports itself as "Unavailable" on every update.
pub struct CudaMetric {
    // NVML handle; `None` when `Nvml::init()` failed in `new`.
    nvml: Option<Nvml>,
}
9
10impl CudaMetric {
11    /// Creates a new metric for CUDA.
12    pub fn new() -> Self {
13        Self {
14            nvml: Nvml::init().map(Some).unwrap_or_else(|err| {
15                log::warn!("Unable to initialize CUDA Metric: {err}");
16                None
17            }),
18        }
19    }
20}
21
22impl Default for CudaMetric {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
/// Blanket implementation letting any type be adapted to the unit input `()`.
///
/// `CudaMetric` declares `()` as its `Input`, so there is nothing to extract
/// from the adapted item.
impl<T> Adaptor<()> for T {
    // Returns `()`; the empty body is the whole conversion.
    fn adapt(&self) {}
}
31
32impl Metric for CudaMetric {
33    type Input = ();
34
35    fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> MetricEntry {
36        let not_available = || {
37            MetricEntry::new(
38                self.name(),
39                "Unavailable".to_string(),
40                "Unavailable".to_string(),
41            )
42        };
43
44        let available = |nvml: &Nvml| {
45            let mut formatted = String::new();
46            let mut raw_running = String::new();
47
48            let device_count = match nvml.device_count() {
49                Ok(val) => val,
50                Err(err) => {
51                    log::warn!("Unable to get the number of cuda devices: {err}");
52                    return not_available();
53                }
54            };
55
56            for index in 0..device_count {
57                let device = match nvml.device_by_index(index) {
58                    Ok(val) => val,
59                    Err(err) => {
60                        log::warn!("Unable to get device {index}: {err}");
61                        return not_available();
62                    }
63                };
64                let memory_info = match device.memory_info() {
65                    Ok(info) => info,
66                    Err(err) => {
67                        log::warn!("Unable to get memory info from device {index}: {err}");
68                        return not_available();
69                    }
70                };
71
72                let used_gb = memory_info.used as f64 * 1e-9;
73                let total_gb = memory_info.total as f64 * 1e-9;
74
75                let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
76                let memory_info_raw = format!("{used_gb}/{total_gb}");
77
78                formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
79                raw_running = format!("{memory_info_raw} ");
80
81                let utilization_rates = match device.utilization_rates() {
82                    Ok(rate) => rate,
83                    Err(err) => {
84                        log::warn!("Unable to get utilization rates from device {index}: {err}");
85                        return not_available();
86                    }
87                };
88                let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
89                formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
90            }
91
92            MetricEntry::new(self.name(), formatted, raw_running)
93        };
94
95        match &self.nvml {
96            Some(nvml) => available(nvml),
97            None => not_available(),
98        }
99    }
100
101    fn clear(&mut self) {}
102
103    fn name(&self) -> String {
104        "CUDA Stats".to_string()
105    }
106}