1use std::fs;
8use std::io::Write;
9use std::path::Path;
10use std::process::{Command, Stdio};
11
12use crate::context::Context;
13use crate::util::docker;
14use crate::util::prompt;
15use crate::util::system;
16use super::detect;
17
/// Dockerfile for containerized source builds, embedded at compile time.
const DOCKERFILE_CONTENT: &str = include_str!("../../assets/Dockerfile.cuda.source");
/// Tag given to the locally-built builder image.
const IMAGE_NAME: &str = "flodl-libtorch-builder";
/// libtorch release recorded in the generated `.arch` metadata file.
const LIBTORCH_VERSION: &str = "2.10.0";
/// Git tag checked out when cloning PyTorch for a native build.
const PYTORCH_VERSION: &str = "v2.10.0";

/// Python packages PyTorch's build scripts need (native builds only).
const PYTHON_DEPS: &[&str] = &[
    "typing_extensions", "pyyaml", "filelock",
    "jinja2", "networkx", "sympy", "packaging",
];
27
/// How the libtorch source build should be executed.
#[derive(Default)]
pub enum BuildBackend {
    /// Choose automatically: prompt if both Docker and a native toolchain
    /// are usable, otherwise pick whichever one is available.
    #[default]
    Auto,
    /// Build inside a Docker container (isolated, resumable via layer cache).
    Docker,
    /// Build directly on the host with the installed CUDA toolchain.
    Native,
}
42
/// Options controlling a libtorch source build.
pub struct BuildOpts {
    /// Explicit arch list (e.g. `"8.6;12.0"`); `None` auto-detects from GPUs.
    pub archs: Option<String>,
    /// Maximum parallel compile jobs (exported as `MAX_JOBS`).
    pub max_jobs: usize,
    /// Print what would happen without actually building.
    pub dry_run: bool,
    /// Build method selection; see [`BuildBackend`].
    pub backend: BuildBackend,
}
54
55impl Default for BuildOpts {
56 fn default() -> Self {
57 Self {
58 archs: None,
59 max_jobs: 6,
60 dry_run: false,
61 backend: BuildBackend::Auto,
62 }
63 }
64}
65
66fn detect_arch_list() -> Result<String, String> {
71 let gpus = system::detect_gpus();
72 if gpus.is_empty() {
73 return Err(
74 "No NVIDIA GPUs detected.\n\
75 Source builds require GPUs to auto-detect architectures.\n\
76 Use --archs to specify manually (e.g. --archs \"8.6;12.0\")."
77 .into(),
78 );
79 }
80
81 let mut caps: Vec<(u32, u32)> = gpus
83 .iter()
84 .map(|g| (g.sm_major, g.sm_minor))
85 .collect();
86 caps.sort();
87 caps.dedup();
88 let caps: Vec<String> = caps.iter().map(|(ma, mi)| format!("{}.{}", ma, mi)).collect();
89
90 println!(" GPUs detected:");
91 for g in &gpus {
92 println!(
93 " [{}] {} (sm_{}.{})",
94 g.index, g.short_name(), g.sm_major, g.sm_minor
95 );
96 }
97
98 Ok(caps.join(";"))
99}
100
/// Map a semicolon-separated capability list ("8.6;12.0") to a directory
/// name ("sm86-sm120") by dropping dots and prefixing each entry with "sm".
fn arch_dir_name(archs: &str) -> String {
    let parts: Vec<String> = archs
        .split(';')
        .map(|cap| format!("sm{}", cap.replace('.', "")))
        .collect();
    parts.join("-")
}
112
/// Availability of each host tool required for a native (non-Docker) build.
struct NativeTools {
    nvcc: bool,    // CUDA compiler
    cmake: bool,   // build system driver
    python3: bool, // runs PyTorch's build scripts
    git: bool,     // clones the PyTorch source
    gcc: bool,     // host C++ compiler (gcc or cc)
}
124
125impl NativeTools {
126 fn detect() -> Self {
127 Self {
128 nvcc: has_tool("nvcc"),
129 cmake: has_tool("cmake"),
130 python3: has_tool("python3"),
131 git: has_tool("git"),
132 gcc: has_tool("gcc") || has_tool("cc"),
133 }
134 }
135
136 fn ready(&self) -> bool {
137 self.nvcc && self.cmake && self.python3 && self.git && self.gcc
138 }
139
140 fn missing(&self) -> Vec<&'static str> {
141 let mut m = Vec::new();
142 if !self.nvcc { m.push("nvcc (CUDA toolkit)"); }
143 if !self.cmake { m.push("cmake"); }
144 if !self.python3 { m.push("python3"); }
145 if !self.git { m.push("git"); }
146 if !self.gcc { m.push("gcc/cc (C++ compiler)"); }
147 m
148 }
149}
150
/// Returns true if `name --version` can be spawned and exits successfully,
/// discarding all of its output.
fn has_tool(name: &str) -> bool {
    let probe = Command::new(name)
        .arg("--version")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();
    matches!(probe, Ok(status) if status.success())
}
159
160fn select_backend(backend: &BuildBackend) -> Result<&'static str, String> {
165 let has_docker = docker::has_docker();
166 let native = NativeTools::detect();
167
168 match backend {
169 BuildBackend::Docker => {
170 if !has_docker {
171 return Err(
172 "Docker was requested but is not available.\n\
173 Install Docker: https://docs.docker.com/engine/install/"
174 .into(),
175 );
176 }
177 Ok("docker")
178 }
179 BuildBackend::Native => {
180 if !native.ready() {
181 let missing = native.missing();
182 return Err(format!(
183 "Native build was requested but these tools are missing:\n {}\n\n\
184 Install them or use --docker instead.",
185 missing.join("\n ")
186 ));
187 }
188 Ok("native")
189 }
190 BuildBackend::Auto => {
191 if has_docker && native.ready() {
192 println!();
194 println!(" Both Docker and native toolchains are available.");
195 println!();
196 let choice = prompt::ask_choice(
197 " Build method",
198 &[
199 "Docker (isolated, reproducible, resumes via layer cache)",
200 "Native (faster, uses your host CUDA toolkit directly)",
201 ],
202 1,
203 );
204 Ok(if choice == 2 { "native" } else { "docker" })
205 } else if has_docker {
206 println!(" Using Docker (native toolchain not complete).");
207 Ok("docker")
208 } else if native.ready() {
209 println!(" Using native build (Docker not available).");
210 Ok("native")
211 } else {
212 let missing = native.missing();
213 Err(format!(
214 "Cannot build libtorch. Need either:\n\n\
215 \x20 Docker: https://docs.docker.com/engine/install/\n\n\
216 Or native tools (missing: {})",
217 missing.join(", ")
218 ))
219 }
220 }
221 }
222}
223
224pub fn run(opts: BuildOpts) -> Result<(), String> {
229 let ctx = Context::resolve();
230
231 let archs = match &opts.archs {
233 Some(a) => {
234 println!(" Using specified architectures: {}", a);
235 a.clone()
236 }
237 None => detect_arch_list()?,
238 };
239
240 let arch_dir = arch_dir_name(&archs);
241 let install_path = ctx.root.join(format!("libtorch/builds/{}", arch_dir));
242 let variant_id = format!("builds/{}", arch_dir);
243
244 let backend = select_backend(&opts.backend)?;
246
247 println!();
248 println!(" libtorch source build");
249 println!(" Archs: {}", archs);
250 println!(" Output: {}", install_path.display());
251 println!(" Jobs: {}", opts.max_jobs);
252 println!(" Method: {}", backend);
253 println!();
254
255 if opts.dry_run {
256 println!(" [dry-run] Would build libtorch from source via {}.", backend);
257 println!(" This typically takes 2-6 hours depending on CPU cores.");
258 return Ok(());
259 }
260
261 println!(" This will take 2-6 hours. You can safely Ctrl-C and resume later.");
262 println!();
263
264 let install_str = install_path.to_str().unwrap_or("libtorch/builds");
265 match backend {
266 "docker" => build_docker(&archs, install_str, opts.max_jobs)?,
267 "native" => build_native(&archs, install_str, &ctx, opts.max_jobs)?,
268 _ => unreachable!(),
269 }
270
271 let lib_dir = install_path.join("lib");
273 if !lib_dir.join("libtorch.so").exists() && !lib_dir.join("libtorch.dylib").exists() {
274 return Err(format!(
275 "libtorch library not found at {}.\n\
276 The build may have failed silently.",
277 lib_dir.display()
278 ));
279 }
280
281 let arch_spaces = archs.replace(';', " ");
283 let arch_content = format!(
284 "cuda=12.8\ntorch={}\narchs={}\nsource=compiled\n",
285 LIBTORCH_VERSION, arch_spaces
286 );
287 fs::write(install_path.join(".arch"), arch_content)
288 .map_err(|e| format!("cannot write .arch: {}", e))?;
289
290 detect::set_active(&ctx.root, &variant_id)?;
292
293 println!();
294 println!(" ================================================");
295 println!(" libtorch {} (source build) complete!", LIBTORCH_VERSION);
296 println!(" Archs: {}", arch_spaces);
297 println!(" Path: {}", install_path.display());
298 println!(" Active: {}", variant_id);
299 println!(" ================================================");
300 println!();
301 if ctx.is_project {
302 println!(" Run 'make cuda-test' to verify.");
303 } else {
304 println!(" To use, add to your shell profile:");
305 println!(" export LIBTORCH=\"{}\"", install_path.display());
306 println!(
307 " export LD_LIBRARY_PATH=\"{}/lib${{LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}}\"",
308 install_path.display()
309 );
310 }
311
312 Ok(())
313}
314
315fn build_docker(archs: &str, install_path: &str, max_jobs: usize) -> Result<(), String> {
320 println!(" Docker layer caching means restarting picks up where it left off.");
321 println!();
322
323 let tmp_dir = std::env::temp_dir();
325 let dockerfile_path = tmp_dir.join("flodl-libtorch-builder.Dockerfile");
326 {
327 let mut f = fs::File::create(&dockerfile_path)
328 .map_err(|e| format!("cannot write Dockerfile: {}", e))?;
329 f.write_all(DOCKERFILE_CONTENT.as_bytes())
330 .map_err(|e| format!("cannot write Dockerfile: {}", e))?;
331 }
332
333 println!(" Building Docker image...");
335 let status = docker::docker_run(&[
336 "build",
337 "-t",
338 IMAGE_NAME,
339 "--build-arg",
340 &format!("TORCH_CUDA_ARCH_LIST={}", archs),
341 "--build-arg",
342 &format!("MAX_JOBS={}", max_jobs),
343 "-f",
344 dockerfile_path
345 .to_str()
346 .ok_or("temp path not UTF-8")?,
347 ".",
348 ])?;
349
350 let _ = fs::remove_file(&dockerfile_path);
351
352 if !status.success() {
353 return Err(format!(
354 "Docker build failed (exit code {}).\n\
355 Check the output above for errors.\n\
356 You can re-run this command to resume (Docker caches completed layers).",
357 status.code().unwrap_or(-1)
358 ));
359 }
360
361 println!();
363 println!(" Extracting libtorch from builder image...");
364
365 let container_out = docker::docker_output(&["create", IMAGE_NAME])?;
366 if !container_out.status.success() {
367 return Err("failed to create container from builder image".into());
368 }
369 let container_id = String::from_utf8_lossy(&container_out.stdout)
370 .trim()
371 .to_string();
372
373 fs::create_dir_all(install_path)
374 .map_err(|e| format!("cannot create {}: {}", install_path, e))?;
375
376 let cp_status = docker::docker_run(&[
377 "cp",
378 &format!("{}:/usr/local/libtorch/.", container_id),
379 install_path,
380 ])?;
381
382 let _ = docker::docker_output(&["rm", &container_id]);
383
384 if !cp_status.success() {
385 return Err("failed to extract libtorch from builder container".into());
386 }
387
388 Ok(())
389}
390
391fn build_native(archs: &str, install_path: &str, ctx: &Context, max_jobs: usize) -> Result<(), String> {
396 let build_dir = ctx.root.join("libtorch/.build-cache/pytorch");
397
398 if !build_dir.join(".git").exists() {
400 println!(" Cloning PyTorch {}...", PYTORCH_VERSION);
401 fs::create_dir_all(ctx.root.join("libtorch/.build-cache"))
402 .map_err(|e| format!("cannot create build cache: {}", e))?;
403
404 let status = Command::new("git")
405 .args([
406 "clone", "--depth", "1",
407 "--branch", PYTORCH_VERSION,
408 "--recurse-submodules", "--shallow-submodules",
409 "https://github.com/pytorch/pytorch.git",
410 build_dir.to_str().ok_or("path not UTF-8")?,
411 ])
412 .stdout(Stdio::inherit())
413 .stderr(Stdio::inherit())
414 .status()
415 .map_err(|e| format!("failed to run git: {}", e))?;
416
417 if !status.success() {
418 let _ = fs::remove_dir_all(build_dir);
420 return Err("git clone failed. Check your network connection.".into());
421 }
422 } else {
423 println!(" Using cached PyTorch source at {}", build_dir.display());
424 }
425
426 println!(" Checking Python dependencies...");
428 let pip_status = Command::new("pip3")
429 .args(["install", "--quiet"])
430 .args(PYTHON_DEPS)
431 .stdout(Stdio::inherit())
432 .stderr(Stdio::inherit())
433 .status();
434
435 if pip_status.is_err() || !pip_status.unwrap().success() {
437 let _ = Command::new("pip3")
438 .args(["install", "--quiet", "--break-system-packages"])
439 .args(PYTHON_DEPS)
440 .stdout(Stdio::inherit())
441 .stderr(Stdio::inherit())
442 .status();
443 }
444
445 println!(" Building libtorch (TORCH_CUDA_ARCH_LIST=\"{}\", MAX_JOBS={})...", archs, max_jobs);
447 println!();
448
449 let status = Command::new("python3")
450 .arg("tools/build_libtorch.py")
451 .current_dir(&build_dir)
452 .env("TORCH_CUDA_ARCH_LIST", archs)
453 .env("USE_CUDA", "1")
454 .env("USE_CUDNN", "1")
455 .env("USE_NCCL", "1")
456 .env("USE_DISTRIBUTED", "1")
457 .env("BUILD_SHARED_LIBS", "ON")
458 .env("CMAKE_BUILD_TYPE", "Release")
459 .env("MAX_JOBS", max_jobs.to_string())
460 .env("BUILD_PYTHON", "OFF")
461 .env("BUILD_TEST", "OFF")
462 .env("BUILD_CAFFE2", "OFF")
463 .stdout(Stdio::inherit())
464 .stderr(Stdio::inherit())
465 .status()
466 .map_err(|e| format!("failed to run build_libtorch.py: {}", e))?;
467
468 if !status.success() {
469 return Err(format!(
470 "Native build failed (exit code {}).\n\
471 Check the output above for errors.\n\
472 The PyTorch source is cached at {} -- re-running will skip the clone.",
473 status.code().unwrap_or(-1),
474 build_dir.display()
475 ));
476 }
477
478 println!();
480 println!(" Packaging libtorch to {}...", install_path);
481
482 let torch_dir = build_dir.join("torch");
483 fs::create_dir_all(install_path)
484 .map_err(|e| format!("cannot create {}: {}", install_path, e))?;
485
486 for subdir in ["lib", "include", "share"] {
487 let src = torch_dir.join(subdir);
488 let dst = Path::new(install_path).join(subdir);
489 if src.is_dir() {
490 copy_dir_recursive(&src, &dst)?;
491 }
492 }
493
494 Ok(())
495}
496
/// Recursively copy the directory tree at `src` into `dest`.
///
/// `dest` and any nested directories are created as needed; regular files
/// are copied with `fs::copy`. NOTE(review): `Path::is_dir` follows
/// symlinks, so a symlinked directory is deep-copied rather than re-linked.
fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<(), String> {
    fs::create_dir_all(dest)
        .map_err(|e| format!("cannot create {}: {}", dest.display(), e))?;

    let entries = fs::read_dir(src).map_err(|e| format!("read {}: {}", src.display(), e))?;
    for item in entries {
        let item = item.map_err(|e| format!("read_dir error: {}", e))?;
        let source = item.path();
        let target = dest.join(item.file_name());

        if source.is_dir() {
            copy_dir_recursive(&source, &target)?;
            continue;
        }
        fs::copy(&source, &target).map_err(|e| {
            format!("copy {} -> {}: {}", source.display(), target.display(), e)
        })?;
    }
    Ok(())
}