// flodl_cli/libtorch/download.rs
//! `fdl libtorch download` -- download pre-built libtorch.

use std::fs;
use std::path::{Path, PathBuf};

use crate::context::Context;
use crate::util::archive;
use crate::util::http;
use crate::util::system;

use super::detect;
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------

/// Pinned libtorch release that this command downloads.
const LIBTORCH_VERSION: &str = "2.10.0";

/// Pre-built variant metadata.
struct VariantSpec {
    /// Label for display (e.g. "CUDA 12.8").
    label: &'static str,
    /// Directory name under precompiled/ (e.g. "cu128").
    dir_name: &'static str,
    /// Value for .arch `cuda=` field.
    arch_cuda: &'static str,
    /// Space-separated compute capabilities covered.
    arch_archs: &'static str,
    /// Value for .arch `variant=` field.
    arch_variant: &'static str,
}

/// CPU-only build: no CUDA runtime required.
const CPU_SPEC: VariantSpec = VariantSpec {
    label: "CPU",
    dir_name: "cpu",
    arch_cuda: "none",
    arch_archs: "cpu",
    arch_variant: "cpu",
};

/// CUDA 12.6 build: broadest coverage, back to sm_50 (Maxwell).
const CU126_SPEC: VariantSpec = VariantSpec {
    label: "CUDA 12.6",
    dir_name: "cu126",
    arch_cuda: "12.6",
    arch_archs: "5.0 5.2 6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0",
    arch_variant: "cu126",
};

/// CUDA 12.8 build: sm_70 (Volta) through sm_120; drops pre-Volta support.
const CU128_SPEC: VariantSpec = VariantSpec {
    label: "CUDA 12.8",
    dir_name: "cu128",
    arch_cuda: "12.8",
    arch_archs: "7.0 7.5 8.0 8.6 8.9 9.0 12.0",
    arch_variant: "cu128",
};
55
// ---------------------------------------------------------------------------
// Download options
// ---------------------------------------------------------------------------

/// Which pre-built variant to download.
pub enum Variant {
    /// CPU-only build.
    Cpu,
    /// CUDA 12.6 build.
    Cuda126,
    /// CUDA 12.8 build.
    Cuda128,
    /// Choose based on detected GPUs (see `auto_detect_variant`).
    Auto,
}
66
/// Options controlling `fdl libtorch download`.
pub struct DownloadOpts {
    /// Variant to install (CPU / CUDA / auto-detect).
    pub variant: Variant,
    /// Install here instead of `<root>/libtorch/precompiled/<variant>`.
    pub custom_path: Option<PathBuf>,
    /// Mark the installed variant as active after installing.
    pub activate: bool,
    /// Print what would happen without downloading or writing anything.
    pub dry_run: bool,
}
73
74impl Default for DownloadOpts {
75    fn default() -> Self {
76        Self {
77            variant: Variant::Auto,
78            custom_path: None,
79            activate: true,
80            dry_run: false,
81        }
82    }
83}
84
85// ---------------------------------------------------------------------------
86// URL construction
87// ---------------------------------------------------------------------------
88
89fn download_url(spec: &VariantSpec) -> Result<String, String> {
90    let os = std::env::consts::OS;
91    let arch = std::env::consts::ARCH;
92
93    match (os, arch) {
94        ("linux", "x86_64") => {}
95        ("macos", "aarch64") => {
96            if spec.arch_cuda != "none" {
97                return Err("macOS only supports CPU libtorch".into());
98            }
99        }
100        ("macos", _) => {
101            return Err(format!(
102                "macOS libtorch requires Apple Silicon (arm64), got {}.\n\
103                 macOS x86_64 was dropped after PyTorch 2.2.",
104                arch
105            ));
106        }
107        ("windows", "x86_64") => {}
108        _ => {
109            return Err(format!(
110                "Unsupported platform: {} {}.\n\
111                 libtorch is available for Linux x86_64, macOS arm64, and Windows x86_64.",
112                os, arch
113            ));
114        }
115    }
116
117    // macOS ARM has a different filename pattern
118    if os == "macos" {
119        return Ok(format!(
120            "https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-{}.zip",
121            LIBTORCH_VERSION
122        ));
123    }
124
125    // Linux and Windows use the same URL pattern
126    let filename = match spec.arch_variant {
127        "cpu" => format!(
128            "libtorch-shared-with-deps-{}%2Bcpu.zip",
129            LIBTORCH_VERSION
130        ),
131        variant => format!(
132            "libtorch-shared-with-deps-{}%2B{}.zip",
133            LIBTORCH_VERSION, variant
134        ),
135    };
136
137    let bucket = spec.arch_variant; // "cpu", "cu126", "cu128"
138    Ok(format!(
139        "https://download.pytorch.org/libtorch/{}/{}",
140        bucket, filename
141    ))
142}
143
// ---------------------------------------------------------------------------
// Auto-detection
// ---------------------------------------------------------------------------

/// Pick a pre-built variant based on the GPUs visible on this machine.
///
/// Falls back to CPU when no NVIDIA GPU is detected. Otherwise chooses
/// between cu126 (sm_50..sm_90) and cu128 (sm_70..sm_120) from the range of
/// detected compute-capability majors. Prints its reasoning to stdout.
fn auto_detect_variant() -> &'static VariantSpec {
    let gpus = system::detect_gpus();
    if gpus.is_empty() {
        println!("  No NVIDIA GPU detected. Using CPU variant.");
        return &CPU_SPEC;
    }

    // Find lowest and highest major compute capability
    let lo_major = gpus.iter().map(|g| g.sm_major).min().unwrap_or(0);
    let hi_major = gpus.iter().map(|g| g.sm_major).max().unwrap_or(0);

    // cu128 requires Volta+ (sm_70+), cu126 supports down to sm_50
    if lo_major >= 7 {
        println!("  Detected Volta+ GPU(s). Using cu128.");
        &CU128_SPEC
    } else if hi_major >= 10 {
        // Mixed: old + new GPUs. cu126 covers the old ones, cu128 covers the new.
        // Default to cu126 which covers more architectures.
        // NOTE(review): `hi_major >= 10` only flags mixes that include a
        // Blackwell-class (sm_10x+) part; a pre-Volta + Ampere mix falls
        // through to the branch below. Both branches return cu126, so the
        // selected variant is the same either way -- only the message differs.
        println!(
            "  Mixed GPU architectures (sm_{}.x to sm_{}.x).",
            lo_major, hi_major
        );
        println!("  Using cu126 (broadest pre-Volta coverage).");
        println!("  For all GPUs, consider: fdl libtorch build");
        &CU126_SPEC
    } else {
        // lo_major < 7 here; the message assumes every GPU is pre-Volta,
        // which may understate a mixed pre-Volta + Turing/Ampere system --
        // cu126 still covers all of them (sm_50 through sm_90).
        println!("  Detected pre-Volta GPU(s). Using cu126.");
        &CU126_SPEC
    }
}
178
179fn resolve_variant(variant: &Variant) -> &'static VariantSpec {
180    match variant {
181        Variant::Cpu => &CPU_SPEC,
182        Variant::Cuda126 => &CU126_SPEC,
183        Variant::Cuda128 => &CU128_SPEC,
184        Variant::Auto => auto_detect_variant(),
185    }
186}
187
188// ---------------------------------------------------------------------------
189// Core download logic
190// ---------------------------------------------------------------------------
191
192pub fn run(opts: DownloadOpts) -> Result<(), String> {
193    let ctx = Context::resolve();
194    run_with_context(opts, &ctx)
195}
196
/// Run with an explicit context (used by `setup` which has its own context).
///
/// Steps: resolve the variant, build the URL, (re)install into
/// `<root>/libtorch/precompiled/<variant>` (or `opts.custom_path`), verify
/// the extracted tree contains a torch library, write `.arch` metadata, and
/// optionally mark the variant active. Progress goes to stdout; failures are
/// returned as human-readable error strings.
pub fn run_with_context(opts: DownloadOpts, ctx: &Context) -> Result<(), String> {
    let spec = resolve_variant(&opts.variant);
    let url = download_url(spec)?;

    // Determine install path
    let install_path = if let Some(ref p) = opts.custom_path {
        p.clone()
    } else {
        ctx.root.join(format!("libtorch/precompiled/{}", spec.dir_name))
    };

    // Identifier recorded in libtorch/.active when activating.
    // NOTE(review): this is always "precompiled/<dir>" even when
    // opts.custom_path points somewhere else entirely -- confirm that the
    // activation machinery resolves it correctly in that case.
    let variant_id = format!("precompiled/{}", spec.dir_name);

    println!();
    println!("  libtorch {} ({})", LIBTORCH_VERSION, spec.label);
    println!("  URL:  {}", url);
    println!("  Path: {}", install_path.display());

    // Dry-run stops after printing the plan; nothing is downloaded or written.
    if opts.dry_run {
        println!();
        println!("  [dry-run] Would download and extract to above path.");
        return Ok(());
    }

    // Check existing installation
    if install_path.exists() {
        let build_ver_path = install_path.join("build-version");
        let existing_ver = fs::read_to_string(&build_ver_path)
            .ok()
            .map(|s| s.trim().to_string());

        // build-version may contain variant suffix (e.g. "2.10.0+cpu")
        let ver_matches = existing_ver.as_deref().is_some_and(|v| {
            v == LIBTORCH_VERSION || v.starts_with(&format!("{}+", LIBTORCH_VERSION))
        });

        // Same version already present: nothing to do.
        if ver_matches {
            println!();
            println!("  Already installed (version {}).", LIBTORCH_VERSION);
            return Ok(());
        }

        // Different (or unreadable) version: wipe and reinstall.
        println!();
        println!(
            "  Removing existing installation (version: {})...",
            existing_ver.as_deref().unwrap_or("unknown")
        );
        fs::remove_dir_all(&install_path)
            .map_err(|e| format!("cannot remove {}: {}", install_path.display(), e))?;
    }

    // Download to temp file
    let tmp_dir = std::env::temp_dir();
    let tmp_zip = tmp_dir.join(format!("libtorch-{}-{}.zip", spec.dir_name, LIBTORCH_VERSION));

    println!();
    println!("  Downloading...");
    http::download_file(&url, &tmp_zip)?;

    // Extract to temp directory (zip contains a top-level "libtorch/" dir)
    let tmp_extract = tmp_dir.join(format!("libtorch-extract-{}", std::process::id()));
    println!("  Extracting...");
    archive::extract_zip(&tmp_zip, &tmp_extract)?;

    // Move extracted contents to target path
    let extracted_lt = tmp_extract.join("libtorch");
    let source = if extracted_lt.is_dir() {
        &extracted_lt
    } else {
        // No top-level "libtorch/" dir: archive was extracted flat.
        &tmp_extract
    };

    fs::create_dir_all(&install_path)
        .map_err(|e| format!("cannot create {}: {}", install_path.display(), e))?;

    // Move all files from extracted dir to install path
    move_contents(source, &install_path)?;

    // Cleanup temp files
    // NOTE(review): the `?` returns above skip this cleanup, so a failed run
    // leaves tmp_zip / tmp_extract behind in the temp directory.
    let _ = fs::remove_file(&tmp_zip);
    let _ = fs::remove_dir_all(&tmp_extract);

    // Verify: one of these must exist -- libtorch.so (Linux),
    // libtorch.dylib (macOS), torch.lib (Windows import library).
    let lib_dir = install_path.join("lib");
    let has_lib = lib_dir.join("libtorch.so").exists()
        || lib_dir.join("libtorch.dylib").exists()
        || lib_dir.join("torch.lib").exists();

    if !has_lib {
        return Err(format!(
            "libtorch library not found at {}.\n\
             The archive structure may have changed.\n\
             Check: ls {}",
            lib_dir.display(),
            lib_dir.display()
        ));
    }

    // Write .arch metadata (always, both project and global)
    let arch_content = format!(
        "cuda={}\ntorch={}\narchs={}\nsource=precompiled\nvariant={}\n",
        spec.arch_cuda, LIBTORCH_VERSION, spec.arch_archs, spec.arch_variant
    );
    fs::write(install_path.join(".arch"), arch_content)
        .map_err(|e| format!("cannot write .arch: {}", e))?;

    if opts.activate {
        detect::set_active(&ctx.root, &variant_id)?;
    }

    println!();
    println!("  ================================================");
    println!("  libtorch {} ({}) installed", LIBTORCH_VERSION, spec.label);
    println!("  {}", install_path.display());
    println!("  ================================================");

    // Project context gets make-based verification hints; a global install
    // gets shell-profile environment instructions instead.
    if ctx.is_project {
        println!();
        println!("  .arch:   {}/.arch", install_path.display());
        if opts.activate {
            println!("  .active: libtorch/.active -> {}", variant_id);
        }
        println!();
        if spec.arch_cuda != "none" {
            println!("  Run 'make cuda-test' to verify.");
        } else {
            println!("  Run 'make test' to verify.");
        }
    } else {
        println!();
        println!("  Installed to: {}", install_path.display());
        println!();
        println!("  To use with tch-rs or flodl, add to your shell profile:");
        println!();
        println!("    export LIBTORCH=\"{}\"", install_path.display());
        println!(
            "    export LD_LIBRARY_PATH=\"{}/lib${{LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}}\"",
            install_path.display()
        );
        println!();
        println!("  Or start a new floDl project:");
        println!("    fdl init my-project");
    }

    Ok(())
}
344
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Move all files and directories from `src` into `dest`.
///
/// Tries `rename` per entry first (cheap, atomic on the same filesystem).
/// When that fails -- typically because `src` and `dest` live on different
/// filesystems -- it falls back to copy-then-delete, so the result is still
/// a true move rather than a silent copy.
fn move_contents(src: &Path, dest: &Path) -> Result<(), String> {
    let entries = fs::read_dir(src)
        .map_err(|e| format!("cannot read {}: {}", src.display(), e))?;

    for entry in entries {
        let entry = entry.map_err(|e| format!("read_dir error: {}", e))?;
        let from = entry.path();
        let to = dest.join(entry.file_name());

        // Try rename first (fast, same filesystem). Fall back to copy + delete.
        if fs::rename(&from, &to).is_err() {
            if from.is_dir() {
                copy_dir_recursive(&from, &to)?;
                // Remove the source tree so the fallback is still a move
                // (previously it was left behind, degrading move to copy).
                fs::remove_dir_all(&from)
                    .map_err(|e| format!("cannot remove {}: {}", from.display(), e))?;
            } else {
                fs::copy(&from, &to)
                    .map_err(|e| format!("copy {} -> {}: {}", from.display(), to.display(), e))?;
                fs::remove_file(&from)
                    .map_err(|e| format!("cannot remove {}: {}", from.display(), e))?;
            }
        }
    }
    Ok(())
}

/// Recursively copy the directory tree at `src` into `dest`, creating
/// `dest` (and intermediate directories) as needed.
///
/// NOTE(review): `fs::copy` follows symlinks, so symlinked shared libraries
/// are materialized as full copies on this fallback path -- larger on disk
/// but functionally equivalent.
fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<(), String> {
    fs::create_dir_all(dest)
        .map_err(|e| format!("cannot create {}: {}", dest.display(), e))?;

    for entry in fs::read_dir(src).map_err(|e| format!("read {}: {}", src.display(), e))? {
        let entry = entry.map_err(|e| format!("read_dir error: {}", e))?;
        let from = entry.path();
        let to = dest.join(entry.file_name());

        if from.is_dir() {
            copy_dir_recursive(&from, &to)?;
        } else {
            fs::copy(&from, &to)
                .map_err(|e| format!("copy {} -> {}: {}", from.display(), to.display(), e))?;
        }
    }
    Ok(())
}
391
392/// Get the current libtorch version constant (for display and checks).
393#[allow(dead_code)]
394pub fn libtorch_version() -> &'static str {
395    LIBTORCH_VERSION
396}