
flodl_cli/init.rs

//! `fdl init <name>` -- scaffold a new floDl project.
//!
//! Three modes, selected by flag or interactive prompt:
//! - `Mounted` (default): Docker with libtorch host-mounted at runtime.
//! - `Docker` (`--docker`): Docker with libtorch baked into the image.
//! - `Native` (`--native`): no Docker; libtorch and cargo provided on the host.
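//!
//! Illustrative invocations (`--with-hf` mirrors the `with_hf` argument to
//! [`run`]; the exact CLI wiring lives in the caller):
//!
//! ```text
//! fdl init my-model                  # interactive: prompts for mode and HF
//! fdl init my-model --docker         # baked-in libtorch, no prompts
//! fdl init my-model --native --with-hf
//! ```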

use std::fs;
use std::path::Path;
use std::process::Command;

use crate::util::prompt;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    Mounted,
    Docker,
    Native,
}

pub fn run(
    name: Option<&str>,
    docker: bool,
    native: bool,
    with_hf: bool,
) -> Result<(), String> {
    let name = name.ok_or("usage: fdl init <project-name>")?;
    validate_name(name)?;

    if Path::new(name).exists() {
        return Err(format!("'{}' already exists", name));
    }

    if docker && native {
        return Err("--docker and --native are mutually exclusive".into());
    }
    let flag_driven = docker || native || with_hf;
    let mode = if docker {
        Mode::Docker
    } else if native {
        Mode::Native
    } else {
        pick_mode_interactively()
    };
    // `--with-hf` exists so scripted init can skip the prompt entirely.
    // With no flags at all, ask after mode selection; once *any* flag is
    // set, the user has signalled non-interactive intent, so take
    // `--with-hf` verbatim and skip the prompt.
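    // Resulting behaviour, as implemented below:
    //
    //   flags passed           mode prompt?   HF prompt?   include_hf
    //   (none)                 yes            yes          answer
    //   --docker / --native    no             no           with_hf
    //   --with-hf only         yes            no           true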
    let include_hf = if flag_driven {
        with_hf
    } else {
        prompt::ask_yn(
            "Include flodl-hf (HuggingFace: BERT/RoBERTa/DistilBERT, Hub loader, tokenizer)?",
            false,
        )
    };

    let crate_name = name.replace('-', "_");
    let flodl_dep = resolve_flodl_dep();

    fs::create_dir_all(format!("{}/src", name))
        .map_err(|e| format!("cannot create directory: {}", e))?;

    match mode {
        Mode::Mounted => scaffold_mounted(name, &crate_name, &flodl_dep)?,
        Mode::Docker => scaffold_docker(name, &crate_name, &flodl_dep)?,
        Mode::Native => scaffold_native(name, &crate_name, &flodl_dep)?,
    }

    // Shared across all modes.
    write_file(
        &format!("{}/src/main.rs", name),
        &main_rs_template(),
    )?;
    write_file(
        &format!("{}/.gitignore", name),
        &gitignore_template(mode),
    )?;
    write_file(
        &format!("{}/fdl.yml.example", name),
        &fdl_yml_example_template(name, mode),
    )?;
    write_fdl_bootstrap(name)?;

    if include_hf {
        let project_dir = Path::new(name);
        if let Err(e) = crate::add::add_flodl_hf_at(project_dir) {
            // Scaffolded project is still usable even if the HF sub-crate
            // failed; surface the error but don't roll back.
            eprintln!("warning: flodl-hf scaffold failed: {e}");
            eprintln!("You can retry after `cd {}` with `fdl add flodl-hf`.", name);
        }
    }

    print_next_steps(name, mode, include_hf);
    crate::util::install_prompt::offer_global_install();
    Ok(())
}

/// Ask the user interactively which mode to generate. When no TTY is
/// attached the prompt defaults apply, so this falls back to `Mounted` --
/// the same default non-interactive tooling gets by passing no mode flag.
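///
/// An illustrative session (exact rendering is up to `crate::util::prompt`):
///
/// ```text
/// Use Docker for builds? [Y/n] y
/// libtorch location
///   1) Mounted from host (recommended: lighter image, swap CUDA variants)
///   2) Baked into the Docker image (zero host dependencies)
/// > 1
/// ```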
fn pick_mode_interactively() -> Mode {
    println!();
    if !prompt::ask_yn("Use Docker for builds?", true) {
        return Mode::Native;
    }
    // 1-based: 1 = mounted (default), 2 = baked-in.
    let choice = prompt::ask_choice(
        "libtorch location",
        &[
            "Mounted from host (recommended: lighter image, swap CUDA variants)",
            "Baked into the Docker image (zero host dependencies)",
        ],
        1,
    );
    match choice {
        2 => Mode::Docker,
        _ => Mode::Mounted,
    }
}

fn print_next_steps(name: &str, mode: Mode, include_hf: bool) {
    println!();
    println!("Project '{}' created. Next steps:", name);
    println!();
    println!("  cd {}", name);
    match mode {
        Mode::Mounted => {
            println!("  ./fdl setup   # detect hardware + download libtorch");
            println!("  ./fdl build   # build the project");
        }
        Mode::Docker => {
            println!("  ./fdl build   # first build (downloads libtorch, ~5 min)");
        }
        Mode::Native => {
            println!("  ./fdl libtorch download --cpu     # or --cuda 12.8");
            println!("  ./fdl build                       # cargo build on the host");
        }
    }
    println!("  ./fdl test    # run tests");
    println!("  ./fdl run     # train the model");
    if mode != Mode::Native {
        println!("  ./fdl shell   # interactive shell");
    }
    if include_hf {
        println!();
        println!("  cd flodl-hf && fdl classify   # try the HuggingFace playground");
    }
    println!();
    println!("`./fdl --help` lists every command defined in fdl.yml.");
    println!("Edit src/main.rs to build your model.");
    println!();
    println!("Guides:");
    println!("  Tutorials:         https://flodl.dev/guide/tensors");
    println!("  Graph Tree:        https://flodl.dev/guide/graph-tree");
    println!("  PyTorch migration: https://flodl.dev/guide/migration");
    println!("  Troubleshooting:   https://flodl.dev/guide/troubleshooting");
}

fn write_fdl_bootstrap(name: &str) -> Result<(), String> {
    let fdl_script = include_str!("../assets/fdl");
    write_file(&format!("{}/fdl", name), fdl_script)?;
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let _ = fs::set_permissions(
            format!("{}/fdl", name),
            fs::Permissions::from_mode(0o755),
        );
    }
    Ok(())
}

fn validate_name(name: &str) -> Result<(), String> {
    if name.is_empty() {
        return Err("project name cannot be empty".into());
    }
    if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') {
        return Err("project name must contain only ASCII letters, digits, hyphens, and underscores".into());
    }
    Ok(())
}

fn resolve_flodl_dep() -> String {
    // Prefer the latest stable release on crates.io; fall back to the git repo.
    if let Some(version) = crates_io_version() {
        format!("flodl = \"{}\"", version)
    } else {
        "flodl = { git = \"https://github.com/flodl-labs/flodl.git\" }".into()
    }
}

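/// Scrape `max_stable_version` out of the crates.io API response without
/// pulling in a JSON dependency. The relevant (abridged, illustrative)
/// response shape is assumed to be:
///
/// ```text
/// {"crate":{"name":"flodl","max_stable_version":"1.2.3",...},...}
/// ```
///
/// Any failure (no curl, no network, unexpected body) yields `None`, and
/// the caller falls back to the git dependency.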
fn crates_io_version() -> Option<String> {
    let output = Command::new("curl")
        .args(["-sL", "https://crates.io/api/v1/crates/flodl"])
        .output()
        .ok()?;
    let body = String::from_utf8_lossy(&output.stdout);
    // Extract "max_stable_version":"X.Y.Z".
    let marker = "\"max_stable_version\":\"";
    let start = body.find(marker)? + marker.len();
    let end = start + body[start..].find('"')?;
    let version = &body[start..end];
    if version.is_empty() { None } else { Some(version.to_string()) }
}

// ---------------------------------------------------------------------------
// Docker scaffold (standalone, libtorch baked into images)
// ---------------------------------------------------------------------------

fn scaffold_docker(name: &str, crate_name: &str, flodl_dep: &str) -> Result<(), String> {
    write_file(
        &format!("{}/Cargo.toml", name),
        &cargo_toml_template(crate_name, flodl_dep),
    )?;
    write_file(
        &format!("{}/Dockerfile.cpu", name),
        DOCKERFILE_CPU,
    )?;
    write_file(
        &format!("{}/Dockerfile.cuda", name),
        DOCKERFILE_CUDA,
    )?;
    write_file(
        &format!("{}/docker-compose.yml", name),
        &docker_compose_template(crate_name, true),
    )?;
    Ok(())
}

// ---------------------------------------------------------------------------
// Mounted scaffold (libtorch from host, like the main repo)
// ---------------------------------------------------------------------------

fn scaffold_mounted(name: &str, crate_name: &str, flodl_dep: &str) -> Result<(), String> {
    write_file(
        &format!("{}/Cargo.toml", name),
        &cargo_toml_template(crate_name, flodl_dep),
    )?;
    write_file(
        &format!("{}/Dockerfile", name),
        DOCKERFILE_MOUNTED,
    )?;
    write_file(
        &format!("{}/Dockerfile.cuda", name),
        DOCKERFILE_CUDA_MOUNTED,
    )?;
    write_file(
        &format!("{}/docker-compose.yml", name),
        &docker_compose_template(crate_name, false),
    )?;
    Ok(())
}

// ---------------------------------------------------------------------------
// Native scaffold (no Docker; libtorch and cargo live on the host)
// ---------------------------------------------------------------------------

fn scaffold_native(name: &str, crate_name: &str, flodl_dep: &str) -> Result<(), String> {
    write_file(
        &format!("{}/Cargo.toml", name),
        &cargo_toml_template(crate_name, flodl_dep),
    )?;
    // Intentionally no Dockerfile*/docker-compose.yml -- the user opted out
    // of Docker. They can switch later by regenerating or adding their own.
    Ok(())
}

// ---------------------------------------------------------------------------
// Templates
// ---------------------------------------------------------------------------

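/// Renders the project `Cargo.toml`. `flodl_dep` is the full dependency line
/// produced by `resolve_flodl_dep`, e.g. `flodl = "1.2.3"` (version
/// illustrative) or the git fallback when crates.io is unreachable.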
fn cargo_toml_template(crate_name: &str, flodl_dep: &str) -> String {
    format!(
        r#"[package]
name = "{crate_name}"
version = "0.1.0"
edition = "2024"

[dependencies]
{flodl_dep}

# Optimize floDl in dev builds -- your code stays fast to compile.
# After the first build, only your graph code recompiles (~2s).
[profile.dev.package.flodl]
opt-level = 3

[profile.dev.package.flodl-sys]
opt-level = 3

# Release: cross-crate optimization for maximum throughput.
[profile.release]
lto = "thin"
codegen-units = 1
"#
    )
}

fn main_rs_template() -> String {
    r#"//! floDl training template.
//!
//! This is a starting point for your model. Edit the architecture,
//! data loading, and training loop to fit your task.
//!
//! New to Rust? Read: https://flodl.dev/guide/rust-primer
//! Stuck?       Read: https://flodl.dev/guide/troubleshooting

use flodl::*;
use flodl::monitor::Monitor;

fn main() -> Result<()> {
    // --- Model ---
    let model = FlowBuilder::from(Linear::new(4, 32)?)
        .through(GELU)
        .through(LayerNorm::new(32)?)
        .also(Linear::new(32, 32)?)       // residual connection
        .through(Linear::new(32, 1)?)
        .build()?;

    // --- Optimizer ---
    let params = model.parameters();
    let mut optimizer = Adam::new(&params, 0.001);
    let scheduler = CosineScheduler::new(0.001, 1e-6, 100);
    model.train();

    // --- Data ---
    // Replace this with your data loading.
    let opts = TensorOptions::default();
    let batches: Vec<(Tensor, Tensor)> = (0..32)
        .map(|_| {
            let x = Tensor::randn(&[16, 4], opts).unwrap();
            let y = Tensor::randn(&[16, 1], opts).unwrap();
            (x, y)
        })
        .collect();

    // --- Training loop ---
    let num_epochs = 100usize;
    let mut monitor = Monitor::new(num_epochs);
    // monitor.serve(3000)?;              // uncomment for live dashboard
    // monitor.watch(&model);             // uncomment to show graph SVG
    // monitor.save_html("report.html");  // uncomment to save HTML report

    for epoch in 0..num_epochs {
        let t = std::time::Instant::now();
        let mut epoch_loss = 0.0;

        for (input_t, target_t) in &batches {
            let input = Variable::new(input_t.clone(), true);
            let target = Variable::new(target_t.clone(), false);

            optimizer.zero_grad();
            let pred = model.forward(&input)?;
            let loss = mse_loss(&pred, &target)?;
            loss.backward()?;
            clip_grad_norm(&params, 1.0)?;
            optimizer.step()?;

            epoch_loss += loss.item()?;
        }

        let avg_loss = epoch_loss / batches.len() as f64;
        let lr = scheduler.lr(epoch);
        optimizer.set_lr(lr);
        monitor.log(epoch, t.elapsed(), &[("loss", avg_loss), ("lr", lr)]);
    }

    monitor.finish();
    Ok(())
}
"#
    .into()
}

fn gitignore_template(mode: Mode) -> String {
    let mut s = String::from(
        "/target
*.fdl
*.log
*.csv
*.html

# Local fdl config (fdl.yml.example is committed; fdl copies it on first run)
fdl.yml
fdl.yaml
",
    );
    match mode {
        Mode::Docker => {
            // libtorch lives inside the image; only the per-service cargo
            // caches land on the host.
            s.push_str(
                ".cargo-cache/
.cargo-git/
.cargo-cache-cuda/
.cargo-git-cuda/
",
            );
        }
        Mode::Mounted => {
            // Mounted libtorch + separate cargo caches per docker service.
            s.push_str(
                ".cargo-cache/
.cargo-git/
.cargo-cache-cuda/
.cargo-git-cuda/
libtorch/
",
            );
        }
        Mode::Native => {
            // No docker, no container caches. libtorch/ is still ignored
            // because `./fdl libtorch download` installs it locally.
            s.push_str("libtorch/\n");
        }
    }
    s
}

fn docker_compose_template(crate_name: &str, baked: bool) -> String {
    if baked {
        format!(
            r#"services:
  dev:
    build:
      context: .
      dockerfile: Dockerfile.cpu
    image: {crate_name}-dev
    user: "${{UID:-1000}}:${{GID:-1000}}"
    volumes:
      - .:/workspace
      - ./.cargo-cache:/usr/local/cargo/registry
      - ./.cargo-git:/usr/local/cargo/git
    working_dir: /workspace
    stdin_open: true
    tty: true

  cuda:
    build:
      context: .
      dockerfile: Dockerfile.cuda
    image: {crate_name}-cuda
    user: "${{UID:-1000}}:${{GID:-1000}}"
    volumes:
      - .:/workspace
      - ./.cargo-cache-cuda:/usr/local/cargo/registry
      - ./.cargo-git-cuda:/usr/local/cargo/git
    working_dir: /workspace
    stdin_open: true
    tty: true
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
"#
        )
    } else {
        format!(
            r#"services:
  dev:
    build:
      context: .
      dockerfile: Dockerfile
    image: {crate_name}-dev
    user: "${{UID:-1000}}:${{GID:-1000}}"
    volumes:
      - .:/workspace
      - ./.cargo-cache:/usr/local/cargo/registry
      - ./.cargo-git:/usr/local/cargo/git
      - ${{LIBTORCH_CPU_PATH:-./libtorch/precompiled/cpu}}:/usr/local/libtorch:ro
    working_dir: /workspace
    stdin_open: true
    tty: true

  cuda:
    build:
      context: .
      dockerfile: Dockerfile.cuda
      args:
        CUDA_VERSION: ${{CUDA_VERSION:-12.8.0}}
    image: {crate_name}-cuda:${{CUDA_TAG:-12.8}}
    user: "${{UID:-1000}}:${{GID:-1000}}"
    volumes:
      - .:/workspace
      - ./.cargo-cache-cuda:/usr/local/cargo/registry
      - ./.cargo-git-cuda:/usr/local/cargo/git
      - ${{LIBTORCH_HOST_PATH:-./libtorch/precompiled/cu128}}:/usr/local/libtorch:ro
    working_dir: /workspace
    stdin_open: true
    tty: true
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
"#
        )
    }
}

// ---------------------------------------------------------------------------
// Dockerfile templates
// ---------------------------------------------------------------------------

// Docker mode: libtorch baked into images
const DOCKERFILE_CPU: &str = r#"# CPU-only dev image for floDl projects.
FROM ubuntu:24.04

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y --no-install-recommends \
    wget curl unzip ca-certificates git gcc g++ pkg-config graphviz \
    && rm -rf /var/lib/apt/lists/*

# Rust
ENV CARGO_HOME="/usr/local/cargo"
ENV RUSTUP_HOME="/usr/local/rustup"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable \
    && chmod -R a+rwx "$CARGO_HOME" "$RUSTUP_HOME"
ENV PATH="${CARGO_HOME}/bin:${PATH}"

# libtorch (CPU-only, ~200MB)
ARG LIBTORCH_VERSION=2.10.0
RUN wget -q "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${LIBTORCH_VERSION}%2Bcpu.zip" \
    && unzip -q "libtorch-shared-with-deps-${LIBTORCH_VERSION}+cpu.zip" -d /usr/local \
    && rm "libtorch-shared-with-deps-${LIBTORCH_VERSION}+cpu.zip"

ENV LIBTORCH_PATH="/usr/local/libtorch"
ENV LD_LIBRARY_PATH="${LIBTORCH_PATH}/lib"
ENV LIBRARY_PATH="${LIBTORCH_PATH}/lib"

WORKDIR /workspace
"#;

const DOCKERFILE_CUDA: &str = r#"# CUDA dev image for floDl projects.
# Requires: docker run --gpus all ...
FROM nvidia/cuda:12.8.0-devel-ubuntu24.04

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y --no-install-recommends \
    wget curl unzip ca-certificates git gcc g++ pkg-config graphviz \
    && rm -rf /var/lib/apt/lists/*

# Rust
ENV CARGO_HOME="/usr/local/cargo"
ENV RUSTUP_HOME="/usr/local/rustup"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable \
    && chmod -R a+rwx "$CARGO_HOME" "$RUSTUP_HOME"
ENV PATH="${CARGO_HOME}/bin:${PATH}"

# libtorch (CUDA 12.8)
ARG LIBTORCH_VERSION=2.10.0
RUN wget -q "https://download.pytorch.org/libtorch/cu128/libtorch-shared-with-deps-${LIBTORCH_VERSION}%2Bcu128.zip" \
    && unzip -q "libtorch-shared-with-deps-${LIBTORCH_VERSION}+cu128.zip" -d /usr/local \
    && rm "libtorch-shared-with-deps-${LIBTORCH_VERSION}+cu128.zip"

ENV LIBTORCH_PATH="/usr/local/libtorch"
ENV LD_LIBRARY_PATH="${LIBTORCH_PATH}/lib:/usr/local/cuda/lib64"
ENV LIBRARY_PATH="${LIBTORCH_PATH}/lib:/usr/local/cuda/lib64"
ENV CUDA_HOME="/usr/local/cuda"

WORKDIR /workspace
"#;

// Mounted mode: libtorch provided at runtime via volume mount
const DOCKERFILE_MOUNTED: &str = r#"# CPU dev image for floDl projects (libtorch mounted at runtime).
FROM ubuntu:24.04

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y --no-install-recommends \
    wget curl unzip ca-certificates git gcc g++ pkg-config graphviz \
    && rm -rf /var/lib/apt/lists/*

# Rust
ENV CARGO_HOME="/usr/local/cargo"
ENV RUSTUP_HOME="/usr/local/rustup"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable \
    && chmod -R a+rwx "$CARGO_HOME" "$RUSTUP_HOME"
ENV PATH="${CARGO_HOME}/bin:${PATH}"

ENV LIBTORCH_PATH="/usr/local/libtorch"
ENV LD_LIBRARY_PATH="${LIBTORCH_PATH}/lib"
ENV LIBRARY_PATH="${LIBTORCH_PATH}/lib"

WORKDIR /workspace
"#;

const DOCKERFILE_CUDA_MOUNTED: &str = r#"# CUDA dev image for floDl projects (libtorch mounted at runtime).
# Requires: docker run --gpus all ...
ARG CUDA_VERSION=12.8.0
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu24.04

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y --no-install-recommends \
    wget curl unzip ca-certificates git gcc g++ pkg-config graphviz \
    && rm -rf /var/lib/apt/lists/*

# Rust
ENV CARGO_HOME="/usr/local/cargo"
ENV RUSTUP_HOME="/usr/local/rustup"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable \
    && chmod -R a+rwx "$CARGO_HOME" "$RUSTUP_HOME"
ENV PATH="${CARGO_HOME}/bin:${PATH}"

ENV LIBTORCH_PATH="/usr/local/libtorch"
ENV LD_LIBRARY_PATH="${LIBTORCH_PATH}/lib:/usr/local/cuda/lib64"
ENV LIBRARY_PATH="${LIBTORCH_PATH}/lib:/usr/local/cuda/lib64"
ENV CUDA_HOME="/usr/local/cuda"

WORKDIR /workspace
"#;
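
// A sketch of how the mounted CUDA service can be driven from the host
// (values illustrative; compose substitutes the defaults above when the
// variables are unset):
//
//   LIBTORCH_HOST_PATH=./libtorch/precompiled/cu128 CUDA_VERSION=12.8.0 \
//     docker compose run --rm cuda cargo build --features cuda
//
// In practice `./fdl cuda-build` runs the equivalent via fdl.yml.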

// ---------------------------------------------------------------------------
// fdl.yml.example template
// ---------------------------------------------------------------------------

/// The scaffold ships `fdl.yml.example` (committed) and fdl auto-copies it to
/// the gitignored `fdl.yml` on first use. Docker modes attach `docker:` to
/// every command; native mode drops `docker:` so the commands run directly
/// on the host. Libtorch env vars (`LIBTORCH_HOST_PATH`, `CUDA_VERSION`,
/// `CUDA_TAG`, etc.) are derived from `libtorch/.active` by
/// `flodl-cli/src/run.rs::libtorch_env` before each `docker compose run`
/// (Docker modes) or exported into the child process (native mode).
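///
/// For example, in the Docker modes the `build` entry under `commands:`
/// renders as:
///
/// ```yaml
///   build:
///     description: Build (debug)
///     run: cargo build
///     docker: dev
/// ```
///
/// while native mode emits the same entry without the `docker:` line.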
fn fdl_yml_example_template(project_name: &str, mode: Mode) -> String {
    let use_docker = matches!(mode, Mode::Mounted | Mode::Docker);
    let (cpu_svc, cuda_svc) = if use_docker {
        ("\n    docker: dev", "\n    docker: cuda")
    } else {
        ("", "")
    };
    let cuda_note = if use_docker {
        "(requires NVIDIA Container Toolkit)"
    } else {
        "(requires a matching CUDA toolkit on the host)"
    };
    let preamble = if use_docker {
        "# Run any of these with `./fdl <cmd>` (or `fdl <cmd>` once installed\n\
         # globally via `./fdl install`). Libtorch env vars are derived from\n\
         # `libtorch/.active` automatically; missing libtorch surfaces as a\n\
         # clean linker error, with `./fdl setup` one call away."
    } else {
        "# Native mode: commands run on the host. Make sure libtorch is\n\
         # installed (`./fdl libtorch download --cpu` or `--cuda 12.8`)\n\
         # and that `$LIBTORCH` / `$LD_LIBRARY_PATH` are exported so\n\
         # cargo can link. `./fdl libtorch info` prints the commands you\n\
         # need after a download."
    };

    let shell_block = if use_docker {
        format!(
            r#"  shell:
    description: Interactive shell (CPU container)
    run: bash{cpu_svc}

"#
        )
    } else {
        // Native mode: no container to drop into; users open their own shell.
        String::new()
    };

    let cuda_shell_block = if use_docker {
        format!(
            r#"  cuda-shell:
    description: Interactive shell (CUDA container)
    run: bash{cuda_svc}
"#
        )
    } else {
        String::new()
    };

    format!(
        r#"description: {project_name}

{preamble}

commands:
  # --- CPU ---
  build:
    description: Build (debug)
    run: cargo build{cpu_svc}
  test:
    description: Run CPU tests
    run: cargo test -- --nocapture{cpu_svc}
  run:
    description: cargo run
    run: cargo run{cpu_svc}
  check:
    description: Type-check without building
    run: cargo check{cpu_svc}
  clippy:
    description: Lint
    run: cargo clippy -- -W clippy::all{cpu_svc}
{shell_block}  # --- CUDA {cuda_note} ---
  cuda-build:
    description: Build with CUDA feature
    run: cargo build --features cuda{cuda_svc}
  cuda-test:
    description: Run CUDA tests
    run: cargo test --features cuda -- --nocapture{cuda_svc}
  cuda-run:
    description: cargo run --features cuda
    run: cargo run --features cuda{cuda_svc}
{cuda_shell_block}"#
    )
}

// ---------------------------------------------------------------------------
// File writing helper
// ---------------------------------------------------------------------------

fn write_file(path: &str, content: &str) -> Result<(), String> {
    fs::write(path, content).map_err(|e| format!("cannot write {}: {}", path, e))
}
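
// A minimal sketch of unit coverage for the pure helpers in this module;
// case names and inputs are illustrative.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_name_accepts_kebab_and_snake_case() {
        assert!(validate_name("my-project").is_ok());
        assert!(validate_name("my_project_2").is_ok());
    }

    #[test]
    fn validate_name_rejects_empty_whitespace_and_non_ascii() {
        assert!(validate_name("").is_err());
        assert!(validate_name("bad name").is_err());
        assert!(validate_name("prøjekt").is_err());
    }

    #[test]
    fn native_gitignore_skips_docker_cargo_caches() {
        let s = gitignore_template(Mode::Native);
        assert!(s.contains("libtorch/"));
        assert!(!s.contains(".cargo-cache/"));
    }
}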