use crate::core::ctx::Ctx;
use crate::core::engine::Engine;
use crate::types::{CapToken, TaskId, WorkerId};
use crate::worker::adapter::{SpawnError, SpawnerAdapter, WorkerError, WorkerResult};
use crate::worker::output::{ContentRef, OutputEvent};
use crate::worker::{Worker, WorkerJoinHandler};
use async_trait::async_trait;
use serde_json::Value;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
/// Framing used to decode `OutputEvent`s from a streaming child's stdout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMode {
    /// One JSON-encoded `OutputEvent` per line; blank lines and lines that do
    /// not decode as an event are skipped.
    NdjsonLines,
    /// SSE-style framing: `data:` lines are joined with newlines until a blank
    /// line ends the event, then the joined text is decoded as one
    /// `OutputEvent`. Other fields are ignored.
    SseEvents,
    /// Each event is a 4-byte big-endian length followed by that many bytes of
    /// JSON.
    LengthPrefixed,
}

/// Spawner that runs each task attempt as an external process.
///
/// Construct with [`ProcessSpawner::new`] / [`ProcessSpawner::cmd`] to exec a
/// program directly, or [`ProcessSpawner::run`] to go through `sh -c`, then
/// refine it with the builder methods.
#[derive(Debug, Clone)]
pub struct ProcessSpawner {
    /// Program to execute.
    pub program: String,
    /// Arguments passed to `program`, in order.
    pub args: Vec<String>,
    /// When `true` the task directive is written to the child's stdin;
    /// otherwise it is appended as the final argument.
    pub use_stdin: bool,
    /// `None` collects stdout in one piece when the child exits; `Some(mode)`
    /// decodes a stream of events with the given framing as they are produced.
    pub stream_mode: Option<StreamMode>,
}

impl ProcessSpawner {
    /// Creates a spawner that execs `program` with no arguments, delivers the
    /// directive on stdin and collects output in plain (non-streaming) mode.
    #[must_use]
    pub fn new(program: impl Into<String>) -> Self {
        Self {
            program: program.into(),
            args: Vec::new(),
            use_stdin: true,
            stream_mode: None,
        }
    }

    /// Appends a single argument.
    #[must_use]
    pub fn arg(mut self, a: impl Into<String>) -> Self {
        self.args.push(a.into());
        self
    }

    /// Appends every argument yielded by `args`, in order.
    #[must_use]
    pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.args.extend(args.into_iter().map(Into::into));
        self
    }

    /// Chooses between stdin (`true`) and a trailing argument (`false`) for
    /// delivering the directive.
    #[must_use]
    pub fn use_stdin(mut self, v: bool) -> Self {
        self.use_stdin = v;
        self
    }

    /// Switches to streaming mode with the given framing.
    #[must_use]
    pub fn stream_mode(mut self, mode: StreamMode) -> Self {
        self.stream_mode = Some(mode);
        self
    }

    /// Switches back to plain mode (whole stdout collected at exit).
    #[must_use]
    pub fn plain(mut self) -> Self {
        self.stream_mode = None;
        self
    }

    /// Shorthand for `stream_mode(StreamMode::NdjsonLines)` when `v` is true,
    /// and for [`ProcessSpawner::plain`] otherwise.
    #[must_use]
    pub fn ndjson(self, v: bool) -> Self {
        if v {
            self.stream_mode(StreamMode::NdjsonLines)
        } else {
            self.plain()
        }
    }

    /// Creates a spawner that runs `cmd` through `sh -c`, so shell syntax is
    /// available. This requires a POSIX `sh` on the host.
    ///
    /// Never interpolate untrusted text into `cmd`, because it is interpreted
    /// by the shell. With `use_stdin(false)` the directive becomes the first
    /// argument after the command string, which `sh -c` exposes to the script
    /// as `$0`, not `$1`.
    #[must_use]
    pub fn run(cmd: impl Into<String>) -> Self {
        Self::new("sh").arg("-c").arg(cmd)
    }

    /// Alias for [`ProcessSpawner::new`].
    #[must_use]
    pub fn cmd(program: impl Into<String>) -> Self {
        Self::new(program)
    }
}
#[async_trait]
impl SpawnerAdapter for ProcessSpawner {
async fn spawn(
&self,
engine: &Engine,
ctx: &Ctx,
task_id: TaskId,
attempt: u32,
token: CapToken,
) -> Result<Box<dyn Worker>, SpawnError> {
let directive = engine
.fetch_prompt(&token, &task_id)
.await
.map_err(|e| SpawnError::Internal(format!("fetch_prompt: {e}")))?;
let mut cmd = Command::new(&self.program);
cmd.args(&self.args)
.env("MSE_TOKEN_AGENT_ID", &token.agent_id)
.env("MSE_TOKEN_NONCE", &token.nonce)
.env("MSE_TASK_ID", &task_id.0)
.env("MSE_ATTEMPT", attempt.to_string())
.env("MSE_CTX_AGENT", &ctx.agent)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
if !self.use_stdin {
cmd.arg(&directive);
}
let mut child = cmd
.spawn()
.map_err(|e| SpawnError::Internal(format!("spawn failed: {e}")))?;
if self.use_stdin {
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(directive.as_bytes())
.await
.map_err(|e| SpawnError::Internal(format!("stdin write: {e}")))?;
drop(stdin); }
}
let cancel = CancellationToken::new();
let cancel_inner = cancel.clone();
let worker_id = WorkerId::new();
let (tx, rx) = oneshot::channel();
let engine_for_emit = engine.clone();
let token_for_emit = token.clone();
let task_id_for_emit = task_id.clone();
let stream_mode = self.stream_mode.clone();
tokio::spawn(async move {
let result: Result<WorkerResult, WorkerError> = if let Some(mode) = stream_mode {
run_streaming_mode(
mode,
child,
&engine_for_emit,
&token_for_emit,
&task_id_for_emit,
attempt,
cancel_inner,
)
.await
} else {
let result = tokio::select! {
output = child.wait_with_output() => {
match output {
Ok(out) => {
let stdout = String::from_utf8_lossy(&out.stdout).to_string();
let value: Value = serde_json::from_str(stdout.trim())
.unwrap_or_else(|_| serde_json::json!({
"raw": stdout.trim_end(),
"stderr": String::from_utf8_lossy(&out.stderr).to_string(),
}));
Ok(WorkerResult { value, ok: out.status.success() })
}
Err(e) => Err(WorkerError::Failed(format!("wait_with_output: {e}"))),
}
}
_ = cancel_inner.cancelled() => Err(WorkerError::Cancelled),
};
if let Ok(wr) = &result {
let ev = OutputEvent::Final {
content: ContentRef::Inline {
value: wr.value.clone(),
},
ok: wr.ok,
};
let _ = engine_for_emit
.submit_output(&token_for_emit, &task_id_for_emit, attempt, ev)
.await;
}
result
};
let signal: Result<(), WorkerError> = result.map(|_| ());
let _ = tx.send(signal);
});
Ok(Box::new(ProcessWorker {
handler: WorkerJoinHandler {
worker_id,
cancel,
completion: rx,
},
}))
}
}
pub struct ProcessWorker {
pub handler: WorkerJoinHandler,
}
#[async_trait]
impl Worker for ProcessWorker {
fn id(&self) -> &WorkerId {
&self.handler.worker_id
}
fn cancel_token(&self) -> CancellationToken {
self.handler.cancel.clone()
}
async fn join(self: Box<Self>) -> Result<(), WorkerError> {
self.handler.await_completion().await
}
}
/// Drives a child whose stdout carries `OutputEvent`s framed according to
/// `mode`.
///
/// Each decoded event is submitted to the engine as it arrives. When the
/// per-mode reader finishes, the child is reaped. The last `Final` event seen
/// becomes the result, with its `ok` flag AND-ed with the exit status. If no
/// `Final` was seen, a synthetic failed `Final` is submitted so the engine
/// still observes a terminal event.
///
/// Stdin is closed up front and stderr is drained in the background, so the
/// child cannot stall on either pipe while stdout is being read.
///
/// # Errors
/// Returns `WorkerError::Cancelled` if `cancel` fires, and
/// `WorkerError::Failed` on a read or wait error. On cancellation or a read
/// error, the child is killed and reaped before returning.
async fn run_streaming_mode(
    mode: StreamMode,
    mut child: tokio::process::Child,
    engine: &Engine,
    token: &CapToken,
    task_id: &TaskId,
    attempt: u32,
    cancel: CancellationToken,
) -> Result<WorkerResult, WorkerError> {
    // Plain mode's `wait_with_output` closes stdin before waiting; do the same
    // here, otherwise a child that reads stdin until EOF never gets there.
    drop(child.stdin.take());
    // Nothing else reads stderr in this mode. Left alone, a chatty child fills
    // the pipe buffer and then blocks on its next stderr write.
    let stderr_drain = child.stderr.take().map(|mut stderr| {
        tokio::spawn(async move {
            let _ = tokio::io::copy(&mut stderr, &mut tokio::io::sink()).await;
        })
    });

    let read_outcome = match child.stdout.take() {
        Some(stdout) => match mode {
            StreamMode::NdjsonLines => {
                read_ndjson(stdout, engine, token, task_id, attempt, cancel.clone()).await
            }
            StreamMode::SseEvents => {
                read_sse(stdout, engine, token, task_id, attempt, cancel.clone()).await
            }
            StreamMode::LengthPrefixed => {
                read_length_prefixed(stdout, engine, token, task_id, attempt, cancel.clone())
                    .await
            }
        },
        None => Err(WorkerError::Failed("streaming: stdout pipe missing".into())),
    };

    let finished = match read_outcome {
        // The reader may stop before the child exits, for example on an
        // oversized frame, so keep honouring cancellation while waiting.
        Ok(last_final) => tokio::select! {
            // Prefer a completed exit over a simultaneous cancellation.
            biased;
            waited = child.wait() => waited
                .map(|status| (last_final, status))
                .map_err(|e| WorkerError::Failed(format!("streaming wait: {e}"))),
            _ = cancel.cancelled() => {
                let _ = child.kill().await;
                Err(WorkerError::Cancelled)
            }
        },
        Err(err) => {
            // Don't leave the child running once we stop listening to it.
            let _ = child.kill().await;
            Err(err)
        }
    };
    if let Some(drain) = stderr_drain {
        // The child is gone; stop draining in case a descendant inherited
        // stderr and keeps it open.
        drain.abort();
    }
    let (last_final, status) = finished?;

    match last_final {
        Some((value, ok)) => Ok(WorkerResult {
            value,
            ok: ok && status.success(),
        }),
        None => {
            let value = serde_json::json!({
                "raw": "",
                "note": "streaming mode: no Final event received",
                "exit_success": status.success(),
            });
            let _ = engine
                .submit_output(
                    token,
                    task_id,
                    attempt,
                    OutputEvent::Final {
                        content: ContentRef::Inline {
                            value: value.clone(),
                        },
                        ok: false,
                    },
                )
                .await;
            Ok(WorkerResult { value, ok: false })
        }
    }
}
/// Submits `ev` to the engine. If `ev` is a `Final` event, it is first
/// recorded in `last_final` as a `(value, ok)` pair.
///
/// An inline payload is recorded as-is. A file reference is recorded as a
/// JSON object with `file_ref`, `mime` and `size_hint` keys, so the caller
/// always has a `Value` to report. Non-`Final` events leave `last_final`
/// untouched. The result of `submit_output` is deliberately discarded.
async fn forward_event(
    engine: &Engine,
    token: &CapToken,
    task_id: &TaskId,
    attempt: u32,
    ev: OutputEvent,
    last_final: &mut Option<(Value, bool)>,
) {
    let snapshot = if let OutputEvent::Final { content, ok } = &ev {
        let recorded = match content {
            ContentRef::FileRef {
                path,
                mime,
                size_hint,
            } => serde_json::json!({
                "file_ref": path.to_string_lossy(),
                "mime": mime,
                "size_hint": size_hint,
            }),
            ContentRef::Inline { value } => value.clone(),
        };
        Some((recorded, *ok))
    } else {
        None
    };
    if snapshot.is_some() {
        *last_final = snapshot;
    }

    let _ = engine.submit_output(token, task_id, attempt, ev).await;
}
async fn read_ndjson(
stdout: tokio::process::ChildStdout,
engine: &Engine,
token: &CapToken,
task_id: &TaskId,
attempt: u32,
cancel: CancellationToken,
) -> Result<Option<(Value, bool)>, WorkerError> {
let mut reader = BufReader::new(stdout).lines();
let mut last_final = None;
loop {
tokio::select! {
line_res = reader.next_line() => match line_res {
Ok(Some(line)) => {
let trimmed = line.trim();
if trimmed.is_empty() { continue; }
if let Ok(ev) = serde_json::from_str::<OutputEvent>(trimmed) {
forward_event(engine, token, task_id, attempt, ev, &mut last_final).await;
}
}
Ok(None) => break,
Err(e) => return Err(WorkerError::Failed(format!("ndjson read: {e}"))),
},
_ = cancel.cancelled() => return Err(WorkerError::Cancelled),
}
}
Ok(last_final)
}
async fn read_sse(
stdout: tokio::process::ChildStdout,
engine: &Engine,
token: &CapToken,
task_id: &TaskId,
attempt: u32,
cancel: CancellationToken,
) -> Result<Option<(Value, bool)>, WorkerError> {
let mut reader = BufReader::new(stdout).lines();
let mut last_final = None;
let mut data_buf = String::new();
loop {
tokio::select! {
line_res = reader.next_line() => match line_res {
Ok(Some(line)) => {
if line.is_empty() {
if !data_buf.is_empty() {
if let Ok(ev) = serde_json::from_str::<OutputEvent>(data_buf.trim()) {
forward_event(engine, token, task_id, attempt, ev, &mut last_final).await;
}
data_buf.clear();
}
} else if let Some(rest) = line.strip_prefix("data:") {
let payload = rest.strip_prefix(' ').unwrap_or(rest);
if !data_buf.is_empty() {
data_buf.push('\n');
}
data_buf.push_str(payload);
}
}
Ok(None) => {
if !data_buf.is_empty() {
if let Ok(ev) = serde_json::from_str::<OutputEvent>(data_buf.trim()) {
forward_event(engine, token, task_id, attempt, ev, &mut last_final).await;
}
}
break;
}
Err(e) => return Err(WorkerError::Failed(format!("sse read: {e}"))),
},
_ = cancel.cancelled() => return Err(WorkerError::Cancelled),
}
}
Ok(last_final)
}
/// Reads `OutputEvent`s framed as a 4-byte big-endian length followed by that
/// many bytes of JSON, and forwards each decoded event via `forward_event`.
///
/// Reading stops and returns the last `Final` seen when any of these occur:
/// - EOF, either between frames or inside a truncated frame;
/// - a zero-length header;
/// - a header larger than `MAX_FRAME_LEN`.
///
/// Frames whose payload does not decode as an `OutputEvent` are skipped.
///
/// # Errors
/// Returns `WorkerError::Cancelled` if `cancel` fires. Returns
/// `WorkerError::Failed` on any I/O error other than EOF while reading a
/// header or a payload.
async fn read_length_prefixed(
    mut stdout: tokio::process::ChildStdout,
    engine: &Engine,
    token: &CapToken,
    task_id: &TaskId,
    attempt: u32,
    cancel: CancellationToken,
) -> Result<Option<(Value, bool)>, WorkerError> {
    use tokio::io::AsyncReadExt;

    /// Largest accepted frame; bigger headers are treated as a corrupt stream.
    const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;

    let mut last_final = None;
    loop {
        let mut len_buf = [0u8; 4];
        let header = tokio::select! {
            r = stdout.read_exact(&mut len_buf) => r,
            _ = cancel.cancelled() => return Err(WorkerError::Cancelled),
        };
        match header {
            Ok(_) => {}
            // Clean end of stream between frames.
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(WorkerError::Failed(format!("len read: {e}"))),
        }

        // Saturate on (hypothetical) targets where usize is narrower than u32,
        // so the bound check below rejects the frame instead of truncating.
        let len = usize::try_from(u32::from_be_bytes(len_buf)).unwrap_or(usize::MAX);
        // NOTE(review): a zero-length frame ends the stream. Presumably it is
        // an explicit terminator; confirm with the producers.
        if len == 0 || len > MAX_FRAME_LEN {
            break;
        }

        let mut payload = vec![0u8; len];
        let body = tokio::select! {
            r = stdout.read_exact(&mut payload) => r,
            _ = cancel.cancelled() => return Err(WorkerError::Cancelled),
        };
        match body {
            Ok(_) => {}
            // Truncated final frame: treat like EOF, consistent with the header read.
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
            // Any other I/O failure is reported instead of silently ending the stream.
            Err(e) => return Err(WorkerError::Failed(format!("payload read: {e}"))),
        }

        if let Ok(ev) = serde_json::from_slice::<OutputEvent>(&payload) {
            forward_event(engine, token, task_id, attempt, ev, &mut last_final).await;
        }
    }
    Ok(last_final)
}