use std::cell::RefCell;
use std::ops::ControlFlow;
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock, RwLock};
use std::time::{Duration, Instant};
use ahash::{AHashMap, AHashSet};
use thiserror::Error;
use tree_sitter::{Language, ParseOptions, Parser, Query, Tree};
/// Wall-clock budget for a single parse when `BASEMIND_PARSE_TIMEOUT_MS` is not usable.
pub const DEFAULT_PARSE_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Reads the per-parse timeout, in milliseconds, from `BASEMIND_PARSE_TIMEOUT_MS`.
///
/// Falls back to [`DEFAULT_PARSE_TIMEOUT`] when the variable is unset, is not
/// valid Unicode, or does not parse as a `u64` (no whitespace trimming is done).
fn parse_timeout_from_env() -> Duration {
    match std::env::var("BASEMIND_PARSE_TIMEOUT_MS") {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(millis) => Duration::from_millis(millis),
            Err(_) => DEFAULT_PARSE_TIMEOUT,
        },
        Err(_) => DEFAULT_PARSE_TIMEOUT,
    }
}
#[derive(Debug, Error)]
pub enum LangError {
#[error("language pack error: {0}")]
Pack(String),
#[error("grammar download failed: {0}")]
Download(String),
#[error("query compile error for {lang}/{kind}: {msg}")]
QueryCompile {
lang: &'static str,
kind: &'static str,
msg: String,
},
#[error("failed to set language {0} on parser")]
ParserSetLanguage(String),
}
/// Language identifier; always a `'static` string so it can be used as a cheap map key.
pub type LangId = &'static str;

/// Languages that ship a bundled query file under `queries/`.
pub const OVERRIDE_LANGUAGES: &[LangId] = &[
    "rust",
    "python",
    "typescript",
    "tsx",
    "javascript",
    "go",
];

/// Languages with first-class support; currently identical to [`OVERRIDE_LANGUAGES`].
pub const SUPPORTED_LANGUAGES: &[LangId] = OVERRIDE_LANGUAGES;
/// Returns the bundled query file (all sections) for `lang`, if one ships with the crate.
fn override_query_source(lang: LangId) -> Option<&'static str> {
    const SOURCES: &[(LangId, &str)] = &[
        ("rust", include_str!("queries/rust.scm")),
        ("python", include_str!("queries/python.scm")),
        ("typescript", include_str!("queries/typescript.scm")),
        ("tsx", include_str!("queries/tsx.scm")),
        ("javascript", include_str!("queries/javascript.scm")),
        ("go", include_str!("queries/go.scm")),
    ];
    SOURCES
        .iter()
        .find(|(name, _)| *name == lang)
        .map(|&(_, text)| text)
}

/// `true` when `lang` has a bundled query file.
pub fn has_override(lang: LangId) -> bool {
    matches!(override_query_source(lang), Some(_))
}
/// Maps a language name to a process-wide `'static` identifier.
///
/// Override languages resolve to the entries of [`OVERRIDE_LANGUAGES`] directly.
/// Any other name is accepted only if the language pack knows it. In that case the
/// name is leaked once and remembered in `INTERNED`, so later calls with the
/// same name return the same pointer.
///
/// Returns `None` for names the language pack does not recognise.
///
/// # Panics
/// Panics if the intern pool lock is poisoned.
pub fn intern(name: &str) -> Option<LangId> {
    if let Some(&lid) = OVERRIDE_LANGUAGES.iter().find(|&&lid| lid == name) {
        return Some(lid);
    }
    let pool = INTERNED.get_or_init(|| RwLock::new(AHashSet::new()));
    // Fast path: hashed lookup under the shared lock.
    if let Some(&existing) = pool.read().expect("intern pool poisoned").get(name) {
        return Some(existing);
    }
    if !tree_sitter_language_pack::has_language(name) {
        return None;
    }
    let mut set = pool.write().expect("intern pool poisoned");
    // Re-check under the exclusive lock. Another thread may have interned `name`
    // after our read lock was released, and we must return *its* pointer rather
    // than leak a second copy that no longer matches the pool.
    if let Some(&existing) = set.get(name) {
        return Some(existing);
    }
    let leaked: &'static str = Box::leak(name.to_owned().into_boxed_str());
    set.insert(leaked);
    Some(leaked)
}

/// Pool of leaked names for non-override languages interned by [`intern`].
static INTERNED: OnceLock<RwLock<AHashSet<&'static str>>> = OnceLock::new();
/// Report produced by the one-time grammar bootstrap (`ensure_grammars`).
#[derive(Debug, Clone)]
pub struct BootstrapSummary {
    /// Override languages that were already installed in the grammar cache.
    pub already_cached: Vec<String>,
    /// Override languages that were missing and were fetched during bootstrap.
    pub downloaded: Vec<String>,
    /// Grammar cache directory used by the bootstrap, when known.
    pub cache_dir: Option<PathBuf>,
}

impl BootstrapSummary {
    /// Whether the bootstrap had to fetch at least one grammar.
    pub fn did_download(&self) -> bool {
        self.downloaded.first().is_some()
    }
}
/// Memo of the first grammar-bootstrap attempt (success or failure) for this process.
static GRAMMAR_BOOTSTRAP: OnceLock<Result<Arc<BootstrapSummary>, Arc<LangError>>> =
    OnceLock::new();
/// Derives the language-pack version from its cache path.
///
/// The version is read from the name of `p`'s parent directory, which must look like
/// `v<version>`. For example, `…/v1.2.3/libs` yields `"1.2.3"`. Returns `None` if
/// there is no parent, the name is not UTF-8, or it lacks the `v` prefix.
fn tslp_version_from_cache_dir(p: &Path) -> Option<String> {
    p.parent()
        .and_then(Path::file_name)
        .and_then(|os_name| os_name.to_str())
        .and_then(|leaf| leaf.strip_prefix('v'))
        .map(str::to_owned)
}
/// Makes sure every override grammar is present in the language-pack cache.
///
/// The work runs at most once per process. The first result, success or failure,
/// is memoised in `GRAMMAR_BOOTSTRAP` and returned by every later call, so a
/// failed download is not retried until the process restarts.
///
/// # Errors
/// - [`LangError::Pack`] if the cache directory or the pack version cannot be resolved.
/// - [`LangError::Download`] if grammars are missing and `BASEMIND_GRAMMAR_OFFLINE`
///   is set to anything other than empty or `"0"`, or if fetching them fails.
pub fn ensure_grammars() -> Result<Arc<BootstrapSummary>, Arc<LangError>> {
    GRAMMAR_BOOTSTRAP.get_or_init(bootstrap_grammars).clone()
}

/// Performs the actual bootstrap: inspects the cache and downloads missing grammars.
fn bootstrap_grammars() -> Result<Arc<BootstrapSummary>, Arc<LangError>> {
    let cache_dir_str = tree_sitter_language_pack::cache_dir()
        .map_err(|e| Arc::new(LangError::Pack(format!("resolve cache dir: {e}"))))?;
    let cache_dir = PathBuf::from(&cache_dir_str);
    let Some(version) = tslp_version_from_cache_dir(&cache_dir) else {
        return Err(Arc::new(LangError::Pack(format!(
            "could not parse tslp version out of {cache_dir_str:?}"
        ))));
    };
    let dm = tree_sitter_language_pack::DownloadManager::with_cache_dir(
        &version,
        cache_dir.clone(),
    );
    let installed: Vec<String> = dm.installed_languages();

    // Split the override list, in order, into "already installed" and "needs fetching".
    let (cached, missing): (Vec<&'static str>, Vec<&'static str>) = OVERRIDE_LANGUAGES
        .iter()
        .copied()
        .partition(|name| installed.iter().any(|n| n.as_str() == *name));

    if !missing.is_empty() {
        let offline = std::env::var("BASEMIND_GRAMMAR_OFFLINE")
            .is_ok_and(|v| v != "0" && !v.is_empty());
        if offline {
            return Err(Arc::new(LangError::Download(format!(
                "offline mode: missing grammars {missing:?} and \
                 BASEMIND_GRAMMAR_OFFLINE is set",
            ))));
        }
        dm.ensure_languages(&missing)
            .map_err(|e| Arc::new(LangError::Download(format!("{e}"))))?;
    }

    Ok(Arc::new(BootstrapSummary {
        already_cached: cached.into_iter().map(str::to_string).collect(),
        downloaded: missing.into_iter().map(str::to_string).collect(),
        cache_dir: Some(cache_dir),
    }))
}
/// Returns the language pack's list of downloaded grammars.
pub fn downloaded_languages() -> Vec<String> {
    tree_sitter_language_pack::downloaded_languages()
}

/// Returns the language pack's grammar cache directory.
///
/// Returns `None` if the pack cannot resolve it; the underlying error is discarded.
pub fn grammar_cache_dir() -> Option<PathBuf> {
    match tree_sitter_language_pack::cache_dir() {
        Ok(dir) => Some(PathBuf::from(dir)),
        Err(_) => None,
    }
}

/// Asks the language pack to clean its grammar cache.
///
/// # Errors
/// [`LangError::Pack`] carrying the pack's error message.
pub fn clean_grammar_cache() -> Result<(), LangError> {
    tree_sitter_language_pack::clean_cache().map_err(|err| LangError::Pack(err.to_string()))
}
/// Guesses the language of `path` via the language pack's detection.
///
/// Returns `None` for non-UTF-8 paths or paths the pack does not recognise.
pub fn detect(path: &Path) -> Option<LangId> {
    let path_str = path.to_str()?;
    tree_sitter_language_pack::detect_language(path_str)
}

/// Loads the tree-sitter grammar for `lang` from the language pack.
///
/// # Errors
/// [`LangError::Pack`] carrying the pack's error message.
pub fn language(lang: LangId) -> Result<Language, LangError> {
    match tree_sitter_language_pack::get_language(lang) {
        Ok(grammar) => Ok(grammar),
        Err(err) => Err(LangError::Pack(format!("{err}"))),
    }
}
thread_local! {
    /// Per-thread cache of configured parsers, keyed by language id.
    static PARSERS: RefCell<AHashMap<LangId, Parser>> = RefCell::new(AHashMap::new());
}

/// Runs `f` with this thread's cached [`Parser`] for `lang`, creating and
/// configuring one on first use.
///
/// The parser is removed from the thread-local cache while `f` runs and is put
/// back afterwards, so the cache's `RefCell` is never borrowed across the
/// callback. Nested calls from inside `f` therefore work instead of panicking
/// with a `BorrowMutError`; a nested call for the *same* language simply gets a
/// fresh parser. If `f` panics, the parser is dropped and rebuilt on the next call.
///
/// # Errors
/// [`LangError::Pack`] if the grammar cannot be loaded, or
/// [`LangError::ParserSetLanguage`] if the parser rejects it.
pub fn with_parser<F, R>(lang: LangId, f: F) -> Result<R, LangError>
where
    F: FnOnce(&mut Parser) -> R,
{
    let cached = PARSERS.with(|cell| cell.borrow_mut().remove(&lang));
    let mut parser = match cached {
        Some(p) => p,
        None => {
            let mut p = Parser::new();
            let ts_lang = language(lang)?;
            p.set_language(&ts_lang)
                .map_err(|_| LangError::ParserSetLanguage(lang.to_string()))?;
            p
        }
    };
    let out = f(&mut parser);
    PARSERS.with(|cell| {
        cell.borrow_mut().insert(lang, parser);
    });
    Ok(out)
}
/// Result of a time-bounded parse performed by [`parse_timed`].
#[derive(Debug)]
pub enum ParseOutcome {
    /// The parser produced a syntax tree.
    Ok(Tree),
    /// No tree was produced and the timeout was not the cause.
    Failed,
    /// The progress callback cancelled the parse because the timeout elapsed.
    TimedOut,
}
/// Parses `source` with `parser`, cancelling once `timeout` has elapsed.
///
/// The deadline is only checked when tree-sitter invokes the progress callback,
/// so the actual cut-off can overshoot `timeout` slightly.
///
/// Whenever no tree is produced, the parser is `reset()`. After a cancelled
/// parse, tree-sitter otherwise keeps the partial state and *resumes* it on the
/// next parse call. Because parsers are cached and reused across documents, that
/// would splice the next document onto a stale half-finished parse.
pub fn parse_timed(parser: &mut Parser, source: &[u8], timeout: Duration) -> ParseOutcome {
    let started = Instant::now();
    let mut timed_out = false;
    let len = source.len();
    let mut input = |i: usize, _| -> &[u8] { if i < len { &source[i..] } else { &[] } };
    let mut progress = |_state: &tree_sitter::ParseState| -> ControlFlow<()> {
        if started.elapsed() > timeout {
            timed_out = true;
            ControlFlow::Break(())
        } else {
            ControlFlow::Continue(())
        }
    };
    let opts = ParseOptions::new().progress_callback(&mut progress);
    let tree = parser.parse_with_options(&mut input, None, Some(opts));
    if let Some(tree) = tree {
        return ParseOutcome::Ok(tree);
    }
    // Discard any partial state so the next parse starts from scratch.
    parser.reset();
    if timed_out {
        ParseOutcome::TimedOut
    } else {
        ParseOutcome::Failed
    }
}
/// Like [`parse_timed`], with the timeout taken from `BASEMIND_PARSE_TIMEOUT_MS`
/// (or the default). The environment variable is re-read on every call.
pub fn parse_with_default_timeout(parser: &mut Parser, source: &[u8]) -> ParseOutcome {
    let budget = parse_timeout_from_env();
    parse_timed(parser, source, budget)
}
/// The kinds of query a language can provide; each maps to a `;; section:` name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryKind {
    /// `symbols` section; the TSLP fallback fills it from `@definition.*` patterns.
    Symbols,
    /// `imports` section; only available from bundled override queries.
    Imports,
    /// `calls` section; the TSLP fallback fills it from `@reference.call` patterns.
    Calls,
    /// `docs` section; only available from bundled override queries.
    Docs,
}

impl QueryKind {
    /// Section name used in query files (`;; section: <name>`).
    pub fn name(self) -> &'static str {
        match self {
            Self::Symbols => "symbols",
            Self::Imports => "imports",
            Self::Calls => "calls",
            Self::Docs => "docs",
        }
    }
}
/// A compiled query, or `None` when no source exists for that language/kind.
type CachedQuery = Option<Arc<Query>>;
/// Compiled-query cache keyed by `(language, kind)`.
type QueryMap = AHashMap<(LangId, QueryKind), CachedQuery>;
/// Process-wide compiled-query cache, populated lazily by `try_get_query`.
static QUERIES: OnceLock<RwLock<QueryMap>> = OnceLock::new();
/// Collects the body of every `;; section: <name>` block in `source`.
///
/// A marker is any line whose left-trimmed text starts with `;; section:`. A
/// section runs from its marker to the next marker or end of input; marker lines
/// are not included. The section name is the first whitespace-delimited token
/// after `;; section:` and must equal `name` *exactly*, so `calls` does not also
/// match a `calls_extra` section. Anything after that token on the marker line
/// is ignored. Multiple sections with the same name are concatenated in order,
/// and every copied line is terminated with `\n`.
///
/// Returns `None` when nothing matches or the collected text is only whitespace.
fn extract_section(source: &str, name: &str) -> Option<String> {
    const MARKER: &str = ";; section:";
    let mut out = String::new();
    let mut in_section = false;
    for line in source.lines() {
        if let Some(rest) = line.trim_start().strip_prefix(MARKER) {
            in_section = rest.split_whitespace().next() == Some(name);
            continue;
        }
        if in_section {
            out.push_str(line);
            out.push('\n');
        }
    }
    if out.trim().is_empty() {
        None
    } else {
        Some(out)
    }
}
/// Converts a TSLP `tags.scm` query into this module's sectioned query format.
///
/// `@definition.*` patterns go to a `symbols` section and `@reference.call`
/// patterns go to a `calls` section, with captures renamed by `rewrite_pattern`.
/// Every other pattern is dropped. Both section markers are always emitted,
/// even when a section ends up empty.
fn adapt_tslp_tags(source: &str) -> String {
    let (mut symbols, mut calls) = (String::new(), String::new());
    for pattern in split_top_level_patterns(source) {
        let kind = classify_pattern(pattern);
        let sink = match kind {
            PatternKind::Definition => &mut symbols,
            PatternKind::ReferenceCall => &mut calls,
            PatternKind::Other => continue,
        };
        sink.push_str(&rewrite_pattern(pattern, kind));
    }
    let mut adapted = String::with_capacity(symbols.len() + calls.len() + 64);
    adapted.push_str(";; section: symbols\n");
    adapted.push_str(&symbols);
    adapted.push_str("\n;; section: calls\n");
    adapted.push_str(&calls);
    adapted
}
/// Role of a top-level tags pattern, decided by its role capture.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PatternKind {
    /// Carries a `@definition.*` role capture; feeds the `symbols` section.
    Definition,
    /// Carries a `@reference.call` role capture; feeds the `calls` section.
    ReferenceCall,
    /// Any other role (e.g. `@reference.implementation`) or no role; dropped.
    Other,
}

/// Splits a tree-sitter query source into its top-level patterns.
///
/// A pattern is a balanced `( … )` group or `[ … ]` alternation at depth 0,
/// plus any `@capture` names that follow it on the same line, separated only by
/// spaces or tabs. `;` line comments and `"…"` string literals are skipped, so
/// brackets inside them do not affect nesting. A stray closing bracket with no
/// matching opener is ignored rather than corrupting the depth count for the
/// rest of the file. Unterminated trailing groups are not returned. The
/// returned slices borrow from `source`.
fn split_top_level_patterns(source: &str) -> Vec<&str> {
    let bytes = source.as_bytes();
    let mut patterns: Vec<&str> = Vec::new();
    let mut i = 0;
    let mut start = 0;
    let mut depth: usize = 0;
    while i < bytes.len() {
        match bytes[i] {
            b';' => {
                // Line comment: advance to (not past) the newline.
                while i < bytes.len() && bytes[i] != b'\n' {
                    i += 1;
                }
            }
            b'"' => {
                // String literal (e.g. a predicate argument); honour `\` escapes.
                i += 1;
                while i < bytes.len() {
                    match bytes[i] {
                        b'\\' if i + 1 < bytes.len() => i += 2,
                        b'"' => {
                            i += 1;
                            break;
                        }
                        _ => i += 1,
                    }
                }
            }
            b'(' | b'[' => {
                if depth == 0 {
                    start = i;
                }
                depth += 1;
                i += 1;
            }
            b')' | b']' => {
                i += 1;
                match depth {
                    // Stray closer: ignore it so later patterns are still found.
                    0 => {}
                    1 => {
                        depth = 0;
                        i = skip_trailing_captures(bytes, i);
                        patterns.push(&source[start..i]);
                    }
                    _ => depth -= 1,
                }
            }
            _ => i += 1,
        }
    }
    patterns
}

/// Returns the index just past any `@capture` names that follow position `i` on
/// the same line (separated by spaces/tabs), or `i` unchanged if none follow.
fn skip_trailing_captures(bytes: &[u8], mut i: usize) -> usize {
    loop {
        let mut j = i;
        while j < bytes.len() && matches!(bytes[j], b' ' | b'\t') {
            j += 1;
        }
        if bytes.get(j) != Some(&b'@') {
            return i;
        }
        j += 1;
        while j < bytes.len() && is_capture_ident_byte(bytes[j]) {
            j += 1;
        }
        i = j;
    }
}

/// Whether `b` may appear in a capture name after the leading `@`.
fn is_capture_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'.' || b == b'?' || b == b'!'
}

/// Lists every `@capture` name in `pattern` (without the `@`), in source order.
fn capture_names(pattern: &str) -> Vec<&str> {
    let bytes = pattern.as_bytes();
    let mut names = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] != b'@' {
            i += 1;
            continue;
        }
        let begin = i + 1;
        let mut end = begin;
        while end < bytes.len() && is_capture_ident_byte(bytes[end]) {
            end += 1;
        }
        // Capture bytes are ASCII, so `begin`/`end` are always char boundaries.
        names.push(&pattern[begin..end]);
        i = end;
    }
    names
}

/// Classifies a pattern by its *last role capture* (`@definition.*` or
/// `@reference.*`), ignoring plain captures such as `@name` or `@doc`.
///
/// A trailing predicate like `(#not-match? @name "…")` therefore does not hide
/// the pattern's role. Patterns without any role capture are `Other`.
fn classify_pattern(pattern: &str) -> PatternKind {
    capture_names(pattern)
        .into_iter()
        .rev()
        .find(|cap| cap.starts_with("definition.") || cap.starts_with("reference."))
        .map_or(PatternKind::Other, classify_capture)
}

/// Maps a single capture name to the role it denotes.
fn classify_capture(cap: &str) -> PatternKind {
    if cap.starts_with("definition.") {
        PatternKind::Definition
    } else if cap == "reference.call" {
        PatternKind::ReferenceCall
    } else {
        PatternKind::Other
    }
}

/// Rewrites every capture in `pattern` via `rewrite_capture` and appends `\n`.
///
/// Text between captures is copied verbatim as `&str` slices, so non-ASCII
/// UTF-8 (in predicate strings or comments) is preserved. Copying byte by byte
/// as `char` would re-encode each byte as Latin-1.
fn rewrite_pattern(pattern: &str, kind: PatternKind) -> String {
    let mut out = String::with_capacity(pattern.len() + 16);
    let mut rest = pattern;
    while let Some(at) = rest.find('@') {
        out.push_str(&rest[..at]);
        let tail = &rest[at + 1..];
        let cap_len = tail.bytes().take_while(|&b| is_capture_ident_byte(b)).count();
        out.push('@');
        out.push_str(&rewrite_capture(&tail[..cap_len], kind));
        rest = &tail[cap_len..];
    }
    out.push_str(rest);
    out.push('\n');
    out
}

/// Renames one TSLP capture to this module's naming scheme.
///
/// For definitions, `name` becomes `symbol.name`; for calls, it becomes
/// `call.callee`. `definition.X` becomes `symbol.X` and `reference.call` becomes
/// `call.range`. Anything else is kept as is.
fn rewrite_capture(cap: &str, kind: PatternKind) -> String {
    match (cap, kind) {
        ("name", PatternKind::Definition) => "symbol.name".to_string(),
        ("name", PatternKind::ReferenceCall) => "call.callee".to_string(),
        ("reference.call", _) => "call.range".to_string(),
        _ => match cap.strip_prefix("definition.") {
            Some(suffix) => format!("symbol.{suffix}"),
            None => cap.to_string(),
        },
    }
}
/// Adapted TSLP tags queries, keyed by language.
type AdaptedTagsMap = AHashMap<LangId, Arc<str>>;
/// Process-wide cache of `adapt_tslp_tags` output, filled lazily.
static ADAPTED_TAGS: OnceLock<RwLock<AdaptedTagsMap>> = OnceLock::new();

/// Returns the sectioned (`symbols`/`calls`) adaptation of the language pack's
/// tags query for `lang`, computing and caching it on first request.
///
/// Returns `None` when the pack ships no tags query for `lang` (such misses are
/// not cached). Two threads racing on a miss may both adapt the query; the last
/// insert wins.
fn tslp_tags_adapted(lang: LangId) -> Option<Arc<str>> {
    let pool = ADAPTED_TAGS.get_or_init(|| RwLock::new(AHashMap::new()));
    if let Some(hit) = pool
        .read()
        .expect("adapted tags pool poisoned")
        .get(&lang)
        .cloned()
    {
        return Some(hit);
    }
    let raw = tree_sitter_language_pack::get_tags_query(lang)?;
    let adapted: Arc<str> = Arc::from(adapt_tslp_tags(raw));
    let mut cache = pool.write().expect("adapted tags pool poisoned");
    cache.insert(lang, Arc::clone(&adapted));
    Some(adapted)
}
/// Returns the compiled query of `kind` for `lang`, compiling and caching it on first use.
///
/// The query source is chosen as follows:
/// - If `lang` has a bundled override file, the matching section of that file is used.
/// - Otherwise, for `Symbols`/`Calls` only, the adapted TSLP tags query is used.
/// - Otherwise there is no source.
///
/// `Ok(None)` means no source exists; that answer is cached too. Errors are not
/// cached, so a failing language/kind is retried on every call.
///
/// # Errors
/// [`LangError::Pack`] if the grammar cannot be loaded, or
/// [`LangError::QueryCompile`] if the query source does not compile.
///
/// # Panics
/// Panics if the query cache lock is poisoned.
pub fn try_get_query(lang: LangId, kind: QueryKind) -> Result<CachedQuery, LangError> {
    let key = (lang, kind);
    let pool = QUERIES.get_or_init(|| RwLock::new(AHashMap::new()));
    if let Some(slot) = pool.read().expect("query pool poisoned").get(&key) {
        return Ok(slot.clone());
    }

    let section = kind.name();
    let source: Option<String> = match override_query_source(lang) {
        Some(raw) => extract_section(raw, section),
        None if matches!(kind, QueryKind::Symbols | QueryKind::Calls) => {
            tslp_tags_adapted(lang).and_then(|adapted| extract_section(&adapted, section))
        }
        None => None,
    };

    let compiled: CachedQuery = match source {
        None => None,
        Some(text) => {
            let ts_lang = language(lang)?;
            let query = Query::new(&ts_lang, &text).map_err(|e| LangError::QueryCompile {
                lang,
                kind: kind.name(),
                msg: format!("{e}"),
            })?;
            Some(Arc::new(query))
        }
    };

    pool.write()
        .expect("query pool poisoned")
        .insert(key, compiled.clone());
    Ok(compiled)
}
/// Like [`try_get_query`], but treats "no query source" as an error.
///
/// # Errors
/// Everything [`try_get_query`] returns. If no source exists, returns
/// [`LangError::QueryCompile`] with a "no override or TSLP fallback" message.
pub fn get_query(lang: LangId, kind: QueryKind) -> Result<Arc<Query>, LangError> {
    match try_get_query(lang, kind)? {
        Some(query) => Ok(query),
        None => Err(LangError::QueryCompile {
            lang,
            kind: kind.name(),
            msg: format!("no override or TSLP fallback for {}/{}", lang, kind.name()),
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detect_known_extensions() {
        assert_eq!(detect(Path::new("foo.rs")), Some("rust"));
        assert_eq!(detect(Path::new("foo.py")), Some("python"));
        assert_eq!(detect(Path::new("foo.go")), Some("go"));
    }

    #[test]
    fn detect_dynamic_extension_resolves() {
        // A language without a bundled override is still detected via the pack.
        assert_eq!(detect(Path::new("foo.cpp")), Some("cpp"));
    }

    #[test]
    fn extract_section_basic() {
        let src = ";; section: a\n(foo)\n;; section: b\n(bar)\n";
        assert_eq!(extract_section(src, "a").unwrap().trim(), "(foo)");
        assert_eq!(extract_section(src, "b").unwrap().trim(), "(bar)");
        assert_eq!(extract_section(src, "c"), None);
    }

    #[test]
    fn has_override_for_each_supported() {
        for &name in OVERRIDE_LANGUAGES {
            assert!(has_override(name), "missing override source for {name}");
        }
    }

    #[test]
    fn intern_known_overrides_returns_static() {
        // Build the key at runtime so the lookup cannot be satisfied by literal identity.
        let owned = "rust".to_string();
        let id = intern(&owned).expect("rust must intern");
        // Compare against the OVERRIDE_LANGUAGES entry itself. The language does not
        // guarantee that two identical string literals share an address, so
        // comparing with a fresh `"rust"` literal would test the compiler, not `intern`.
        let canonical = OVERRIDE_LANGUAGES
            .iter()
            .copied()
            .find(|&l| l == "rust")
            .expect("rust is an override language");
        assert!(std::ptr::eq(id, canonical));
    }

    #[test]
    fn intern_unknown_returns_none() {
        assert!(intern("this-is-not-a-real-grammar-name").is_none());
    }

    #[test]
    fn try_get_query_returns_none_for_unsupported_lang() {
        let res = try_get_query("json", QueryKind::Symbols).expect("query lookup must not error");
        assert!(res.is_none());
    }

    #[test]
    fn adapt_tslp_tags_emits_two_sections() {
        let src = "(function_item name: (identifier) @name) @definition.function\n\
                   (call_expression function: (identifier) @name) @reference.call\n";
        let out = adapt_tslp_tags(src);
        assert!(out.contains(";; section: symbols"));
        assert!(out.contains(";; section: calls"));
        assert!(out.contains("@symbol.function"));
        assert!(out.contains("@symbol.name"));
        assert!(out.contains("@call.range"));
        assert!(out.contains("@call.callee"));
    }

    #[test]
    fn adapt_tslp_tags_drops_reference_class() {
        // Non-call references (e.g. implementations) are not emitted at all.
        let src = "(impl_item trait: (type_identifier) @name) @reference.implementation\n\
                   (call_expression function: (identifier) @name) @reference.call\n";
        let out = adapt_tslp_tags(src);
        assert!(!out.contains("@reference"));
        assert!(out.contains("@call.range"));
        assert!(out.contains("@call.callee"));
    }

    #[test]
    fn adapt_tslp_tags_handles_multiline_patterns() {
        let src = "(struct_item\n name: (type_identifier) @name) @definition.class\n";
        let out = adapt_tslp_tags(src);
        assert!(out.contains("@symbol.class"));
        assert!(out.contains("@symbol.name"));
    }

    #[test]
    fn adapt_tslp_tags_real_rust_compiles() {
        // End-to-end: the adapted sections of the pack's real Rust tags query compile.
        let raw = tree_sitter_language_pack::get_tags_query("rust").expect("rust ships tags.scm");
        let adapted = adapt_tslp_tags(raw);
        let sym = extract_section(&adapted, "symbols").expect("symbols section");
        let calls = extract_section(&adapted, "calls").expect("calls section");
        let ts_lang = language("rust").expect("rust language resolves");
        Query::new(&ts_lang, &sym).expect("adapted symbols query compiles");
        Query::new(&ts_lang, &calls).expect("adapted calls query compiles");
    }
}