use std::collections::BTreeMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
#[derive(Debug)]
struct FnSig {
name: String,
signature: String,
}
#[derive(Debug)]
struct ApiType {
name: String,
category: &'static str,
file: String,
doc_summary: String,
doc_examples: Vec<String>,
constructors: Vec<FnSig>,
methods: Vec<FnSig>,
builder_methods: Vec<FnSig>,
traits: Vec<String>,
}
struct ApiRef {
version: String,
types: Vec<ApiType>,
}
pub fn find_flodl_src(explicit: Option<&str>) -> Option<PathBuf> {
if let Some(p) = explicit {
let path = PathBuf::from(p);
if path.is_dir() {
return Some(path);
}
}
let mut dir = std::env::current_dir().ok()?;
for _ in 0..5 {
let candidate = dir.join("flodl/src");
if candidate.join("lib.rs").is_file() {
return Some(candidate);
}
if !dir.pop() {
break;
}
}
if let Some(home) = home_dir() {
let registry = home.join(".cargo/registry/src");
if registry.is_dir() {
if let Ok(entries) = fs::read_dir(®istry) {
for index_dir in entries.flatten() {
if let Ok(crates) = fs::read_dir(index_dir.path()) {
let mut best: Option<PathBuf> = None;
for entry in crates.flatten() {
let name = entry.file_name().to_string_lossy().to_string();
if name.starts_with("flodl-") && !name.starts_with("flodl-sys") && !name.starts_with("flodl-cli") {
let src = entry.path().join("src");
if src.join("lib.rs").is_file() {
best = Some(src);
}
}
}
if best.is_some() {
return best;
}
}
}
}
}
}
if let Some(tag) = fetch_latest_tag() {
if let Some(cache) = cache_dir(&tag) {
if let Some(src) = find_src_in_cache(&cache) {
return Some(src);
}
}
match download_source(&tag) {
Ok(src) => return Some(src),
Err(e) => eprintln!("warning: could not download source: {}", e),
}
}
None
}
fn home_dir() -> Option<PathBuf> {
std::env::var_os("HOME")
.or_else(|| std::env::var_os("USERPROFILE"))
.map(PathBuf::from)
}
const REPO: &str = "fab2s/floDl";
fn fetch_latest_tag() -> Option<String> {
let output = Command::new("curl")
.args(["-sI", &format!("https://github.com/{}/releases/latest", REPO)])
.stdout(Stdio::piped())
.stderr(Stdio::null())
.output()
.ok()?;
let stdout = String::from_utf8_lossy(&output.stdout);
for line in stdout.lines() {
let lower = line.to_lowercase();
if lower.starts_with("location:") {
let tag = line.rsplit('/').next()?.trim();
if !tag.is_empty() {
return Some(tag.to_string());
}
}
}
None
}
fn cache_dir(tag: &str) -> Option<PathBuf> {
let home = home_dir()?;
let flodl_home = std::env::var("FLODL_HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| home.join(".flodl"));
Some(flodl_home.join("api-ref-cache").join(tag))
}
fn download_source(tag: &str) -> Result<PathBuf, String> {
let cache = cache_dir(tag)
.ok_or_else(|| "cannot determine home directory".to_string())?;
let src_dir = find_src_in_cache(&cache);
if let Some(src) = src_dir {
return Ok(src);
}
eprintln!("Downloading flodl {} source from GitHub...", tag);
let zip_url = format!(
"https://github.com/{}/archive/refs/tags/{}.zip",
REPO, tag
);
fs::create_dir_all(&cache)
.map_err(|e| format!("cannot create cache dir: {}", e))?;
let zip_path = cache.join("source.zip");
crate::util::http::download_file(&zip_url, &zip_path)?;
eprintln!("Extracting...");
crate::util::archive::extract_zip(&zip_path, &cache)?;
let _ = fs::remove_file(&zip_path);
find_src_in_cache(&cache)
.ok_or_else(|| "downloaded archive does not contain flodl/src/lib.rs".to_string())
}
fn find_src_in_cache(cache: &Path) -> Option<PathBuf> {
if !cache.is_dir() {
return None;
}
let direct = cache.join("flodl/src");
if direct.join("lib.rs").is_file() {
return Some(direct);
}
if let Ok(entries) = fs::read_dir(cache) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
let candidate = path.join("flodl/src");
if candidate.join("lib.rs").is_file() {
return Some(candidate);
}
}
}
}
None
}
fn categorize(rel_path: &str) -> &'static str {
if rel_path.contains("loss") {
"losses"
} else if rel_path.contains("optim") {
"optimizers"
} else if rel_path.contains("scheduler") {
"schedulers"
} else if rel_path.contains("nn/") || rel_path.starts_with("nn/") {
"modules"
} else if rel_path.starts_with("tensor") {
"tensor"
} else if rel_path.starts_with("autograd") {
"autograd"
} else if rel_path.starts_with("graph") {
"graph"
} else if rel_path.starts_with("distributed") {
"distributed"
} else if rel_path.starts_with("data") {
"data"
} else {
"other"
}
}
fn extract_docs(lines: &[&str], item_line: usize) -> (String, Vec<String>) {
let mut doc_lines = Vec::new();
let mut i = item_line.saturating_sub(1);
loop {
let line = lines[i].trim();
if line.starts_with("///") {
let text = line.trim_start_matches("///");
let text = text.strip_prefix(' ').unwrap_or(text);
doc_lines.push(text.to_string());
} else if line.starts_with("#[") || line.is_empty() {
if !doc_lines.is_empty() && line.is_empty() {
break;
}
} else {
break;
}
if i == 0 {
break;
}
i -= 1;
}
doc_lines.reverse();
let summary = doc_lines.first().cloned().unwrap_or_default();
let mut examples = Vec::new();
let mut in_code = false;
let mut current_block = String::new();
for line in &doc_lines {
if line.starts_with("```") {
if in_code {
if !current_block.trim().is_empty() {
examples.push(current_block.trim().to_string());
}
current_block.clear();
in_code = false;
} else {
in_code = true;
}
} else if in_code {
if !current_block.is_empty() {
current_block.push('\n');
}
current_block.push_str(line);
}
}
(summary, examples)
}
fn extract_fn_sig(line: &str) -> Option<String> {
let trimmed = line.trim();
let start = if trimmed.contains("pub fn ") {
trimmed.find("pub fn ")?
} else if trimmed.contains("pub const fn ") {
trimmed.find("pub const fn ")?
} else {
return None;
};
let sig = &trimmed[start..];
let sig = sig.trim_end_matches('{').trim_end_matches("where").trim();
Some(sig.to_string())
}
fn extract_fn_name(sig: &str) -> String {
let after_fn = sig.split("fn ").nth(1).unwrap_or("");
let name_end = after_fn.find('(').unwrap_or(after_fn.len());
let name_end = name_end.min(after_fn.find('<').unwrap_or(name_end));
after_fn[..name_end].to_string()
}
fn parse_file(src_root: &Path, path: &Path) -> Vec<ApiType> {
let content = match fs::read_to_string(path) {
Ok(c) => c,
Err(_) => return Vec::new(),
};
let rel_path = path
.strip_prefix(src_root)
.unwrap_or(path)
.to_string_lossy()
.to_string();
let category = categorize(&rel_path);
let lines: Vec<&str> = content.lines().collect();
let mut types: BTreeMap<String, ApiType> = BTreeMap::new();
for (i, line) in lines.iter().enumerate() {
let trimmed = line.trim();
if let Some(after) = trimmed.strip_prefix("pub struct ") {
let name_end = after
.find(|c: char| !c.is_alphanumeric() && c != '_')
.unwrap_or(after.len());
let name = after[..name_end].to_string();
if name.is_empty() || name.starts_with('_') {
continue;
}
if name.ends_with("Inner") || name.ends_with("State") && !name.contains("Trained") {
continue;
}
let (doc, examples) = extract_docs(&lines, i);
types.insert(
name.clone(),
ApiType {
name,
category,
file: rel_path.clone(),
doc_summary: doc,
doc_examples: examples,
constructors: Vec::new(),
methods: Vec::new(),
builder_methods: Vec::new(),
traits: Vec::new(),
},
);
}
if let Some(after) = trimmed.strip_prefix("pub enum ") {
let name_end = after
.find(|c: char| !c.is_alphanumeric() && c != '_')
.unwrap_or(after.len());
let name = after[..name_end].to_string();
if !name.is_empty() && !name.starts_with('_') {
let (doc, examples) = extract_docs(&lines, i);
types.insert(
name.clone(),
ApiType {
name,
category,
file: rel_path.clone(),
doc_summary: doc,
doc_examples: examples,
constructors: Vec::new(),
methods: Vec::new(),
builder_methods: Vec::new(),
traits: Vec::new(),
},
);
}
}
}
let mut current_impl: Option<(String, Option<String>)> = None; let mut brace_depth: i32 = 0;
let mut in_impl = false;
let mut in_test = false;
for line in lines.iter() {
let trimmed = line.trim();
if trimmed.contains("#[cfg(test)]") {
in_test = true;
}
if in_test {
if trimmed == "}" && brace_depth <= 1 {
in_test = false;
}
for c in trimmed.chars() {
if c == '{' { brace_depth += 1; }
if c == '}' { brace_depth -= 1; }
}
continue;
}
if trimmed.starts_with("impl ") || trimmed.starts_with("impl<") {
let impl_str = trimmed.to_string();
let (type_name, trait_name) = if impl_str.contains(" for ") {
let parts: Vec<&str> = impl_str.split(" for ").collect();
let trait_part = parts[0]
.trim_start_matches("impl ")
.trim_start_matches("impl<")
.split('>')
.next_back()
.unwrap_or("")
.trim();
let trait_name = trait_part.split('<').next().unwrap_or(trait_part).trim();
let type_part = parts.get(1).unwrap_or(&"");
let type_name = type_part
.split(|c: char| !c.is_alphanumeric() && c != '_')
.next()
.unwrap_or("")
.trim();
(type_name.to_string(), Some(trait_name.to_string()))
} else {
let after_impl = impl_str
.trim_start_matches("impl<")
.split('>')
.next_back()
.unwrap_or(impl_str.strip_prefix("impl ").unwrap_or(&impl_str));
let after_impl = after_impl
.strip_prefix("impl ")
.unwrap_or(after_impl.trim());
let type_name = after_impl
.split(|c: char| !c.is_alphanumeric() && c != '_')
.next()
.unwrap_or("")
.trim();
(type_name.to_string(), None)
};
if types.contains_key(&type_name) {
current_impl = Some((type_name, trait_name));
in_impl = true;
}
}
for c in trimmed.chars() {
if c == '{' {
brace_depth += 1;
}
if c == '}' {
brace_depth -= 1;
if brace_depth <= 0 && in_impl {
in_impl = false;
current_impl = None;
}
}
}
if in_impl && (trimmed.starts_with("pub fn ") || trimmed.starts_with("pub const fn ")) {
if let Some((ref type_name, ref trait_name)) = current_impl {
if let Some(sig) = extract_fn_sig(trimmed) {
let fn_name = extract_fn_name(&sig);
let fn_sig = FnSig {
name: fn_name.clone(),
signature: sig,
};
if let Some(api_type) = types.get_mut(type_name) {
if let Some(t) = &trait_name {
if !api_type.traits.contains(t) {
api_type.traits.push(t.clone());
}
}
if fn_name == "new"
|| fn_name == "on_device"
|| fn_name == "no_bias"
|| fn_name == "no_bias_on_device"
|| fn_name == "configure"
|| fn_name == "default"
{
api_type.constructors.push(fn_sig);
} else if fn_name.starts_with("with_") || fn_name == "done" || fn_name == "build" {
api_type.builder_methods.push(fn_sig);
} else {
api_type.methods.push(fn_sig);
}
}
}
}
}
}
let mut free_fns: Vec<FnSig> = Vec::new();
let mut depth: i32 = 0;
let mut in_test_block = false;
for (i, line) in lines.iter().enumerate() {
let trimmed = line.trim();
if trimmed.contains("#[cfg(test)]") {
in_test_block = true;
}
for c in trimmed.chars() {
if c == '{' { depth += 1; }
if c == '}' { depth -= 1; }
}
if in_test_block {
if depth <= 0 { in_test_block = false; }
continue;
}
if depth <= 1 && trimmed.starts_with("pub fn ") {
if let Some(sig) = extract_fn_sig(trimmed) {
let fn_name = extract_fn_name(&sig);
let (doc, _) = extract_docs(&lines, i);
free_fns.push(FnSig {
name: format!("{} -- {}", fn_name, doc),
signature: sig,
});
}
}
}
if !free_fns.is_empty() {
let file_stem = std::path::Path::new(&rel_path)
.file_stem()
.unwrap_or_default()
.to_string_lossy()
.to_string();
let label = match file_stem.as_str() {
"mod" => {
std::path::Path::new(&rel_path)
.parent()
.and_then(|p| p.file_name())
.unwrap_or_default()
.to_string_lossy()
.to_string()
}
other => other.to_string(),
};
types.insert(
format!("{}()", label),
ApiType {
name: format!("{} (functions)", label),
category: categorize(&rel_path),
file: rel_path,
doc_summary: String::new(),
doc_examples: Vec::new(),
constructors: Vec::new(),
methods: free_fns,
builder_methods: Vec::new(),
traits: Vec::new(),
},
);
}
types.into_values().collect()
}
fn parse_source_tree(src_root: &Path) -> Vec<ApiType> {
let mut all_types = Vec::new();
walk_dir(src_root, src_root, &mut all_types);
all_types.sort_by(|a, b| a.category.cmp(b.category).then(a.name.cmp(&b.name)));
all_types
}
fn walk_dir(root: &Path, dir: &Path, types: &mut Vec<ApiType>) {
let entries = match fs::read_dir(dir) {
Ok(e) => e,
Err(_) => return,
};
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
walk_dir(root, &path, types);
} else if path.extension().is_some_and(|e| e == "rs") {
let mut file_types = parse_file(root, &path);
types.append(&mut file_types);
}
}
}
fn get_version(src_root: &Path) -> String {
let crate_dir = src_root.parent().unwrap_or(src_root);
for dir in &[crate_dir, crate_dir.parent().unwrap_or(crate_dir)] {
let cargo_toml = dir.join("Cargo.toml");
if let Ok(content) = fs::read_to_string(cargo_toml) {
for line in content.lines() {
let trimmed = line.trim();
if trimmed.starts_with("version") && trimmed.contains('"') && !trimmed.contains("workspace") {
if let Some(v) = trimmed.split('"').nth(1) {
return v.to_string();
}
}
}
}
}
"unknown".to_string()
}
fn print_text(api: &ApiRef) {
println!("flodl API Reference v{}", api.version);
println!("{}", "=".repeat(40));
println!();
let mut by_category: BTreeMap<&str, Vec<&ApiType>> = BTreeMap::new();
for t in &api.types {
by_category.entry(t.category).or_default().push(t);
}
for (category, types) in &by_category {
println!("## {}", category_title(category));
println!();
for t in types {
if t.constructors.is_empty() && t.methods.is_empty() && t.builder_methods.is_empty() {
continue;
}
print!("### {}", t.name);
if !t.traits.is_empty() {
print!(" (implements: {})", t.traits.join(", "));
}
println!();
if !t.doc_summary.is_empty() {
println!(" {}", t.doc_summary);
}
println!(" file: {}", t.file);
if !t.constructors.is_empty() {
println!(" constructors:");
for f in &t.constructors {
println!(" {}", f.signature);
}
}
if !t.builder_methods.is_empty() {
println!(" builder:");
for f in &t.builder_methods {
println!(" .{}()", f.name);
}
}
if !t.methods.is_empty() {
println!(" methods:");
for f in &t.methods {
println!(" {}", f.signature);
}
}
if !t.doc_examples.is_empty() {
println!(" examples:");
for (ei, ex) in t.doc_examples.iter().enumerate() {
if ei > 0 {
println!();
}
for line in ex.lines() {
println!(" {}", line);
}
}
}
println!();
}
}
}
fn print_json(api: &ApiRef) {
print!("{{\"version\":\"{}\",\"types\":[", escape_json(&api.version));
for (i, t) in api.types.iter().enumerate() {
if t.constructors.is_empty() && t.methods.is_empty() && t.builder_methods.is_empty() {
continue;
}
if i > 0 {
print!(",");
}
print!(
"{{\"name\":\"{}\",\"category\":\"{}\",\"file\":\"{}\",\"doc\":\"{}\",",
escape_json(&t.name),
escape_json(t.category),
escape_json(&t.file),
escape_json(&t.doc_summary),
);
print!("\"traits\":[{}],",
t.traits.iter()
.map(|s| format!("\"{}\"", escape_json(s)))
.collect::<Vec<_>>()
.join(",")
);
print!("\"constructors\":[{}],",
t.constructors.iter()
.map(|f| format!("{{\"name\":\"{}\",\"sig\":\"{}\"}}", escape_json(&f.name), escape_json(&f.signature)))
.collect::<Vec<_>>()
.join(",")
);
print!("\"builder_methods\":[{}],",
t.builder_methods.iter()
.map(|f| format!("\"{}\"", escape_json(&f.name)))
.collect::<Vec<_>>()
.join(",")
);
print!("\"methods\":[{}],",
t.methods.iter()
.map(|f| format!("{{\"name\":\"{}\",\"sig\":\"{}\"}}", escape_json(&f.name), escape_json(&f.signature)))
.collect::<Vec<_>>()
.join(",")
);
print!("\"examples\":[{}]",
t.doc_examples.iter()
.map(|e| format!("\"{}\"", escape_json(e)))
.collect::<Vec<_>>()
.join(",")
);
print!("}}");
}
println!("]}}");
}
fn category_title(cat: &str) -> &str {
match cat {
"modules" => "Modules (nn)",
"losses" => "Losses",
"optimizers" => "Optimizers",
"schedulers" => "Schedulers",
"tensor" => "Tensor",
"autograd" => "Autograd",
"graph" => "Graph",
"distributed" => "Distributed",
"data" => "Data",
other => other,
}
}
fn escape_json(s: &str) -> String {
s.replace('\\', "\\\\")
.replace('"', "\\\"")
.replace('\n', "\\n")
.replace('\r', "")
.replace('\t', "\\t")
}
pub fn run(json: bool, path: Option<&str>) -> Result<(), String> {
let src_root = find_flodl_src(path)
.ok_or_else(|| {
"Could not find flodl source. Run from a flodl checkout, \
or pass --path <flodl/src/>."
.to_string()
})?;
let version = get_version(&src_root);
let types = parse_source_tree(&src_root);
let api = ApiRef { version, types };
if json {
print_json(&api);
} else {
print_text(&api);
}
Ok(())
}