burn_train/metric/cuda.rs

1use super::{Adaptor, MetricMetadata};
2use crate::metric::{Metric, MetricEntry};
3use nvml_wrapper::Nvml;
4
5/// Track basic cuda infos.
6pub struct CudaMetric {
7    nvml: Option<Nvml>,
8}
9
10impl CudaMetric {
11    /// Creates a new metric for CUDA.
12    pub fn new() -> Self {
13        Self {
14            nvml: Nvml::init().map(Some).unwrap_or_else(|err| {
15                log::warn!("Unable to initialize CUDA Metric: {err}");
16                None
17            }),
18        }
19    }
20}
21
22impl Default for CudaMetric {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl<T> Adaptor<()> for T {
29    fn adapt(&self) {}
30}
31
32impl Metric for CudaMetric {
33    const NAME: &'static str = "CUDA Stats";
34
35    type Input = ();
36
37    fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> MetricEntry {
38        let not_available = || {
39            MetricEntry::new(
40                Self::NAME.to_string(),
41                "Unavailable".to_string(),
42                "Unavailable".to_string(),
43            )
44        };
45
46        let available = |nvml: &Nvml| {
47            let mut formatted = String::new();
48            let mut raw_running = String::new();
49
50            let device_count = match nvml.device_count() {
51                Ok(val) => val,
52                Err(err) => {
53                    log::warn!("Unable to get the number of cuda devices: {err}");
54                    return not_available();
55                }
56            };
57
58            for index in 0..device_count {
59                let device = match nvml.device_by_index(index) {
60                    Ok(val) => val,
61                    Err(err) => {
62                        log::warn!("Unable to get device {index}: {err}");
63                        return not_available();
64                    }
65                };
66                let memory_info = match device.memory_info() {
67                    Ok(info) => info,
68                    Err(err) => {
69                        log::warn!("Unable to get memory info from device {index}: {err}");
70                        return not_available();
71                    }
72                };
73
74                let used_gb = memory_info.used as f64 * 1e-9;
75                let total_gb = memory_info.total as f64 * 1e-9;
76
77                let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
78                let memory_info_raw = format!("{used_gb}/{total_gb}");
79
80                formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
81                raw_running = format!("{memory_info_raw} ");
82
83                let utilization_rates = match device.utilization_rates() {
84                    Ok(rate) => rate,
85                    Err(err) => {
86                        log::warn!("Unable to get utilization rates from device {index}: {err}");
87                        return not_available();
88                    }
89                };
90                let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
91                formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
92            }
93
94            MetricEntry::new(Self::NAME.to_string(), formatted, raw_running)
95        };
96
97        match &self.nvml {
98            Some(nvml) => available(nvml),
99            None => not_available(),
100        }
101    }
102
103    fn clear(&mut self) {}
104}