use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use tree_sitter::{Language, Node, Parser};
/// Coarse category assigned to an extracted symbol.
///
/// The mapping from grammar node kinds to these categories lives in `classify`;
/// the per-variant notes below summarize that mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SymbolKind {
/// Function definitions/declarations (for JavaScript/TypeScript also function
/// expressions and arrow functions, which may be named `<anon>`).
Function,
/// JavaScript/TypeScript method definitions, Go method declarations, and Java
/// methods and constructors.
Method,
/// Class declarations (Python, JavaScript, TypeScript, Java, C++); Java enums
/// are also reported with this kind.
Class,
/// Rust `struct` items and C/C++ struct specifiers.
Struct,
/// Rust traits, and Java/TypeScript interfaces.
Trait,
/// Rust `impl` blocks; the symbol name is the text of the block's `type` field.
Impl,
/// Import-like statements (`use`, `import`, `#include`); the symbol name is
/// the trimmed statement text.
Import,
/// Rust `type` items, Go type declarations, TypeScript type aliases and
/// C/C++ typedefs.
TypeAlias,
/// Rust `const` items.
Const,
/// Not produced by the classifier in this file.
Unknown,
}
/// One symbol extracted from a source file, together with its location and text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Symbol {
/// Path that was passed to the extraction call (copied as-is).
pub file: PathBuf,
/// Symbol name; for imports this is the trimmed statement text.
pub name: String,
/// Category assigned by the classifier.
pub kind: SymbolKind,
/// Byte offset of the start of the syntax node within the source.
pub start_byte: usize,
/// Byte offset of the end of the syntax node within the source.
pub end_byte: usize,
/// 1-based line on which the node starts (parser row + 1).
pub start_line: usize,
/// 1-based line on which the node ends (parser row + 1).
pub end_line: usize,
/// The source text covered by `start_byte..end_byte`.
pub source: String,
}
/// Picks the tree-sitter grammar and language tag for `path` from its file extension.
///
/// Returns `None` when the path has no extension, the extension is not valid UTF-8,
/// or the extension is not supported. Matching is case-sensitive.
///
/// The returned tag is the string `classify` dispatches on. `.tsx` files use the TSX
/// grammar but share the `"typescript"` tag with `.ts` files.
pub fn detect_language(path: &Path) -> Option<(Language, &'static str)> {
    let ext = path.extension()?.to_str()?;
    let detected: (Language, &'static str) = match ext {
        "rs" => (tree_sitter_rust::LANGUAGE.into(), "rust"),
        // `.pyi` stub files use ordinary Python syntax.
        "py" | "pyi" => (tree_sitter_python::LANGUAGE.into(), "python"),
        // `.mjs` / `.cjs` are the ES-module / CommonJS spellings of `.js`.
        "js" | "jsx" | "mjs" | "cjs" => (tree_sitter_javascript::LANGUAGE.into(), "javascript"),
        "go" => (tree_sitter_go::LANGUAGE.into(), "go"),
        "java" => (tree_sitter_java::LANGUAGE.into(), "java"),
        // `.mts` / `.cts` are the module-flavoured spellings of `.ts` (no JSX).
        "ts" | "mts" | "cts" => (
            tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
            "typescript",
        ),
        "tsx" => (tree_sitter_typescript::LANGUAGE_TSX.into(), "typescript"),
        // `.h` is ambiguous between C and C++; it is treated as C.
        "c" | "h" => (tree_sitter_c::LANGUAGE.into(), "c"),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => (tree_sitter_cpp::LANGUAGE.into(), "cpp"),
        _ => return None,
    };
    Some(detected)
}
/// Parses `source` with `language` and returns every symbol found, in pre-order
/// (a node is reported before the nodes nested inside it).
///
/// `language` is only used for parsing. The classification rules are selected from
/// the extension of `file` (via `detect_language`), so the two must agree for any
/// symbols to be found. `file` is also copied into each returned [`Symbol`].
///
/// Returns an empty vector when the extension of `file` is not supported, when the
/// parser rejects `language`, or when parsing produces no tree.
pub fn extract_symbols(source: &str, language: Language, file: &Path) -> Vec<Symbol> {
    // Without a known tag nothing can be classified, so skip parsing entirely.
    let Some((_, lang_tag)) = detect_language(file) else {
        return Vec::new();
    };

    let mut parser = Parser::new();
    if parser.set_language(&language).is_err() {
        return Vec::new();
    }
    let Some(tree) = parser.parse(source, None) else {
        return Vec::new();
    };

    let mut out = Vec::new();
    walk(tree.root_node(), source, file, lang_tag, &mut out);
    out
}
/// Visits `node` and all of its descendants in pre-order, appending a [`Symbol`] to
/// `out` for every node that `classify` recognizes for `lang`.
///
/// The traversal uses an explicit stack instead of recursion so that very deeply
/// nested syntax trees cannot overflow the call stack.
fn walk(node: Node, source: &str, file: &Path, lang: &str, out: &mut Vec<Symbol>) {
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        if let Some((kind, name)) = classify(current, source, lang) {
            let start_byte = current.start_byte();
            let end_byte = current.end_byte();
            // `get` returns `None` for an out-of-range or non-char-boundary span; such a
            // node is skipped rather than panicking.
            if let Some(text) = source.get(start_byte..end_byte) {
                out.push(Symbol {
                    file: file.to_path_buf(),
                    name,
                    kind,
                    start_byte,
                    end_byte,
                    // Parser rows are 0-based; reported lines are 1-based.
                    start_line: current.start_position().row + 1,
                    end_line: current.end_position().row + 1,
                    source: text.to_string(),
                });
            }
        }

        // Push the children, then reverse just that tail so the first child is popped
        // next. This keeps the same pre-order as a recursive descent.
        let first_child = pending.len();
        let mut cursor = current.walk();
        pending.extend(current.children(&mut cursor));
        pending[first_child..].reverse();
    }
}
/// Decides whether `node` is a symbol of interest for the language tag `lang` and, if
/// so, returns its kind and name.
///
/// `lang` is one of the tags returned by `detect_language`; any other tag yields
/// `None`. A node whose kind is recognized but whose name cannot be read also yields
/// `None`, except for JavaScript/TypeScript functions, which fall back to `<anon>`.
/// Import-like nodes use their whole trimmed text as the name.
fn classify(node: Node, source: &str, lang: &str) -> Option<(SymbolKind, String)> {
    let kind_str = node.kind();
    let bytes = source.as_bytes();
    match lang {
        "rust" => match kind_str {
            "function_item" => Some((SymbolKind::Function, name_field(node, bytes)?)),
            "struct_item" => Some((SymbolKind::Struct, name_field(node, bytes)?)),
            "trait_item" => Some((SymbolKind::Trait, name_field(node, bytes)?)),
            "impl_item" => {
                // An impl block has no name of its own; use the implementing type.
                let n = node
                    .child_by_field_name("type")
                    .and_then(|c| c.utf8_text(bytes).ok())
                    .map(|s| s.to_string())?;
                Some((SymbolKind::Impl, n))
            }
            "use_declaration" => Some((SymbolKind::Import, trimmed_text(node, bytes)?)),
            "type_item" => Some((SymbolKind::TypeAlias, name_field(node, bytes)?)),
            "const_item" => Some((SymbolKind::Const, name_field(node, bytes)?)),
            _ => None,
        },
        "python" => match kind_str {
            "function_definition" | "async_function_definition" => {
                Some((SymbolKind::Function, name_field(node, bytes)?))
            }
            "class_definition" => Some((SymbolKind::Class, name_field(node, bytes)?)),
            "import_statement" | "import_from_statement" => {
                Some((SymbolKind::Import, trimmed_text(node, bytes)?))
            }
            _ => None,
        },
        "javascript" => match kind_str {
            "function_declaration" | "function_expression" | "arrow_function" => {
                let n = name_field(node, bytes).unwrap_or_else(|| "<anon>".to_string());
                Some((SymbolKind::Function, n))
            }
            "method_definition" => Some((SymbolKind::Method, name_field(node, bytes)?)),
            "class_declaration" => Some((SymbolKind::Class, name_field(node, bytes)?)),
            "import_statement" => Some((SymbolKind::Import, trimmed_text(node, bytes)?)),
            _ => None,
        },
        "go" => match kind_str {
            "function_declaration" => Some((SymbolKind::Function, name_field(node, bytes)?)),
            "method_declaration" => Some((SymbolKind::Method, name_field(node, bytes)?)),
            "type_declaration" => {
                // `type X T` yields a `type_spec` child and `type X = T` a `type_alias`
                // child; both carry a `name` field. Only the first named entry of a
                // grouped `type ( ... )` declaration is reported.
                let mut cursor = node.walk();
                let first_named = node
                    .children(&mut cursor)
                    .filter(|child| matches!(child.kind(), "type_spec" | "type_alias"))
                    .find_map(|child| name_field(child, bytes));
                first_named.map(|n| (SymbolKind::TypeAlias, n))
            }
            "import_declaration" => Some((SymbolKind::Import, trimmed_text(node, bytes)?)),
            _ => None,
        },
        "java" => match kind_str {
            "method_declaration" | "constructor_declaration" => {
                Some((SymbolKind::Method, name_field(node, bytes)?))
            }
            "class_declaration" | "enum_declaration" => {
                Some((SymbolKind::Class, name_field(node, bytes)?))
            }
            "interface_declaration" => Some((SymbolKind::Trait, name_field(node, bytes)?)),
            "import_declaration" => Some((SymbolKind::Import, trimmed_text(node, bytes)?)),
            _ => None,
        },
        "typescript" => match kind_str {
            "function_declaration" | "function_expression" | "arrow_function" => {
                let n = name_field(node, bytes).unwrap_or_else(|| "<anon>".to_string());
                Some((SymbolKind::Function, n))
            }
            "method_definition" => Some((SymbolKind::Method, name_field(node, bytes)?)),
            "class_declaration" => Some((SymbolKind::Class, name_field(node, bytes)?)),
            "interface_declaration" => Some((SymbolKind::Trait, name_field(node, bytes)?)),
            "type_alias_declaration" => Some((SymbolKind::TypeAlias, name_field(node, bytes)?)),
            "import_statement" => Some((SymbolKind::Import, trimmed_text(node, bytes)?)),
            _ => None,
        },
        "c" => match kind_str {
            // C function definitions expose no `name` field; the identifier is nested
            // inside the `declarator` chain.
            "function_definition" => Some((SymbolKind::Function, declarator_name(node, bytes)?)),
            "struct_specifier" => Some((SymbolKind::Struct, name_field(node, bytes)?)),
            "type_definition" => {
                let n = node
                    .child_by_field_name("declarator")
                    .and_then(|d| trimmed_text(d, bytes))?;
                Some((SymbolKind::TypeAlias, n))
            }
            "preproc_include" => Some((SymbolKind::Import, trimmed_text(node, bytes)?)),
            _ => None,
        },
        "cpp" => match kind_str {
            // Same as C: the name lives under `declarator`, not in a `name` field.
            "function_definition" => Some((SymbolKind::Function, declarator_name(node, bytes)?)),
            "class_specifier" => Some((SymbolKind::Class, name_field(node, bytes)?)),
            "struct_specifier" => Some((SymbolKind::Struct, name_field(node, bytes)?)),
            "type_definition" => {
                let n = node
                    .child_by_field_name("declarator")
                    .and_then(|d| trimmed_text(d, bytes))?;
                Some((SymbolKind::TypeAlias, n))
            }
            "preproc_include" => Some((SymbolKind::Import, trimmed_text(node, bytes)?)),
            _ => None,
        },
        _ => None,
    }
}

/// Returns the node's source text with surrounding whitespace removed, or `None` if
/// the text cannot be read as UTF-8.
fn trimmed_text(node: Node, bytes: &[u8]) -> Option<String> {
    node.utf8_text(bytes).ok().map(|s| s.trim().to_string())
}

/// Resolves the declared name of a C/C++ node by following its `declarator` chain
/// (e.g. pointer declarator -> function declarator -> identifier) to the innermost
/// node and returning that node's text.
///
/// Returns `None` when `node` has no `declarator` field or the text is not UTF-8.
fn declarator_name(node: Node, bytes: &[u8]) -> Option<String> {
    let mut current = node.child_by_field_name("declarator")?;
    loop {
        let inner = match current.kind() {
            // These wrappers hold the wrapped declarator as a plain (unnamed-field)
            // child; it is the last named child in both cases.
            "parenthesized_declarator" | "reference_declarator" => {
                let mut cursor = current.walk();
                let last = current.named_children(&mut cursor).last();
                last
            }
            // A conversion operator's `declarator` is only its parameter list, so the
            // whole node is the most useful name.
            "operator_cast" => None,
            _ => current.child_by_field_name("declarator"),
        };
        match inner {
            Some(next) => current = next,
            None => break,
        }
    }
    trimmed_text(current, bytes)
}
/// Reads the text of the node's `name` field as an owned string.
///
/// Returns `None` when the node has no `name` field or its text is not valid UTF-8.
fn name_field(node: Node, bytes: &[u8]) -> Option<String> {
    let name_node = node.child_by_field_name("name")?;
    let text = name_node.utf8_text(bytes).ok()?;
    Some(text.to_owned())
}
/// Extracts all symbols from `source` and returns the first one, in extraction order,
/// whose name equals `name` exactly; `None` if there is no such symbol.
pub fn get_symbol(source: &str, lang: Language, file: &Path, name: &str) -> Option<Symbol> {
    for symbol in extract_symbols(source, lang, file) {
        if symbol.name == name {
            return Some(symbol);
        }
    }
    None
}
/// Reads `file` from disk and returns all symbols found in it.
///
/// # Errors
///
/// Returns an error if the file extension is not supported by `detect_language`, or
/// if the file cannot be read as UTF-8 text. The extension is checked first so that
/// unsupported files are rejected without being read.
pub fn list_symbols(file: &Path) -> Result<Vec<Symbol>> {
    let (lang, _) = detect_language(file)
        .with_context(|| format!("unsupported file extension: {}", file.display()))?;
    let source = std::fs::read_to_string(file)
        .with_context(|| format!("failed to read {}", file.display()))?;
    Ok(extract_symbols(&source, lang, file))
}
// Unit tests for symbol extraction; each test parses a small inline fixture.
#[cfg(test)]
mod tests {
use super::*;
// Rust fixture containing a use, const, struct, trait, impl, fn and type alias.
// Checks that each named item is reported, that the impl block is reported as
// `Impl` named `Foo`, and that `get_symbol` returns the function with its source text.
#[test]
fn extract_symbols_rust() {
let src = r#"
use std::collections::HashMap;
const MAX: u32 = 100;
struct Foo { x: i32 }
trait Bar { fn baz(&self); }
impl Bar for Foo {
fn baz(&self) {}
}
fn standalone() -> i32 { 42 }
type Alias = u64;
"#;
// The path is never read from disk; its extension selects the classification rules.
let path = PathBuf::from("test.rs");
let syms = extract_symbols(src, tree_sitter_rust::LANGUAGE.into(), &path);
let names: Vec<&str> = syms.iter().map(|s| s.name.as_str()).collect();
assert!(names.contains(&"Foo"), "expected Foo, got {names:?}");
assert!(names.contains(&"Bar"), "expected Bar, got {names:?}");
assert!(
names.contains(&"standalone"),
"expected standalone, got {names:?}"
);
assert!(names.contains(&"MAX"), "expected MAX const, got {names:?}");
assert!(names.contains(&"Alias"), "expected Alias, got {names:?}");
assert!(
syms.iter()
.any(|s| matches!(s.kind, SymbolKind::Impl) && s.name == "Foo")
);
// Lookup by name should return the free function together with its body text.
let got = get_symbol(src, tree_sitter_rust::LANGUAGE.into(), &path, "standalone").unwrap();
assert_eq!(got.kind, SymbolKind::Function);
assert!(got.source.contains("42"));
}
// Python fixture: expects the class, the top-level function and at least one import.
#[test]
fn extract_symbols_python() {
let src = "import os\nfrom typing import List\n\nclass Foo:\n def bar(self):\n pass\n\ndef baz():\n return 1\n";
let path = PathBuf::from("test.py");
let syms = extract_symbols(src, tree_sitter_python::LANGUAGE.into(), &path);
assert!(
syms.iter()
.any(|s| s.name == "Foo" && matches!(s.kind, SymbolKind::Class))
);
assert!(
syms.iter()
.any(|s| s.name == "baz" && matches!(s.kind, SymbolKind::Function))
);
assert!(syms.iter().any(|s| matches!(s.kind, SymbolKind::Import)));
}
// Go fixture: only asserts that the function `Foo` is reported.
#[test]
fn extract_symbols_go() {
let src = "package main\n\nimport \"fmt\"\n\nfunc Foo() {}\n\ntype X int\n";
let path = PathBuf::from("test.go");
let syms = extract_symbols(src, tree_sitter_go::LANGUAGE.into(), &path);
assert!(syms.iter().any(|s| s.name == "Foo"));
}
}