use std::{
process::Stdio,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use anyhow::{Context, Result, bail};
use serde_json::{Value, json};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
process::{Child, ChildStdin, ChildStdout, Command},
sync::Mutex,
time,
};
use tracing::{debug, warn};
use super::manifest::PluginManifest;
/// Number of seconds a plugin call waits for a response line before it fails with a
/// timeout error. Every `ShellBridgePlugin` is created with this value as its timeout.
const DEFAULT_CALL_TIMEOUT_SECS: u64 = 30;
/// Handle to a plugin running as a child process, driven over its stdin/stdout pipes
/// with one JSON document per line.
pub struct ShellBridgePlugin {
/// Plugin name copied from the manifest; used in log fields and error messages.
name: String,
/// The child's stdin pipe; requests are written here as newline-terminated JSON.
stdin: Arc<Mutex<ChildStdin>>,
/// Buffered reader over the child's stdout; responses are read from it line by line.
stdout: Arc<Mutex<BufReader<ChildStdout>>>,
/// The child process handle; `shutdown` kills the process through it.
child: Arc<Mutex<Child>>,
/// Source of request IDs; starts at 1 and is incremented once per call.
next_id: Arc<AtomicU64>,
/// Maximum time a call waits for the plugin's response.
timeout: Duration,
}
impl ShellBridgePlugin {
pub async fn spawn(manifest: &PluginManifest) -> Result<Self> {
let runtime = resolve_runtime(&manifest.runtime)?;
let entry = manifest.dir.join(&manifest.entry);
if !entry.exists() {
bail!(
"plugin `{}` entry not found: {}",
manifest.name,
entry.display()
);
}
let mut child = Command::new(&runtime)
.arg(&entry)
.current_dir(&manifest.dir)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit()) .kill_on_drop(true)
.spawn()
.with_context(|| {
format!(
"spawn plugin `{}` with {runtime} {}",
manifest.name,
entry.display()
)
})?;
let stdin = child.stdin.take().context("plugin stdin")?;
let stdout = child.stdout.take().context("plugin stdout")?;
debug!(plugin = %manifest.name, runtime, "plugin subprocess started");
Ok(Self {
name: manifest.name.clone(),
stdin: Arc::new(Mutex::new(stdin)),
stdout: Arc::new(Mutex::new(BufReader::new(stdout))),
child: Arc::new(Mutex::new(child)),
next_id: Arc::new(AtomicU64::new(1)),
timeout: Duration::from_secs(DEFAULT_CALL_TIMEOUT_SECS),
})
}
pub async fn call(&self, method: &str, params: Value) -> Result<Value> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let request = json!({
"id": id,
"method": method,
"params": params,
});
let line = serde_json::to_string(&request).context("serialize request")?;
{
let mut stdin = self.stdin.lock().await;
stdin
.write_all(line.as_bytes())
.await
.with_context(|| format!("write to plugin `{}`", self.name))?;
stdin.write_all(b"\n").await.context("write newline")?;
stdin.flush().await.context("flush stdin")?;
}
let response_line = time::timeout(self.timeout, self.read_line()).await;
let line = match response_line {
Ok(Ok(l)) => l,
Ok(Err(e)) => {
bail!("plugin `{}` read error: {e:#}", self.name)
}
Err(_) => bail!(
"plugin `{}` call `{method}` timed out after {}s",
self.name,
self.timeout.as_secs()
),
};
let resp: Value = serde_json::from_str(&line)
.with_context(|| format!("plugin `{}` returned invalid JSON: {line}", self.name))?;
if resp["id"] != id {
warn!(
plugin = %self.name,
expected = id,
got = ?resp["id"],
"response ID mismatch"
);
}
if let Some(err) = resp.get("error") {
bail!("plugin `{}` error: {err}", self.name);
}
Ok(resp["result"].clone())
}
pub async fn shutdown(&self) {
let mut child = self.child.lock().await;
let _ = child.kill().await;
debug!(plugin = %self.name, "plugin subprocess terminated");
}
async fn read_line(&self) -> Result<String> {
let mut line = String::new();
let mut stdout = self.stdout.lock().await;
stdout
.read_line(&mut line)
.await
.with_context(|| format!("read from plugin `{}`", self.name))?;
Ok(line.trim_end_matches('\n').to_owned())
}
}
/// Checks that the requested runtime executable can be found on `PATH` and returns its
/// name.
///
/// The candidate list for any requested name — `bun`, `deno`, `node`, or anything else —
/// holds just that name, so there is no fallback from one runtime to another.
///
/// # Errors
///
/// Fails when `which` cannot locate the executable.
fn resolve_runtime(runtime: &str) -> Result<String> {
    let candidates = [runtime];
    let located = candidates
        .iter()
        .copied()
        .find(|&name| which::which(name).is_ok());
    match located {
        Some(name) => Ok(name.to_owned()),
        None => bail!(
            "no suitable JS runtime found for `{runtime}`. \
             Install node, bun, or deno and ensure it is on PATH."
        ),
    }
}
/// A spawned plugin paired with the manifest it was started from.
pub struct Plugin {
/// Subprocess bridge that `call` and `shutdown` delegate to.
inner: ShellBridgePlugin,
/// The manifest passed to `Plugin::spawn`.
pub manifest: PluginManifest,
}
impl Plugin {
    /// Starts the subprocess described by `manifest` and keeps the manifest alongside it.
    ///
    /// # Errors
    ///
    /// Returns whatever error `ShellBridgePlugin::spawn` reports.
    pub async fn spawn(manifest: PluginManifest) -> Result<Self> {
        ShellBridgePlugin::spawn(&manifest)
            .await
            .map(|inner| Self { inner, manifest })
    }

    /// Forwards a `method` invocation with `params` to the underlying subprocess bridge
    /// and returns its result unchanged.
    pub async fn call(&self, method: &str, params: Value) -> Result<Value> {
        let bridge = &self.inner;
        bridge.call(method, params).await
    }

    /// Shuts down the underlying subprocess bridge.
    pub async fn shutdown(&self) {
        let bridge = &self.inner;
        bridge.shutdown().await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// When `node` is discoverable on PATH, resolving `"node"` must succeed. Nothing is
    /// asserted otherwise, so the test also passes on machines without Node.
    #[test]
    fn resolve_node_runtime() {
        let node_on_path = which::which("node").is_ok();
        let outcome = resolve_runtime("node");
        if !node_on_path {
            return;
        }
        assert!(outcome.is_ok());
    }

    /// A runtime name that matches no executable must be rejected.
    #[test]
    fn resolve_unknown_runtime_fails() {
        let outcome = resolve_runtime("__nonexistent_runtime_xyz__");
        assert!(outcome.is_err());
    }
}