burn_train/metric/
cuda.rs1use super::{Adaptor, MetricMetadata};
2use crate::metric::{Metric, MetricEntry};
3use nvml_wrapper::Nvml;
4
5pub struct CudaMetric {
7 nvml: Option<Nvml>,
8}
9
10impl CudaMetric {
11 pub fn new() -> Self {
13 Self {
14 nvml: Nvml::init().map(Some).unwrap_or_else(|err| {
15 log::warn!("Unable to initialize CUDA Metric: {err}");
16 None
17 }),
18 }
19 }
20}
21
22impl Default for CudaMetric {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl<T> Adaptor<()> for T {
29 fn adapt(&self) {}
30}
31
32impl Metric for CudaMetric {
33 const NAME: &'static str = "CUDA Stats";
34
35 type Input = ();
36
37 fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> MetricEntry {
38 let not_available = || {
39 MetricEntry::new(
40 Self::NAME.to_string(),
41 "Unavailable".to_string(),
42 "Unavailable".to_string(),
43 )
44 };
45
46 let available = |nvml: &Nvml| {
47 let mut formatted = String::new();
48 let mut raw_running = String::new();
49
50 let device_count = match nvml.device_count() {
51 Ok(val) => val,
52 Err(err) => {
53 log::warn!("Unable to get the number of cuda devices: {err}");
54 return not_available();
55 }
56 };
57
58 for index in 0..device_count {
59 let device = match nvml.device_by_index(index) {
60 Ok(val) => val,
61 Err(err) => {
62 log::warn!("Unable to get device {index}: {err}");
63 return not_available();
64 }
65 };
66 let memory_info = match device.memory_info() {
67 Ok(info) => info,
68 Err(err) => {
69 log::warn!("Unable to get memory info from device {index}: {err}");
70 return not_available();
71 }
72 };
73
74 let used_gb = memory_info.used as f64 * 1e-9;
75 let total_gb = memory_info.total as f64 * 1e-9;
76
77 let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
78 let memory_info_raw = format!("{used_gb}/{total_gb}");
79
80 formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
81 raw_running = format!("{memory_info_raw} ");
82
83 let utilization_rates = match device.utilization_rates() {
84 Ok(rate) => rate,
85 Err(err) => {
86 log::warn!("Unable to get utilization rates from device {index}: {err}");
87 return not_available();
88 }
89 };
90 let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
91 formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
92 }
93
94 MetricEntry::new(Self::NAME.to_string(), formatted, raw_running)
95 };
96
97 match &self.nvml {
98 Some(nvml) => available(nvml),
99 None => not_available(),
100 }
101 }
102
103 fn clear(&mut self) {}
104}