1use crate::context::Context;
6use crate::libtorch::{build, detect, download};
7use crate::util::{docker, prompt, system};
8
/// Options controlling the behaviour of the setup wizard ([`run`]).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SetupOpts {
    /// When `true`, the wizard never prompts and takes the built-in default
    /// for every decision.
    pub non_interactive: bool,
    /// When `true`, the check for an already-installed libtorch is skipped
    /// and libtorch is downloaded again.
    pub force: bool,
}
16
17pub fn run(opts: SetupOpts) -> Result<(), String> {
18 println!();
19 println!(" floDl Setup");
20 println!(" ===========");
21 println!();
22 println!(" floDl is a Rust deep learning framework built on libtorch");
23 println!(" (PyTorch's C++ backend). This wizard will help you set up");
24 println!(" your development environment.");
25 println!();
26
27 println!(" Step 1: Detecting your system");
30 println!(" -----------------------------");
31 println!();
32
33 let cpu = system::cpu_model().unwrap_or_else(|| "Unknown".into());
34 let threads = system::cpu_threads();
35 let ram_gb = system::ram_total_gb();
36 println!(" CPU: {} ({} threads, {}GB RAM)", cpu, threads, ram_gb);
37
38 let has_docker = docker::has_docker();
39 let has_cargo = system::has_cargo();
40
41 if has_docker {
42 if let Some(v) = system::docker_version() {
43 println!(" Docker: {}", v);
44 } else {
45 println!(" Docker: available");
46 }
47 } else {
48 println!(" Docker: not found");
49 }
50
51 if has_cargo {
52 println!(" Rust: available");
53 } else {
54 println!(" Rust: not found");
55 }
56
57 let gpus = system::detect_gpus();
58 if !gpus.is_empty() {
59 println!();
60 println!(" GPUs:");
61 for g in &gpus {
62 println!(
63 " [{}] {} -- sm_{}.{}, {}GB VRAM",
64 g.index,
65 g.name,
66 g.sm_major,
67 g.sm_minor,
68 g.total_memory_mb / 1024
69 );
70 }
71 } else {
72 println!();
73 println!(" GPU: not detected (CPU-only mode)");
74 }
75
76 if !has_docker && !has_cargo {
77 println!();
78 println!(" You need at least one of these to continue:");
79 println!();
80 println!(" Rust: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh");
81 println!(" Docker: https://docs.docker.com/engine/install/");
82 println!();
83 println!(" Install one or both and run 'fdl setup' again.");
84 return Err("no Rust or Docker found".into());
85 }
86
87 println!();
90 println!(" Step 2: libtorch");
91 println!(" ----------------");
92 println!();
93 println!(" floDl needs libtorch, PyTorch's C++ library.");
94 println!(" This downloads pre-built binaries (~2GB for CUDA, ~200MB for CPU).");
95 println!();
96
97 let ctx = Context::resolve();
98 let root = &ctx.root;
99
100 if !ctx.is_project {
101 println!(" Not inside a floDl project.");
102 println!(" libtorch will be installed to: {}", ctx.libtorch_dir().display());
103 println!();
104 }
105
106 let existing = detect::read_active(root);
107 let mut skip_download = false;
108
109 if !opts.force {
110 if let Some(ref info) = existing {
111 let is_cuda = info.cuda_version.as_deref() != Some("none");
112 if is_cuda {
113 println!(" Found existing CUDA libtorch: {}", info.path);
114 if opts.non_interactive {
115 println!(" Keeping existing installation.");
116 skip_download = true;
117 } else if !prompt::ask_yn(" Download fresh?", false) {
118 skip_download = true;
119 }
120 println!();
121 } else {
122 println!(" Found existing CPU libtorch.");
123 }
124 }
125 }
126
127 if !skip_download {
128 println!(" Downloading CPU libtorch...");
130 let cpu_opts = download::DownloadOpts {
131 variant: download::Variant::Cpu,
132 activate: false, ..Default::default()
134 };
135 download::run_with_context(cpu_opts, &ctx)?;
136
137 if !gpus.is_empty() {
139 let lo_major = gpus.iter().map(|g| g.sm_major).min().unwrap_or(0);
140 let hi_major = gpus.iter().map(|g| g.sm_major).max().unwrap_or(0);
141
142 if lo_major < 7 && hi_major >= 10 {
143 println!();
145 println!(" Your GPUs span sm_{}.x to sm_{}.x.", lo_major, hi_major);
146 println!(" No pre-built libtorch covers both architectures.");
147 println!();
148
149 let has_source_build = detect::list_variants(root)
151 .iter()
152 .any(|v| v.starts_with("builds/"));
153
154 if has_source_build {
155 println!(" Found existing source build in libtorch/builds/.");
156 } else if opts.non_interactive {
157 println!(" Downloading cu126 (broadest coverage).");
158 let cuda_opts = download::DownloadOpts {
159 variant: download::Variant::Cuda126,
160 ..Default::default()
161 };
162 download::run_with_context(cuda_opts, &ctx)?;
163 } else {
164 let choice = prompt::ask_choice(
165 " Choice",
166 &[
167 "Build libtorch from source (2-6 hours, covers all GPUs)",
168 "Download cu128 (Volta+ only, your older GPU won't work)",
169 "Download cu126 (pre-Volta only, your newer GPU won't work)",
170 "Skip for now",
171 ],
172 4,
173 );
174
175 match choice {
176 1 => {
177 println!();
178 println!(" Starting libtorch source build...");
179 println!(" This will take 2-6 hours. You can safely Ctrl-C and");
180 println!(" resume later with: fdl libtorch build");
181 println!();
182 build::run(build::BuildOpts::default())?;
183 }
184 2 => {
185 println!(" Downloading cu128...");
186 let cuda_opts = download::DownloadOpts {
187 variant: download::Variant::Cuda128,
188 ..Default::default()
189 };
190 download::run_with_context(cuda_opts, &ctx)?;
191 }
192 3 => {
193 println!(" Downloading cu126...");
194 let cuda_opts = download::DownloadOpts {
195 variant: download::Variant::Cuda126,
196 ..Default::default()
197 };
198 download::run_with_context(cuda_opts, &ctx)?;
199 }
200 _ => {
201 println!(" Skipping CUDA libtorch. You can download later with:");
202 println!(" fdl libtorch download --cuda 12.8");
203 println!(" # or build from source:");
204 println!(" fdl libtorch build");
205 }
206 }
207 }
208 } else if lo_major < 7 {
209 println!();
210 println!(" Downloading CUDA libtorch (cu126 for your pre-Volta GPU)...");
211 let cuda_opts = download::DownloadOpts {
212 variant: download::Variant::Cuda126,
213 ..Default::default()
214 };
215 download::run_with_context(cuda_opts, &ctx)?;
216 } else {
217 println!();
218 println!(" Downloading CUDA libtorch (cu128 for your Volta+ GPU)...");
219 let cuda_opts = download::DownloadOpts {
220 variant: download::Variant::Cuda128,
221 ..Default::default()
222 };
223 download::run_with_context(cuda_opts, &ctx)?;
224 }
225 }
226 }
227
228 if !ctx.is_project {
231 println!();
233 println!(" Setup complete!");
234 println!(" ===============");
235 println!();
236 if let Some(info) = detect::read_active(root) {
237 let cuda_str = if info.cuda_version.as_deref() != Some("none") { "CUDA" } else { "CPU" };
238 println!(" libtorch: {} ({})", info.path, cuda_str);
239 println!(" Location: {}", ctx.libtorch_dir().display());
240 }
241 println!();
242 println!(" Next steps:");
243 println!(" fdl init my-project # scaffold a new project");
244 println!(" fdl diagnose # verify GPU compatibility");
245 println!();
246 return Ok(());
247 }
248
249 println!();
250 println!(" Step 3: Build environment");
251 println!(" -------------------------");
252 println!();
253 println!(" floDl compiles Rust code that links against libtorch.");
254 println!(" You can build with Docker (isolated, reproducible) or");
255 println!(" natively (faster iteration, requires Rust + C++ toolchain).");
256 println!();
257
258 let build_mode = if has_docker && has_cargo {
259 if opts.non_interactive {
260 "docker"
261 } else {
262 let choice = prompt::ask_choice(
263 " Choice",
264 &[
265 "Docker (recommended) -- isolated, reproducible builds",
266 "Native -- faster iteration, requires C++ compiler on host",
267 "Both -- set up Docker and show native instructions",
268 ],
269 1,
270 );
271 match choice {
272 1 => "docker",
273 2 => "native",
274 3 => "both",
275 _ => "docker",
276 }
277 }
278 } else if has_docker {
279 if opts.non_interactive {
280 "docker"
281 } else {
282 println!(" Docker is available. Rust is not installed on this machine.");
283 println!(" Docker is the easiest way to get started (no Rust install needed).");
284 println!();
285 if prompt::ask_yn(" Set up Docker build environment?", true) {
286 "docker"
287 } else {
288 "none"
289 }
290 }
291 } else {
292 println!(" Rust is available. Docker is not installed.");
293 println!(" You can build natively (requires C++ compiler on the host).");
294 println!();
295 "native"
296 };
297
298 if build_mode == "docker" || build_mode == "both" {
300 println!();
301 println!(" Building Docker images...");
302
303 let _ = std::fs::create_dir_all(".cargo-cache");
305 let _ = std::fs::create_dir_all(".cargo-git");
306
307 let status = docker::compose_run(".", &["build", "dev"])?;
308 if !status.success() {
309 println!(" Warning: CPU Docker image build failed.");
310 }
311
312 let has_cuda_lt = detect::read_active(root)
314 .is_some_and(|i| i.cuda_version.as_deref() != Some("none"));
315
316 if !gpus.is_empty() && has_cuda_lt {
317 let _ = std::fs::create_dir_all(".cargo-cache-cuda");
318 let _ = std::fs::create_dir_all(".cargo-git-cuda");
319
320 let status = docker::compose_run(".", &["build", "cuda"])?;
321 if !status.success() {
322 println!(" Warning: CUDA Docker image build failed.");
323 }
324 }
325
326 println!(" Docker images ready.");
327 }
328
329 println!();
332 println!(" Setup complete!");
333 println!(" ===============");
334 println!();
335
336 if let Some(info) = detect::read_active(root) {
338 let cuda_str = if info.cuda_version.as_deref() != Some("none") {
339 "CUDA"
340 } else {
341 "CPU"
342 };
343 println!(" libtorch: {} ({})", info.path, cuda_str);
344 }
345
346 if build_mode == "docker" || build_mode == "both" {
348 println!();
349 println!(" Build with Docker:");
350 let has_cuda_lt = detect::read_active(root)
351 .is_some_and(|i| i.cuda_version.as_deref() != Some("none"));
352 if !gpus.is_empty() && has_cuda_lt {
353 println!(" make cuda-test # run GPU tests");
354 println!(" make cuda-build # compile with CUDA");
355 println!(" make cuda-shell # interactive shell");
356 } else {
357 println!(" make test # run tests");
358 println!(" make build # compile");
359 println!(" make shell # interactive shell");
360 }
361 }
362
363 if build_mode == "native" || build_mode == "both" {
365 if let Some(info) = detect::read_active(root) {
366 let lt_path = format!("libtorch/{}", info.path);
367 println!();
368 println!(" Build natively:");
369 println!(" export LIBTORCH_PATH=\"{}\"", lt_path);
370 println!(
371 " export LD_LIBRARY_PATH=\"$LIBTORCH_PATH/lib${{LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}}\""
372 );
373 let has_cuda_lt = info.cuda_version.as_deref() != Some("none");
374 if !gpus.is_empty() && has_cuda_lt {
375 println!(" cargo test --features cuda");
376 } else {
377 println!(" cargo test");
378 }
379 }
380 }
381
382 println!();
383 println!(" Other commands:");
384 println!(" fdl diagnose # verify GPU compatibility");
385 println!(" fdl init my-project # scaffold a new project");
386 println!();
387
388 Ok(())
389}