1use tree_sitter::{Node, Parser};
5
6use crate::languages::Lang;
7
/// A named definition (function, type, class, …) found in a source file.
#[derive(Debug, Clone, PartialEq)]
pub struct SymbolDef {
    /// Unqualified name: the last `::` segment of the name the language reported.
    pub name: String,
    /// Enclosing scope names followed by the definition's own name, joined with `::`
    /// (e.g. `Counter::bump`).
    pub qualified: String,
    /// Definition kind, e.g. `"function"`, `"method"`, `"struct"`, `"class"`.
    pub kind: &'static str,
    /// Definition header: its source text before the body (when the language names a body
    /// field), with whitespace collapsed to single spaces and capped at 300 characters.
    pub signature: String,
    /// Documentation attached to the definition, if the language extracted any.
    pub doc: Option<String>,
    /// 1-based line on which the definition starts.
    pub start_line: usize,
    /// 1-based line on which the definition ends.
    pub end_line: usize,
}
23
/// A call expression found in a source file.
#[derive(Debug, Clone, PartialEq)]
pub struct CallSite {
    /// Scope path enclosing the call (outermost first), joined with the language's
    /// separator; empty for calls made outside any definition or scope.
    pub caller: String,
    /// Called name as reported by the language's call hook (e.g. `helper`).
    pub callee: String,
}
31
/// A name brought into a file by an import statement.
#[derive(Debug, Clone, PartialEq)]
pub struct ImportRef {
    /// Name bound in the importing file (the alias, when one is given).
    pub local: String,
    /// Where the name comes from, written in the language's own syntax
    /// (e.g. `std::collections::HashMap`, `./api`, `collections`).
    pub source: String,
}
39
/// Everything [`extract`] collected from one source file, in traversal order.
#[derive(Debug, Default)]
pub struct FileExtract {
    /// Definitions found in the file.
    pub symbols: Vec<SymbolDef>,
    /// Call sites found in the file.
    pub calls: Vec<CallSite>,
    /// Imports found in the file.
    pub imports: Vec<ImportRef>,
}
46
47pub fn extract(lang: Lang, source: &str) -> FileExtract {
48 let mut parser = Parser::new();
49 if parser.set_language(&lang.language()).is_err() {
50 return FileExtract::default();
51 }
52 let Some(tree) = parser.parse(source, None) else {
53 return FileExtract::default();
54 };
55 let mut out = FileExtract::default();
56 walk(lang, tree.root_node(), source, &mut Vec::new(), &mut out);
57 out
58}
59
60fn walk(lang: Lang, node: Node, src: &str, scope: &mut Vec<String>, out: &mut FileExtract) {
62 let mut pushed = false;
63
64 if let Some((name, kind)) = lang.definition(node, src) {
65 let qualified = qualify(scope, &name);
66 let kind = if kind == "function" && !scope.is_empty() {
68 "method"
69 } else {
70 kind
71 };
72 out.symbols.push(SymbolDef {
73 signature: signature_text(lang, node, src),
74 doc: lang.doc_comment(node, src),
75 start_line: node.start_position().row + 1,
76 end_line: node.end_position().row + 1,
77 name: qualified_tail(&name),
78 qualified: qualified.clone(),
79 kind,
80 });
81 scope.push(name);
84 pushed = true;
85 } else if let Some(scope_name) = lang.scope_only(node, src) {
86 scope.push(scope_name);
89 pushed = true;
90 }
91
92 if let Some(callee) = lang.call(node, src) {
93 out.calls.push(CallSite {
94 caller: scope.join(&lang.separator()),
95 callee,
96 });
97 }
98 lang.imports(node, src, &mut out.imports);
99
100 let mut cursor = node.walk();
101 for child in node.children(&mut cursor) {
102 walk(lang, child, src, scope, out);
103 }
104 if pushed {
105 scope.pop();
106 }
107}
108
/// Builds the qualified name of `name` declared inside `scope` (outermost first). Segments
/// are separated by `::`; at top level (empty `scope`) `name` is returned as-is.
fn qualify(scope: &[String], name: &str) -> String {
    match scope {
        [] => name.to_owned(),
        _ => {
            let mut path = scope.join("::");
            path.push_str("::");
            path.push_str(name);
            path
        }
    }
}
116
/// Returns the last `::`-separated segment of `qualified`. When it contains no `::`, the
/// whole string is returned.
fn qualified_tail(qualified: &str) -> String {
    let tail = match qualified.rsplit_once("::") {
        Some((_, last)) => last,
        None => qualified,
    };
    tail.to_owned()
}
124
125fn signature_text(lang: Lang, node: Node, src: &str) -> String {
127 let full = &src[node.byte_range()];
128 let cut = lang
129 .body_field()
130 .and_then(|f| node.child_by_field_name(f))
131 .map(|b| b.start_byte().saturating_sub(node.start_byte()))
132 .unwrap_or(full.len());
133 let sig: String = full[..cut].split_whitespace().collect::<Vec<_>>().join(" ");
134 sig.chars().take(300).collect()
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// `(qualified, kind)` for every extracted symbol.
    fn names(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.symbols
            .iter()
            .map(|s| (s.qualified.as_str(), s.kind))
            .collect()
    }

    /// `(caller, callee)` for every extracted call site.
    fn call_pairs(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.calls
            .iter()
            .map(|c| (c.caller.as_str(), c.callee.as_str()))
            .collect()
    }

    /// `(local, source)` for every extracted import.
    fn import_pairs(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.imports
            .iter()
            .map(|i| (i.local.as_str(), i.source.as_str()))
            .collect()
    }

    #[test]
    fn qualify_joins_scope_with_double_colon() {
        assert_eq!(qualify(&[], "f"), "f");
        let scope = vec!["Outer".to_string(), "Inner".to_string()];
        assert_eq!(qualify(&scope, "f"), "Outer::Inner::f");
    }

    #[test]
    fn qualified_tail_returns_last_segment() {
        assert_eq!(qualified_tail("Server::Start"), "Start");
        assert_eq!(qualified_tail("a::b::c"), "c");
        assert_eq!(qualified_tail("plain"), "plain");
    }

    #[test]
    fn rust_extraction() {
        let src = r#"
//! module docs

/// Adds things.
pub fn add(a: i32, b: i32) -> i32 { helper(a) + b }

fn helper(x: i32) -> i32 { x }

pub struct Counter { n: u64 }

impl Counter {
    /// Bump it.
    pub fn bump(&mut self) { self.n += 1; validate(self.n); }
}

pub trait Resettable { fn reset(&mut self); }

pub enum Mode { A, B }

use std::collections::HashMap;
use crate::store::resolve_ref as rr;
"#;
        let fx = extract(Lang::Rust, src);
        let n = names(&fx);
        assert!(n.contains(&("add", "function")), "{n:?}");
        assert!(n.contains(&("helper", "function")), "{n:?}");
        assert!(n.contains(&("Counter", "struct")), "{n:?}");
        assert!(n.contains(&("Counter::bump", "method")), "{n:?}");
        assert!(n.contains(&("Resettable", "trait")), "{n:?}");
        assert!(n.contains(&("Mode", "enum")), "{n:?}");

        let add = fx.symbols.iter().find(|s| s.qualified == "add").unwrap();
        assert_eq!(add.doc.as_deref(), Some("Adds things."));
        assert!(add.signature.contains("pub fn add(a: i32, b: i32) -> i32"));

        let calls = call_pairs(&fx);
        assert!(calls.contains(&("add", "helper")), "{calls:?}");
        assert!(calls.contains(&("Counter::bump", "validate")), "{calls:?}");

        let imports = import_pairs(&fx);
        assert!(
            imports.contains(&("HashMap", "std::collections::HashMap")),
            "{imports:?}"
        );
        assert!(
            imports.contains(&("rr", "crate::store::resolve_ref")),
            "{imports:?}"
        );
    }

    #[test]
    fn typescript_extraction() {
        let src = r#"
import { fetchUser, postUser as pu } from "./api";
import db from "../db";

/** Greets. */
export function greet(name: string): string { return hello(name); }

const shout = (s: string) => s.toUpperCase();

export class UserService {
    find(id: number) { return fetchUser(id); }
}

interface Shape { area(): number; }
"#;
        let fx = extract(Lang::TypeScript, src);
        let n = names(&fx);
        assert!(n.contains(&("greet", "function")), "{n:?}");
        assert!(n.contains(&("shout", "function")), "{n:?}");
        assert!(n.contains(&("UserService", "class")), "{n:?}");
        assert!(n.contains(&("UserService::find", "method")), "{n:?}");
        assert!(n.contains(&("Shape", "interface")), "{n:?}");

        let calls = call_pairs(&fx);
        assert!(calls.contains(&("greet", "hello")), "{calls:?}");
        assert!(
            calls.contains(&("UserService::find", "fetchUser")),
            "{calls:?}"
        );

        let imports = import_pairs(&fx);
        assert!(imports.contains(&("fetchUser", "./api")), "{imports:?}");
        assert!(imports.contains(&("pu", "./api")), "{imports:?}");
        assert!(imports.contains(&("db", "../db")), "{imports:?}");
    }

    #[test]
    fn python_extraction() {
        let src = r#"
import os
from collections import OrderedDict as OD
from .util import slugify

def top(x):
    """Top-level docstring."""
    return slugify(x)

class Repo:
    def save(self, item):
        validate(item)
        return persist(item)
"#;
        let fx = extract(Lang::Python, src);
        let n = names(&fx);
        assert!(n.contains(&("top", "function")), "{n:?}");
        assert!(n.contains(&("Repo", "class")), "{n:?}");
        assert!(n.contains(&("Repo::save", "method")), "{n:?}");

        let top_sym = fx.symbols.iter().find(|s| s.qualified == "top").unwrap();
        assert_eq!(top_sym.doc.as_deref(), Some("Top-level docstring."));

        let calls = call_pairs(&fx);
        assert!(calls.contains(&("top", "slugify")), "{calls:?}");
        assert!(calls.contains(&("Repo::save", "validate")), "{calls:?}");

        let imports = import_pairs(&fx);
        assert!(imports.contains(&("os", "os")), "{imports:?}");
        assert!(imports.contains(&("OD", "collections")), "{imports:?}");
        assert!(imports.contains(&("slugify", ".util")), "{imports:?}");
    }

    #[test]
    fn go_extraction() {
        let src = r#"
package main

import (
    "fmt"
    alias "net/http"
)

// Greet says hi.
func Greet(name string) string { return fmt.Sprintf("hi %s", name) }

type Server struct{ port int }

func (s *Server) Start() error { return listen(s.port) }
"#;
        let fx = extract(Lang::Go, src);
        let n = names(&fx);
        assert!(n.contains(&("Greet", "function")), "{n:?}");
        assert!(n.contains(&("Server", "struct")), "{n:?}");
        assert!(n.contains(&("Server::Start", "method")), "{n:?}");

        let greet = fx.symbols.iter().find(|s| s.qualified == "Greet").unwrap();
        assert_eq!(greet.doc.as_deref(), Some("Greet says hi."));

        let calls = call_pairs(&fx);
        assert!(calls.contains(&("Greet", "Sprintf")), "{calls:?}");
        assert!(calls.contains(&("Server::Start", "listen")), "{calls:?}");

        let imports = import_pairs(&fx);
        assert!(imports.contains(&("fmt", "fmt")), "{imports:?}");
        assert!(imports.contains(&("alias", "net/http")), "{imports:?}");
    }

    #[test]
    fn broken_source_does_not_panic() {
        for lang in [Lang::Rust, Lang::TypeScript, Lang::Python, Lang::Go] {
            extract(lang, "fn class def func ((((");
            extract(lang, "");
        }
    }
}