1use std::fs;
4use std::path::{Path, PathBuf};
5
6use crate::context::Context;
7use crate::util::http;
8use crate::util::archive;
9use crate::util::system;
10use super::detect;
11
// Pinned libtorch release this module downloads; also compared against the
// `build-version` file of an existing install to decide whether to reinstall.
const LIBTORCH_VERSION: &str = "2.10.0";
17
/// Static description of one downloadable libtorch build.
///
/// Instances live in the `CPU_SPEC` / `CU126_SPEC` / `CU128_SPEC` constants;
/// the `arch_*` fields are written verbatim into the installed `.arch` file.
/// All fields are `&'static str`, so the type is trivially `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct VariantSpec {
    /// Human-readable name shown in console output (e.g. "CUDA 12.6").
    label: &'static str,
    /// Directory name under `libtorch/precompiled/`.
    dir_name: &'static str,
    /// CUDA toolkit version, or "none" for the CPU build.
    arch_cuda: &'static str,
    /// Space-separated compute capabilities the build covers.
    arch_archs: &'static str,
    /// PyTorch download-bucket id (e.g. "cu126"), also used in the filename.
    arch_variant: &'static str,
}
31
// CPU-only build: no CUDA toolkit, no GPU architecture list.
const CPU_SPEC: VariantSpec = VariantSpec {
    label: "CPU",
    dir_name: "cpu",
    arch_cuda: "none",
    arch_archs: "cpu",
    arch_variant: "cpu",
};
39
// CUDA 12.6 build: broadest GPU coverage, including pre-Volta
// architectures (sm_50 through sm_90).
const CU126_SPEC: VariantSpec = VariantSpec {
    label: "CUDA 12.6",
    dir_name: "cu126",
    arch_cuda: "12.6",
    arch_archs: "5.0 5.2 6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0",
    arch_variant: "cu126",
};
47
// CUDA 12.8 build: Volta and newer only (sm_70 through sm_120);
// pre-Volta architectures are not included.
const CU128_SPEC: VariantSpec = VariantSpec {
    label: "CUDA 12.8",
    dir_name: "cu128",
    arch_cuda: "12.8",
    arch_archs: "7.0 7.5 8.0 8.6 8.9 9.0 12.0",
    arch_variant: "cu128",
};
55
/// User-selectable libtorch flavor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Variant {
    /// CPU-only build.
    Cpu,
    /// CUDA 12.6 build (includes pre-Volta GPU coverage).
    Cuda126,
    /// CUDA 12.8 build (Volta and newer only).
    Cuda128,
    /// Pick a variant by probing the local GPUs.
    Auto,
}
66
/// Options for one libtorch download/install run.
pub struct DownloadOpts {
    /// Which build to install (`Auto` probes the hardware).
    pub variant: Variant,
    /// Override the default `<root>/libtorch/precompiled/<variant>` path.
    pub custom_path: Option<PathBuf>,
    /// When true, mark the new install as active after extraction.
    pub activate: bool,
    /// When true, print the plan and exit before downloading anything.
    pub dry_run: bool,
}
73
74impl Default for DownloadOpts {
75 fn default() -> Self {
76 Self {
77 variant: Variant::Auto,
78 custom_path: None,
79 activate: true,
80 dry_run: false,
81 }
82 }
83}
84
85fn download_url(spec: &VariantSpec) -> Result<String, String> {
90 let os = std::env::consts::OS;
91 let arch = std::env::consts::ARCH;
92
93 match (os, arch) {
94 ("linux", "x86_64") => {}
95 ("macos", "aarch64") => {
96 if spec.arch_cuda != "none" {
97 return Err("macOS only supports CPU libtorch".into());
98 }
99 }
100 ("macos", _) => {
101 return Err(format!(
102 "macOS libtorch requires Apple Silicon (arm64), got {}.\n\
103 macOS x86_64 was dropped after PyTorch 2.2.",
104 arch
105 ));
106 }
107 ("windows", "x86_64") => {}
108 _ => {
109 return Err(format!(
110 "Unsupported platform: {} {}.\n\
111 libtorch is available for Linux x86_64, macOS arm64, and Windows x86_64.",
112 os, arch
113 ));
114 }
115 }
116
117 if os == "macos" {
119 return Ok(format!(
120 "https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-{}.zip",
121 LIBTORCH_VERSION
122 ));
123 }
124
125 let filename = match spec.arch_variant {
127 "cpu" => format!(
128 "libtorch-shared-with-deps-{}%2Bcpu.zip",
129 LIBTORCH_VERSION
130 ),
131 variant => format!(
132 "libtorch-shared-with-deps-{}%2B{}.zip",
133 LIBTORCH_VERSION, variant
134 ),
135 };
136
137 let bucket = spec.arch_variant; Ok(format!(
139 "https://download.pytorch.org/libtorch/{}/{}",
140 bucket, filename
141 ))
142}
143
144fn auto_detect_variant() -> &'static VariantSpec {
149 let gpus = system::detect_gpus();
150 if gpus.is_empty() {
151 println!(" No NVIDIA GPU detected. Using CPU variant.");
152 return &CPU_SPEC;
153 }
154
155 let lo_major = gpus.iter().map(|g| g.sm_major).min().unwrap_or(0);
157 let hi_major = gpus.iter().map(|g| g.sm_major).max().unwrap_or(0);
158
159 if lo_major >= 7 {
161 println!(" Detected Volta+ GPU(s). Using cu128.");
162 &CU128_SPEC
163 } else if hi_major >= 10 {
164 println!(
167 " Mixed GPU architectures (sm_{}.x to sm_{}.x).",
168 lo_major, hi_major
169 );
170 println!(" Using cu126 (broadest pre-Volta coverage).");
171 println!(" For all GPUs, consider: fdl libtorch build");
172 &CU126_SPEC
173 } else {
174 println!(" Detected pre-Volta GPU(s). Using cu126.");
175 &CU126_SPEC
176 }
177}
178
179fn resolve_variant(variant: &Variant) -> &'static VariantSpec {
180 match variant {
181 Variant::Cpu => &CPU_SPEC,
182 Variant::Cuda126 => &CU126_SPEC,
183 Variant::Cuda128 => &CU128_SPEC,
184 Variant::Auto => auto_detect_variant(),
185 }
186}
187
188pub fn run(opts: DownloadOpts) -> Result<(), String> {
193 let ctx = Context::resolve();
194 run_with_context(opts, &ctx)
195}
196
/// Download, extract, and install the selected libtorch variant, then
/// optionally mark it active via `detect::set_active`.
///
/// Pipeline: resolve variant -> build URL -> (dry-run exits here) ->
/// skip or replace any existing install based on its `build-version` file ->
/// download the zip to the temp dir -> extract -> move into place ->
/// sanity-check `lib/` -> write `.arch` metadata -> activate -> print summary.
///
/// Returns `Err(String)` with a human-readable message on any failure.
pub fn run_with_context(opts: DownloadOpts, ctx: &Context) -> Result<(), String> {
    // May probe GPUs when opts.variant is Variant::Auto.
    let spec = resolve_variant(&opts.variant);
    let url = download_url(spec)?;

    // Custom path wins; otherwise <root>/libtorch/precompiled/<variant>.
    let install_path = if let Some(ref p) = opts.custom_path {
        p.clone()
    } else {
        ctx.root.join(format!("libtorch/precompiled/{}", spec.dir_name))
    };

    // Identifier recorded by detect::set_active and echoed in the summary.
    let variant_id = format!("precompiled/{}", spec.dir_name);

    println!();
    println!(" libtorch {} ({})", LIBTORCH_VERSION, spec.label);
    println!(" URL: {}", url);
    println!(" Path: {}", install_path.display());

    // Dry-run: report the plan and stop before any filesystem/network work.
    if opts.dry_run {
        println!();
        println!(" [dry-run] Would download and extract to above path.");
        return Ok(());
    }

    // Existing install: keep it if its version matches, otherwise remove it.
    if install_path.exists() {
        let build_ver_path = install_path.join("build-version");
        let existing_ver = fs::read_to_string(&build_ver_path)
            .ok()
            .map(|s| s.trim().to_string());

        // Accept exact "2.10.0" as well as suffixed builds like "2.10.0+cu126".
        let ver_matches = existing_ver.as_deref().is_some_and(|v| {
            v == LIBTORCH_VERSION || v.starts_with(&format!("{}+", LIBTORCH_VERSION))
        });

        if ver_matches {
            println!();
            println!(" Already installed (version {}).", LIBTORCH_VERSION);
            return Ok(());
        }

        println!();
        println!(
            " Removing existing installation (version: {})...",
            existing_ver.as_deref().unwrap_or("unknown")
        );
        fs::remove_dir_all(&install_path)
            .map_err(|e| format!("cannot remove {}: {}", install_path.display(), e))?;
    }

    // Download the archive into the OS temp directory.
    let tmp_dir = std::env::temp_dir();
    let tmp_zip = tmp_dir.join(format!("libtorch-{}-{}.zip", spec.dir_name, LIBTORCH_VERSION));

    println!();
    println!(" Downloading...");
    http::download_file(&url, &tmp_zip)?;

    // Extract into a per-process scratch dir to avoid collisions.
    let tmp_extract = tmp_dir.join(format!("libtorch-extract-{}", std::process::id()));
    println!(" Extracting...");
    archive::extract_zip(&tmp_zip, &tmp_extract)?;

    // Official archives wrap everything in a top-level "libtorch/" folder;
    // fall back to the extraction root if that folder is absent.
    let extracted_lt = tmp_extract.join("libtorch");
    let source = if extracted_lt.is_dir() {
        &extracted_lt
    } else {
        &tmp_extract
    };

    fs::create_dir_all(&install_path)
        .map_err(|e| format!("cannot create {}: {}", install_path.display(), e))?;

    move_contents(source, &install_path)?;

    // Best-effort cleanup of temp artifacts; failures here are not fatal.
    let _ = fs::remove_file(&tmp_zip);
    let _ = fs::remove_dir_all(&tmp_extract);

    // Sanity check: the main library must exist for at least one platform
    // naming convention (Linux .so, macOS .dylib, Windows .lib).
    let lib_dir = install_path.join("lib");
    let has_lib = lib_dir.join("libtorch.so").exists()
        || lib_dir.join("libtorch.dylib").exists()
        || lib_dir.join("torch.lib").exists();

    if !has_lib {
        return Err(format!(
            "libtorch library not found at {}.\n\
             The archive structure may have changed.\n\
             Check: ls {}",
            lib_dir.display(),
            lib_dir.display()
        ));
    }

    // Record install metadata for later inspection/detection.
    let arch_content = format!(
        "cuda={}\ntorch={}\narchs={}\nsource=precompiled\nvariant={}\n",
        spec.arch_cuda, LIBTORCH_VERSION, spec.arch_archs, spec.arch_variant
    );
    fs::write(install_path.join(".arch"), arch_content)
        .map_err(|e| format!("cannot write .arch: {}", e))?;

    if opts.activate {
        detect::set_active(&ctx.root, &variant_id)?;
    }

    println!();
    println!(" ================================================");
    println!(" libtorch {} ({}) installed", LIBTORCH_VERSION, spec.label);
    println!(" {}", install_path.display());
    println!(" ================================================");

    // Tailor the follow-up instructions: project checkouts get make targets,
    // standalone installs get shell-profile export lines.
    if ctx.is_project {
        println!();
        println!(" .arch: {}/.arch", install_path.display());
        if opts.activate {
            println!(" .active: libtorch/.active -> {}", variant_id);
        }
        println!();
        if spec.arch_cuda != "none" {
            println!(" Run 'make cuda-test' to verify.");
        } else {
            println!(" Run 'make test' to verify.");
        }
    } else {
        println!();
        println!(" Installed to: {}", install_path.display());
        println!();
        println!(" To use with tch-rs or flodl, add to your shell profile:");
        println!();
        println!(" export LIBTORCH=\"{}\"", install_path.display());
        println!(
            " export LD_LIBRARY_PATH=\"{}/lib${{LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}}\"",
            install_path.display()
        );
        println!();
        println!(" Or start a new floDl project:");
        println!(" fdl init my-project");
    }

    Ok(())
}
344
/// Move every entry in `src` into `dest`.
///
/// Tries a cheap `fs::rename` first; when that fails (typically because
/// `src` and `dest` live on different filesystems, where rename cannot
/// cross the device boundary), falls back to copying and then best-effort
/// removes the source entry so the operation still behaves like a move.
/// (Previously the fallback left the source in place, silently degrading
/// "move" to "copy"; callers that clean up the source dir masked this.)
fn move_contents(src: &Path, dest: &Path) -> Result<(), String> {
    let entries = fs::read_dir(src)
        .map_err(|e| format!("cannot read {}: {}", src.display(), e))?;

    for entry in entries {
        let entry = entry.map_err(|e| format!("read_dir error: {}", e))?;
        let from = entry.path();
        let to = dest.join(entry.file_name());

        if fs::rename(&from, &to).is_err() {
            // Cross-device fallback: copy, then drop the original.
            // Removal failures are non-fatal — the data already arrived.
            if from.is_dir() {
                copy_dir_recursive(&from, &to)?;
                let _ = fs::remove_dir_all(&from);
            } else {
                fs::copy(&from, &to)
                    .map_err(|e| format!("copy {} -> {}: {}", from.display(), to.display(), e))?;
                let _ = fs::remove_file(&from);
            }
        }
    }
    Ok(())
}

/// Recursively copy the directory tree at `src` into `dest`,
/// creating `dest` (and intermediate directories) as needed.
fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<(), String> {
    fs::create_dir_all(dest)
        .map_err(|e| format!("cannot create {}: {}", dest.display(), e))?;

    for entry in fs::read_dir(src).map_err(|e| format!("read {}: {}", src.display(), e))? {
        let entry = entry.map_err(|e| format!("read_dir error: {}", e))?;
        let from = entry.path();
        let to = dest.join(entry.file_name());

        if from.is_dir() {
            copy_dir_recursive(&from, &to)?;
        } else {
            fs::copy(&from, &to)
                .map_err(|e| format!("copy {} -> {}: {}", from.display(), to.display(), e))?;
        }
    }
    Ok(())
}
391
#[allow(dead_code)]
/// Returns the pinned libtorch version string (see `LIBTORCH_VERSION`).
pub fn libtorch_version() -> &'static str {
    LIBTORCH_VERSION
}