use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, Command};
use tokio::sync::{mpsc, Mutex, Semaphore};
use tokio::task::JoinHandle;
use super::super::sdk_message::SDKMessage;
use super::{
RunParams, RunnerError, RunnerStream, RunnerUpdate, StdioError, StdioInput,
StdioOutput,
};
type Registry = Arc<Mutex<HashMap<String, mpsc::UnboundedSender<RunnerUpdate>>>>;
/// Handle to a spawned runner child process that carries many concurrent runs
/// over its stdin/stdout/stderr, routing output lines to per-run channels by
/// request id.
///
/// Struct fields are dropped in declaration order, so `_child` (spawned with
/// `kill_on_drop(true)` in `Runner::spawn`) is dropped last. Keep that in mind
/// before reordering fields.
pub struct Runner {
    /// The child's stdin. The lock is held for a whole write-and-flush, so
    /// request lines from concurrent callers are not interleaved.
    stdin: Mutex<ChildStdin>,
    /// Request id -> update sender; shared with the stdout/stderr reader tasks.
    registry: Registry,
    /// Set (once) by `close_all` when a reader task's loop ends; new
    /// registrations are refused afterwards.
    closed: Arc<std::sync::atomic::AtomicBool>,
    /// Bounds concurrently open streams; sized from `query_limit` in `spawn`.
    semaphore: Arc<Semaphore>,
    /// Task reading and dispatching the child's stdout lines.
    _stdout_task: JoinHandle<()>,
    /// Task reading and dispatching the child's stderr lines.
    _stderr_task: JoinHandle<()>,
    /// The child process handle itself.
    _child: Child,
}
impl Runner {
pub async fn spawn(binary: &str, query_limit: u64) -> Result<Self, RunnerError> {
let mut cmd = Command::new(binary);
cmd.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
cmd.env_remove("CLAUDECODE");
cmd.kill_on_drop(true);
let mut child = cmd.spawn().map_err(|e| RunnerError::Spawn(e.to_string()))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| RunnerError::Spawn("stdin not piped".into()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| RunnerError::Spawn("stdout not piped".into()))?;
let stderr = child
.stderr
.take()
.ok_or_else(|| RunnerError::Spawn("stderr not piped".into()))?;
let registry: Registry = Arc::new(Mutex::new(HashMap::new()));
let closed = Arc::new(std::sync::atomic::AtomicBool::new(false));
let semaphore = Arc::new(Semaphore::new(query_limit as usize));
let stdout_task = {
let registry = registry.clone();
let closed = closed.clone();
tokio::spawn(async move {
let mut reader = BufReader::new(stdout).lines();
while let Ok(Some(line)) = reader.next_line().await {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
Self::dispatch_stdout(trimmed, ®istry).await;
}
Self::close_all(®istry, &closed).await;
})
};
let stderr_task = {
let registry = registry.clone();
let closed = closed.clone();
tokio::spawn(async move {
let mut reader = BufReader::new(stderr).lines();
while let Ok(Some(line)) = reader.next_line().await {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
Self::dispatch_stderr(trimmed, ®istry).await;
}
Self::close_all(®istry, &closed).await;
})
};
Ok(Self {
stdin: Mutex::new(stdin),
registry,
closed,
semaphore,
_stdout_task: stdout_task,
_stderr_task: stderr_task,
_child: child,
})
}
pub async fn create_stream<'a>(
self: &Arc<Self>,
id: String,
params: RunParams<'a>,
) -> Result<RunnerStream, RunnerError> {
let permit = self
.semaphore
.clone()
.acquire_owned()
.await
.expect("Runner semaphore is never closed");
let rx = self.register(id.clone()).await?;
if let Err(e) = self.send_run(&id, params).await {
self.unregister(&id).await;
return Err(e);
}
Ok(RunnerStream::new(rx, self.clone(), id, permit))
}
async fn register(
&self,
id: String,
) -> Result<mpsc::UnboundedReceiver<RunnerUpdate>, RunnerError> {
if self.closed.load(std::sync::atomic::Ordering::Acquire) {
return Err(RunnerError::Closed);
}
let (tx, rx) = mpsc::unbounded_channel();
let mut reg = self.registry.lock().await;
if reg.contains_key(&id) {
return Err(RunnerError::DuplicateId(id));
}
reg.insert(id, tx);
Ok(rx)
}
pub(super) async fn unregister(&self, id: &str) {
let mut reg = self.registry.lock().await;
reg.remove(id);
}
async fn send_run<'a>(
&self,
id: &'a str,
params: RunParams<'a>,
) -> Result<(), RunnerError> {
if self.closed.load(std::sync::atomic::Ordering::Acquire) {
return Err(RunnerError::Closed);
}
let request = StdioInput::Run { id, params };
self.write_line(&request).await
}
async fn write_line(&self, request: &StdioInput<'_>) -> Result<(), RunnerError> {
let mut line = serde_json::to_vec(request)?;
line.push(b'\n');
let mut stdin = self.stdin.lock().await;
stdin
.write_all(&line)
.await
.map_err(|e| RunnerError::Write(e.to_string()))?;
stdin
.flush()
.await
.map_err(|e| RunnerError::Write(e.to_string()))?;
Ok(())
}
async fn dispatch_stdout(line: &str, registry: &Registry) {
let parsed: StdioOutput<SDKMessage> = match serde_json::from_str(line) {
Ok(p) => p,
Err(_) => return,
};
match parsed {
StdioOutput::Event { id, event } => {
Self::send_to(registry, &id, RunnerUpdate::Event(event)).await;
}
StdioOutput::End { id, status } => {
let mut reg = registry.lock().await;
if let Some(tx) = reg.remove(&id) {
let _ = tx.send(RunnerUpdate::End(status));
}
}
}
}
async fn dispatch_stderr(line: &str, registry: &Registry) {
let parsed: StdioError = match serde_json::from_str(line) {
Ok(p) => p,
Err(_) => return,
};
match parsed {
StdioError::Diag {
id,
level,
message,
} => {
Self::send_to(
registry,
&id,
RunnerUpdate::Diag { level, message },
)
.await;
}
StdioError::Fatal { message } => {
let mut reg = registry.lock().await;
let entries: Vec<_> = reg.drain().collect();
drop(reg);
for (_, tx) in entries {
let _ = tx.send(RunnerUpdate::Fatal(message.clone()));
}
}
}
}
async fn send_to(registry: &Registry, id: &str, update: RunnerUpdate) {
let reg = registry.lock().await;
if let Some(tx) = reg.get(id) {
let _ = tx.send(update);
}
}
async fn close_all(registry: &Registry, closed: &Arc<std::sync::atomic::AtomicBool>) {
if closed.swap(true, std::sync::atomic::Ordering::AcqRel) {
return; }
let mut reg = registry.lock().await;
let entries: Vec<_> = reg.drain().collect();
drop(reg);
for (_, tx) in entries {
let _ = tx.send(RunnerUpdate::RunnerExited);
}
}
}