burn_train/metric/
cuda.rs1use std::sync::Arc;
2
3use super::MetricMetadata;
4use crate::metric::{Metric, MetricName, SerializedEntry};
5use nvml_wrapper::Nvml;
6
/// Metric reporting per-device CUDA memory usage and GPU utilization via NVML.
#[derive(Clone)]
pub struct CudaMetric {
    /// Name handed back by `Metric::name`.
    name: MetricName,
    /// NVML handle shared between clones; holds `None` when `Nvml::init` failed.
    nvml: Arc<Option<Nvml>>,
}
13
14impl CudaMetric {
15 pub fn new() -> Self {
17 Self {
18 name: Arc::new("Cuda".to_string()),
19 nvml: Arc::new(Nvml::init().map(Some).unwrap_or_else(|err| {
20 log::warn!("Unable to initialize CUDA Metric: {err}");
21 None
22 })),
23 }
24 }
25}
26
27impl Default for CudaMetric {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl Metric for CudaMetric {
34 type Input = ();
35
36 fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> SerializedEntry {
37 let not_available =
38 || SerializedEntry::new("Unavailable".to_string(), "Unavailable".to_string());
39
40 let available = |nvml: &Nvml| {
41 let mut formatted = String::new();
42 let mut raw_running = String::new();
43
44 let device_count = match nvml.device_count() {
45 Ok(val) => val,
46 Err(err) => {
47 log::warn!("Unable to get the number of cuda devices: {err}");
48 return not_available();
49 }
50 };
51
52 for index in 0..device_count {
53 let device = match nvml.device_by_index(index) {
54 Ok(val) => val,
55 Err(err) => {
56 log::warn!("Unable to get device {index}: {err}");
57 return not_available();
58 }
59 };
60 let memory_info = match device.memory_info() {
61 Ok(info) => info,
62 Err(err) => {
63 log::warn!("Unable to get memory info from device {index}: {err}");
64 return not_available();
65 }
66 };
67
68 let used_gb = memory_info.used as f64 * 1e-9;
69 let total_gb = memory_info.total as f64 * 1e-9;
70
71 let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
72 let memory_info_raw = format!("{used_gb}/{total_gb}");
73
74 formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
75 raw_running = format!("{memory_info_raw} ");
76
77 let utilization_rates = match device.utilization_rates() {
78 Ok(rate) => rate,
79 Err(err) => {
80 log::warn!("Unable to get utilization rates from device {index}: {err}");
81 return not_available();
82 }
83 };
84 let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
85 formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
86 }
87
88 SerializedEntry::new(formatted, raw_running)
89 };
90
91 match self.nvml.as_ref() {
92 Some(nvml) => available(nvml),
93 None => not_available(),
94 }
95 }
96
97 fn clear(&mut self) {}
98
99 fn name(&self) -> MetricName {
100 self.name.clone()
101 }
102}