Skip to main content

burn_train/metric/
cuda.rs

1use std::sync::Arc;
2
3use super::MetricMetadata;
4use crate::metric::{Metric, MetricName, SerializedEntry};
5use nvml_wrapper::Nvml;
6
/// Track basic cuda infos.
///
/// Reports, for each device NVML can enumerate, the memory in use and the
/// GPU utilization rate. Cloning is cheap: both fields are reference counted.
#[derive(Clone)]
pub struct CudaMetric {
    // Name returned by `Metric::name`.
    name: MetricName,
    // Shared NVML handle; `None` when `Nvml::init` failed, in which case the
    // metric reports "Unavailable" instead of device statistics.
    nvml: Arc<Option<Nvml>>,
}
13
14impl CudaMetric {
15    /// Creates a new metric for CUDA.
16    pub fn new() -> Self {
17        Self {
18            name: Arc::new("Cuda".to_string()),
19            nvml: Arc::new(Nvml::init().map(Some).unwrap_or_else(|err| {
20                log::warn!("Unable to initialize CUDA Metric: {err}");
21                None
22            })),
23        }
24    }
25}
26
27impl Default for CudaMetric {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl Metric for CudaMetric {
34    type Input = ();
35
36    fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> SerializedEntry {
37        let not_available =
38            || SerializedEntry::new("Unavailable".to_string(), "Unavailable".to_string());
39
40        let available = |nvml: &Nvml| {
41            let mut formatted = String::new();
42            let mut raw_running = String::new();
43
44            let device_count = match nvml.device_count() {
45                Ok(val) => val,
46                Err(err) => {
47                    log::warn!("Unable to get the number of cuda devices: {err}");
48                    return not_available();
49                }
50            };
51
52            for index in 0..device_count {
53                let device = match nvml.device_by_index(index) {
54                    Ok(val) => val,
55                    Err(err) => {
56                        log::warn!("Unable to get device {index}: {err}");
57                        return not_available();
58                    }
59                };
60                let memory_info = match device.memory_info() {
61                    Ok(info) => info,
62                    Err(err) => {
63                        log::warn!("Unable to get memory info from device {index}: {err}");
64                        return not_available();
65                    }
66                };
67
68                let used_gb = memory_info.used as f64 * 1e-9;
69                let total_gb = memory_info.total as f64 * 1e-9;
70
71                let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
72                let memory_info_raw = format!("{used_gb}/{total_gb}");
73
74                formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
75                raw_running = format!("{memory_info_raw} ");
76
77                let utilization_rates = match device.utilization_rates() {
78                    Ok(rate) => rate,
79                    Err(err) => {
80                        log::warn!("Unable to get utilization rates from device {index}: {err}");
81                        return not_available();
82                    }
83                };
84                let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
85                formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
86            }
87
88            SerializedEntry::new(formatted, raw_running)
89        };
90
91        match self.nvml.as_ref() {
92            Some(nvml) => available(nvml),
93            None => not_available(),
94        }
95    }
96
97    fn clear(&mut self) {}
98
99    fn name(&self) -> MetricName {
100        self.name.clone()
101    }
102}