1use std::fs;
19use std::path::{Path, PathBuf};
20
21use crate::util::{cargo_toml as cargo_edit, fdl_yml as yml_edit, prompt};
22
// Playground scaffold files, embedded into the binary at compile time via
// `include_str!` (paths are relative to this source file).

/// Playground `Cargo.toml`; passed through `substitute_version` before being
/// written to `flodl-hf/Cargo.toml`.
const TEMPLATE_CARGO_TOML: &str = include_str!("scaffold/Cargo.toml.in");
/// Playground `src/main.rs`, written verbatim.
const TEMPLATE_MAIN_RS: &str = include_str!("scaffold/src/main.rs");
/// Playground fdl config; rendered for the project mode by `render_fdl_yml`
/// and written as both `flodl-hf/fdl.yml` and `flodl-hf/fdl.yml.example`.
const TEMPLATE_FDL_YML: &str = include_str!("scaffold/fdl.yml.example");
/// Playground README; passed through `substitute_version` before writing.
const TEMPLATE_README: &str = include_str!("scaffold/README.md");
/// Playground `.gitignore`, written verbatim.
const TEMPLATE_GITIGNORE: &str = include_str!("scaffold/.gitignore");

/// Description attached to the `flodl-hf` command that
/// `link_into_root_fdl_yml` registers in the root `fdl.yml` /
/// `fdl.yml.example`.
const FDL_YML_HF_DESCRIPTION: &str =
    "HuggingFace integration (BERT, RoBERTa, DistilBERT, ...)";
38
39pub fn run(target: Option<&str>, playground: bool, install: bool) -> Result<(), String> {
40 let target = target.ok_or(
41 "usage: fdl add <target> [--playground] [--install]\n\n\
42 Supported targets:\n \
43 flodl-hf HuggingFace integration (pre-built BERT / RoBERTa / DistilBERT, Hub loader, tokenizer)",
44 )?;
45 match target {
46 "flodl-hf" | "hf" => {}
47 other => {
48 return Err(format!(
49 "unknown target: {other:?}\n\n\
50 Supported targets:\n \
51 flodl-hf HuggingFace integration\n\n\
52 (More targets land as the flodl ecosystem grows.)",
53 ));
54 }
55 }
56
57 let cwd = std::env::current_dir()
58 .map_err(|e| format!("cannot read current directory: {e}"))?;
59
60 let (do_playground, do_install) = if !playground && !install {
62 resolve_interactive()?
63 } else {
64 (playground, install)
65 };
66
67 if do_install {
68 install_flodl_hf_at(&cwd)?;
69 }
70 if do_playground {
71 add_flodl_hf_at(&cwd)?;
72 }
73 Ok(())
74}
75
76fn resolve_interactive() -> Result<(bool, bool), String> {
80 if !has_tty() {
81 return Err(
82 "fdl add flodl-hf needs an interactive terminal to prompt.\n\
83 Pass --playground (sandbox at ./flodl-hf/) or --install \
84 (add to Cargo.toml), or both."
85 .into(),
86 );
87 }
88
89 println!("Add flodl-hf to your project?");
90 println!();
91 let choice = prompt::ask_choice(
92 "Choose",
93 &[
94 "playground sandbox at ./flodl-hf/ (try it without touching your project)",
95 "install add flodl-hf to your root Cargo.toml as a dependency",
96 "both playground + install (try it, and wire it in)",
97 "cancel",
98 ],
99 1,
100 );
101 println!();
102
103 match choice {
104 1 => Ok((true, false)),
105 2 => Ok((false, true)),
106 3 => Ok((true, true)),
107 _ => Err("cancelled.".into()),
108 }
109}
110
/// Reports whether an interactive console can be opened.
///
/// Probes `/dev/tty` on Unix and `CONIN$` on Windows; the handle is closed
/// immediately. Other platforms are assumed to be interactive.
fn has_tty() -> bool {
    #[cfg(unix)]
    let probe = std::fs::File::open("/dev/tty").map(drop);
    #[cfg(windows)]
    let probe = std::fs::OpenOptions::new()
        .read(true)
        .open("CONIN$")
        .map(drop);
    #[cfg(not(any(unix, windows)))]
    let probe: std::io::Result<()> = Ok(());

    probe.is_ok()
}
131
132pub fn install_flodl_hf_at(cwd: &Path) -> Result<(), String> {
135 let cargo_toml = cwd.join("Cargo.toml");
136 if !cargo_toml.exists() {
137 return Err(format!(
138 "no Cargo.toml in {}.\n\n\
139 fdl add flodl-hf --install must run from a flodl project root.\n\
140 Start with `fdl init <name>` if you don't have one yet.",
141 cwd.display(),
142 ));
143 }
144
145 let flodl_version = detect_flodl_version(&cargo_toml)?;
146 let version_spec = format!("={flodl_version}");
147 let outcome = cargo_edit::add_dep(&cargo_toml, "flodl-hf", &version_spec)?;
148
149 match outcome {
150 cargo_edit::AddDepOutcome::AlreadyPresent => {
151 println!("flodl-hf is already declared in {}.", cargo_toml.display());
152 println!("Edit the entry directly to change version or features.");
153 }
154 cargo_edit::AddDepOutcome::Added => {
155 println!();
156 println!(
157 "Added flodl-hf = \"={flodl_version}\" to {} with default features (hub, tokenizer).",
158 cargo_toml.display(),
159 );
160 println!();
161 println!("Default features include the HuggingFace Hub loader and tokenizer.");
162 println!("To switch to offline / vision-only flavors, edit the entry manually:");
163 println!(" flodl-hf = {{ version = \"={flodl_version}\", default-features = false, features = [...] }}");
164 println!();
165 println!("Run `fdl build` (or `cargo build`) to pull and compile the new dependency.");
166 }
167 }
168 Ok(())
169}
170
171pub fn add_flodl_hf_at(cwd: &Path) -> Result<(), String> {
177 let cargo_toml = cwd.join("Cargo.toml");
179 if !cargo_toml.exists() {
180 return Err(format!(
181 "no Cargo.toml in {}.\n\n\
182 fdl add flodl-hf must run from a flodl project root.\n\
183 Start with `fdl init <name>` if you don't have one yet.",
184 cwd.display(),
185 ));
186 }
187
188 if !has_fdl_config(cwd) {
193 return Err(format!(
194 "no fdl.yml (nor fdl.yml.example) in {}.\n\n\
195 fdl add flodl-hf expects an initialised flodl project: \
196 Docker or native mode already chosen, fdl.yml present. \
197 Run `fdl init <name>` first, or cd into an existing flodl project.",
198 cwd.display(),
199 ));
200 }
201
202 let flodl_version = detect_flodl_version(&cargo_toml)?;
203 let mode = detect_project_mode(cwd);
204
205 let dest = cwd.join("flodl-hf");
207 if dest.exists() {
208 return Err(format!(
209 "{} already exists.\n\n\
210 Remove it first, or keep it. `fdl add flodl-hf` does not overwrite.",
211 dest.display(),
212 ));
213 }
214
215 fs::create_dir_all(dest.join("src"))
217 .map_err(|e| format!("cannot create {}: {e}", dest.join("src").display()))?;
218
219 write_file(
220 &dest.join("Cargo.toml"),
221 &substitute_version(TEMPLATE_CARGO_TOML, &flodl_version),
222 )?;
223 write_file(&dest.join("src/main.rs"), TEMPLATE_MAIN_RS)?;
224 let fdl_yml = render_fdl_yml(TEMPLATE_FDL_YML, mode);
225 write_file(&dest.join("fdl.yml.example"), &fdl_yml)?;
226 write_file(&dest.join("fdl.yml"), &fdl_yml)?;
227 write_file(
228 &dest.join("README.md"),
229 &substitute_version(TEMPLATE_README, &flodl_version),
230 )?;
231 write_file(&dest.join(".gitignore"), TEMPLATE_GITIGNORE)?;
232
233 link_into_root_fdl_yml(cwd)?;
237
238 print_next_steps(&flodl_version, mode);
239 Ok(())
240}
241
242fn link_into_root_fdl_yml(cwd: &Path) -> Result<(), String> {
246 for filename in ["fdl.yml", "fdl.yml.example"] {
247 let path = cwd.join(filename);
248 if !path.exists() {
249 continue;
250 }
251 yml_edit::add_command(&path, "flodl-hf", FDL_YML_HF_DESCRIPTION)?;
252 }
253 Ok(())
254}
255
/// How the host project is set up, used to tailor the playground's `fdl.yml`.
///
/// `detect_project_mode` picks the variant from the presence of
/// `docker-compose.yml` in the project root.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProjectMode {
    /// `docker-compose.yml` present: the fdl.yml template is used unchanged.
    Docker,
    /// No compose file: lines that are exactly `docker: dev` are stripped
    /// from the generated fdl.yml.
    Native,
}
267
/// True when `cwd` contains an `fdl.yml` or an `fdl.yml.example`.
fn has_fdl_config(cwd: &Path) -> bool {
    ["fdl.yml", "fdl.yml.example"]
        .iter()
        .any(|name| cwd.join(name).exists())
}
271
272fn detect_project_mode(cwd: &Path) -> ProjectMode {
273 if cwd.join("docker-compose.yml").exists() {
274 ProjectMode::Docker
275 } else {
276 ProjectMode::Native
277 }
278}
279
280fn render_fdl_yml(template: &str, mode: ProjectMode) -> String {
286 match mode {
287 ProjectMode::Docker => template.to_string(),
288 ProjectMode::Native => template
289 .lines()
290 .filter(|l| l.trim() != "docker: dev")
291 .collect::<Vec<&str>>()
292 .join("\n")
293 + "\n",
294 }
295}
296
297fn detect_flodl_version(cargo_toml: &Path) -> Result<String, String> {
308 let content = fs::read_to_string(cargo_toml)
309 .map_err(|e| format!("cannot read {}: {e}", cargo_toml.display()))?;
310
311 if let Some(v) = parse_flodl_dep(&content)? {
312 return Ok(v);
313 }
314
315 if let Some(ws_root) = find_workspace_root(cargo_toml) {
317 let ws_content = fs::read_to_string(&ws_root)
318 .map_err(|e| format!("cannot read workspace {}: {e}", ws_root.display()))?;
319 if let Some(v) = parse_flodl_dep(&ws_content)? {
320 return Ok(v);
321 }
322 }
323
324 Err(format!(
325 "no flodl dependency found in {}.\n\n\
326 fdl add flodl-hf needs to pin flodl-hf to the same version as \
327 flodl. Add `flodl = \"X.Y.Z\"` to [dependencies] first, or run \
328 `fdl init <name>` to scaffold a flodl project.",
329 cargo_toml.display(),
330 ))
331}
332
/// Extracts the flodl version to pin `flodl-hf` to from the text of a
/// `Cargo.toml`.
///
/// Only `[dependencies]`, `[dev-dependencies]` and `[workspace.dependencies]`
/// are searched, and only the exact `flodl` key matches (`flodl-hf` and
/// `flodl-sys` are ignored). Trailing `#` comments are ignored on every line.
/// A leading `=`, `^` or `~` requirement operator is dropped, so the caller
/// can re-pin with `=` without producing an invalid `==X.Y.Z`.
///
/// Returns `Ok(Some(version))` for `flodl = "X.Y.Z"` or
/// `flodl = { version = "X.Y.Z", ... }`. Returns `Ok(None)` when flodl is
/// absent, or inherited with `{ workspace = true }`, in which case the caller
/// consults the workspace root.
///
/// # Errors
///
/// Git-only or path-only declarations, and version requirements that are not
/// a single version (ranges, wildcards), cannot be pinned.
fn parse_flodl_dep(content: &str) -> Result<Option<String>, String> {
    let mut in_dep_table = false;
    for raw_line in content.lines() {
        let t = strip_toml_comment(raw_line).trim();
        if t.starts_with('[') {
            in_dep_table = matches!(
                t,
                "[dependencies]" | "[workspace.dependencies]" | "[dev-dependencies]",
            );
            continue;
        }
        if !in_dep_table {
            continue;
        }
        // Exact key only: `flodl-hf = ...` leaves `-hf = ...`, which has no leading `=`.
        let Some(rhs) = t
            .strip_prefix("flodl")
            .and_then(|rest| rest.trim_start().strip_prefix('='))
        else {
            continue;
        };
        let rhs = rhs.trim();

        if let Some(v) = unquote_toml_string(rhs) {
            return normalize_version_req(v).map(Some);
        }
        if let Some(v) = extract_version_from_table(rhs) {
            return normalize_version_req(&v).map(Some);
        }
        if rhs.contains("workspace") && rhs.contains("true") {
            // `flodl = { workspace = true }`: the version lives in the workspace root.
            return Ok(None);
        }
        if rhs.contains("git =") || rhs.contains("git=") {
            return Err(
                "flodl is declared as a git dependency. \
                 fdl add flodl-hf needs a pinnable crates.io version. \
                 Switch to `flodl = \"X.Y.Z\"` first."
                    .into(),
            );
        }
        if rhs.contains("path =") || rhs.contains("path=") {
            return Err(
                "flodl is declared as a path dependency only. \
                 Add an explicit `version = \"X.Y.Z\"` so fdl add can \
                 pin the matching flodl-hf release."
                    .into(),
            );
        }
    }
    Ok(None)
}

/// Returns `line` with any trailing `#` comment removed. A `#` inside a TOML
/// basic (`"..."`) or literal (`'...'`) string does not start a comment.
fn strip_toml_comment(line: &str) -> &str {
    let mut open_quote: Option<char> = None;
    let mut escaped = false;
    for (i, c) in line.char_indices() {
        match open_quote {
            None if c == '#' => return &line[..i],
            None if c == '"' || c == '\'' => open_quote = Some(c),
            None => {}
            // Basic strings honour backslash escapes; literal strings do not.
            Some('"') if escaped => escaped = false,
            Some('"') if c == '\\' => escaped = true,
            Some(q) if c == q => open_quote = None,
            Some(_) => {}
        }
    }
    line
}

/// Returns the contents of `s` when it is exactly one single-line TOML basic
/// (`"..."`) or literal (`'...'`) string, and `None` otherwise.
fn unquote_toml_string(s: &str) -> Option<&str> {
    let quote = s.chars().next().filter(|c| matches!(c, '"' | '\''))?;
    s[quote.len_utf8()..].strip_suffix(quote)
}

/// Reduces a Cargo version requirement to the bare version to pin against:
/// `"=0.6.0"`, `"^0.6.0"` and `"~0.6.0"` all become `"0.6.0"`.
///
/// Ranges (`">=0.5, <0.7"`), wildcards (`"0.6.*"`) and empty requirements
/// cannot be pinned and are rejected with an explanatory error.
fn normalize_version_req(req: &str) -> Result<String, String> {
    let req = req.trim();
    let bare = req
        .strip_prefix(|c: char| matches!(c, '=' | '^' | '~'))
        .unwrap_or(req)
        .trim_start();
    let pinnable = bare.starts_with(|c: char| c.is_ascii_digit())
        && bare
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '+'));
    if pinnable {
        Ok(bare.to_string())
    } else {
        Err(format!(
            "flodl is declared with version requirement {req:?}, which is not a \
             single version. fdl add flodl-hf pins flodl-hf to the exact flodl \
             release, so switch to `flodl = \"X.Y.Z\"` first."
        ))
    }
}

/// Returns the quoted `version` value from a single-line inline table such as
/// `{ version = "0.5.1", features = ["cuda"] }`. Returns `None` when `rhs` is
/// not an inline table or has no quoted `version` entry.
fn extract_version_from_table(rhs: &str) -> Option<String> {
    let body = rhs.strip_prefix('{')?.strip_suffix('}')?;
    body.split(',').find_map(|entry| {
        let value = entry
            .trim()
            .strip_prefix("version")?
            .trim_start()
            .strip_prefix('=')?;
        unquote_toml_string(value.trim()).map(str::to_string)
    })
}
423
/// Finds the nearest enclosing workspace manifest for the crate whose
/// manifest is `from`.
///
/// The search starts in the parent of the crate's directory, so the crate's
/// own manifest is never returned. It walks up to the filesystem root and
/// returns the first `Cargo.toml` containing a line that trims to exactly
/// `[workspace]`; unreadable manifests are skipped. Returns `None` when
/// nothing matches or when `from` has no grandparent directory.
fn find_workspace_root(from: &Path) -> Option<PathBuf> {
    let search_start = from.parent()?.parent()?;
    search_start
        .ancestors()
        .map(|dir| dir.join("Cargo.toml"))
        .find(|manifest| {
            manifest.exists()
                && fs::read_to_string(manifest)
                    .map(|text| text.lines().any(|line| line.trim() == "[workspace]"))
                    .unwrap_or(false)
        })
}
442
/// Replaces every `{{FLODL_VERSION}}` placeholder in `template` with `version`.
fn substitute_version(template: &str, version: &str) -> String {
    const PLACEHOLDER: &str = "{{FLODL_VERSION}}";
    template.split(PLACEHOLDER).collect::<Vec<_>>().join(version)
}
446
/// Writes `content` to `path`, creating or truncating the file.
///
/// # Errors
///
/// Returns `"cannot write <path>: <io error>"` on failure.
fn write_file(path: &Path, content: &str) -> Result<(), String> {
    match fs::write(path, content) {
        Ok(()) => Ok(()),
        Err(err) => Err(format!("cannot write {}: {err}", path.display())),
    }
}
450
451fn print_next_steps(version: &str, mode: ProjectMode) {
452 println!();
453 println!(
454 "Scaffolded flodl-hf/ playground (flodl {version}, {} mode).",
455 match mode {
456 ProjectMode::Docker => "Docker",
457 ProjectMode::Native => "native",
458 },
459 );
460 println!();
461 println!("Next steps:");
462 println!(" fdl flodl-hf classify # default RoBERTa sentiment checkpoint");
463 println!(" fdl flodl-hf classify -- bert-base-uncased # any other BERT-family repo id");
464 println!();
465 println!("(Or `cd flodl-hf` and run `fdl classify` directly.)");
466 println!();
467 println!("See flodl-hf/README.md for feature flavors (offline / vision-only),");
468 println!("`.bin` to safetensors conversion for older checkpoints, and how to wire");
469 println!("flodl-hf into your main crate when you're ready (`fdl add flodl-hf --install`).");
470}
471
// Unit tests. The parser and renderer tests run on in-memory strings; the
// install / playground tests build throwaway projects under the OS temp dir.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_plain_version_string() {
        let c = r#"
[dependencies]
flodl = "0.6.0"
other = "1.0"
"#;
        assert_eq!(parse_flodl_dep(c).unwrap(), Some("0.6.0".into()));
    }

    #[test]
    fn parse_table_version() {
        let c = r#"
[dependencies]
flodl = { version = "0.5.1", features = ["cuda"] }
"#;
        assert_eq!(parse_flodl_dep(c).unwrap(), Some("0.5.1".into()));
    }

    #[test]
    fn parse_workspace_inheritance_returns_none() {
        let c = r#"
[dependencies]
flodl = { workspace = true }
"#;
        // `None` makes detect_flodl_version fall back to the workspace root.
        assert_eq!(parse_flodl_dep(c).unwrap(), None);
    }

    #[test]
    fn parse_git_dep_errors() {
        let c = r#"
[dependencies]
flodl = { git = "https://github.com/flodl-labs/flodl" }
"#;
        let err = parse_flodl_dep(c).unwrap_err();
        assert!(err.contains("git dependency"), "got: {err}");
    }

    #[test]
    fn parse_no_flodl_returns_none() {
        let c = r#"
[dependencies]
other = "1.0"
"#;
        assert_eq!(parse_flodl_dep(c).unwrap(), None);
    }

    #[test]
    fn parse_ignores_flodl_hf_and_flodl_sys() {
        // Keys that merely share the `flodl` prefix must not be treated as flodl.
        let c = r#"
[dependencies]
flodl-hf = "0.6.0"
flodl-sys = "0.6.0"
"#;
        assert_eq!(parse_flodl_dep(c).unwrap(), None);
    }

    #[test]
    fn parse_ignores_non_dep_tables() {
        // A `flodl = ...` line outside a dependency table is ignored.
        let c = r#"
[package]
flodl = "0.6.0" # not actually a dep; this is bogus but must not match
"#;
        assert_eq!(parse_flodl_dep(c).unwrap(), None);
    }

    #[test]
    fn substitute_version_replaces_all_occurrences() {
        let t = "flodl = \"={{FLODL_VERSION}}\"\nflodl-hf = \"={{FLODL_VERSION}}\"";
        let out = substitute_version(t, "0.6.0");
        assert_eq!(out, "flodl = \"=0.6.0\"\nflodl-hf = \"=0.6.0\"");
    }

    #[test]
    fn render_fdl_yml_docker_preserves_docker_lines() {
        let t = "commands:\n classify:\n run: cargo run --release\n docker: dev\n";
        assert_eq!(render_fdl_yml(t, ProjectMode::Docker), t);
    }

    #[test]
    fn render_fdl_yml_native_strips_docker_lines() {
        let t = "commands:\n classify:\n run: cargo run --release\n docker: dev\n check:\n run: cargo check\n docker: dev\n";
        let out = render_fdl_yml(t, ProjectMode::Native);
        assert!(
            !out.contains("docker: dev"),
            "native output must not contain docker: dev lines: {out}"
        );
        // The non-docker lines of every command survive.
        assert!(out.contains("cargo run --release"));
        assert!(out.contains("cargo check"));
    }

    #[test]
    fn render_fdl_yml_native_only_strips_exact_docker_line() {
        // Only lines that trim to exactly `docker: dev` are removed; other
        // docker services and text that merely mentions it are kept.
        let t = "\
commands:
 classify:
 run: cargo run
 docker: dev
 other:
 description: docker: dev isn't a literal directive here
 docker: hf-parity
";
        let out = render_fdl_yml(t, ProjectMode::Native);
        assert!(!out.contains(" docker: dev\n"), "exact match stripped: {out}");
        assert!(out.contains("hf-parity"), "other services preserved: {out}");
        assert!(
            out.contains("docker: dev isn't a literal"),
            "description text preserved: {out}",
        );
    }

    /// Creates a fresh temp project with a `Cargo.toml` depending on
    /// `flodl = "0.5.2"` and a minimal root `fdl.yml`.
    fn temp_project(tag: &str) -> PathBuf {
        // Directory name is unique per process and per call, so tests running
        // in parallel do not collide.
        use std::sync::atomic::{AtomicU64, Ordering};
        static N: AtomicU64 = AtomicU64::new(0);
        let n = N.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let dir = std::env::temp_dir().join(format!("fdl-add-test-{pid}-{n}-{tag}"));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        fs::write(
            dir.join("Cargo.toml"),
            "[package]\nname = \"x\"\nversion = \"0.1.0\"\nedition = \"2024\"\n\n[dependencies]\nflodl = \"0.5.2\"\n",
        )
        .unwrap();
        fs::write(
            dir.join("fdl.yml"),
            "description: test project\n\ncommands:\n build:\n run: cargo build\n",
        )
        .unwrap();
        dir
    }

    #[test]
    fn install_appends_dep_and_is_idempotent() {
        let dir = temp_project("install-idem");
        install_flodl_hf_at(&dir).unwrap();
        let toml = fs::read_to_string(dir.join("Cargo.toml")).unwrap();
        assert!(toml.contains("flodl-hf = \"=0.5.2\""), "first install: {toml}");

        // A second install must leave the manifest byte-identical.
        install_flodl_hf_at(&dir).unwrap();
        let toml2 = fs::read_to_string(dir.join("Cargo.toml")).unwrap();
        assert_eq!(toml, toml2, "install is idempotent");

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn install_errors_without_cargo_toml() {
        // Own counter and name prefix, separate from temp_project's.
        use std::sync::atomic::{AtomicU64, Ordering};
        static N: AtomicU64 = AtomicU64::new(9000);
        let n = N.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let dir = std::env::temp_dir().join(format!("fdl-add-test-no-cargo-{pid}-{n}"));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        let err = install_flodl_hf_at(&dir).unwrap_err();
        assert!(err.contains("no Cargo.toml"), "got: {err}");
        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn playground_links_root_fdl_yml() {
        let dir = temp_project("playground-link");
        add_flodl_hf_at(&dir).unwrap();
        let yml = fs::read_to_string(dir.join("fdl.yml")).unwrap();
        assert!(yml.contains("flodl-hf:"), "linked into root fdl.yml: {yml}");
        // The pre-existing command is kept alongside the new one.
        assert!(yml.contains("build:"));
        // The playground crate itself was scaffolded.
        assert!(dir.join("flodl-hf/Cargo.toml").exists());
        assert!(dir.join("flodl-hf/fdl.yml").exists());
        let _ = fs::remove_dir_all(&dir);
    }
}