use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread;
use tracing::debug;
use super::loader::GrammarLoader;
use super::manifest::{LangSpec, ManifestMeta};
/// Error delivered to `LoadHandle` subscribers.
///
/// Derives `Clone` because a single load result is fanned out to every
/// subscriber waiting on the same grammar name (see `worker_loop`).
#[derive(Clone, Debug, thiserror::Error)]
pub enum LoadError {
    /// The underlying `GrammarLoader::load` call failed; the payload is the
    /// formatted error chain (`{e:#}`).
    #[error("grammar load failed: {0}")]
    Failed(String),
    /// The result channel disconnected before a result arrived (e.g. the
    /// worker side went away without completing the load).
    #[error("load handle dropped before completion")]
    Cancelled,
}
/// One queued load request, handed from `load_async` to a worker thread.
struct Job {
    // Grammar name; also the key into the in-flight map.
    name: String,
    spec: LangSpec,
    meta: ManifestMeta,
}
/// Grammar name -> result senders for every subscriber of the load currently
/// queued or running under that name. Shared between `load_async` (insert /
/// push) and the workers (remove + fan-out on completion).
type InFlight = Arc<Mutex<HashMap<String, Vec<Sender<Result<PathBuf, LoadError>>>>>>;
/// Off-thread wrapper around `GrammarLoader` that deduplicates concurrent
/// requests for the same grammar name.
pub struct AsyncGrammarLoader {
    // Synchronous loader shared with the worker threads.
    inner: Arc<GrammarLoader>,
    // Subscribers keyed by grammar name; see the `InFlight` alias.
    in_flight: InFlight,
    // Producer side of the job queue; when this (and thus the channel's last
    // sender) is dropped, workers see a recv error and exit their loop.
    job_tx: Sender<Job>,
    // Retained for ownership only; the handles are never joined — workers
    // shut themselves down when the job channel disconnects.
    _worker_handles: Arc<[thread::JoinHandle<()>]>,
}
impl AsyncGrammarLoader {
    /// Spawn a small pool of worker threads around `loader`.
    ///
    /// Workers pull [`Job`]s from a shared channel; when this struct is
    /// dropped the channel disconnects and the workers exit on their own.
    /// The join handles are kept but never joined.
    pub fn new(loader: GrammarLoader) -> Self {
        // Two workers, matching the original hard-coded pool size.
        const WORKER_COUNT: usize = 2;
        let inner = Arc::new(loader);
        let in_flight: InFlight = Arc::new(Mutex::new(HashMap::new()));
        let (job_tx, job_rx) = mpsc::channel::<Job>();
        // std mpsc receivers are single-consumer; sharing one behind a mutex
        // lets whichever worker is idle pull the next job.
        let shared_rx = Arc::new(Mutex::new(job_rx));
        let mut handles = Vec::with_capacity(WORKER_COUNT);
        for _ in 0..WORKER_COUNT {
            let loader_clone = Arc::clone(&inner);
            let in_flight_clone = Arc::clone(&in_flight);
            let rx_clone = Arc::clone(&shared_rx);
            let handle = thread::Builder::new()
                .name("hjkl-bonsai-grammar-loader".into())
                .spawn(move || worker_loop(loader_clone, in_flight_clone, rx_clone))
                .expect("spawn grammar loader worker");
            handles.push(handle);
        }
        Self {
            inner,
            in_flight,
            job_tx,
            _worker_handles: handles.into(),
        }
    }
    /// Request an asynchronous load of grammar `name`.
    ///
    /// Concurrent requests for the same `name` are deduplicated: only the
    /// first subscriber enqueues a [`Job`]; later ones merely attach a result
    /// channel, and all handles receive a clone of the same result.
    pub fn load_async(&self, name: String, spec: LangSpec, meta: ManifestMeta) -> LoadHandle {
        let (tx, rx) = mpsc::channel();
        let mut map = self.in_flight.lock().expect("in_flight mutex poisoned");
        let senders = map.entry(name.clone()).or_default();
        // Workers remove the whole entry on completion and the error path
        // below removes it on send failure, so a surviving entry is never
        // empty: empty-before-push means we are the first subscriber.
        let is_first_subscriber = senders.is_empty();
        senders.push(tx);
        if is_first_subscriber {
            // Release the lock before touching the job channel; workers take
            // this same lock when publishing results.
            drop(map);
            if let Err(mpsc::SendError(job)) = self.job_tx.send(Job { name, spec, meta }) {
                // All workers have exited, so nobody will ever service this
                // entry. Remove it so the subscriber senders drop and every
                // handle observes LoadError::Cancelled instead of hanging
                // forever (previously the entry leaked on send failure).
                self.in_flight
                    .lock()
                    .expect("in_flight mutex poisoned")
                    .remove(&job.name);
            }
        }
        LoadHandle { rx }
    }
    /// Direct access to the underlying synchronous loader.
    pub fn inner(&self) -> &GrammarLoader {
        &self.inner
    }
    /// Names of grammars with a load currently queued or running.
    pub fn in_flight_names(&self) -> Vec<String> {
        let map = self.in_flight.lock().expect("in_flight mutex poisoned");
        map.keys().cloned().collect()
    }
}
/// Worker body: pull jobs from the shared receiver until the channel
/// disconnects (i.e. the owning `AsyncGrammarLoader` is dropped), run each
/// load, and fan the result out to all subscribers registered for the name.
fn worker_loop(loader: Arc<GrammarLoader>, in_flight: InFlight, rx: Arc<Mutex<Receiver<Job>>>) {
    loop {
        // Hold the receiver lock only while pulling one job so the other
        // worker can grab the next job while this one runs.
        let job = {
            let guard = rx.lock().expect("job receiver mutex poisoned");
            match guard.recv() {
                Ok(j) => j,
                // Channel disconnected: shut this worker down.
                Err(_) => break,
            }
        };
        let result: Result<PathBuf, LoadError> = loader
            .load(&job.name, &job.spec, &job.meta)
            .map_err(|e| LoadError::Failed(format!("{e:#}")));
        // `job` is no longer needed after the load, so move the name out
        // instead of cloning it (the clone was redundant).
        let name = job.name;
        // Take (and clear) the subscriber list for this name under the lock.
        let senders: Vec<Sender<Result<PathBuf, LoadError>>> = in_flight
            .lock()
            .expect("in_flight mutex poisoned")
            .remove(&name)
            .unwrap_or_default();
        if senders.is_empty() {
            debug!("load {name}: all subscribers dropped, completing anyway");
        }
        // Fan the single result out; ignore subscribers that hung up.
        for tx in senders {
            let _ = tx.send(result.clone());
        }
    }
}
/// Receiving end of one `load_async` subscription.
pub struct LoadHandle {
    // Receives exactly one result; disconnects instead if the worker side
    // goes away without delivering one.
    rx: Receiver<Result<PathBuf, LoadError>>,
}
impl LoadHandle {
    /// Non-blocking poll for the load result.
    ///
    /// Returns `None` while the load is still in flight, `Some(result)` once
    /// it completes, and `Some(Err(LoadError::Cancelled))` if the sending
    /// side disconnected without ever delivering a result.
    pub fn try_recv(&self) -> Option<Result<PathBuf, LoadError>> {
        use mpsc::TryRecvError::{Disconnected, Empty};
        match self.rx.try_recv() {
            Ok(result) => Some(result),
            Err(Empty) => None,
            Err(Disconnected) => Some(Err(LoadError::Cancelled)),
        }
    }
    /// Block until the load finishes, consuming the handle.
    ///
    /// A disconnected channel is reported as `LoadError::Cancelled`.
    pub fn recv_blocking(self) -> Result<PathBuf, LoadError> {
        match self.rx.recv() {
            Ok(result) => result,
            Err(_) => Err(LoadError::Cancelled),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use crate::runtime::manifest::{LangSpec, ManifestMeta, QuerySource};
/// Test seam: pluggable backend standing in for the real grammar loader.
trait LoaderBackend: Send + Sync + 'static {
    // Resolve `name`/`spec`/`meta` to a compiled-grammar path, or fail.
    fn load(&self, name: &str, spec: &LangSpec, meta: &ManifestMeta)
        -> anyhow::Result<PathBuf>;
}
/// Instrumented loader used by the tests in place of `GrammarLoader`.
struct MockLoader {
    // The behavior under test (success, failure, slow completion).
    backend: Box<dyn LoaderBackend>,
    // Incremented once per backend `load` call; used to assert deduplication.
    call_count: Arc<AtomicUsize>,
}
impl MockLoader {
    /// Wrap `backend` with a fresh call counter starting at zero.
    fn new(backend: impl LoaderBackend) -> Self {
        let call_count = Arc::new(AtomicUsize::new(0));
        let backend: Box<dyn LoaderBackend> = Box::new(backend);
        Self { backend, call_count }
    }
}
/// Test double for `AsyncGrammarLoader`: same worker-pool and dedup wiring,
/// but driven by a `MockLoader` instead of the real `GrammarLoader`.
struct TestAsyncLoader {
    inner: Arc<MockLoader>,
    in_flight: InFlight,
    job_tx: Sender<TestJob>,
    // Held for ownership only; never joined.
    _handles: Vec<thread::JoinHandle<()>>,
}
/// Test counterpart of `Job`, queued from the test `load_async` to the
/// mock worker threads.
struct TestJob {
    name: String,
    spec: LangSpec,
    meta: ManifestMeta,
}
impl TestAsyncLoader {
    /// Test-local mirror of `AsyncGrammarLoader::new`: spawns two worker
    /// threads pulling `TestJob`s from a shared, mutex-guarded receiver.
    fn new(mock: MockLoader) -> Self {
        let inner = Arc::new(mock);
        let in_flight: InFlight = Arc::new(Mutex::new(HashMap::new()));
        let (job_tx, job_rx) = mpsc::channel::<TestJob>();
        // std mpsc receivers are single-consumer; the mutex lets either
        // worker pull the next job.
        let shared_rx = Arc::new(Mutex::new(job_rx));
        let mut handles = Vec::with_capacity(2);
        for _ in 0..2 {
            let loader_clone = Arc::clone(&inner);
            let in_flight_clone = Arc::clone(&in_flight);
            let rx_clone = Arc::clone(&shared_rx);
            let handle = thread::spawn(move || {
                loop {
                    // Scope the receiver lock to the recv() call so the
                    // other worker is not blocked while this job runs.
                    let job = {
                        let guard = rx_clone.lock().unwrap();
                        match guard.recv() {
                            Ok(j) => j,
                            // Channel disconnected: the loader was dropped.
                            Err(_) => break,
                        }
                    };
                    // Count every backend invocation so tests can assert
                    // deduplication (one call for N subscribers).
                    loader_clone.call_count.fetch_add(1, Ordering::SeqCst);
                    let name = job.name.clone();
                    let result = loader_clone
                        .backend
                        .load(&job.name, &job.spec, &job.meta)
                        .map_err(|e| LoadError::Failed(format!("{e:#}")));
                    // Take (and clear) the subscriber list for this name.
                    let senders: Vec<_> = {
                        let mut map = in_flight_clone.lock().unwrap();
                        map.remove(&name).unwrap_or_default()
                    };
                    if senders.is_empty() {
                        debug!("load {name}: all subscribers dropped, completing anyway");
                    }
                    // Fan the single result out to every subscriber.
                    for tx in senders {
                        let _ = tx.send(result.clone());
                    }
                }
            });
            handles.push(handle);
        }
        Self {
            inner,
            in_flight,
            job_tx,
            _handles: handles,
        }
    }
    /// Test-local mirror of `AsyncGrammarLoader::load_async`: dedups by
    /// name, enqueuing a job only for the first subscriber.
    fn load_async(&self, name: String, spec: LangSpec, meta: ManifestMeta) -> LoadHandle {
        let (tx, rx) = mpsc::channel();
        let mut map = self.in_flight.lock().unwrap();
        if let Some(senders) = map.get_mut(&name) {
            // A load for `name` is already in flight: just subscribe.
            senders.push(tx);
        } else {
            map.insert(name.clone(), vec![tx]);
            // Release the lock before sending so workers can make progress.
            drop(map);
            let _ = self.job_tx.send(TestJob { name, spec, meta });
            return LoadHandle { rx };
        }
        LoadHandle { rx }
    }
    /// Number of backend `load` calls observed so far.
    fn call_count(&self) -> usize {
        self.inner.call_count.load(Ordering::SeqCst)
    }
}
/// Fixed placeholder manifest metadata; the mock backends never inspect it.
fn dummy_meta() -> ManifestMeta {
    ManifestMeta {
        helix_repo: String::from("https://github.com/helix-editor/helix"),
        helix_rev: String::from("aaaa0000bbbb1111cccc2222dddd3333eeee4444"),
        nvim_treesitter_repo: String::from("https://github.com/nvim-treesitter/nvim-treesitter"),
        nvim_treesitter_rev: String::from("ffff5555aaaa0000bbbb1111cccc2222dddd3333"),
    }
}
/// Minimal placeholder language spec; the mock backends never inspect it.
fn dummy_spec() -> LangSpec {
    LangSpec {
        git_url: String::from("https://example.invalid/repo"),
        git_rev: String::from("0000000000000000"),
        subpath: None,
        extensions: vec![String::from("x")],
        c_files: vec![String::from("src/parser.c")],
        query_source: QuerySource::Helix,
        query_subdir: None,
        source: None,
    }
}
/// Backend that always succeeds with a fixed path.
struct OkBackend {
    path: PathBuf,
}

impl LoaderBackend for OkBackend {
    fn load(&self, _name: &str, _spec: &LangSpec, _meta: &ManifestMeta) -> anyhow::Result<PathBuf> {
        Ok(self.path.clone())
    }
}
/// Backend that always fails, for error-propagation tests.
struct ErrBackend;

impl LoaderBackend for ErrBackend {
    fn load(&self, _name: &str, _spec: &LangSpec, _meta: &ManifestMeta) -> anyhow::Result<PathBuf> {
        anyhow::bail!("mock compile error: cc not found")
    }
}
/// Backend that sleeps for `delay` before succeeding with `path`, keeping a
/// load "in flight" long enough for tests to observe intermediate state.
struct SlowBackend {
    delay: Duration,
    path: PathBuf,
}

impl LoaderBackend for SlowBackend {
    fn load(&self, _name: &str, _spec: &LangSpec, _meta: &ManifestMeta) -> anyhow::Result<PathBuf> {
        // Simulate a slow compile.
        thread::sleep(self.delay);
        Ok(self.path.clone())
    }
}
#[test]
fn load_async_dedups_concurrent_requests() {
    let expected = PathBuf::from("/fake/rust.so");
    let loader = TestAsyncLoader::new(MockLoader::new(SlowBackend {
        delay: Duration::from_millis(80),
        path: expected.clone(),
    }));
    // Five subscribers for the same grammar while the slow load is in flight.
    let handles: Vec<_> = (0..5)
        .map(|_| loader.load_async("test_grammar".into(), dummy_spec(), dummy_meta()))
        .collect();
    // Every subscriber observes the same successful result...
    for handle in handles {
        assert_eq!(handle.recv_blocking().unwrap(), expected);
    }
    // ...but the backend was invoked exactly once.
    assert_eq!(
        loader.call_count(),
        1,
        "expected 1 load() call, got {}",
        loader.call_count()
    );
}
#[test]
fn load_async_propagates_failure_to_all_subscribers() {
    let loader = TestAsyncLoader::new(MockLoader::new(ErrBackend));
    let handles: Vec<_> = (0..3)
        .map(|_| loader.load_async("fail_grammar".into(), dummy_spec(), dummy_meta()))
        .collect();
    // Every subscriber receives the same cloned failure.
    for handle in handles {
        let outcome = handle.recv_blocking();
        match outcome {
            Err(LoadError::Failed(msg)) => {
                assert!(msg.contains("mock compile error"), "unexpected error: {msg}");
            }
            other => panic!("expected LoadError::Failed, got {other:?}"),
        }
    }
}
#[test]
fn try_recv_returns_none_while_in_flight_then_some_on_completion() {
    let path = PathBuf::from("/fake/slow.so");
    let mock = MockLoader::new(SlowBackend {
        delay: Duration::from_millis(100),
        path: path.clone(),
    });
    let loader = TestAsyncLoader::new(mock);
    let handle = loader.load_async("slow_grammar".into(), dummy_spec(), dummy_meta());
    // The backend sleeps 100ms, so an immediate poll must still be pending.
    assert!(handle.try_recv().is_none(), "expected None while in-flight");
    // Poll with a generous deadline instead of one fixed sleep: the original
    // slept a flat 300ms, which is both slower on the happy path and flaky on
    // loaded CI machines if the worker falls behind.
    let deadline = std::time::Instant::now() + Duration::from_secs(5);
    loop {
        if let Some(result) = handle.try_recv() {
            assert_eq!(result.unwrap(), path, "expected Ok(path) after completion");
            break;
        }
        assert!(
            std::time::Instant::now() < deadline,
            "load did not complete before deadline"
        );
        thread::sleep(Duration::from_millis(10));
    }
}
#[test]
fn recv_blocking_returns_result() {
    let expected = PathBuf::from("/fake/rust.so");
    let backend = OkBackend { path: expected.clone() };
    let loader = TestAsyncLoader::new(MockLoader::new(backend));
    let handle = loader.load_async("rust_grammar".into(), dummy_spec(), dummy_meta());
    assert_eq!(handle.recv_blocking().unwrap(), expected);
}
}