gni_lib/
lib.rs

1// Copyright (c) Meta Platforms, Inc. and affiliates.
2// All rights reserved.
3
4// This source code is licensed under the license found in the
5// LICENSE file in the root directory of this source tree.
6
7use anyhow::{anyhow, Context, Result};
8use std::fs;
9use std::path::Path;
10use std::path::PathBuf;
11
12#[cfg(feature = "c")]
13mod c;
14
15#[cfg(feature = "cpp")]
16mod cpp;
17
18#[cfg(feature = "python")]
19mod python;
20
21mod constants;
22
23use blake3::{Hash, Hasher};
24use nvml_wrapper::{Device, Nvml};
25
26fn blake3_hash_string(input: &str) -> String {
27    let mut hasher: Hasher = Hasher::new();
28    hasher.update(input.as_bytes());
29    let result: Hash = hasher.finalize();
30    result.to_hex().to_string()
31}
32
33/// Returns the GPU Node ID as String
34pub fn get_gpu_node_id(cache_file_path: Option<&PathBuf>) -> Result<String, anyhow::Error> {
35    let default_path: &Path = Path::new(constants::DEFAULT_CACHE_FILEPATH);
36    let binding: PathBuf = default_path.to_path_buf();
37    let path: &PathBuf = cache_file_path.unwrap_or(&binding);
38
39    if Path::new(path).exists() {
40        let contents: String = fs::read_to_string(path).context("Failed to read cache file")?;
41        return Ok(contents);
42    }
43
44    let nvml: Nvml = Nvml::init().context("Failed to init nvml")?;
45    let device_count: u32 = nvml
46        .device_count()
47        .context("Failed to get nvml device count")?;
48    let mut uuids: Vec<String> = Vec::new();
49
50    for n in 0..device_count {
51        let device: Device<'_> = nvml
52            .device_by_index(n)
53            .context("Failed to get nvml device by index")?;
54        let uuid: String = device.uuid().context("Failed to get device uuid")?;
55        uuids.push(uuid);
56    }
57
58    if uuids.is_empty() {
59        return Err(anyhow!("No GPUs found"));
60    }
61
62    // sort the UUIDs to ensure a consistent hash (the node ID should be the same regardless of the order of the GPUs)
63    uuids.sort();
64
65    let concatenated_uuids: String = uuids.join("");
66
67    let gpu_node_id: String = blake3_hash_string(&concatenated_uuids);
68
69    fs::write(path, &gpu_node_id).context("Failed to write cache")?;
70
71    Ok(gpu_node_id)
72}