use crate::context::Context;
use crate::libtorch::{build, detect, download};
use crate::util::{docker, prompt, system};
#[derive(Default)]
pub struct SetupOpts {
pub non_interactive: bool,
pub force: bool,
}
pub fn run(opts: SetupOpts) -> Result<(), String> {
println!();
println!(" floDl Setup");
println!(" ===========");
println!();
println!(" floDl is a Rust deep learning framework built on libtorch");
println!(" (PyTorch's C++ backend). This wizard will help you set up");
println!(" your development environment.");
println!();
println!(" Step 1: Detecting your system");
println!(" -----------------------------");
println!();
let cpu = system::cpu_model().unwrap_or_else(|| "Unknown".into());
let threads = system::cpu_threads();
let ram_gb = system::ram_total_gb();
println!(" CPU: {} ({} threads, {}GB RAM)", cpu, threads, ram_gb);
let has_docker = docker::has_docker();
let has_cargo = system::has_cargo();
if has_docker {
if let Some(v) = system::docker_version() {
println!(" Docker: {}", v);
} else {
println!(" Docker: available");
}
} else {
println!(" Docker: not found");
}
if has_cargo {
println!(" Rust: available");
} else {
println!(" Rust: not found");
}
let gpus = system::detect_gpus();
if !gpus.is_empty() {
println!();
println!(" GPUs:");
for g in &gpus {
println!(
" [{}] {} -- sm_{}.{}, {}GB VRAM",
g.index,
g.name,
g.sm_major,
g.sm_minor,
g.total_memory_mb / 1024
);
}
} else {
println!();
println!(" GPU: not detected (CPU-only mode)");
}
if !has_docker && !has_cargo {
println!();
println!(" You need at least one of these to continue:");
println!();
println!(" Rust: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh");
println!(" Docker: https://docs.docker.com/engine/install/");
println!();
println!(" Install one or both and run 'fdl setup' again.");
return Err("no Rust or Docker found".into());
}
println!();
println!(" Step 2: libtorch");
println!(" ----------------");
println!();
println!(" floDl needs libtorch, PyTorch's C++ library.");
println!(" This downloads pre-built binaries (~2GB for CUDA, ~200MB for CPU).");
println!();
let ctx = Context::resolve();
let root = &ctx.root;
if !ctx.is_project {
println!(" Not inside a floDl project.");
println!(" libtorch will be installed to: {}", ctx.libtorch_dir().display());
println!();
}
let existing = detect::read_active(root);
let mut skip_download = false;
if !opts.force {
if let Some(ref info) = existing {
let is_cuda = info.cuda_version.as_deref() != Some("none");
if is_cuda {
println!(" Found existing CUDA libtorch: {}", info.path);
if opts.non_interactive {
println!(" Keeping existing installation.");
skip_download = true;
} else if !prompt::ask_yn(" Download fresh?", false) {
skip_download = true;
}
println!();
} else {
println!(" Found existing CPU libtorch.");
}
}
}
if !skip_download {
println!(" Downloading CPU libtorch...");
let cpu_opts = download::DownloadOpts {
variant: download::Variant::Cpu,
activate: false, ..Default::default()
};
download::run_with_context(cpu_opts, &ctx)?;
if !gpus.is_empty() {
let lo_major = gpus.iter().map(|g| g.sm_major).min().unwrap_or(0);
let hi_major = gpus.iter().map(|g| g.sm_major).max().unwrap_or(0);
if lo_major < 7 && hi_major >= 10 {
println!();
println!(" Your GPUs span sm_{}.x to sm_{}.x.", lo_major, hi_major);
println!(" No pre-built libtorch covers both architectures.");
println!();
let has_source_build = detect::list_variants(root)
.iter()
.any(|v| v.starts_with("builds/"));
if has_source_build {
println!(" Found existing source build in libtorch/builds/.");
} else if opts.non_interactive {
println!(" Downloading cu126 (broadest coverage).");
let cuda_opts = download::DownloadOpts {
variant: download::Variant::Cuda126,
..Default::default()
};
download::run_with_context(cuda_opts, &ctx)?;
} else {
let choice = prompt::ask_choice(
" Choice",
&[
"Build libtorch from source (2-6 hours, covers all GPUs)",
"Download cu128 (Volta+ only, your older GPU won't work)",
"Download cu126 (pre-Volta only, your newer GPU won't work)",
"Skip for now",
],
4,
);
match choice {
1 => {
println!();
println!(" Starting libtorch source build...");
println!(" This will take 2-6 hours. You can safely Ctrl-C and");
println!(" resume later with: fdl libtorch build");
println!();
build::run(build::BuildOpts::default())?;
}
2 => {
println!(" Downloading cu128...");
let cuda_opts = download::DownloadOpts {
variant: download::Variant::Cuda128,
..Default::default()
};
download::run_with_context(cuda_opts, &ctx)?;
}
3 => {
println!(" Downloading cu126...");
let cuda_opts = download::DownloadOpts {
variant: download::Variant::Cuda126,
..Default::default()
};
download::run_with_context(cuda_opts, &ctx)?;
}
_ => {
println!(" Skipping CUDA libtorch. You can download later with:");
println!(" fdl libtorch download --cuda 12.8");
println!(" # or build from source:");
println!(" fdl libtorch build");
}
}
}
} else if lo_major < 7 {
println!();
println!(" Downloading CUDA libtorch (cu126 for your pre-Volta GPU)...");
let cuda_opts = download::DownloadOpts {
variant: download::Variant::Cuda126,
..Default::default()
};
download::run_with_context(cuda_opts, &ctx)?;
} else {
println!();
println!(" Downloading CUDA libtorch (cu128 for your Volta+ GPU)...");
let cuda_opts = download::DownloadOpts {
variant: download::Variant::Cuda128,
..Default::default()
};
download::run_with_context(cuda_opts, &ctx)?;
}
}
}
if !ctx.is_project {
println!();
println!(" Setup complete!");
println!(" ===============");
println!();
if let Some(info) = detect::read_active(root) {
let cuda_str = if info.cuda_version.as_deref() != Some("none") { "CUDA" } else { "CPU" };
println!(" libtorch: {} ({})", info.path, cuda_str);
println!(" Location: {}", ctx.libtorch_dir().display());
}
println!();
println!(" Next steps:");
println!(" fdl init my-project # scaffold a new project");
println!(" fdl diagnose # verify GPU compatibility");
println!();
return Ok(());
}
println!();
println!(" Step 3: Build environment");
println!(" -------------------------");
println!();
println!(" floDl compiles Rust code that links against libtorch.");
println!(" You can build with Docker (isolated, reproducible) or");
println!(" natively (faster iteration, requires Rust + C++ toolchain).");
println!();
let build_mode = if has_docker && has_cargo {
if opts.non_interactive {
"docker"
} else {
let choice = prompt::ask_choice(
" Choice",
&[
"Docker (recommended) -- isolated, reproducible builds",
"Native -- faster iteration, requires C++ compiler on host",
"Both -- set up Docker and show native instructions",
],
1,
);
match choice {
1 => "docker",
2 => "native",
3 => "both",
_ => "docker",
}
}
} else if has_docker {
if opts.non_interactive {
"docker"
} else {
println!(" Docker is available. Rust is not installed on this machine.");
println!(" Docker is the easiest way to get started (no Rust install needed).");
println!();
if prompt::ask_yn(" Set up Docker build environment?", true) {
"docker"
} else {
"none"
}
}
} else {
println!(" Rust is available. Docker is not installed.");
println!(" You can build natively (requires C++ compiler on the host).");
println!();
"native"
};
if build_mode == "docker" || build_mode == "both" {
println!();
println!(" Building Docker images...");
let _ = std::fs::create_dir_all(".cargo-cache");
let _ = std::fs::create_dir_all(".cargo-git");
let status = docker::compose_run(".", &["build", "dev"])?;
if !status.success() {
println!(" Warning: CPU Docker image build failed.");
}
let has_cuda_lt = detect::read_active(root)
.is_some_and(|i| i.cuda_version.as_deref() != Some("none"));
if !gpus.is_empty() && has_cuda_lt {
let _ = std::fs::create_dir_all(".cargo-cache-cuda");
let _ = std::fs::create_dir_all(".cargo-git-cuda");
let status = docker::compose_run(".", &["build", "cuda"])?;
if !status.success() {
println!(" Warning: CUDA Docker image build failed.");
}
}
println!(" Docker images ready.");
}
println!();
println!(" Setup complete!");
println!(" ===============");
println!();
if let Some(info) = detect::read_active(root) {
let cuda_str = if info.cuda_version.as_deref() != Some("none") {
"CUDA"
} else {
"CPU"
};
println!(" libtorch: {} ({})", info.path, cuda_str);
}
if build_mode == "docker" || build_mode == "both" {
println!();
println!(" Build with Docker:");
let has_cuda_lt = detect::read_active(root)
.is_some_and(|i| i.cuda_version.as_deref() != Some("none"));
if !gpus.is_empty() && has_cuda_lt {
println!(" make cuda-test # run GPU tests");
println!(" make cuda-build # compile with CUDA");
println!(" make cuda-shell # interactive shell");
} else {
println!(" make test # run tests");
println!(" make build # compile");
println!(" make shell # interactive shell");
}
}
if build_mode == "native" || build_mode == "both" {
if let Some(info) = detect::read_active(root) {
let lt_path = format!("libtorch/{}", info.path);
println!();
println!(" Build natively:");
println!(" export LIBTORCH_PATH=\"{}\"", lt_path);
println!(
" export LD_LIBRARY_PATH=\"$LIBTORCH_PATH/lib${{LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}}\""
);
let has_cuda_lt = info.cuda_version.as_deref() != Some("none");
if !gpus.is_empty() && has_cuda_lt {
println!(" cargo test --features cuda");
} else {
println!(" cargo test");
}
}
}
println!();
println!(" Other commands:");
println!(" fdl diagnose # verify GPU compatibility");
println!(" fdl init my-project # scaffold a new project");
println!();
Ok(())
}