1use anyhow::{anyhow, Context, Result};
8use std::fs;
9use std::path::Path;
10use std::path::PathBuf;
11
12#[cfg(feature = "c")]
13mod c;
14
15#[cfg(feature = "cpp")]
16mod cpp;
17
18#[cfg(feature = "python")]
19mod python;
20
21mod constants;
22
23use blake3::{Hash, Hasher};
24use nvml_wrapper::{Device, Nvml};
25
26fn blake3_hash_string(input: &str) -> String {
27 let mut hasher: Hasher = Hasher::new();
28 hasher.update(input.as_bytes());
29 let result: Hash = hasher.finalize();
30 result.to_hex().to_string()
31}
32
33pub fn get_gpu_node_id(cache_file_path: Option<&PathBuf>) -> Result<String, anyhow::Error> {
35 let default_path: &Path = Path::new(constants::DEFAULT_CACHE_FILEPATH);
36 let binding: PathBuf = default_path.to_path_buf();
37 let path: &PathBuf = cache_file_path.unwrap_or(&binding);
38
39 if Path::new(path).exists() {
40 let contents: String = fs::read_to_string(path).context("Failed to read cache file")?;
41 return Ok(contents);
42 }
43
44 let nvml: Nvml = Nvml::init().context("Failed to init nvml")?;
45 let device_count: u32 = nvml
46 .device_count()
47 .context("Failed to get nvml device count")?;
48 let mut uuids: Vec<String> = Vec::new();
49
50 for n in 0..device_count {
51 let device: Device<'_> = nvml
52 .device_by_index(n)
53 .context("Failed to get nvml device by index")?;
54 let uuid: String = device.uuid().context("Failed to get device uuid")?;
55 uuids.push(uuid);
56 }
57
58 if uuids.is_empty() {
59 return Err(anyhow!("No GPUs found"));
60 }
61
62 uuids.sort();
64
65 let concatenated_uuids: String = uuids.join("");
66
67 let gpu_node_id: String = blake3_hash_string(&concatenated_uuids);
68
69 fs::write(path, &gpu_node_id).context("Failed to write cache")?;
70
71 Ok(gpu_node_id)
72}