Skip to main content

flodl_cli/
api_ref.rs

1//! API reference generator: extracts the public API surface from flodl source.
2//!
3//! Parses Rust source files to find pub structs, constructors, methods, and
4//! trait implementations. No external dependencies (string-based parsing).
5//!
6//! Used by the `/port` agent skill to understand what flodl offers, and by
7//! anyone who wants a quick reference without building docs.
8
9use std::collections::BTreeMap;
10use std::fs;
11use std::path::{Path, PathBuf};
12use std::process::{Command, Stdio};
13
14// ---------------------------------------------------------------------------
15// Data model
16// ---------------------------------------------------------------------------
17
18/// A single public function/method signature.
19#[derive(Debug)]
20struct FnSig {
21    name: String,
22    signature: String,
23}
24
25/// A public type extracted from the source.
26#[derive(Debug)]
27struct ApiType {
28    name: String,
29    category: &'static str,
30    file: String,
31    doc_summary: String,
32    doc_examples: Vec<String>,
33    constructors: Vec<FnSig>,
34    methods: Vec<FnSig>,
35    builder_methods: Vec<FnSig>,
36    traits: Vec<String>,
37}
38
39/// Top-level API reference.
40struct ApiRef {
41    version: String,
42    types: Vec<ApiType>,
43}
44
45// ---------------------------------------------------------------------------
46// Source locator
47// ---------------------------------------------------------------------------
48
49/// Find the flodl source directory. Checks (in order):
50/// 1. Explicit path from --path flag
51/// 2. ./flodl/src/ (dev checkout, walk up to 5 levels)
52/// 3. Cargo registry (~/.cargo/registry/src/*/flodl-*/src/)
53/// 4. Cached download (`~/.flodl/api-ref-cache/<tag>/`)
54/// 5. Download from latest GitHub release (cached for next time)
55pub fn find_flodl_src(explicit: Option<&str>) -> Option<PathBuf> {
56    if let Some(p) = explicit {
57        let path = PathBuf::from(p);
58        if path.is_dir() {
59            return Some(path);
60        }
61    }
62
63    // Dev checkout: walk up from cwd looking for flodl/src/lib.rs
64    let mut dir = std::env::current_dir().ok()?;
65    for _ in 0..5 {
66        let candidate = dir.join("flodl/src");
67        if candidate.join("lib.rs").is_file() {
68            return Some(candidate);
69        }
70        if !dir.pop() {
71            break;
72        }
73    }
74
75    // Cargo registry
76    if let Some(home) = home_dir() {
77        let registry = home.join(".cargo/registry/src");
78        if registry.is_dir() {
79            // Find the latest flodl version in registry
80            if let Ok(entries) = fs::read_dir(&registry) {
81                for index_dir in entries.flatten() {
82                    if let Ok(crates) = fs::read_dir(index_dir.path()) {
83                        let mut best: Option<PathBuf> = None;
84                        for entry in crates.flatten() {
85                            let name = entry.file_name().to_string_lossy().to_string();
86                            if name.starts_with("flodl-") && !name.starts_with("flodl-sys") && !name.starts_with("flodl-cli") {
87                                let src = entry.path().join("src");
88                                if src.join("lib.rs").is_file() {
89                                    best = Some(src);
90                                }
91                            }
92                        }
93                        if best.is_some() {
94                            return best;
95                        }
96                    }
97                }
98            }
99        }
100    }
101
102    // Check cached downloads
103    if let Some(tag) = fetch_latest_tag() {
104        if let Some(cache) = cache_dir(&tag) {
105            if let Some(src) = find_src_in_cache(&cache) {
106                return Some(src);
107            }
108        }
109        // Download from GitHub
110        match download_source(&tag) {
111            Ok(src) => return Some(src),
112            Err(e) => eprintln!("warning: could not download source: {}", e),
113        }
114    }
115
116    None
117}
118
/// Home directory from `$HOME`, falling back to `%USERPROFILE%` (Windows).
///
/// Empty values are treated as unset, because an empty `PathBuf` would make
/// every derived path (registry, cache) silently relative to the cwd.
fn home_dir() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|key| std::env::var_os(key))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}
124
125// ---------------------------------------------------------------------------
126// GitHub source download
127// ---------------------------------------------------------------------------
128
const REPO: &str = "fab2s/floDl";

/// Get the latest release tag from GitHub.
///
/// Sends a HEAD request with `curl -sI` (redirects are not followed) to
/// `https://github.com/<REPO>/releases/latest` and reads the tag from the
/// `Location` redirect header. Returns `None` if curl cannot be run or the
/// response does not redirect to a `/releases/tag/<tag>` URL.
fn fetch_latest_tag() -> Option<String> {
    let output = Command::new("curl")
        .args(["-sI", &format!("https://github.com/{}/releases/latest", REPO)])
        .stdout(Stdio::piped())
        .stderr(Stdio::null())
        .output()
        .ok()?;

    parse_release_tag(&String::from_utf8_lossy(&output.stdout))
}

/// Extract `<tag>` from a `Location: .../releases/tag/<tag>` response header.
///
/// Header names match case-insensitively. The tag is later used as a cache
/// directory name, so values that are empty, `.`/`..`, or contain path
/// separators are rejected.
fn parse_release_tag(headers: &str) -> Option<String> {
    headers.lines().find_map(|line| {
        let (name, value) = line.split_once(':')?;
        if !name.trim().eq_ignore_ascii_case("location") {
            return None;
        }
        let (_, tail) = value.trim().split_once("/releases/tag/")?;
        let tag = tail.trim().trim_end_matches('/');
        let is_safe = !tag.is_empty()
            && tag != "."
            && tag != ".."
            && !tag.contains(|c: char| c == '/' || c == '\\');
        if is_safe {
            Some(tag.to_string())
        } else {
            None
        }
    })
}
154
155/// Cache directory for downloaded source: `~/.flodl/api-ref-cache/<tag>/`
156fn cache_dir(tag: &str) -> Option<PathBuf> {
157    let home = home_dir()?;
158    let flodl_home = std::env::var("FLODL_HOME")
159        .map(PathBuf::from)
160        .unwrap_or_else(|_| home.join(".flodl"));
161    Some(flodl_home.join("api-ref-cache").join(tag))
162}
163
164/// Download and extract flodl source from a GitHub release.
165/// Returns the path to the flodl/src/ directory inside the cache.
166fn download_source(tag: &str) -> Result<PathBuf, String> {
167    let cache = cache_dir(tag)
168        .ok_or_else(|| "cannot determine home directory".to_string())?;
169
170    // Check if already cached
171    let src_dir = find_src_in_cache(&cache);
172    if let Some(src) = src_dir {
173        return Ok(src);
174    }
175
176    eprintln!("Downloading flodl {} source from GitHub...", tag);
177
178    let zip_url = format!(
179        "https://github.com/{}/archive/refs/tags/{}.zip",
180        REPO, tag
181    );
182
183    fs::create_dir_all(&cache)
184        .map_err(|e| format!("cannot create cache dir: {}", e))?;
185
186    let zip_path = cache.join("source.zip");
187    crate::util::http::download_file(&zip_url, &zip_path)?;
188
189    eprintln!("Extracting...");
190    crate::util::archive::extract_zip(&zip_path, &cache)?;
191
192    // Clean up zip
193    let _ = fs::remove_file(&zip_path);
194
195    find_src_in_cache(&cache)
196        .ok_or_else(|| "downloaded archive does not contain flodl/src/lib.rs".to_string())
197}
198
/// Locate `flodl/src` (identified by its `lib.rs`) inside a cache directory.
///
/// Accepts both a direct `cache/flodl/src` layout and the GitHub archive
/// layout `cache/<repo-name>-<tag>/flodl/src` (e.g. `floDl-0.3.0/`). Returns
/// `None` if `cache` is not a directory or no candidate contains `lib.rs`.
fn find_src_in_cache(cache: &Path) -> Option<PathBuf> {
    if !cache.is_dir() {
        return None;
    }
    let has_lib = |dir: &Path| dir.join("lib.rs").is_file();

    let direct = cache.join("flodl/src");
    if has_lib(&direct) {
        return Some(direct);
    }

    fs::read_dir(cache)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .filter(|child| child.is_dir())
        .map(|child| child.join("flodl/src"))
        .find(|candidate| has_lib(candidate))
}
224
225// ---------------------------------------------------------------------------
226// Parser
227// ---------------------------------------------------------------------------
228
/// Categorize a file path (relative to `src/`) into an API category.
///
/// Substring rules (`loss`, `optim`, `scheduler`, `nn/`) are tried first, in
/// that order, then top-level directory/file prefixes. Anything else is
/// `"other"`. Windows `\` separators are normalized to `/` first, so the
/// `nn/` rule also matches on Windows.
fn categorize(rel_path: &str) -> &'static str {
    const BY_SUBSTRING: [(&str, &str); 4] = [
        ("loss", "losses"),
        ("optim", "optimizers"),
        ("scheduler", "schedulers"),
        ("nn/", "modules"),
    ];
    // Prefix rules: the category label equals the matched prefix.
    const BY_PREFIX: [&str; 5] = ["tensor", "autograd", "graph", "distributed", "data"];

    let path = rel_path.replace('\\', "/");
    BY_SUBSTRING
        .iter()
        .find(|(needle, _)| path.contains(needle))
        .map(|&(_, category)| category)
        .or_else(|| BY_PREFIX.iter().copied().find(|prefix| path.starts_with(prefix)))
        .unwrap_or("other")
}
253
/// Collect the `///` doc comment attached to the item on line `item_line`.
///
/// Scans upward from the line just above the item. `#[...]` attribute lines
/// are stepped over, and so are blank lines until doc text has been found. A
/// blank line after doc text, or any other line, ends the scan.
///
/// Returns `(summary, examples)`: the first doc line, and the trimmed body of
/// every fenced code block (a doc line starting with three backticks opens or
/// closes a block) that is not empty.
fn extract_docs(lines: &[&str], item_line: usize) -> (String, Vec<String>) {
    // Gathered bottom-up, then reversed into source order.
    let mut collected: Vec<String> = Vec::new();
    let mut cursor = item_line.saturating_sub(1);
    loop {
        let current = lines[cursor].trim();
        let keep_scanning = if current.starts_with("///") {
            let body = current.trim_start_matches("///");
            // Drop the conventional single space after `///`.
            collected.push(body.strip_prefix(' ').unwrap_or(body).to_string());
            true
        } else if current.is_empty() {
            // Blank lines only end the scan once doc text has been seen.
            collected.is_empty()
        } else {
            current.starts_with("#[")
        };
        if !keep_scanning || cursor == 0 {
            break;
        }
        cursor -= 1;
    }
    collected.reverse();

    let summary = collected.first().cloned().unwrap_or_default();

    let mut examples = Vec::new();
    // `Some(buffer)` while inside a fenced code block.
    let mut open_block: Option<String> = None;
    for text in &collected {
        if text.starts_with("```") {
            match open_block.take() {
                Some(block) => {
                    let body = block.trim();
                    if !body.is_empty() {
                        examples.push(body.to_string());
                    }
                }
                None => open_block = Some(String::new()),
            }
        } else if let Some(block) = open_block.as_mut() {
            if !block.is_empty() {
                block.push('\n');
            }
            block.push_str(text);
        }
    }

    (summary, examples)
}
310
/// Extract a function signature from a line like `pub fn new(a: i64, b: i64) -> Result<Self> {`.
///
/// Returns the text from `pub fn` / `pub const fn` up to (not including) the
/// first `{`. One-line bodies such as `pub fn len(&self) -> usize { self.n }`
/// are therefore dropped. A trailing `where` keyword is removed, but only when
/// it is a separate word, so `-> Somewhere` is kept intact. Returns `None`
/// when the line contains neither `pub fn ` nor `pub const fn `.
fn extract_fn_sig(line: &str) -> Option<String> {
    let trimmed = line.trim();
    let start = trimmed
        .find("pub fn ")
        .or_else(|| trimmed.find("pub const fn "))?;

    let sig = &trimmed[start..];
    let sig = sig.find('{').map_or(sig, |brace| &sig[..brace]).trim_end();
    let sig = match sig.strip_suffix("where") {
        Some(head) if head.ends_with(char::is_whitespace) => head.trim_end(),
        _ => sig,
    };
    Some(sig.to_string())
}
328
/// Extract a function name from a signature.
///
/// Looks at the text between the first and second occurrence of `fn ` and
/// keeps everything before the first `(` or `<`, so
/// `"pub fn map<F>(f: F)"` gives `"map"`. Returns an empty string when the
/// signature contains no `fn `.
fn extract_fn_name(sig: &str) -> String {
    match sig.split("fn ").nth(1) {
        Some(segment) => segment
            .chars()
            .take_while(|&c| c != '(' && c != '<')
            .collect(),
        None => String::new(),
    }
}
338
339/// Parse a single Rust source file and extract pub types and their API.
340fn parse_file(src_root: &Path, path: &Path) -> Vec<ApiType> {
341    let content = match fs::read_to_string(path) {
342        Ok(c) => c,
343        Err(_) => return Vec::new(),
344    };
345
346    let rel_path = path
347        .strip_prefix(src_root)
348        .unwrap_or(path)
349        .to_string_lossy()
350        .to_string();
351
352    let category = categorize(&rel_path);
353    let lines: Vec<&str> = content.lines().collect();
354    let mut types: BTreeMap<String, ApiType> = BTreeMap::new();
355
356    // Pass 1: find all pub struct declarations
357    for (i, line) in lines.iter().enumerate() {
358        let trimmed = line.trim();
359        if let Some(after) = trimmed.strip_prefix("pub struct ") {
360            let name_end = after
361                .find(|c: char| !c.is_alphanumeric() && c != '_')
362                .unwrap_or(after.len());
363            let name = after[..name_end].to_string();
364
365            if name.is_empty() || name.starts_with('_') {
366                continue;
367            }
368
369            // Skip test helper types, internal types
370            if name.ends_with("Inner") || name.ends_with("State") && !name.contains("Trained") {
371                continue;
372            }
373
374            let (doc, examples) = extract_docs(&lines, i);
375
376            types.insert(
377                name.clone(),
378                ApiType {
379                    name,
380                    category,
381                    file: rel_path.clone(),
382                    doc_summary: doc,
383                    doc_examples: examples,
384                    constructors: Vec::new(),
385                    methods: Vec::new(),
386                    builder_methods: Vec::new(),
387                    traits: Vec::new(),
388                },
389            );
390        }
391
392        // Also capture pub enum
393        if let Some(after) = trimmed.strip_prefix("pub enum ") {
394            let name_end = after
395                .find(|c: char| !c.is_alphanumeric() && c != '_')
396                .unwrap_or(after.len());
397            let name = after[..name_end].to_string();
398            if !name.is_empty() && !name.starts_with('_') {
399                let (doc, examples) = extract_docs(&lines, i);
400                types.insert(
401                    name.clone(),
402                    ApiType {
403                        name,
404                        category,
405                        file: rel_path.clone(),
406                        doc_summary: doc,
407                        doc_examples: examples,
408                        constructors: Vec::new(),
409                        methods: Vec::new(),
410                        builder_methods: Vec::new(),
411                        traits: Vec::new(),
412                    },
413                );
414            }
415        }
416    }
417
418    // Pass 2: find impl blocks and extract pub methods
419    let mut current_impl: Option<(String, Option<String>)> = None; // (type_name, trait_name)
420    let mut brace_depth: i32 = 0;
421    let mut in_impl = false;
422    let mut in_test = false;
423
424    for line in lines.iter() {
425        let trimmed = line.trim();
426
427        // Skip test modules
428        if trimmed.contains("#[cfg(test)]") {
429            in_test = true;
430        }
431        if in_test {
432            if trimmed == "}" && brace_depth <= 1 {
433                in_test = false;
434            }
435            // Count braces even in test to track depth
436            for c in trimmed.chars() {
437                if c == '{' { brace_depth += 1; }
438                if c == '}' { brace_depth -= 1; }
439            }
440            continue;
441        }
442
443        // Detect impl blocks
444        if trimmed.starts_with("impl ") || trimmed.starts_with("impl<") {
445            let impl_str = trimmed.to_string();
446
447            // Parse: "impl TypeName {" or "impl TraitName for TypeName {"
448            let (type_name, trait_name) = if impl_str.contains(" for ") {
449                // impl Trait for Type
450                let parts: Vec<&str> = impl_str.split(" for ").collect();
451                let trait_part = parts[0]
452                    .trim_start_matches("impl ")
453                    .trim_start_matches("impl<")
454                    .split('>')
455                    .next_back()
456                    .unwrap_or("")
457                    .trim();
458                // Remove generic bounds from trait name
459                let trait_name = trait_part.split('<').next().unwrap_or(trait_part).trim();
460                let type_part = parts.get(1).unwrap_or(&"");
461                let type_name = type_part
462                    .split(|c: char| !c.is_alphanumeric() && c != '_')
463                    .next()
464                    .unwrap_or("")
465                    .trim();
466                (type_name.to_string(), Some(trait_name.to_string()))
467            } else {
468                // impl Type
469                let after_impl = impl_str
470                    .trim_start_matches("impl<")
471                    .split('>')
472                    .next_back()
473                    .unwrap_or(impl_str.strip_prefix("impl ").unwrap_or(&impl_str));
474                let after_impl = after_impl
475                    .strip_prefix("impl ")
476                    .unwrap_or(after_impl.trim());
477                let type_name = after_impl
478                    .split(|c: char| !c.is_alphanumeric() && c != '_')
479                    .next()
480                    .unwrap_or("")
481                    .trim();
482                (type_name.to_string(), None)
483            };
484
485            if types.contains_key(&type_name) {
486                current_impl = Some((type_name, trait_name));
487                in_impl = true;
488            }
489        }
490
491        // Track brace depth
492        for c in trimmed.chars() {
493            if c == '{' {
494                brace_depth += 1;
495            }
496            if c == '}' {
497                brace_depth -= 1;
498                if brace_depth <= 0 && in_impl {
499                    in_impl = false;
500                    current_impl = None;
501                }
502            }
503        }
504
505        // Extract pub fn inside impl blocks
506        if in_impl && (trimmed.starts_with("pub fn ") || trimmed.starts_with("pub const fn ")) {
507            if let Some((ref type_name, ref trait_name)) = current_impl {
508                if let Some(sig) = extract_fn_sig(trimmed) {
509                    let fn_name = extract_fn_name(&sig);
510                    let fn_sig = FnSig {
511                        name: fn_name.clone(),
512                        signature: sig,
513                    };
514
515                    if let Some(api_type) = types.get_mut(type_name) {
516                        // Record trait implementation
517                        if let Some(t) = &trait_name {
518                            if !api_type.traits.contains(t) {
519                                api_type.traits.push(t.clone());
520                            }
521                        }
522
523                        // Categorize the method
524                        if fn_name == "new"
525                            || fn_name == "on_device"
526                            || fn_name == "no_bias"
527                            || fn_name == "no_bias_on_device"
528                            || fn_name == "configure"
529                            || fn_name == "default"
530                        {
531                            api_type.constructors.push(fn_sig);
532                        } else if fn_name.starts_with("with_") || fn_name == "done" || fn_name == "build" {
533                            api_type.builder_methods.push(fn_sig);
534                        } else {
535                            api_type.methods.push(fn_sig);
536                        }
537                    }
538                }
539            }
540        }
541    }
542
543    // Pass 3: collect top-level pub fns (not inside impl blocks).
544    // These are common for losses, init functions, utility functions.
545    let mut free_fns: Vec<FnSig> = Vec::new();
546    let mut depth: i32 = 0;
547    let mut in_test_block = false;
548
549    for (i, line) in lines.iter().enumerate() {
550        let trimmed = line.trim();
551
552        if trimmed.contains("#[cfg(test)]") {
553            in_test_block = true;
554        }
555
556        for c in trimmed.chars() {
557            if c == '{' { depth += 1; }
558            if c == '}' { depth -= 1; }
559        }
560
561        if in_test_block {
562            if depth <= 0 { in_test_block = false; }
563            continue;
564        }
565
566        // Top-level pub fn: depth 0 (module level) or 1 (inside mod block)
567        if depth <= 1 && trimmed.starts_with("pub fn ") {
568            if let Some(sig) = extract_fn_sig(trimmed) {
569                let fn_name = extract_fn_name(&sig);
570                let (doc, _) = extract_docs(&lines, i);
571                free_fns.push(FnSig {
572                    name: format!("{} -- {}", fn_name, doc),
573                    signature: sig,
574                });
575            }
576        }
577    }
578
579    if !free_fns.is_empty() {
580        // Determine a good label from the file name
581        let file_stem = std::path::Path::new(&rel_path)
582            .file_stem()
583            .unwrap_or_default()
584            .to_string_lossy()
585            .to_string();
586
587        let label = match file_stem.as_str() {
588            "mod" => {
589                // Use parent directory name
590                std::path::Path::new(&rel_path)
591                    .parent()
592                    .and_then(|p| p.file_name())
593                    .unwrap_or_default()
594                    .to_string_lossy()
595                    .to_string()
596            }
597            other => other.to_string(),
598        };
599
600        types.insert(
601            format!("{}()", label),
602            ApiType {
603                name: format!("{} (functions)", label),
604                category: categorize(&rel_path),
605                file: rel_path,
606                doc_summary: String::new(),
607                doc_examples: Vec::new(),
608                constructors: Vec::new(),
609                methods: free_fns,
610                builder_methods: Vec::new(),
611                traits: Vec::new(),
612            },
613        );
614    }
615
616    types.into_values().collect()
617}
618
619/// Walk a source tree and parse all .rs files.
620fn parse_source_tree(src_root: &Path) -> Vec<ApiType> {
621    let mut all_types = Vec::new();
622    walk_dir(src_root, src_root, &mut all_types);
623    // Sort by category then name
624    all_types.sort_by(|a, b| a.category.cmp(b.category).then(a.name.cmp(&b.name)));
625    all_types
626}
627
628fn walk_dir(root: &Path, dir: &Path, types: &mut Vec<ApiType>) {
629    let entries = match fs::read_dir(dir) {
630        Ok(e) => e,
631        Err(_) => return,
632    };
633    for entry in entries.flatten() {
634        let path = entry.path();
635        if path.is_dir() {
636            walk_dir(root, &path, types);
637        } else if path.extension().is_some_and(|e| e == "rs") {
638            let mut file_types = parse_file(root, &path);
639            types.append(&mut file_types);
640        }
641    }
642}
643
644// ---------------------------------------------------------------------------
645// Output
646// ---------------------------------------------------------------------------
647
/// Read the crate version for the source tree rooted at `src_root`.
///
/// Checks the crate's `Cargo.toml` (parent of `src_root`) first, then the
/// workspace root one level above. Only `version = "..."` keys inside
/// `[package]` or `[workspace.package]` count, so `version.workspace = true`
/// and dependency tables like `[dependencies.foo]` are ignored. Returns
/// `"unknown"` if no version is found.
fn get_version(src_root: &Path) -> String {
    let crate_dir = src_root.parent().unwrap_or(src_root);
    let workspace_dir = crate_dir.parent().unwrap_or(crate_dir);
    [crate_dir, workspace_dir]
        .iter()
        .filter_map(|dir| fs::read_to_string(dir.join("Cargo.toml")).ok())
        .find_map(|manifest| package_version(&manifest))
        .unwrap_or_else(|| "unknown".to_string())
}

/// Literal `version = "..."` from a `[package]` or `[workspace.package]` table.
fn package_version(manifest: &str) -> Option<String> {
    let mut in_package = false;
    for line in manifest.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with('[') {
            let table = trimmed
                .trim_start_matches('[')
                .split(']')
                .next()
                .unwrap_or("")
                .trim();
            in_package = table == "package" || table == "workspace.package";
            continue;
        }
        if !in_package {
            continue;
        }
        let Some((key, value)) = trimmed.split_once('=') else {
            continue;
        };
        if key.trim() == "version" {
            if let Some(v) = value.split('"').nth(1) {
                return Some(v.to_string());
            }
        }
    }
    None
}
667
668fn print_text(api: &ApiRef) {
669    println!("flodl API Reference v{}", api.version);
670    println!("{}", "=".repeat(40));
671    println!();
672
673    let mut by_category: BTreeMap<&str, Vec<&ApiType>> = BTreeMap::new();
674    for t in &api.types {
675        by_category.entry(t.category).or_default().push(t);
676    }
677
678    for (category, types) in &by_category {
679        println!("## {}", category_title(category));
680        println!();
681
682        for t in types {
683            // Skip types with no public API
684            if t.constructors.is_empty() && t.methods.is_empty() && t.builder_methods.is_empty() {
685                continue;
686            }
687
688            print!("### {}", t.name);
689            if !t.traits.is_empty() {
690                print!("  (implements: {})", t.traits.join(", "));
691            }
692            println!();
693
694            if !t.doc_summary.is_empty() {
695                println!("  {}", t.doc_summary);
696            }
697            println!("  file: {}", t.file);
698
699            if !t.constructors.is_empty() {
700                println!("  constructors:");
701                for f in &t.constructors {
702                    println!("    {}", f.signature);
703                }
704            }
705            if !t.builder_methods.is_empty() {
706                println!("  builder:");
707                for f in &t.builder_methods {
708                    println!("    .{}()", f.name);
709                }
710            }
711            if !t.methods.is_empty() {
712                println!("  methods:");
713                for f in &t.methods {
714                    println!("    {}", f.signature);
715                }
716            }
717            if !t.doc_examples.is_empty() {
718                println!("  examples:");
719                for (ei, ex) in t.doc_examples.iter().enumerate() {
720                    if ei > 0 {
721                        println!();
722                    }
723                    for line in ex.lines() {
724                        println!("    {}", line);
725                    }
726                }
727            }
728            println!();
729        }
730    }
731}
732
733fn print_json(api: &ApiRef) {
734    print!("{{\"version\":\"{}\",\"types\":[", escape_json(&api.version));
735
736    for (i, t) in api.types.iter().enumerate() {
737        if t.constructors.is_empty() && t.methods.is_empty() && t.builder_methods.is_empty() {
738            continue;
739        }
740
741        if i > 0 {
742            print!(",");
743        }
744
745        print!(
746            "{{\"name\":\"{}\",\"category\":\"{}\",\"file\":\"{}\",\"doc\":\"{}\",",
747            escape_json(&t.name),
748            escape_json(t.category),
749            escape_json(&t.file),
750            escape_json(&t.doc_summary),
751        );
752
753        print!("\"traits\":[{}],",
754            t.traits.iter()
755                .map(|s| format!("\"{}\"", escape_json(s)))
756                .collect::<Vec<_>>()
757                .join(",")
758        );
759
760        print!("\"constructors\":[{}],",
761            t.constructors.iter()
762                .map(|f| format!("{{\"name\":\"{}\",\"sig\":\"{}\"}}", escape_json(&f.name), escape_json(&f.signature)))
763                .collect::<Vec<_>>()
764                .join(",")
765        );
766
767        print!("\"builder_methods\":[{}],",
768            t.builder_methods.iter()
769                .map(|f| format!("\"{}\"", escape_json(&f.name)))
770                .collect::<Vec<_>>()
771                .join(",")
772        );
773
774        print!("\"methods\":[{}],",
775            t.methods.iter()
776                .map(|f| format!("{{\"name\":\"{}\",\"sig\":\"{}\"}}", escape_json(&f.name), escape_json(&f.signature)))
777                .collect::<Vec<_>>()
778                .join(",")
779        );
780
781        print!("\"examples\":[{}]",
782            t.doc_examples.iter()
783                .map(|e| format!("\"{}\"", escape_json(e)))
784                .collect::<Vec<_>>()
785                .join(",")
786        );
787
788        print!("}}");
789    }
790
791    println!("]}}");
792}
793
/// Human-readable section title for a category label.
/// Unknown labels are returned unchanged.
fn category_title(cat: &str) -> &str {
    const TITLES: [(&str, &str); 9] = [
        ("modules", "Modules (nn)"),
        ("losses", "Losses"),
        ("optimizers", "Optimizers"),
        ("schedulers", "Schedulers"),
        ("tensor", "Tensor"),
        ("autograd", "Autograd"),
        ("graph", "Graph"),
        ("distributed", "Distributed"),
        ("data", "Data"),
    ];
    TITLES
        .iter()
        .find(|(label, _)| *label == cat)
        .map_or(cat, |&(_, title)| title)
}
808
/// Escape `s` for embedding inside a JSON string literal.
///
/// - `\`, `"`, newline and tab are backslash-escaped.
/// - Every other control character (U+0000..U+001F) is written as `\u00XX`,
///   as JSON requires.
/// - `\r` is dropped, so CRLF sources produce the same output as LF ones.
fn escape_json(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\t' => out.push_str("\\t"),
            '\r' => {}
            c if c < '\u{20}' => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}
816
817// ---------------------------------------------------------------------------
818// Public entry point
819// ---------------------------------------------------------------------------
820
821pub fn run(json: bool, path: Option<&str>) -> Result<(), String> {
822    let src_root = find_flodl_src(path)
823        .ok_or_else(|| {
824            "Could not find flodl source. Run from a flodl checkout, \
825             or pass --path <flodl/src/>."
826                .to_string()
827        })?;
828
829    let version = get_version(&src_root);
830    let types = parse_source_tree(&src_root);
831
832    let api = ApiRef { version, types };
833
834    if json {
835        print_json(&api);
836    } else {
837        print_text(&api);
838    }
839
840    Ok(())
841}