burn_train/metric/
cuda.rs1use super::{Adaptor, MetricMetadata};
2use crate::metric::{Metric, MetricEntry};
3use nvml_wrapper::Nvml;
4
5pub struct CudaMetric {
7 nvml: Option<Nvml>,
8}
9
10impl CudaMetric {
11 pub fn new() -> Self {
13 Self {
14 nvml: Nvml::init().map(Some).unwrap_or_else(|err| {
15 log::warn!("Unable to initialize CUDA Metric: {err}");
16 None
17 }),
18 }
19 }
20}
21
22impl Default for CudaMetric {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl<T> Adaptor<()> for T {
29 fn adapt(&self) {}
30}
31
32impl Metric for CudaMetric {
33 type Input = ();
34
35 fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> MetricEntry {
36 let not_available = || {
37 MetricEntry::new(
38 self.name(),
39 "Unavailable".to_string(),
40 "Unavailable".to_string(),
41 )
42 };
43
44 let available = |nvml: &Nvml| {
45 let mut formatted = String::new();
46 let mut raw_running = String::new();
47
48 let device_count = match nvml.device_count() {
49 Ok(val) => val,
50 Err(err) => {
51 log::warn!("Unable to get the number of cuda devices: {err}");
52 return not_available();
53 }
54 };
55
56 for index in 0..device_count {
57 let device = match nvml.device_by_index(index) {
58 Ok(val) => val,
59 Err(err) => {
60 log::warn!("Unable to get device {index}: {err}");
61 return not_available();
62 }
63 };
64 let memory_info = match device.memory_info() {
65 Ok(info) => info,
66 Err(err) => {
67 log::warn!("Unable to get memory info from device {index}: {err}");
68 return not_available();
69 }
70 };
71
72 let used_gb = memory_info.used as f64 * 1e-9;
73 let total_gb = memory_info.total as f64 * 1e-9;
74
75 let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
76 let memory_info_raw = format!("{used_gb}/{total_gb}");
77
78 formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
79 raw_running = format!("{memory_info_raw} ");
80
81 let utilization_rates = match device.utilization_rates() {
82 Ok(rate) => rate,
83 Err(err) => {
84 log::warn!("Unable to get utilization rates from device {index}: {err}");
85 return not_available();
86 }
87 };
88 let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
89 formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
90 }
91
92 MetricEntry::new(self.name(), formatted, raw_running)
93 };
94
95 match &self.nvml {
96 Some(nvml) => available(nvml),
97 None => not_available(),
98 }
99 }
100
101 fn clear(&mut self) {}
102
103 fn name(&self) -> String {
104 "CUDA Stats".to_string()
105 }
106}