1use std::collections::BTreeMap;
10use std::fs;
11use std::path::{Path, PathBuf};
12use std::process::{Command, Stdio};
13
/// One public function signature extracted from source text.
#[derive(Debug)]
struct FnSig {
    /// Function name as extracted from the signature. For the per-file
    /// free-function entries built by `parse_file` this is instead
    /// `"<name> -- <first doc line>"`.
    name: String,
    /// Signature text starting at `pub fn` / `pub const fn`, with the
    /// trailing body brace (and a dangling `where`) trimmed off.
    signature: String,
}
24
/// One documented entry of the API reference: a `pub struct` / `pub enum`,
/// or a synthetic `"<label> (functions)"` entry grouping a file's free
/// `pub fn`s.
#[derive(Debug)]
struct ApiType {
    /// Type name as declared, or `"<label> (functions)"` for free functions.
    name: String,
    /// Category label produced by `categorize` from the file's relative path.
    category: &'static str,
    /// Path of the defining file relative to the scanned `src/` root.
    file: String,
    /// First line of the `///` doc comment above the item (empty if none).
    doc_summary: String,
    /// Bodies of fenced code blocks found in that doc comment.
    doc_examples: Vec<String>,
    /// `pub fn`s whose name marks them as constructors (`new`, `on_device`,
    /// `no_bias`, `no_bias_on_device`, `configure`, `default`).
    constructors: Vec<FnSig>,
    /// All other collected `pub fn`s (or the file's free functions).
    methods: Vec<FnSig>,
    /// Builder-style `pub fn`s: `with_*`, `done`, `build`.
    builder_methods: Vec<FnSig>,
    /// Trait names collected from `impl Trait for Type` blocks.
    traits: Vec<String>,
}
38
39struct ApiRef {
41 version: String,
42 types: Vec<ApiType>,
43}
44
45pub fn find_flodl_src(explicit: Option<&str>) -> Option<PathBuf> {
56 if let Some(p) = explicit {
57 let path = PathBuf::from(p);
58 if path.is_dir() {
59 return Some(path);
60 }
61 }
62
63 let mut dir = std::env::current_dir().ok()?;
65 for _ in 0..5 {
66 let candidate = dir.join("flodl/src");
67 if candidate.join("lib.rs").is_file() {
68 return Some(candidate);
69 }
70 if !dir.pop() {
71 break;
72 }
73 }
74
75 if let Some(home) = home_dir() {
77 let registry = home.join(".cargo/registry/src");
78 if registry.is_dir() {
79 if let Ok(entries) = fs::read_dir(®istry) {
81 for index_dir in entries.flatten() {
82 if let Ok(crates) = fs::read_dir(index_dir.path()) {
83 let mut best: Option<PathBuf> = None;
84 for entry in crates.flatten() {
85 let name = entry.file_name().to_string_lossy().to_string();
86 if name.starts_with("flodl-") && !name.starts_with("flodl-sys") && !name.starts_with("flodl-cli") {
87 let src = entry.path().join("src");
88 if src.join("lib.rs").is_file() {
89 best = Some(src);
90 }
91 }
92 }
93 if best.is_some() {
94 return best;
95 }
96 }
97 }
98 }
99 }
100 }
101
102 if let Some(tag) = fetch_latest_tag() {
104 if let Some(cache) = cache_dir(&tag) {
105 if let Some(src) = find_src_in_cache(&cache) {
106 return Some(src);
107 }
108 }
109 match download_source(&tag) {
111 Ok(src) => return Some(src),
112 Err(e) => eprintln!("warning: could not download source: {}", e),
113 }
114 }
115
116 None
117}
118
/// Best-effort home directory: `$HOME`, falling back to `USERPROFILE`
/// (typically the variable set on Windows).
fn home_dir() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .find_map(|var| std::env::var_os(var))
        .map(PathBuf::from)
}
124
/// GitHub `owner/repo` slug used to look up the latest release and to build
/// tagged source-archive download URLs.
const REPO: &str = "fab2s/floDl";
130
131fn fetch_latest_tag() -> Option<String> {
133 let output = Command::new("curl")
135 .args(["-sI", &format!("https://github.com/{}/releases/latest", REPO)])
136 .stdout(Stdio::piped())
137 .stderr(Stdio::null())
138 .output()
139 .ok()?;
140
141 let stdout = String::from_utf8_lossy(&output.stdout);
142 for line in stdout.lines() {
143 let lower = line.to_lowercase();
144 if lower.starts_with("location:") {
145 let tag = line.rsplit('/').next()?.trim();
147 if !tag.is_empty() {
148 return Some(tag.to_string());
149 }
150 }
151 }
152 None
153}
154
155fn cache_dir(tag: &str) -> Option<PathBuf> {
157 let home = home_dir()?;
158 let flodl_home = std::env::var("FLODL_HOME")
159 .map(PathBuf::from)
160 .unwrap_or_else(|_| home.join(".flodl"));
161 Some(flodl_home.join("api-ref-cache").join(tag))
162}
163
164fn download_source(tag: &str) -> Result<PathBuf, String> {
167 let cache = cache_dir(tag)
168 .ok_or_else(|| "cannot determine home directory".to_string())?;
169
170 let src_dir = find_src_in_cache(&cache);
172 if let Some(src) = src_dir {
173 return Ok(src);
174 }
175
176 eprintln!("Downloading flodl {} source from GitHub...", tag);
177
178 let zip_url = format!(
179 "https://github.com/{}/archive/refs/tags/{}.zip",
180 REPO, tag
181 );
182
183 fs::create_dir_all(&cache)
184 .map_err(|e| format!("cannot create cache dir: {}", e))?;
185
186 let zip_path = cache.join("source.zip");
187 crate::util::http::download_file(&zip_url, &zip_path)?;
188
189 eprintln!("Extracting...");
190 crate::util::archive::extract_zip(&zip_path, &cache)?;
191
192 let _ = fs::remove_file(&zip_path);
194
195 find_src_in_cache(&cache)
196 .ok_or_else(|| "downloaded archive does not contain flodl/src/lib.rs".to_string())
197}
198
/// Finds a `flodl/src` directory (one containing `lib.rs`) inside `cache`.
///
/// It checks `cache/flodl/src` first. It then checks `<sub>/flodl/src` for
/// each immediate subdirectory, which covers the `<repo>-<tag>/` top-level
/// folder of GitHub archives. Returns `None` if `cache` is not a directory or
/// nothing matches.
fn find_src_in_cache(cache: &Path) -> Option<PathBuf> {
    fn has_lib_rs(dir: &Path) -> bool {
        dir.join("lib.rs").is_file()
    }

    if !cache.is_dir() {
        return None;
    }

    let direct = cache.join("flodl/src");
    if has_lib_rs(&direct) {
        return Some(direct);
    }

    fs::read_dir(cache)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .filter(|sub| sub.is_dir())
        .map(|sub| sub.join("flodl/src"))
        .find(|candidate| has_lib_rs(candidate))
}
224
/// Maps a source path, relative to `src/`, to a category label.
///
/// The first match wins, in this order:
/// - path contains `loss` → `losses`
/// - path contains `optim` → `optimizers`
/// - path contains `scheduler` → `schedulers`
/// - path contains `nn/` → `modules`
/// - path starts with `tensor`, `autograd`, `graph`, `distributed` or `data`
///   → that name
/// - anything else → `other`
///
/// Windows `\` separators are normalised to `/` first, so `nn\linear.rs` is
/// classified the same as `nn/linear.rs`.
fn categorize(rel_path: &str) -> &'static str {
    let rel = rel_path.replace('\\', "/");
    if rel.contains("loss") {
        "losses"
    } else if rel.contains("optim") {
        "optimizers"
    } else if rel.contains("scheduler") {
        "schedulers"
    } else if rel.contains("nn/") {
        "modules"
    } else if rel.starts_with("tensor") {
        "tensor"
    } else if rel.starts_with("autograd") {
        "autograd"
    } else if rel.starts_with("graph") {
        "graph"
    } else if rel.starts_with("distributed") {
        "distributed"
    } else if rel.starts_with("data") {
        "data"
    } else {
        "other"
    }
}
253
/// Collects the `///` doc comment directly above `lines[item_line]`.
///
/// Scanning moves upward from the line before the item. It skips `#[...]`
/// attribute lines, and also blank lines until some doc text has been found.
/// It stops at the first other line. `////` is treated as a plain comment,
/// as in rustdoc.
///
/// Returns `(summary, examples)`:
/// - `summary` is the first doc line (empty if there is none).
/// - `examples` holds the trimmed bodies of fenced code blocks in the doc.
///
/// An `item_line` past the end of `lines` is clamped instead of panicking.
fn extract_docs(lines: &[&str], item_line: usize) -> (String, Vec<String>) {
    let mut doc_lines: Vec<String> = Vec::new();
    let above = &lines[..item_line.min(lines.len())];
    for raw in above.iter().rev() {
        let line = raw.trim();
        let doc_text = line
            .strip_prefix("///")
            .filter(|_| !line.starts_with("////"));
        if let Some(text) = doc_text {
            let text = text.strip_prefix(' ').unwrap_or(text);
            doc_lines.push(text.to_string());
        } else if line.is_empty() {
            // A blank line ends the comment once doc text was found; blanks
            // between the item and its docs are tolerated.
            if !doc_lines.is_empty() {
                break;
            }
        } else if !line.starts_with("#[") {
            break;
        }
    }
    doc_lines.reverse();

    let summary = doc_lines.first().cloned().unwrap_or_default();

    let mut examples = Vec::new();
    let mut in_code = false;
    let mut current_block = String::new();

    for line in &doc_lines {
        if line.starts_with("```") {
            if in_code {
                if !current_block.trim().is_empty() {
                    examples.push(current_block.trim().to_string());
                }
                current_block.clear();
            }
            in_code = !in_code;
        } else if in_code {
            if !current_block.is_empty() {
                current_block.push('\n');
            }
            current_block.push_str(line);
        }
    }

    (summary, examples)
}
310
/// Extracts a `pub fn` / `pub const fn` signature from a single source line.
///
/// Returns the text from `pub fn` (or `pub const fn`) up to, but not
/// including, the first `{`. This keeps one-line bodies such as
/// `pub fn len(&self) -> usize { self.0 }` out of the result. A trailing
/// standalone `where` keyword is removed; identifiers that merely end in
/// "where" (e.g. `-> Elsewhere`) are left intact. Returns `None` if the line
/// declares neither form.
fn extract_fn_sig(line: &str) -> Option<String> {
    let trimmed = line.trim();
    let start = trimmed
        .find("pub fn ")
        .or_else(|| trimmed.find("pub const fn "))?;

    let sig = &trimmed[start..];
    let sig = sig.split('{').next().unwrap_or(sig).trim_end();
    let sig = match sig.strip_suffix("where") {
        Some(head) if !head.ends_with(|c: char| c.is_alphanumeric() || c == '_') => head,
        _ => sig,
    };
    Some(sig.trim().to_string())
}
328
/// Returns the function name in a signature such as `pub fn map<F>(f: F)`.
///
/// The name is the text after the first `fn `, cut at the first `(` or `<`.
/// Returns an empty string when the signature contains no `fn `.
fn extract_fn_name(sig: &str) -> String {
    let Some(tail) = sig.split("fn ").nth(1) else {
        return String::new();
    };
    match tail.find(|ch: char| ch == '(' || ch == '<') {
        Some(end) => tail[..end].to_string(),
        None => tail.to_string(),
    }
}
338
339fn parse_file(src_root: &Path, path: &Path) -> Vec<ApiType> {
341 let content = match fs::read_to_string(path) {
342 Ok(c) => c,
343 Err(_) => return Vec::new(),
344 };
345
346 let rel_path = path
347 .strip_prefix(src_root)
348 .unwrap_or(path)
349 .to_string_lossy()
350 .to_string();
351
352 let category = categorize(&rel_path);
353 let lines: Vec<&str> = content.lines().collect();
354 let mut types: BTreeMap<String, ApiType> = BTreeMap::new();
355
356 for (i, line) in lines.iter().enumerate() {
358 let trimmed = line.trim();
359 if let Some(after) = trimmed.strip_prefix("pub struct ") {
360 let name_end = after
361 .find(|c: char| !c.is_alphanumeric() && c != '_')
362 .unwrap_or(after.len());
363 let name = after[..name_end].to_string();
364
365 if name.is_empty() || name.starts_with('_') {
366 continue;
367 }
368
369 if name.ends_with("Inner") || name.ends_with("State") && !name.contains("Trained") {
371 continue;
372 }
373
374 let (doc, examples) = extract_docs(&lines, i);
375
376 types.insert(
377 name.clone(),
378 ApiType {
379 name,
380 category,
381 file: rel_path.clone(),
382 doc_summary: doc,
383 doc_examples: examples,
384 constructors: Vec::new(),
385 methods: Vec::new(),
386 builder_methods: Vec::new(),
387 traits: Vec::new(),
388 },
389 );
390 }
391
392 if let Some(after) = trimmed.strip_prefix("pub enum ") {
394 let name_end = after
395 .find(|c: char| !c.is_alphanumeric() && c != '_')
396 .unwrap_or(after.len());
397 let name = after[..name_end].to_string();
398 if !name.is_empty() && !name.starts_with('_') {
399 let (doc, examples) = extract_docs(&lines, i);
400 types.insert(
401 name.clone(),
402 ApiType {
403 name,
404 category,
405 file: rel_path.clone(),
406 doc_summary: doc,
407 doc_examples: examples,
408 constructors: Vec::new(),
409 methods: Vec::new(),
410 builder_methods: Vec::new(),
411 traits: Vec::new(),
412 },
413 );
414 }
415 }
416 }
417
418 let mut current_impl: Option<(String, Option<String>)> = None; let mut brace_depth: i32 = 0;
421 let mut in_impl = false;
422 let mut in_test = false;
423
424 for line in lines.iter() {
425 let trimmed = line.trim();
426
427 if trimmed.contains("#[cfg(test)]") {
429 in_test = true;
430 }
431 if in_test {
432 if trimmed == "}" && brace_depth <= 1 {
433 in_test = false;
434 }
435 for c in trimmed.chars() {
437 if c == '{' { brace_depth += 1; }
438 if c == '}' { brace_depth -= 1; }
439 }
440 continue;
441 }
442
443 if trimmed.starts_with("impl ") || trimmed.starts_with("impl<") {
445 let impl_str = trimmed.to_string();
446
447 let (type_name, trait_name) = if impl_str.contains(" for ") {
449 let parts: Vec<&str> = impl_str.split(" for ").collect();
451 let trait_part = parts[0]
452 .trim_start_matches("impl ")
453 .trim_start_matches("impl<")
454 .split('>')
455 .next_back()
456 .unwrap_or("")
457 .trim();
458 let trait_name = trait_part.split('<').next().unwrap_or(trait_part).trim();
460 let type_part = parts.get(1).unwrap_or(&"");
461 let type_name = type_part
462 .split(|c: char| !c.is_alphanumeric() && c != '_')
463 .next()
464 .unwrap_or("")
465 .trim();
466 (type_name.to_string(), Some(trait_name.to_string()))
467 } else {
468 let after_impl = impl_str
470 .trim_start_matches("impl<")
471 .split('>')
472 .next_back()
473 .unwrap_or(impl_str.strip_prefix("impl ").unwrap_or(&impl_str));
474 let after_impl = after_impl
475 .strip_prefix("impl ")
476 .unwrap_or(after_impl.trim());
477 let type_name = after_impl
478 .split(|c: char| !c.is_alphanumeric() && c != '_')
479 .next()
480 .unwrap_or("")
481 .trim();
482 (type_name.to_string(), None)
483 };
484
485 if types.contains_key(&type_name) {
486 current_impl = Some((type_name, trait_name));
487 in_impl = true;
488 }
489 }
490
491 for c in trimmed.chars() {
493 if c == '{' {
494 brace_depth += 1;
495 }
496 if c == '}' {
497 brace_depth -= 1;
498 if brace_depth <= 0 && in_impl {
499 in_impl = false;
500 current_impl = None;
501 }
502 }
503 }
504
505 if in_impl && (trimmed.starts_with("pub fn ") || trimmed.starts_with("pub const fn ")) {
507 if let Some((ref type_name, ref trait_name)) = current_impl {
508 if let Some(sig) = extract_fn_sig(trimmed) {
509 let fn_name = extract_fn_name(&sig);
510 let fn_sig = FnSig {
511 name: fn_name.clone(),
512 signature: sig,
513 };
514
515 if let Some(api_type) = types.get_mut(type_name) {
516 if let Some(t) = &trait_name {
518 if !api_type.traits.contains(t) {
519 api_type.traits.push(t.clone());
520 }
521 }
522
523 if fn_name == "new"
525 || fn_name == "on_device"
526 || fn_name == "no_bias"
527 || fn_name == "no_bias_on_device"
528 || fn_name == "configure"
529 || fn_name == "default"
530 {
531 api_type.constructors.push(fn_sig);
532 } else if fn_name.starts_with("with_") || fn_name == "done" || fn_name == "build" {
533 api_type.builder_methods.push(fn_sig);
534 } else {
535 api_type.methods.push(fn_sig);
536 }
537 }
538 }
539 }
540 }
541 }
542
543 let mut free_fns: Vec<FnSig> = Vec::new();
546 let mut depth: i32 = 0;
547 let mut in_test_block = false;
548
549 for (i, line) in lines.iter().enumerate() {
550 let trimmed = line.trim();
551
552 if trimmed.contains("#[cfg(test)]") {
553 in_test_block = true;
554 }
555
556 for c in trimmed.chars() {
557 if c == '{' { depth += 1; }
558 if c == '}' { depth -= 1; }
559 }
560
561 if in_test_block {
562 if depth <= 0 { in_test_block = false; }
563 continue;
564 }
565
566 if depth <= 1 && trimmed.starts_with("pub fn ") {
568 if let Some(sig) = extract_fn_sig(trimmed) {
569 let fn_name = extract_fn_name(&sig);
570 let (doc, _) = extract_docs(&lines, i);
571 free_fns.push(FnSig {
572 name: format!("{} -- {}", fn_name, doc),
573 signature: sig,
574 });
575 }
576 }
577 }
578
579 if !free_fns.is_empty() {
580 let file_stem = std::path::Path::new(&rel_path)
582 .file_stem()
583 .unwrap_or_default()
584 .to_string_lossy()
585 .to_string();
586
587 let label = match file_stem.as_str() {
588 "mod" => {
589 std::path::Path::new(&rel_path)
591 .parent()
592 .and_then(|p| p.file_name())
593 .unwrap_or_default()
594 .to_string_lossy()
595 .to_string()
596 }
597 other => other.to_string(),
598 };
599
600 types.insert(
601 format!("{}()", label),
602 ApiType {
603 name: format!("{} (functions)", label),
604 category: categorize(&rel_path),
605 file: rel_path,
606 doc_summary: String::new(),
607 doc_examples: Vec::new(),
608 constructors: Vec::new(),
609 methods: free_fns,
610 builder_methods: Vec::new(),
611 traits: Vec::new(),
612 },
613 );
614 }
615
616 types.into_values().collect()
617}
618
619fn parse_source_tree(src_root: &Path) -> Vec<ApiType> {
621 let mut all_types = Vec::new();
622 walk_dir(src_root, src_root, &mut all_types);
623 all_types.sort_by(|a, b| a.category.cmp(b.category).then(a.name.cmp(&b.name)));
625 all_types
626}
627
628fn walk_dir(root: &Path, dir: &Path, types: &mut Vec<ApiType>) {
629 let entries = match fs::read_dir(dir) {
630 Ok(e) => e,
631 Err(_) => return,
632 };
633 for entry in entries.flatten() {
634 let path = entry.path();
635 if path.is_dir() {
636 walk_dir(root, &path, types);
637 } else if path.extension().is_some_and(|e| e == "rs") {
638 let mut file_types = parse_file(root, &path);
639 types.append(&mut file_types);
640 }
641 }
642}
643
/// Reads the crate version for the source tree at `src_root`.
///
/// Looks in the crate's `Cargo.toml` (the parent of `src/`), then in the one
/// directory above it, which is typically the workspace root. Only a literal
/// `version = "..."` in a `[package]` or `[workspace.package]` table counts,
/// so an inherited `version.workspace = true` falls through to the workspace.
/// Dependency tables such as `[dependencies.foo] version = "1"` are never
/// mistaken for the crate version. Returns `"unknown"` if nothing matches.
fn get_version(src_root: &Path) -> String {
    let crate_dir = src_root.parent().unwrap_or(src_root);
    let workspace_dir = crate_dir.parent().unwrap_or(crate_dir);
    [crate_dir, workspace_dir]
        .iter()
        .filter_map(|dir| fs::read_to_string(dir.join("Cargo.toml")).ok())
        .find_map(|manifest| manifest_version(&manifest))
        .unwrap_or_else(|| "unknown".to_string())
}

/// Returns the quoted `version` value of the `[package]` or
/// `[workspace.package]` table in a Cargo manifest, if present.
fn manifest_version(manifest: &str) -> Option<String> {
    let mut section = "";
    for line in manifest.lines() {
        let line = line.trim();
        if line.starts_with('[') {
            section = line
                .split('#')
                .next()
                .unwrap_or("")
                .trim()
                .trim_matches(|c| c == '[' || c == ']')
                .trim();
            continue;
        }
        if section != "package" && section != "workspace.package" {
            continue;
        }
        let Some((key, value)) = line.split_once('=') else {
            continue;
        };
        if key.trim() != "version" {
            continue;
        }
        let value = value.trim();
        // Accept TOML basic ("...") and literal ('...') strings; skip tables
        // like `{ workspace = true }`.
        let Some(quote) = value.chars().next().filter(|c| *c == '"' || *c == '\'') else {
            continue;
        };
        let inner = &value[1..];
        if let Some(end) = inner.find(quote) {
            return Some(inner[..end].to_string());
        }
    }
    None
}
667
668fn print_text(api: &ApiRef) {
669 println!("flodl API Reference v{}", api.version);
670 println!("{}", "=".repeat(40));
671 println!();
672
673 let mut by_category: BTreeMap<&str, Vec<&ApiType>> = BTreeMap::new();
674 for t in &api.types {
675 by_category.entry(t.category).or_default().push(t);
676 }
677
678 for (category, types) in &by_category {
679 println!("## {}", category_title(category));
680 println!();
681
682 for t in types {
683 if t.constructors.is_empty() && t.methods.is_empty() && t.builder_methods.is_empty() {
685 continue;
686 }
687
688 print!("### {}", t.name);
689 if !t.traits.is_empty() {
690 print!(" (implements: {})", t.traits.join(", "));
691 }
692 println!();
693
694 if !t.doc_summary.is_empty() {
695 println!(" {}", t.doc_summary);
696 }
697 println!(" file: {}", t.file);
698
699 if !t.constructors.is_empty() {
700 println!(" constructors:");
701 for f in &t.constructors {
702 println!(" {}", f.signature);
703 }
704 }
705 if !t.builder_methods.is_empty() {
706 println!(" builder:");
707 for f in &t.builder_methods {
708 println!(" .{}()", f.name);
709 }
710 }
711 if !t.methods.is_empty() {
712 println!(" methods:");
713 for f in &t.methods {
714 println!(" {}", f.signature);
715 }
716 }
717 if !t.doc_examples.is_empty() {
718 println!(" examples:");
719 for (ei, ex) in t.doc_examples.iter().enumerate() {
720 if ei > 0 {
721 println!();
722 }
723 for line in ex.lines() {
724 println!(" {}", line);
725 }
726 }
727 }
728 println!();
729 }
730 }
731}
732
733fn print_json(api: &ApiRef) {
734 print!("{{\"version\":\"{}\",\"types\":[", escape_json(&api.version));
735
736 for (i, t) in api.types.iter().enumerate() {
737 if t.constructors.is_empty() && t.methods.is_empty() && t.builder_methods.is_empty() {
738 continue;
739 }
740
741 if i > 0 {
742 print!(",");
743 }
744
745 print!(
746 "{{\"name\":\"{}\",\"category\":\"{}\",\"file\":\"{}\",\"doc\":\"{}\",",
747 escape_json(&t.name),
748 escape_json(t.category),
749 escape_json(&t.file),
750 escape_json(&t.doc_summary),
751 );
752
753 print!("\"traits\":[{}],",
754 t.traits.iter()
755 .map(|s| format!("\"{}\"", escape_json(s)))
756 .collect::<Vec<_>>()
757 .join(",")
758 );
759
760 print!("\"constructors\":[{}],",
761 t.constructors.iter()
762 .map(|f| format!("{{\"name\":\"{}\",\"sig\":\"{}\"}}", escape_json(&f.name), escape_json(&f.signature)))
763 .collect::<Vec<_>>()
764 .join(",")
765 );
766
767 print!("\"builder_methods\":[{}],",
768 t.builder_methods.iter()
769 .map(|f| format!("\"{}\"", escape_json(&f.name)))
770 .collect::<Vec<_>>()
771 .join(",")
772 );
773
774 print!("\"methods\":[{}],",
775 t.methods.iter()
776 .map(|f| format!("{{\"name\":\"{}\",\"sig\":\"{}\"}}", escape_json(&f.name), escape_json(&f.signature)))
777 .collect::<Vec<_>>()
778 .join(",")
779 );
780
781 print!("\"examples\":[{}]",
782 t.doc_examples.iter()
783 .map(|e| format!("\"{}\"", escape_json(e)))
784 .collect::<Vec<_>>()
785 .join(",")
786 );
787
788 print!("}}");
789 }
790
791 println!("]}}");
792}
793
/// Human-readable section heading for a category key. Unknown keys are
/// returned unchanged.
fn category_title(cat: &str) -> &str {
    const TITLES: [(&str, &str); 9] = [
        ("modules", "Modules (nn)"),
        ("losses", "Losses"),
        ("optimizers", "Optimizers"),
        ("schedulers", "Schedulers"),
        ("tensor", "Tensor"),
        ("autograd", "Autograd"),
        ("graph", "Graph"),
        ("distributed", "Distributed"),
        ("data", "Data"),
    ];
    TITLES
        .iter()
        .find(|(key, _)| *key == cat)
        .map_or(cat, |&(_, title)| title)
}
808
/// Escapes `s` for embedding inside a JSON string literal.
///
/// Backslash, double quote, `\n` and `\t` get their short escapes. `\r` is
/// dropped, which keeps the existing CRLF handling. All other control
/// characters below U+0020 become `\u00XX`, because JSON forbids them
/// unescaped. The original version let them through, producing invalid JSON
/// for doc text containing e.g. form feeds.
fn escape_json(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\t' => out.push_str("\\t"),
            '\r' => {}
            c if c < '\u{20}' => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}
816
817pub fn run(json: bool, path: Option<&str>) -> Result<(), String> {
822 let src_root = find_flodl_src(path)
823 .ok_or_else(|| {
824 "Could not find flodl source. Run from a flodl checkout, \
825 or pass --path <flodl/src/>."
826 .to_string()
827 })?;
828
829 let version = get_version(&src_root);
830 let types = parse_source_tree(&src_root);
831
832 let api = ApiRef { version, types };
833
834 if json {
835 print_json(&api);
836 } else {
837 print_text(&api);
838 }
839
840 Ok(())
841}