//! `fdl setup` -- interactive guided setup wizard.
//!
//! Detects hardware, downloads libtorch, optionally builds Docker images.

use crate::context::Context;
use crate::libtorch::{build, detect, download};
use crate::util::{docker, prompt, system};

/// Options controlling the `fdl setup` wizard.
///
/// Constructed by the CLI layer; `Default` gives the fully interactive,
/// non-forcing behavior.
#[derive(Debug, Clone, Copy, Default)]
pub struct SetupOpts {
    /// Skip all prompts, use auto-detected defaults.
    pub non_interactive: bool,
    /// Re-download/rebuild even if libtorch exists.
    pub force: bool,
}

17pub fn run(opts: SetupOpts) -> Result<(), String> {
18    println!();
19    println!("  floDl Setup");
20    println!("  ===========");
21    println!();
22    println!("  floDl is a Rust deep learning framework built on libtorch");
23    println!("  (PyTorch's C++ backend). This wizard will help you set up");
24    println!("  your development environment.");
25    println!();
26
27    // ---- Step 1: Detect system ----
28
29    println!("  Step 1: Detecting your system");
30    println!("  -----------------------------");
31    println!();
32
33    let cpu = system::cpu_model().unwrap_or_else(|| "Unknown".into());
34    let threads = system::cpu_threads();
35    let ram_gb = system::ram_total_gb();
36    println!("  CPU:    {} ({} threads, {}GB RAM)", cpu, threads, ram_gb);
37
38    let has_docker = docker::has_docker();
39    let has_cargo = system::has_cargo();
40
41    if has_docker {
42        if let Some(v) = system::docker_version() {
43            println!("  Docker: {}", v);
44        } else {
45            println!("  Docker: available");
46        }
47    } else {
48        println!("  Docker: not found");
49    }
50
51    if has_cargo {
52        println!("  Rust:   available");
53    } else {
54        println!("  Rust:   not found");
55    }
56
57    let gpus = system::detect_gpus();
58    if !gpus.is_empty() {
59        println!();
60        println!("  GPUs:");
61        for g in &gpus {
62            println!(
63                "    [{}] {} -- sm_{}.{}, {}GB VRAM",
64                g.index,
65                g.name,
66                g.sm_major,
67                g.sm_minor,
68                g.total_memory_mb / 1024
69            );
70        }
71    } else {
72        println!();
73        println!("  GPU:    not detected (CPU-only mode)");
74    }
75
76    if !has_docker && !has_cargo {
77        println!();
78        println!("  You need at least one of these to continue:");
79        println!();
80        println!("    Rust:   curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh");
81        println!("    Docker: https://docs.docker.com/engine/install/");
82        println!();
83        println!("  Install one or both and run 'fdl setup' again.");
84        return Err("no Rust or Docker found".into());
85    }
86
87    // ---- Step 2: libtorch ----
88
89    println!();
90    println!("  Step 2: libtorch");
91    println!("  ----------------");
92    println!();
93    println!("  floDl needs libtorch, PyTorch's C++ library.");
94    println!("  This downloads pre-built binaries (~2GB for CUDA, ~200MB for CPU).");
95    println!();
96
97    let ctx = Context::resolve();
98    let root = &ctx.root;
99
100    if !ctx.is_project {
101        println!("  Not inside a floDl project.");
102        println!("  libtorch will be installed to: {}", ctx.libtorch_dir().display());
103        println!();
104    }
105
106    let existing = detect::read_active(root);
107    let mut skip_download = false;
108
109    if !opts.force {
110        if let Some(ref info) = existing {
111            let is_cuda = info.cuda_version.as_deref() != Some("none");
112            if is_cuda {
113                println!("  Found existing CUDA libtorch: {}", info.path);
114                if opts.non_interactive {
115                    println!("  Keeping existing installation.");
116                    skip_download = true;
117                } else if !prompt::ask_yn("  Download fresh?", false) {
118                    skip_download = true;
119                }
120                println!();
121            } else {
122                println!("  Found existing CPU libtorch.");
123            }
124        }
125    }
126
127    if !skip_download {
128        // Always download CPU variant (useful as fallback)
129        println!("  Downloading CPU libtorch...");
130        let cpu_opts = download::DownloadOpts {
131            variant: download::Variant::Cpu,
132            activate: false, // don't activate CPU if we'll also get CUDA
133            ..Default::default()
134        };
135        download::run_with_context(cpu_opts, &ctx)?;
136
137        // CUDA libtorch
138        if !gpus.is_empty() {
139            let lo_major = gpus.iter().map(|g| g.sm_major).min().unwrap_or(0);
140            let hi_major = gpus.iter().map(|g| g.sm_major).max().unwrap_or(0);
141
142            if lo_major < 7 && hi_major >= 10 {
143                // Mixed architectures -- no single prebuilt covers both
144                println!();
145                println!("  Your GPUs span sm_{}.x to sm_{}.x.", lo_major, hi_major);
146                println!("  No pre-built libtorch covers both architectures.");
147                println!();
148
149                // Check for existing source build
150                let has_source_build = detect::list_variants(root)
151                    .iter()
152                    .any(|v| v.starts_with("builds/"));
153
154                if has_source_build {
155                    println!("  Found existing source build in libtorch/builds/.");
156                } else if opts.non_interactive {
157                    println!("  Downloading cu126 (broadest coverage).");
158                    let cuda_opts = download::DownloadOpts {
159                        variant: download::Variant::Cuda126,
160                        ..Default::default()
161                    };
162                    download::run_with_context(cuda_opts, &ctx)?;
163                } else {
164                    let choice = prompt::ask_choice(
165                        "  Choice",
166                        &[
167                            "Build libtorch from source (2-6 hours, covers all GPUs)",
168                            "Download cu128 (Volta+ only, your older GPU won't work)",
169                            "Download cu126 (pre-Volta only, your newer GPU won't work)",
170                            "Skip for now",
171                        ],
172                        4,
173                    );
174
175                    match choice {
176                        1 => {
177                            println!();
178                            println!("  Starting libtorch source build...");
179                            println!("  This will take 2-6 hours. You can safely Ctrl-C and");
180                            println!("  resume later with: fdl libtorch build");
181                            println!();
182                            build::run(build::BuildOpts::default())?;
183                        }
184                        2 => {
185                            println!("  Downloading cu128...");
186                            let cuda_opts = download::DownloadOpts {
187                                variant: download::Variant::Cuda128,
188                                ..Default::default()
189                            };
190                            download::run_with_context(cuda_opts, &ctx)?;
191                        }
192                        3 => {
193                            println!("  Downloading cu126...");
194                            let cuda_opts = download::DownloadOpts {
195                                variant: download::Variant::Cuda126,
196                                ..Default::default()
197                            };
198                            download::run_with_context(cuda_opts, &ctx)?;
199                        }
200                        _ => {
201                            println!("  Skipping CUDA libtorch. You can download later with:");
202                            println!("    fdl libtorch download --cuda 12.8");
203                            println!("    # or build from source:");
204                            println!("    fdl libtorch build");
205                        }
206                    }
207                }
208            } else if lo_major < 7 {
209                println!();
210                println!("  Downloading CUDA libtorch (cu126 for your pre-Volta GPU)...");
211                let cuda_opts = download::DownloadOpts {
212                    variant: download::Variant::Cuda126,
213                    ..Default::default()
214                };
215                download::run_with_context(cuda_opts, &ctx)?;
216            } else {
217                println!();
218                println!("  Downloading CUDA libtorch (cu128 for your Volta+ GPU)...");
219                let cuda_opts = download::DownloadOpts {
220                    variant: download::Variant::Cuda128,
221                    ..Default::default()
222                };
223                download::run_with_context(cuda_opts, &ctx)?;
224            }
225        }
226    }
227
228    // ---- Step 3: Build environment (project-only) ----
229
230    if !ctx.is_project {
231        // Skip Docker image building when running standalone
232        println!();
233        println!("  Setup complete!");
234        println!("  ===============");
235        println!();
236        if let Some(info) = detect::read_active(root) {
237            let cuda_str = if info.cuda_version.as_deref() != Some("none") { "CUDA" } else { "CPU" };
238            println!("  libtorch:  {} ({})", info.path, cuda_str);
239            println!("  Location:  {}", ctx.libtorch_dir().display());
240        }
241        println!();
242        println!("  Next steps:");
243        println!("    fdl init my-project  # scaffold a new project");
244        println!("    fdl diagnose         # verify GPU compatibility");
245        println!();
246        return Ok(());
247    }
248
249    println!();
250    println!("  Step 3: Build environment");
251    println!("  -------------------------");
252    println!();
253    println!("  floDl compiles Rust code that links against libtorch.");
254    println!("  You can build with Docker (isolated, reproducible) or");
255    println!("  natively (faster iteration, requires Rust + C++ toolchain).");
256    println!();
257
258    let build_mode = if has_docker && has_cargo {
259        if opts.non_interactive {
260            "docker"
261        } else {
262            let choice = prompt::ask_choice(
263                "  Choice",
264                &[
265                    "Docker (recommended) -- isolated, reproducible builds",
266                    "Native -- faster iteration, requires C++ compiler on host",
267                    "Both -- set up Docker and show native instructions",
268                ],
269                1,
270            );
271            match choice {
272                1 => "docker",
273                2 => "native",
274                3 => "both",
275                _ => "docker",
276            }
277        }
278    } else if has_docker {
279        if opts.non_interactive {
280            "docker"
281        } else {
282            println!("  Docker is available. Rust is not installed on this machine.");
283            println!("  Docker is the easiest way to get started (no Rust install needed).");
284            println!();
285            if prompt::ask_yn("  Set up Docker build environment?", true) {
286                "docker"
287            } else {
288                "none"
289            }
290        }
291    } else {
292        println!("  Rust is available. Docker is not installed.");
293        println!("  You can build natively (requires C++ compiler on the host).");
294        println!();
295        "native"
296    };
297
298    // Build Docker images
299    if build_mode == "docker" || build_mode == "both" {
300        println!();
301        println!("  Building Docker images...");
302
303        // Create cargo cache dirs
304        let _ = std::fs::create_dir_all(".cargo-cache");
305        let _ = std::fs::create_dir_all(".cargo-git");
306
307        let status = docker::compose_run(".", &["build", "dev"])?;
308        if !status.success() {
309            println!("  Warning: CPU Docker image build failed.");
310        }
311
312        // CUDA image if we have GPUs and CUDA libtorch
313        let has_cuda_lt = detect::read_active(root)
314            .is_some_and(|i| i.cuda_version.as_deref() != Some("none"));
315
316        if !gpus.is_empty() && has_cuda_lt {
317            let _ = std::fs::create_dir_all(".cargo-cache-cuda");
318            let _ = std::fs::create_dir_all(".cargo-git-cuda");
319
320            let status = docker::compose_run(".", &["build", "cuda"])?;
321            if !status.success() {
322                println!("  Warning: CUDA Docker image build failed.");
323            }
324        }
325
326        println!("  Docker images ready.");
327    }
328
329    // ---- Summary ----
330
331    println!();
332    println!("  Setup complete!");
333    println!("  ===============");
334    println!();
335
336    // Show active libtorch
337    if let Some(info) = detect::read_active(root) {
338        let cuda_str = if info.cuda_version.as_deref() != Some("none") {
339            "CUDA"
340        } else {
341            "CPU"
342        };
343        println!("  libtorch:  {} ({})", info.path, cuda_str);
344    }
345
346    // Docker instructions
347    if build_mode == "docker" || build_mode == "both" {
348        println!();
349        println!("  Build with Docker:");
350        let has_cuda_lt = detect::read_active(root)
351            .is_some_and(|i| i.cuda_version.as_deref() != Some("none"));
352        if !gpus.is_empty() && has_cuda_lt {
353            println!("    make cuda-test       # run GPU tests");
354            println!("    make cuda-build      # compile with CUDA");
355            println!("    make cuda-shell      # interactive shell");
356        } else {
357            println!("    make test            # run tests");
358            println!("    make build           # compile");
359            println!("    make shell           # interactive shell");
360        }
361    }
362
363    // Native instructions
364    if build_mode == "native" || build_mode == "both" {
365        if let Some(info) = detect::read_active(root) {
366            let lt_path = format!("libtorch/{}", info.path);
367            println!();
368            println!("  Build natively:");
369            println!("    export LIBTORCH_PATH=\"{}\"", lt_path);
370            println!(
371                "    export LD_LIBRARY_PATH=\"$LIBTORCH_PATH/lib${{LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}}\""
372            );
373            let has_cuda_lt = info.cuda_version.as_deref() != Some("none");
374            if !gpus.is_empty() && has_cuda_lt {
375                println!("    cargo test --features cuda");
376            } else {
377                println!("    cargo test");
378            }
379        }
380    }
381
382    println!();
383    println!("  Other commands:");
384    println!("    fdl diagnose         # verify GPU compatibility");
385    println!("    fdl init my-project  # scaffold a new project");
386    println!();
387
388    Ok(())
389}