use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;
use crossbeam_channel::{Receiver, Sender, bounded};
use crate::installer::{InstallStatus, install_blocking};
use crate::manifest::ToolSpec;
/// One queued unit of work, sent from `InstallPool::install` to a worker thread.
struct Job {
    /// Tool name; also the key under which waiters are registered in `InFlight`.
    name: String,
    /// Manifest description of what to install.
    spec: ToolSpec,
}
/// Status senders for every handle waiting on one tool's install.
type Waiters = Vec<Sender<InstallStatus>>;

/// Shared registry of installs that are queued or running, keyed by tool name.
type InFlight = Arc<Mutex<HashMap<String, Waiters>>>;
/// A fixed set of background threads that run tool installs and deduplicate
/// concurrent requests for the same tool name.
pub struct InstallPool {
    /// Producer side of the job queue consumed by the worker threads.
    job_tx: Sender<Job>,
    /// Registry of queued or running installs and the handles waiting on them;
    /// shared with every worker.
    in_flight: InFlight,
}
impl InstallPool {
    /// Number of worker threads started by [`InstallPool::new`].
    const DEFAULT_WORKERS: usize = 2;
    /// Jobs that may sit in the queue before [`InstallPool::install`] blocks on enqueue.
    const JOB_QUEUE_CAPACITY: usize = 64;
    /// Capacity of each handle's status channel.
    const STATUS_CAPACITY: usize = 128;

    /// Creates a pool with two worker threads.
    pub fn new() -> Self {
        Self::with_workers(Self::DEFAULT_WORKERS)
    }

    /// Creates a pool backed by `workers` detached worker threads.
    ///
    /// A value of `0` is treated as `1`, because with no workers no job could ever run.
    /// Workers stop once the pool is dropped (the pool owns the only job sender)
    /// and the queue has been drained.
    pub fn with_workers(workers: usize) -> Self {
        let (job_tx, job_rx) = bounded::<Job>(Self::JOB_QUEUE_CAPACITY);
        let in_flight: InFlight = Arc::new(Mutex::new(HashMap::new()));
        for _ in 0..workers.max(1) {
            let job_rx = job_rx.clone();
            let in_flight = Arc::clone(&in_flight);
            thread::spawn(move || worker_loop(job_rx, in_flight));
        }
        Self { job_tx, in_flight }
    }

    /// Requests an install of `name` and returns a handle that observes it.
    ///
    /// If an install of `name` is already queued or running, the handle joins
    /// it and `spec` is ignored. Otherwise a new job is enqueued. Enqueueing
    /// blocks while the job queue is full.
    ///
    /// If no worker is left to accept the job, the registration is withdrawn.
    /// The handle's channel then disconnects, so `wait` returns instead of
    /// blocking forever.
    pub fn install(&self, name: String, spec: ToolSpec) -> InstallHandle {
        let (tx, rx) = bounded::<InstallStatus>(Self::STATUS_CAPACITY);

        // Join an existing install or register a new one in a single critical
        // section, so two callers can never both start a job for the same name.
        let starts_new_job = {
            let mut guard = self.in_flight.lock().unwrap();
            if let Some(waiters) = guard.get_mut(&name) {
                waiters.push(tx);
                false
            } else {
                guard.insert(name.clone(), vec![tx]);
                true
            }
        };

        // Enqueue outside the lock: `send` may block while the queue is full.
        if starts_new_job {
            let job = Job {
                name: name.clone(),
                spec,
            };
            if self.job_tx.send(job).is_err() {
                // Every worker is gone. Dropping the registered senders disconnects
                // all waiters (including any that joined meanwhile) instead of
                // leaving them, and later installs of `name`, hung on a dead entry.
                self.in_flight.lock().unwrap().remove(&name);
            }
        }

        InstallHandle { name, rx }
    }

    /// Returns the names of installs currently queued or running, sorted ascending.
    pub fn in_flight_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.in_flight.lock().unwrap().keys().cloned().collect();
        names.sort_unstable();
        names
    }
}
impl Default for InstallPool {
fn default() -> Self {
Self::new()
}
}
/// Caller-side view of one install request, returned by `InstallPool::install`.
///
/// Clones share the same underlying receiver, so each status update is
/// delivered to only one of the clones, not to all of them.
#[derive(Clone)]
pub struct InstallHandle {
    /// Tool name this handle was requested for.
    name: String,
    /// Stream of status updates for this request.
    rx: Receiver<InstallStatus>,
}
impl InstallHandle {
    /// The tool name this handle is tracking.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the next buffered status update without blocking.
    ///
    /// Returns `None` if nothing is queued right now or the channel is closed.
    pub fn try_recv(&self) -> Option<InstallStatus> {
        self.rx.try_recv().ok()
    }

    /// Blocks until a terminal status (`Done` or `Failed`) arrives and returns it.
    ///
    /// Non-terminal updates received along the way are discarded. If every
    /// sender is dropped before a terminal status arrives, the result is
    /// `Failed("<channel closed>")`. A terminal status already taken by
    /// `try_recv` is not seen again here.
    pub fn wait(&self) -> InstallStatus {
        self.rx
            .iter()
            .find(|status| {
                matches!(
                    status,
                    InstallStatus::Done { .. } | InstallStatus::Failed(_)
                )
            })
            .unwrap_or_else(|| InstallStatus::Failed("<channel closed>".to_string()))
    }
}
fn worker_loop(job_rx: Receiver<Job>, in_flight: InFlight) {
for job in &job_rx {
let name = job.name.clone();
let in_flight_clone = Arc::clone(&in_flight);
let name_clone = name.clone();
let progress = move |status: InstallStatus| {
broadcast(&in_flight_clone, &name_clone, status);
};
let result = install_blocking(&name, &job.spec, &progress);
let terminal = match result {
Ok(bin_path) => InstallStatus::Done { bin_path },
Err(e) => InstallStatus::Failed(e.to_string()),
};
broadcast(&in_flight, &name, terminal);
in_flight.lock().unwrap().remove(&name);
}
}
fn broadcast(in_flight: &InFlight, key: &str, status: InstallStatus) {
let mut guard = in_flight.lock().unwrap();
if let Some(senders) = guard.get_mut(key) {
senders.retain(|tx| tx.send(status.clone()).is_ok());
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifest::{GithubMethod, InstallMethod, ToolCategory};
    use std::collections::BTreeMap;

    /// Builds a GitHub-release `ToolSpec` whose only checksum entry is for
    /// `x86_64-unknown-linux-gnu`.
    fn github_spec(bin: &str, repo: &str, asset_pattern: &str, sha: &str) -> ToolSpec {
        let mut sha256 = BTreeMap::new();
        sha256.insert("x86_64-unknown-linux-gnu".to_string(), sha.to_string());
        ToolSpec {
            category: ToolCategory::Lsp,
            description: "test".to_string(),
            version: "v1.0".to_string(),
            bin: bin.to_string(),
            method: InstallMethod::Github(GithubMethod {
                repo: repo.to_string(),
                asset_pattern: asset_pattern.to_string(),
                sha256,
            }),
        }
    }

    /// True for the statuses that end an install.
    fn is_terminal(status: &InstallStatus) -> bool {
        matches!(status, InstallStatus::Failed(_) | InstallStatus::Done { .. })
    }

    #[test]
    fn in_flight_names_reports_active_job() {
        let pool = InstallPool::new();
        let spec = github_spec(
            "test-tool",
            "owner/fake-repo",
            "tool-{triple}.tar.gz",
            &"0".repeat(64),
        );
        let handle = pool.install("test-tool".to_string(), spec);
        // The job may already have finished, so the name is either still
        // registered or gone. Nothing else can appear in this private pool.
        let names = pool.in_flight_names();
        assert!(
            names.is_empty() || names == ["test-tool"],
            "unexpected in-flight names: {names:?}"
        );
        let status = handle.wait();
        assert!(is_terminal(&status), "expected terminal status, got: {status:?}");
    }

    #[test]
    fn dropped_handle_does_not_poison_pool() {
        let pool = InstallPool::new();
        let spec = github_spec(
            "drop-tool",
            "owner/repo",
            "drop-tool-{triple}.tar.gz",
            "bad",
        );
        drop(pool.install("drop-tool".to_string(), spec.clone()));
        // The pool must still serve the same tool after a waiter went away.
        let status = pool.install("drop-tool".to_string(), spec).wait();
        assert!(is_terminal(&status), "expected terminal status, got: {status:?}");
    }

    #[test]
    fn concurrent_install_same_tool_both_handles_terminate() {
        let pool = InstallPool::new();
        let spec = github_spec(
            "shared-tool",
            "owner/repo",
            "shared-tool-{triple}.tar.gz",
            "badhash",
        );
        let h1 = pool.install("shared-tool".to_string(), spec.clone());
        let h2 = pool.install("shared-tool".to_string(), spec);
        let s1 = h1.wait();
        let s2 = h2.wait();
        assert!(is_terminal(&s1), "h1 got: {s1:?}");
        assert!(is_terminal(&s2), "h2 got: {s2:?}");
    }
}