Skip to main content

flodl_cli/libtorch/
detect.rs

//! libtorch installation detection and .arch metadata parsing.
2
use std::fs;
use std::path::Path;

use crate::util::system::GpuInfo;
7
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
11
/// Metadata about an installed libtorch variant (from `.arch` file).
///
/// Only `path` is always present. Every other field is `None` when the
/// variant's `.arch` file is missing, unreadable, or has no matching
/// `key=value` line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LibtorchInfo {
    /// Variant path as stored in `libtorch/.active`, resolved relative to
    /// `<root>/libtorch/` (e.g. "precompiled/cu128", "builds/sm61-sm120").
    pub path: String,
    /// Value of the `torch=` line in `.arch`.
    pub torch_version: Option<String>,
    /// Value of the `cuda=` line in `.arch`.
    pub cuda_version: Option<String>,
    /// Value of the `archs=` line in `.arch` (the architectures the variant
    /// was compiled for), kept as the raw string.
    pub archs: Option<String>,
    /// Value of the `source=` line in `.arch`.
    pub source: Option<String>,
}
21
// ---------------------------------------------------------------------------
// Detection
// ---------------------------------------------------------------------------
25
26/// Read the active libtorch variant from `<root>/libtorch/.active` and parse
27/// its `.arch` metadata.
28pub fn read_active(root: &Path) -> Option<LibtorchInfo> {
29    let active_path = root.join("libtorch/.active");
30    let active = fs::read_to_string(active_path).ok()?;
31    let path = active.trim().to_string();
32    if path.is_empty() {
33        return None;
34    }
35
36    let arch_path = root.join(format!("libtorch/{}/.arch", path));
37    let mut info = LibtorchInfo {
38        path,
39        torch_version: None,
40        cuda_version: None,
41        archs: None,
42        source: None,
43    };
44
45    if let Ok(content) = fs::read_to_string(arch_path) {
46        for line in content.lines() {
47            if let Some(val) = line.strip_prefix("torch=") {
48                info.torch_version = Some(val.to_string());
49            } else if let Some(val) = line.strip_prefix("cuda=") {
50                info.cuda_version = Some(val.to_string());
51            } else if let Some(val) = line.strip_prefix("archs=") {
52                info.archs = Some(val.to_string());
53            } else if let Some(val) = line.strip_prefix("source=") {
54                info.source = Some(val.to_string());
55            }
56        }
57    }
58
59    Some(info)
60}
61
/// List all installed libtorch variants under `<root>/libtorch/`.
///
/// Scans the `precompiled/` and `builds/` subdirectories. An entry counts as
/// a variant when it contains a `lib/` directory. Each result is formatted as
/// `<subdir>/<name>` (e.g. "precompiled/cu128") and the list is returned in
/// sorted order.
///
/// Subdirectories that are missing or unreadable, entries that fail to read,
/// and entries whose names are not valid UTF-8 are skipped silently.
pub fn list_variants(root: &Path) -> Vec<String> {
    let lt_dir = root.join("libtorch");
    let mut variants = Vec::new();

    for subdir in ["precompiled", "builds"] {
        let entries = match fs::read_dir(lt_dir.join(subdir)) {
            Ok(entries) => entries,
            Err(_) => continue,
        };
        variants.extend(
            entries
                .flatten()
                .filter(|entry| entry.path().join("lib").is_dir())
                .filter_map(|entry| {
                    entry
                        .file_name()
                        .to_str()
                        .map(|name| format!("{}/{}", subdir, name))
                }),
        );
    }

    variants.sort_unstable();
    variants
}
85
86/// Check whether a GPU's compute capability is covered by the libtorch
87/// variant's compiled architectures (from the .arch file).
88pub fn arch_compatible(gpu: &GpuInfo, archs: &str) -> bool {
89    let exact = format!("{}.{}", gpu.sm_major, gpu.sm_minor);
90    archs.contains(&exact) || archs.contains(&format!("{}", gpu.sm_major))
91}
92
/// Check whether a libtorch variant directory looks valid (has lib/).
///
/// `variant` is the path below `<root>/libtorch/` (e.g. "precompiled/cu128").
/// Returns `true` only when `<root>/libtorch/<variant>/lib` exists and is a
/// directory; nothing inside `lib/` is inspected.
pub fn is_valid_variant(root: &Path, variant: &str) -> bool {
    let lib_dir = root.join(format!("libtorch/{}/lib", variant));
    lib_dir.is_dir()
}
97
/// Set the active libtorch variant by writing `<root>/libtorch/.active`.
///
/// Creates `<root>/libtorch/` if it does not exist yet, then overwrites
/// `.active` with `variant` followed by a newline. The variant is written
/// as given; it is not checked against the installed variants.
///
/// # Errors
///
/// Returns a human-readable message when the `libtorch/` directory cannot be
/// created or the `.active` file cannot be written.
pub fn set_active(root: &Path, variant: &str) -> Result<(), String> {
    let lt_dir = root.join("libtorch");
    fs::create_dir_all(&lt_dir).map_err(|e| format!("cannot create libtorch/: {}", e))?;

    let contents = format!("{}\n", variant);
    fs::write(lt_dir.join(".active"), contents)
        .map_err(|e| format!("cannot write libtorch/.active: {}", e))
}