use tree_sitter::{Node, Parser};
use crate::languages::Lang;
/// A named definition (function, method, type, ...) found in a source file.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so symbols can be
/// deduplicated or used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SymbolDef {
    /// Unqualified name: the last `::`-separated segment of the definition's name.
    pub name: String,
    /// Name prefixed with the enclosing scopes, joined with `::`
    /// (for example `Counter::bump`).
    pub qualified: String,
    /// Category label such as `"function"`, `"method"` or `"struct"`.
    pub kind: &'static str,
    /// Source text of the definition up to its body, with whitespace runs
    /// collapsed to single spaces and truncated to 300 characters.
    pub signature: String,
    /// Documentation attached to the definition, when the language extractor
    /// found any.
    pub doc: Option<String>,
    /// 1-based line on which the definition starts.
    pub start_line: usize,
    /// 1-based line on which the definition ends (inclusive).
    pub end_line: usize,
}
/// A single call expression, attributed to the scope that contains it.
///
/// Derives `Eq` and `Hash` so repeated edges (the same caller invoking the
/// same callee several times) can be collapsed with a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CallSite {
    /// Enclosing scope path at the call; empty when the call is not inside
    /// any named scope.
    pub caller: String,
    /// Name of the callee as reported by the language extractor.
    pub callee: String,
}
/// One imported name and where it comes from.
///
/// Derives `Eq` and `Hash` so imports can be deduplicated or used as keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ImportRef {
    /// Name the import is bound to in the importing file (the alias when one
    /// is given, e.g. `rr` for `use crate::store::resolve_ref as rr`).
    pub local: String,
    /// Path or module specifier the name is imported from
    /// (e.g. `crate::store::resolve_ref`, `./api`).
    pub source: String,
}
/// Everything extracted from a single source file.
///
/// `Default` yields an extract with no symbols, calls or imports. `Clone` and
/// `PartialEq` are derived so whole extracts can be copied and compared, as
/// the element types already can.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct FileExtract {
    /// Definitions, in the order the tree walk encountered them.
    pub symbols: Vec<SymbolDef>,
    /// Call expressions, in the order the tree walk encountered them.
    pub calls: Vec<CallSite>,
    /// Imported names collected by the language extractor.
    pub imports: Vec<ImportRef>,
}
/// Parses `source` with the grammar for `lang` and collects its symbols,
/// call sites and imports.
///
/// The result is empty when the grammar cannot be installed on the parser or
/// when the parser produces no tree; this function does not report errors.
pub fn extract(lang: Lang, source: &str) -> FileExtract {
    let mut extracted = FileExtract::default();
    let mut parser = Parser::new();
    if parser.set_language(&lang.language()).is_ok() {
        if let Some(tree) = parser.parse(source, None) {
            // The walk starts with no enclosing scope.
            let mut scope = Vec::new();
            walk(lang, tree.root_node(), source, &mut scope, &mut extracted);
        }
    }
    extracted
}
/// Visits `node` and, recursively, all of its children, appending whatever
/// the language extractor recognises to `out`.
///
/// * `scope` is the stack of enclosing names. A node that is a definition, or
///   a scope-only construct, pushes one entry for the duration of its subtree
///   and pops it again before returning, so the stack is left as it was found.
/// * Calls are attributed to the current scope stack joined with
///   `lang.separator()`.
///
/// NOTE(review): symbol paths are built by `qualify`, which hard-codes `::`,
/// while caller paths use `lang.separator()`; confirm the two agree for every
/// supported language.
fn walk(lang: Lang, node: Node, src: &str, scope: &mut Vec<String>, out: &mut FileExtract) {
    // Whether this node added a scope entry that must be removed on the way out.
    let mut pushed = false;
    if let Some((name, kind)) = lang.definition(node, src) {
        let qualified = qualify(scope, &name);
        // A plain function found inside any enclosing scope is reported as a method.
        let kind = if kind == "function" && !scope.is_empty() {
            "method"
        } else {
            kind
        };
        out.symbols.push(SymbolDef {
            signature: signature_text(lang, node, src),
            doc: lang.doc_comment(node, src),
            // Rows are zero-based; reported line numbers are 1-based.
            start_line: node.start_position().row + 1,
            end_line: node.end_position().row + 1,
            name: qualified_tail(&name),
            // `qualified` is not needed afterwards, so it is moved, not cloned.
            qualified,
            kind,
        });
        scope.push(name);
        pushed = true;
    } else if let Some(scope_name) = lang.scope_only(node, src) {
        // Contributes to qualified names without being recorded as a symbol.
        scope.push(scope_name);
        pushed = true;
    }
    if let Some(callee) = lang.call(node, src) {
        out.calls.push(CallSite {
            caller: scope.join(&lang.separator()),
            callee,
        });
    }
    lang.imports(node, src, &mut out.imports);
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        walk(lang, child, src, scope, out);
    }
    if pushed {
        scope.pop();
    }
}
/// Builds the fully qualified form of `name` inside `scope`.
///
/// With no enclosing scope the name is returned unchanged; otherwise the
/// scope entries and the name are joined with `::`.
fn qualify(scope: &[String], name: &str) -> String {
    match scope {
        [] => name.to_owned(),
        enclosing => {
            let mut path = enclosing.join("::");
            path.push_str("::");
            path.push_str(name);
            path
        }
    }
}
/// Returns the text after the last `::` in `qualified`, or the whole input
/// when it contains no `::`.
fn qualified_tail(qualified: &str) -> String {
    let tail = match qualified.rsplit_once("::") {
        Some((_, last)) => last,
        None => qualified,
    };
    tail.to_owned()
}
/// Produces a one-line signature for the definition at `node`.
///
/// Takes the node's source text up to the start of its body (the child under
/// `lang.body_field()`), or the whole node text when there is no such child,
/// collapses every whitespace run to a single space and keeps at most the
/// first 300 characters.
fn signature_text(lang: Lang, node: Node, src: &str) -> String {
    let text = &src[node.byte_range()];
    let body = lang
        .body_field()
        .and_then(|field| node.child_by_field_name(field));
    // Length of the header in bytes, relative to the start of the node.
    let header_len = match body {
        Some(body) => body.start_byte().saturating_sub(node.start_byte()),
        None => text.len(),
    };
    let mut collapsed = String::new();
    for word in text[..header_len].split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    // Truncate by characters, not bytes, so multi-byte text is never split.
    collapsed.chars().take(300).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Projects the extracted symbols to `(qualified name, kind)` pairs.
    fn names(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.symbols
            .iter()
            .map(|s| (s.qualified.as_str(), s.kind))
            .collect()
    }

    /// Projects the extracted call sites to `(caller, callee)` pairs.
    fn call_pairs(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.calls
            .iter()
            .map(|c| (c.caller.as_str(), c.callee.as_str()))
            .collect()
    }

    /// Projects the extracted imports to `(local, source)` pairs.
    fn import_pairs(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.imports
            .iter()
            .map(|i| (i.local.as_str(), i.source.as_str()))
            .collect()
    }

    /// Symbols, docs, signatures, calls and imports from a small Rust file.
    #[test]
    fn rust_extraction() {
        let src = r#"
//! module docs
/// Adds things.
pub fn add(a: i32, b: i32) -> i32 { helper(a) + b }
fn helper(x: i32) -> i32 { x }
pub struct Counter { n: u64 }
impl Counter {
/// Bump it.
pub fn bump(&mut self) { self.n += 1; validate(self.n); }
}
pub trait Resettable { fn reset(&mut self); }
pub enum Mode { A, B }
use std::collections::HashMap;
use crate::store::resolve_ref as rr;
"#;
        let fx = extract(Lang::Rust, src);
        let n = names(&fx);
        assert!(n.contains(&("add", "function")), "{n:?}");
        assert!(n.contains(&("helper", "function")), "{n:?}");
        assert!(n.contains(&("Counter", "struct")), "{n:?}");
        assert!(n.contains(&("Counter::bump", "method")), "{n:?}");
        assert!(n.contains(&("Resettable", "trait")), "{n:?}");
        assert!(n.contains(&("Mode", "enum")), "{n:?}");
        let add = fx.symbols.iter().find(|s| s.qualified == "add").unwrap();
        assert_eq!(add.doc.as_deref(), Some("Adds things."));
        assert!(add.signature.contains("pub fn add(a: i32, b: i32) -> i32"));
        let calls = call_pairs(&fx);
        assert!(calls.contains(&("add", "helper")), "{calls:?}");
        assert!(calls.contains(&("Counter::bump", "validate")), "{calls:?}");
        let imports = import_pairs(&fx);
        assert!(
            imports.contains(&("HashMap", "std::collections::HashMap")),
            "{imports:?}"
        );
        assert!(
            imports.contains(&("rr", "crate::store::resolve_ref")),
            "{imports:?}"
        );
    }

    /// Functions, arrow functions, classes, interfaces and ES imports.
    #[test]
    fn typescript_extraction() {
        let src = r#"
import { fetchUser, postUser as pu } from "./api";
import db from "../db";
/** Greets. */
export function greet(name: string): string { return hello(name); }
const shout = (s: string) => s.toUpperCase();
export class UserService {
find(id: number) { return fetchUser(id); }
}
interface Shape { area(): number; }
"#;
        let fx = extract(Lang::TypeScript, src);
        let n = names(&fx);
        assert!(n.contains(&("greet", "function")), "{n:?}");
        assert!(n.contains(&("shout", "function")), "{n:?}");
        assert!(n.contains(&("UserService", "class")), "{n:?}");
        assert!(n.contains(&("UserService::find", "method")), "{n:?}");
        assert!(n.contains(&("Shape", "interface")), "{n:?}");
        let calls = call_pairs(&fx);
        assert!(calls.contains(&("greet", "hello")), "{calls:?}");
        assert!(
            calls.contains(&("UserService::find", "fetchUser")),
            "{calls:?}"
        );
        let imports = import_pairs(&fx);
        assert!(imports.contains(&("fetchUser", "./api")), "{imports:?}");
        assert!(imports.contains(&("pu", "./api")), "{imports:?}");
        assert!(imports.contains(&("db", "../db")), "{imports:?}");
    }

    /// Functions, classes, methods, docstrings and the three import forms.
    #[test]
    fn python_extraction() {
        // Python block structure is defined by indentation, so the fixture
        // must indent the bodies for `save` to nest under `Repo` and for the
        // docstring to belong to `top`.
        let src = r#"
import os
from collections import OrderedDict as OD
from .util import slugify
def top(x):
    """Top-level docstring."""
    return slugify(x)
class Repo:
    def save(self, item):
        validate(item)
        return persist(item)
"#;
        let fx = extract(Lang::Python, src);
        let n = names(&fx);
        assert!(n.contains(&("top", "function")), "{n:?}");
        assert!(n.contains(&("Repo", "class")), "{n:?}");
        assert!(n.contains(&("Repo::save", "method")), "{n:?}");
        let top_sym = fx.symbols.iter().find(|s| s.qualified == "top").unwrap();
        assert_eq!(top_sym.doc.as_deref(), Some("Top-level docstring."));
        let calls = call_pairs(&fx);
        assert!(calls.contains(&("top", "slugify")), "{calls:?}");
        assert!(calls.contains(&("Repo::save", "validate")), "{calls:?}");
        let imports = import_pairs(&fx);
        assert!(imports.contains(&("os", "os")), "{imports:?}");
        assert!(imports.contains(&("OD", "collections")), "{imports:?}");
        assert!(imports.contains(&("slugify", ".util")), "{imports:?}");
    }

    /// Functions, struct types, receiver methods and grouped imports.
    #[test]
    fn go_extraction() {
        let src = r#"
package main
import (
"fmt"
alias "net/http"
)
// Greet says hi.
func Greet(name string) string { return fmt.Sprintf("hi %s", name) }
type Server struct{ port int }
func (s *Server) Start() error { return listen(s.port) }
"#;
        let fx = extract(Lang::Go, src);
        let n = names(&fx);
        assert!(n.contains(&("Greet", "function")), "{n:?}");
        assert!(n.contains(&("Server", "struct")), "{n:?}");
        assert!(n.contains(&("Server::Start", "method")), "{n:?}");
        let greet = fx.symbols.iter().find(|s| s.qualified == "Greet").unwrap();
        assert_eq!(greet.doc.as_deref(), Some("Greet says hi."));
        let calls = call_pairs(&fx);
        assert!(calls.contains(&("Greet", "Sprintf")), "{calls:?}");
        assert!(calls.contains(&("Server::Start", "listen")), "{calls:?}");
        let imports = import_pairs(&fx);
        assert!(imports.contains(&("fmt", "fmt")), "{imports:?}");
        assert!(imports.contains(&("alias", "net/http")), "{imports:?}");
    }

    /// Malformed and empty input must be tolerated; the results are ignored,
    /// only the absence of a panic is checked.
    #[test]
    fn broken_source_does_not_panic() {
        for lang in [Lang::Rust, Lang::TypeScript, Lang::Python, Lang::Go] {
            extract(lang, "fn class def func ((((");
            extract(lang, "");
        }
    }
}