Skip to main content

flodl_cli/libtorch/
build.rs

1//! `fdl libtorch build` -- compile libtorch from PyTorch source.
2//!
3//! Two backends: Docker (isolated, reproducible) or native (faster, requires
4//! CUDA toolkit + build tools on host). Auto-detects available backends and
5//! asks the user when both are present.
6
7use std::fs;
8use std::io::Write;
9use std::path::Path;
10use std::process::{Command, Stdio};
11
12use crate::context::Context;
13use crate::util::docker;
14use crate::util::prompt;
15use crate::util::system;
16use super::detect;
17
/// Dockerfile template for the CUDA source-build image, embedded at compile time.
const DOCKERFILE_CONTENT: &str = include_str!("../../assets/Dockerfile.cuda.source");
/// Tag applied to the locally-built Docker builder image.
const IMAGE_NAME: &str = "flodl-libtorch-builder";
/// libtorch version string written to `.arch` metadata and shown to the user.
const LIBTORCH_VERSION: &str = "2.10.0";
/// Git tag checked out when cloning PyTorch for a native build.
const PYTORCH_VERSION: &str = "v2.10.0";

/// Python packages PyTorch's build scripts import; installed via pip3 before
/// a native build.
const PYTHON_DEPS: &[&str] = &[
    "typing_extensions", "pyyaml", "filelock",
    "jinja2", "networkx", "sympy", "packaging",
];
27
28// ---------------------------------------------------------------------------
29// Options
30// ---------------------------------------------------------------------------
31
/// Which backend compiles libtorch.
///
/// Public enum: derive the standard traits (`Debug` for diagnostics, `Copy`/
/// `Clone` since it is a plain discriminant, `PartialEq`/`Eq` so callers can
/// compare) — the original only derived `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BuildBackend {
    /// Auto-detect: ask user if both available, otherwise use whatever works.
    #[default]
    Auto,
    /// Force Docker build.
    Docker,
    /// Force native build (no Docker).
    Native,
}
42
/// Options controlling a `fdl libtorch build` invocation.
pub struct BuildOpts {
    /// Override CUDA architectures (semicolon-separated, e.g. "6.1;12.0").
    /// None = auto-detect from GPUs.
    pub archs: Option<String>,
    /// Override MAX_JOBS for compilation. Default: 6.
    pub max_jobs: usize,
    /// Print what would happen without building.
    pub dry_run: bool,
    /// Which backend to use. Default: [`BuildBackend::Auto`].
    pub backend: BuildBackend,
}
54
55impl Default for BuildOpts {
56    fn default() -> Self {
57        Self {
58            archs: None,
59            max_jobs: 6,
60            dry_run: false,
61            backend: BuildBackend::Auto,
62        }
63    }
64}
65
66// ---------------------------------------------------------------------------
67// Auto-detect GPU architectures
68// ---------------------------------------------------------------------------
69
70fn detect_arch_list() -> Result<String, String> {
71    let gpus = system::detect_gpus();
72    if gpus.is_empty() {
73        return Err(
74            "No NVIDIA GPUs detected.\n\
75             Source builds require GPUs to auto-detect architectures.\n\
76             Use --archs to specify manually (e.g. --archs \"8.6;12.0\")."
77                .into(),
78        );
79    }
80
81    // Collect unique compute capabilities, sorted numerically
82    let mut caps: Vec<(u32, u32)> = gpus
83        .iter()
84        .map(|g| (g.sm_major, g.sm_minor))
85        .collect();
86    caps.sort();
87    caps.dedup();
88    let caps: Vec<String> = caps.iter().map(|(ma, mi)| format!("{}.{}", ma, mi)).collect();
89
90    println!("  GPUs detected:");
91    for g in &gpus {
92        println!(
93            "    [{}] {} (sm_{}.{})",
94            g.index, g.short_name(), g.sm_major, g.sm_minor
95        );
96    }
97
98    Ok(caps.join(";"))
99}
100
/// Convert "6.1;12.0" -> "sm61-sm120" for directory naming.
fn arch_dir_name(archs: &str) -> String {
    let tags: Vec<String> = archs
        .split(';')
        .map(|cap| format!("sm{}", cap.replace('.', "")))
        .collect();
    tags.join("-")
}
112
113// ---------------------------------------------------------------------------
114// Native toolchain detection
115// ---------------------------------------------------------------------------
116
117struct NativeTools {
118    nvcc: bool,
119    cmake: bool,
120    python3: bool,
121    git: bool,
122    gcc: bool,
123}
124
125impl NativeTools {
126    fn detect() -> Self {
127        Self {
128            nvcc: has_tool("nvcc"),
129            cmake: has_tool("cmake"),
130            python3: has_tool("python3"),
131            git: has_tool("git"),
132            gcc: has_tool("gcc") || has_tool("cc"),
133        }
134    }
135
136    fn ready(&self) -> bool {
137        self.nvcc && self.cmake && self.python3 && self.git && self.gcc
138    }
139
140    fn missing(&self) -> Vec<&'static str> {
141        let mut m = Vec::new();
142        if !self.nvcc { m.push("nvcc (CUDA toolkit)"); }
143        if !self.cmake { m.push("cmake"); }
144        if !self.python3 { m.push("python3"); }
145        if !self.git { m.push("git"); }
146        if !self.gcc { m.push("gcc/cc (C++ compiler)"); }
147        m
148    }
149}
150
/// True if `<name> --version` can be spawned and exits successfully.
///
/// Spawning fails (so we return false) when the binary is not on PATH; output
/// is discarded since only the exit status matters.
fn has_tool(name: &str) -> bool {
    let probe = Command::new(name)
        .arg("--version")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();
    match probe {
        Ok(status) => status.success(),
        Err(_) => false,
    }
}
159
160// ---------------------------------------------------------------------------
161// Backend selection
162// ---------------------------------------------------------------------------
163
164fn select_backend(backend: &BuildBackend) -> Result<&'static str, String> {
165    let has_docker = docker::has_docker();
166    let native = NativeTools::detect();
167
168    match backend {
169        BuildBackend::Docker => {
170            if !has_docker {
171                return Err(
172                    "Docker was requested but is not available.\n\
173                     Install Docker: https://docs.docker.com/engine/install/"
174                        .into(),
175                );
176            }
177            Ok("docker")
178        }
179        BuildBackend::Native => {
180            if !native.ready() {
181                let missing = native.missing();
182                return Err(format!(
183                    "Native build was requested but these tools are missing:\n  {}\n\n\
184                     Install them or use --docker instead.",
185                    missing.join("\n  ")
186                ));
187            }
188            Ok("native")
189        }
190        BuildBackend::Auto => {
191            if has_docker && native.ready() {
192                // Both available, ask the user
193                println!();
194                println!("  Both Docker and native toolchains are available.");
195                println!();
196                let choice = prompt::ask_choice(
197                    "  Build method",
198                    &[
199                        "Docker (isolated, reproducible, resumes via layer cache)",
200                        "Native (faster, uses your host CUDA toolkit directly)",
201                    ],
202                    1,
203                );
204                Ok(if choice == 2 { "native" } else { "docker" })
205            } else if has_docker {
206                println!("  Using Docker (native toolchain not complete).");
207                Ok("docker")
208            } else if native.ready() {
209                println!("  Using native build (Docker not available).");
210                Ok("native")
211            } else {
212                let missing = native.missing();
213                Err(format!(
214                    "Cannot build libtorch. Need either:\n\n\
215                     \x20 Docker: https://docs.docker.com/engine/install/\n\n\
216                     Or native tools (missing: {})",
217                    missing.join(", ")
218                ))
219            }
220        }
221    }
222}
223
224// ---------------------------------------------------------------------------
225// Entry point
226// ---------------------------------------------------------------------------
227
228pub fn run(opts: BuildOpts) -> Result<(), String> {
229    let ctx = Context::resolve();
230
231    // Determine architectures
232    let archs = match &opts.archs {
233        Some(a) => {
234            println!("  Using specified architectures: {}", a);
235            a.clone()
236        }
237        None => detect_arch_list()?,
238    };
239
240    let arch_dir = arch_dir_name(&archs);
241    let install_path = ctx.root.join(format!("libtorch/builds/{}", arch_dir));
242    let variant_id = format!("builds/{}", arch_dir);
243
244    // Select backend
245    let backend = select_backend(&opts.backend)?;
246
247    println!();
248    println!("  libtorch source build");
249    println!("  Archs:   {}", archs);
250    println!("  Output:  {}", install_path.display());
251    println!("  Jobs:    {}", opts.max_jobs);
252    println!("  Method:  {}", backend);
253    println!();
254
255    if opts.dry_run {
256        println!("  [dry-run] Would build libtorch from source via {}.", backend);
257        println!("  This typically takes 2-6 hours depending on CPU cores.");
258        return Ok(());
259    }
260
261    println!("  This will take 2-6 hours. You can safely Ctrl-C and resume later.");
262    println!();
263
264    let install_str = install_path.to_str().unwrap_or("libtorch/builds");
265    match backend {
266        "docker" => build_docker(&archs, install_str, opts.max_jobs)?,
267        "native" => build_native(&archs, install_str, &ctx, opts.max_jobs)?,
268        _ => unreachable!(),
269    }
270
271    // Verify
272    let lib_dir = install_path.join("lib");
273    if !lib_dir.join("libtorch.so").exists() && !lib_dir.join("libtorch.dylib").exists() {
274        return Err(format!(
275            "libtorch library not found at {}.\n\
276             The build may have failed silently.",
277            lib_dir.display()
278        ));
279    }
280
281    // Write .arch metadata
282    let arch_spaces = archs.replace(';', " ");
283    let arch_content = format!(
284        "cuda=12.8\ntorch={}\narchs={}\nsource=compiled\n",
285        LIBTORCH_VERSION, arch_spaces
286    );
287    fs::write(install_path.join(".arch"), arch_content)
288        .map_err(|e| format!("cannot write .arch: {}", e))?;
289
290    // Set as active
291    detect::set_active(&ctx.root, &variant_id)?;
292
293    println!();
294    println!("  ================================================");
295    println!("  libtorch {} (source build) complete!", LIBTORCH_VERSION);
296    println!("  Archs:  {}", arch_spaces);
297    println!("  Path:   {}", install_path.display());
298    println!("  Active: {}", variant_id);
299    println!("  ================================================");
300    println!();
301    if ctx.is_project {
302        println!("  Run 'make cuda-test' to verify.");
303    } else {
304        println!("  To use, add to your shell profile:");
305        println!("    export LIBTORCH=\"{}\"", install_path.display());
306        println!(
307            "    export LD_LIBRARY_PATH=\"{}/lib${{LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}}\"",
308            install_path.display()
309        );
310    }
311
312    Ok(())
313}
314
315// ---------------------------------------------------------------------------
316// Docker backend
317// ---------------------------------------------------------------------------
318
/// Build libtorch inside a Docker image and extract the result to
/// `install_path`.
///
/// The actual compilation is defined by the embedded Dockerfile
/// (`DOCKERFILE_CONTENT`); Docker's layer cache makes an interrupted build
/// resumable by re-running this command.
fn build_docker(archs: &str, install_path: &str, max_jobs: usize) -> Result<(), String> {
    println!("  Docker layer caching means restarting picks up where it left off.");
    println!();

    // Write the embedded Dockerfile to a temp location so `docker build -f`
    // can read it; the build context itself is the current directory (".").
    let tmp_dir = std::env::temp_dir();
    let dockerfile_path = tmp_dir.join("flodl-libtorch-builder.Dockerfile");
    {
        let mut f = fs::File::create(&dockerfile_path)
            .map_err(|e| format!("cannot write Dockerfile: {}", e))?;
        f.write_all(DOCKERFILE_CONTENT.as_bytes())
            .map_err(|e| format!("cannot write Dockerfile: {}", e))?;
    }

    // Build the Docker image; arch list and parallelism are forwarded as
    // build args consumed by the Dockerfile.
    println!("  Building Docker image...");
    let status = docker::docker_run(&[
        "build",
        "-t",
        IMAGE_NAME,
        "--build-arg",
        &format!("TORCH_CUDA_ARCH_LIST={}", archs),
        "--build-arg",
        &format!("MAX_JOBS={}", max_jobs),
        "-f",
        dockerfile_path
            .to_str()
            .ok_or("temp path not UTF-8")?,
        ".",
    ])?;

    // Best-effort cleanup of the temp Dockerfile, regardless of build outcome.
    let _ = fs::remove_file(&dockerfile_path);

    if !status.success() {
        return Err(format!(
            "Docker build failed (exit code {}).\n\
             Check the output above for errors.\n\
             You can re-run this command to resume (Docker caches completed layers).",
            status.code().unwrap_or(-1)
        ));
    }

    // Extract libtorch: create a stopped container from the builder image,
    // `docker cp` the /usr/local/libtorch tree out, then remove the container.
    println!();
    println!("  Extracting libtorch from builder image...");

    let container_out = docker::docker_output(&["create", IMAGE_NAME])?;
    if !container_out.status.success() {
        return Err("failed to create container from builder image".into());
    }
    // `docker create` prints the new container's id on stdout.
    let container_id = String::from_utf8_lossy(&container_out.stdout)
        .trim()
        .to_string();

    fs::create_dir_all(install_path)
        .map_err(|e| format!("cannot create {}: {}", install_path, e))?;

    // The trailing "/." copies the directory's contents, not the dir itself.
    let cp_status = docker::docker_run(&[
        "cp",
        &format!("{}:/usr/local/libtorch/.", container_id),
        install_path,
    ])?;

    // Remove the temporary container whether or not the copy succeeded.
    let _ = docker::docker_output(&["rm", &container_id]);

    if !cp_status.success() {
        return Err("failed to extract libtorch from builder container".into());
    }

    Ok(())
}
390
391// ---------------------------------------------------------------------------
392// Native backend
393// ---------------------------------------------------------------------------
394
395fn build_native(archs: &str, install_path: &str, ctx: &Context, max_jobs: usize) -> Result<(), String> {
396    let build_dir = ctx.root.join("libtorch/.build-cache/pytorch");
397
398    // Clone PyTorch if not cached
399    if !build_dir.join(".git").exists() {
400        println!("  Cloning PyTorch {}...", PYTORCH_VERSION);
401        fs::create_dir_all(ctx.root.join("libtorch/.build-cache"))
402            .map_err(|e| format!("cannot create build cache: {}", e))?;
403
404        let status = Command::new("git")
405            .args([
406                "clone", "--depth", "1",
407                "--branch", PYTORCH_VERSION,
408                "--recurse-submodules", "--shallow-submodules",
409                "https://github.com/pytorch/pytorch.git",
410                build_dir.to_str().ok_or("path not UTF-8")?,
411            ])
412            .stdout(Stdio::inherit())
413            .stderr(Stdio::inherit())
414            .status()
415            .map_err(|e| format!("failed to run git: {}", e))?;
416
417        if !status.success() {
418            // Clean up failed clone
419            let _ = fs::remove_dir_all(build_dir);
420            return Err("git clone failed. Check your network connection.".into());
421        }
422    } else {
423        println!("  Using cached PyTorch source at {}", build_dir.display());
424    }
425
426    // Install Python dependencies
427    println!("  Checking Python dependencies...");
428    let pip_status = Command::new("pip3")
429        .args(["install", "--quiet"])
430        .args(PYTHON_DEPS)
431        .stdout(Stdio::inherit())
432        .stderr(Stdio::inherit())
433        .status();
434
435    // Try --break-system-packages if the first attempt fails (Ubuntu 24.04+)
436    if pip_status.is_err() || !pip_status.unwrap().success() {
437        let _ = Command::new("pip3")
438            .args(["install", "--quiet", "--break-system-packages"])
439            .args(PYTHON_DEPS)
440            .stdout(Stdio::inherit())
441            .stderr(Stdio::inherit())
442            .status();
443    }
444
445    // Build libtorch
446    println!("  Building libtorch (TORCH_CUDA_ARCH_LIST=\"{}\", MAX_JOBS={})...", archs, max_jobs);
447    println!();
448
449    let status = Command::new("python3")
450        .arg("tools/build_libtorch.py")
451        .current_dir(&build_dir)
452        .env("TORCH_CUDA_ARCH_LIST", archs)
453        .env("USE_CUDA", "1")
454        .env("USE_CUDNN", "1")
455        .env("USE_NCCL", "1")
456        .env("USE_DISTRIBUTED", "1")
457        .env("BUILD_SHARED_LIBS", "ON")
458        .env("CMAKE_BUILD_TYPE", "Release")
459        .env("MAX_JOBS", max_jobs.to_string())
460        .env("BUILD_PYTHON", "OFF")
461        .env("BUILD_TEST", "OFF")
462        .env("BUILD_CAFFE2", "OFF")
463        .stdout(Stdio::inherit())
464        .stderr(Stdio::inherit())
465        .status()
466        .map_err(|e| format!("failed to run build_libtorch.py: {}", e))?;
467
468    if !status.success() {
469        return Err(format!(
470            "Native build failed (exit code {}).\n\
471             Check the output above for errors.\n\
472             The PyTorch source is cached at {} -- re-running will skip the clone.",
473            status.code().unwrap_or(-1),
474            build_dir.display()
475        ));
476    }
477
478    // Copy output to install path
479    println!();
480    println!("  Packaging libtorch to {}...", install_path);
481
482    let torch_dir = build_dir.join("torch");
483    fs::create_dir_all(install_path)
484        .map_err(|e| format!("cannot create {}: {}", install_path, e))?;
485
486    for subdir in ["lib", "include", "share"] {
487        let src = torch_dir.join(subdir);
488        let dst = Path::new(install_path).join(subdir);
489        if src.is_dir() {
490            copy_dir_recursive(&src, &dst)?;
491        }
492    }
493
494    Ok(())
495}
496
/// Recursively copy the directory tree at `src` into `dest`.
///
/// `dest` and any nested directories are created as needed; regular files are
/// copied with `fs::copy`. Errors are reported as human-readable strings.
fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<(), String> {
    fs::create_dir_all(dest)
        .map_err(|e| format!("cannot create {}: {}", dest.display(), e))?;

    let entries = fs::read_dir(src).map_err(|e| format!("read {}: {}", src.display(), e))?;
    for entry in entries {
        let entry = entry.map_err(|e| format!("read_dir error: {}", e))?;
        let source_path = entry.path();
        let target_path = dest.join(entry.file_name());

        if source_path.is_dir() {
            copy_dir_recursive(&source_path, &target_path)?;
        } else {
            fs::copy(&source_path, &target_path).map_err(|e| {
                format!(
                    "copy {} -> {}: {}",
                    source_path.display(),
                    target_path.display(),
                    e
                )
            })?;
        }
    }
    Ok(())
}