use async_trait::async_trait;
use mlua_swarm::{
CapToken, Ctx, Operator, SeniorBridge, SpawnHook, TaskId, WorkerError, WorkerResult,
};
use serde_json::Value;
use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot, Mutex};
use super::protocol::{current_parent_req_id, PendingReply, ServerMsg};
/// Operator session reached over a WebSocket-backed message channel.
///
/// Outbound `ServerMsg`s are pushed into `tx`; replies that come back are
/// routed to the waiting caller through `pending`, keyed by `req_id`.
pub struct WSOperatorSession {
// Session id; used as the prefix of every generated `req_id`.
sid: String,
// Current outbound channel. `None` after `clear_tx` (sends then fail with
// "ws operator disconnected"); replaced wholesale by `replace_tx`.
tx: Mutex<Option<mpsc::UnboundedSender<ServerMsg>>>,
// In-flight requests awaiting a reply, keyed by `req_id`; each entry is
// consumed by `resolve_pending`.
pending: Mutex<HashMap<String, oneshot::Sender<PendingReply>>>,
}
impl WSOperatorSession {
pub(super) fn new(sid: String, tx: mpsc::UnboundedSender<ServerMsg>) -> Self {
Self {
sid,
tx: Mutex::new(Some(tx)),
pending: Mutex::new(HashMap::new()),
}
}
pub(super) async fn replace_tx(&self, new_tx: mpsc::UnboundedSender<ServerMsg>) {
*self.tx.lock().await = Some(new_tx);
}
pub(crate) async fn clear_tx(&self) {
*self.tx.lock().await = None;
}
pub(super) async fn resolve_pending(&self, req_id: &str, reply: PendingReply) {
if let Some(otx) = self.pending.lock().await.remove(req_id) {
let _ = otx.send(reply);
}
}
async fn send_and_await(&self, req_id: String, msg: ServerMsg) -> Result<PendingReply, String> {
let (otx, orx) = oneshot::channel::<PendingReply>();
self.pending.lock().await.insert(req_id.clone(), otx);
let send_result = {
let guard = self.tx.lock().await;
match guard.as_ref() {
Some(tx) => tx
.send(msg)
.map_err(|_| "ws send channel closed".to_string()),
None => Err("ws operator disconnected".to_string()),
}
};
if let Err(e) = send_result {
self.pending.lock().await.remove(&req_id);
return Err(e);
}
orx.await
.map_err(|_| "ws operator: oneshot cancelled (= reply path closed)".to_string())
}
async fn send_oneway(&self, msg: ServerMsg) -> Result<(), String> {
let guard = self.tx.lock().await;
match guard.as_ref() {
Some(tx) => tx
.send(msg)
.map_err(|_| "ws send channel closed".to_string()),
None => Err("ws operator disconnected".to_string()),
}
}
}
#[async_trait]
impl SeniorBridge for WSOperatorSession {
    /// Forwards `question` about `task_id` to the operator as an `Ask` message
    /// and waits for the matching reply.
    ///
    /// # Errors
    /// Propagates the transport error from the send/await path, or reports the
    /// reply kind when the operator answers with something other than `Answer`.
    async fn ask(&self, task_id: &TaskId, question: Value) -> Result<Value, String> {
        let req_id = format!("{}-ask-{}", self.sid, uuid::Uuid::new_v4());
        let request = ServerMsg::Ask {
            req_id: req_id.clone(),
            parent_req_id: current_parent_req_id(),
            task_id: task_id.0.clone(),
            question,
        };
        let wrong_kind = match self.send_and_await(req_id, request).await? {
            PendingReply::Answer(answer) => return Ok(answer),
            PendingReply::HookAck { .. } => "hook_ack",
            PendingReply::SpawnAck { .. } => "spawn_ack",
        };
        Err(format!("ws operator: unexpected {wrong_kind} reply to ask"))
    }
}
#[async_trait]
impl SpawnHook for WSOperatorSession {
    /// Asks the operator to approve a spawn attempt by sending `HookBefore`
    /// and waiting for its `HookAck`.
    ///
    /// # Errors
    /// Propagates the transport error; returns the operator's `reason` (or
    /// "ws operator: spawn rejected") on a negative ack; reports the reply kind
    /// for any non-`HookAck` reply.
    async fn before(&self, ctx: &Ctx) -> Result<(), String> {
        let req_id = format!("{}-hb-{}", self.sid, uuid::Uuid::new_v4());
        let request = ServerMsg::HookBefore {
            req_id: req_id.clone(),
            parent_req_id: current_parent_req_id(),
            task_id: ctx.task_id.0.clone(),
            agent: ctx.agent.clone(),
            attempt: ctx.attempt,
        };
        let wrong_kind = match self.send_and_await(req_id, request).await? {
            PendingReply::HookAck { ok, reason } => {
                return if ok {
                    Ok(())
                } else {
                    Err(reason.unwrap_or_else(|| "ws operator: spawn rejected".into()))
                };
            }
            PendingReply::Answer(_) => "answer",
            PendingReply::SpawnAck { .. } => "spawn_ack",
        };
        Err(format!("ws operator: unexpected {wrong_kind} reply to hook_before"))
    }

    /// Notifies the operator that an attempt finished by sending `HookAfter`
    /// with a copy of `result`.
    ///
    /// No reply is awaited, and a failed send is deliberately ignored, so this
    /// hook always returns `Ok(())`.
    async fn after(&self, ctx: &Ctx, result: &Value) -> Result<(), String> {
        let notice = ServerMsg::HookAfter {
            req_id: format!("{}-ha-{}", self.sid, uuid::Uuid::new_v4()),
            parent_req_id: current_parent_req_id(),
            task_id: ctx.task_id.0.clone(),
            agent: ctx.agent.clone(),
            attempt: ctx.attempt,
            result: result.clone(),
        };
        // Best-effort: the outcome of the send is intentionally discarded.
        let _ = self.send_oneway(notice).await;
        Ok(())
    }
}
#[async_trait]
impl Operator for WSOperatorSession {
    /// Delegates a worker run to the operator.
    ///
    /// Sends a `Spawn` request carrying the encoded capability token, the
    /// optional `worker_handle` from runtime metadata, and a generated
    /// directive. The operator's `SpawnAck` is then mapped into a
    /// [`WorkerResult`]. `_system` and `_prompt` are not included in the
    /// `Spawn` message.
    ///
    /// # Errors
    /// `WorkerError::Failed` on transport failure (prefixed
    /// "ws operator spawn: "), when the ack carries an `error`, or when the
    /// reply is not a `SpawnAck`.
    async fn execute(
        &self,
        ctx: &Ctx,
        _system: Option<String>,
        _prompt: String,
        worker_token: CapToken,
    ) -> Result<WorkerResult, WorkerError> {
        let req_id = format!("{}-spawn-{}", self.sid, uuid::Uuid::new_v4());
        // String-valued runtime metadata lookup; missing or non-string values read as `None`.
        let runtime_str =
            |key: &'static str| ctx.meta.runtime.get(key).and_then(|v| v.as_str());
        let worker_handle = runtime_str("worker_handle").map(str::to_owned);
        let directive = default_spawn_directive(
            &ctx.agent,
            &ctx.task_id.0,
            runtime_str("project_name_alias"),
            runtime_str("data_sink_endpoint"),
        );
        let spawn = ServerMsg::Spawn {
            req_id: req_id.clone(),
            parent_req_id: current_parent_req_id(),
            task_id: ctx.task_id.0.clone(),
            agent: ctx.agent.clone(),
            attempt: ctx.attempt,
            capability_token: worker_token.encode(),
            worker_handle,
            directive,
        };
        let reply = self
            .send_and_await(req_id, spawn)
            .await
            .map_err(|e| WorkerError::Failed(format!("ws operator spawn: {e}")))?;
        match reply {
            PendingReply::SpawnAck {
                error: Some(reason),
                ..
            } => Err(WorkerError::Failed(reason)),
            PendingReply::SpawnAck {
                value,
                ok,
                error: None,
            } => Ok(WorkerResult { value, ok }),
            _ => Err(WorkerError::Failed(
                "ws operator: unexpected non-spawn reply".into(),
            )),
        }
    }
}
/// Builds the directive text embedded in a `Spawn` message.
///
/// The directive tells the operator (MainAgent) how to launch an `mse-worker`
/// SubAgent for `agent` / `task_id`. Optional sections are added when the
/// corresponding input is present:
/// - `project_name_alias`: a `project_name_alias:` header line plus the LDS
///   session-alias reminder.
/// - `data_sink_endpoint`: the `DATA_EMIT` / `DATA_GET` data-path block, with
///   URLs formed as `{base}/v1/data/...`.
///
/// Blank (empty or whitespace-only) inputs are treated as absent. Surrounding
/// whitespace and trailing `/` characters are stripped from
/// `data_sink_endpoint`, so the generated URLs never contain `//v1/`.
pub(super) fn default_spawn_directive(
    agent: &str,
    task_id: &str,
    project_name_alias: Option<&str>,
    data_sink_endpoint: Option<&str>,
) -> String {
    // A blank alias would render an empty header and `alias=""` calls; treat it as absent.
    let project_name_alias = project_name_alias.filter(|a| !a.trim().is_empty());
    // The templates below already supply the `/` before `v1`; a configured
    // `http://host:port/` must not turn into `http://host:port//v1/...`.
    let data_sink_endpoint = data_sink_endpoint
        .map(|base| base.trim().trim_end_matches('/'))
        .filter(|base| !base.is_empty());

    let project_alias_line = match project_name_alias {
        Some(a) => format!("project_name_alias: {a}\n"),
        None => String::new(),
    };
    let data_endpoint_block = match data_sink_endpoint {
        Some(base) => format!(
            "\n\
             [Data path endpoint — MainAgent reminder]\n\
             When you kick a SubAgent, inject the following two lines into\n\
             its prompt / environment so Big Response payloads (4k+ tokens,\n\
             files, intermediate artifacts) flow directly to the Store owner,\n\
             bypassing the MainAgent (context stays small; only the out_id\n\
             ref is passed around).\n \
             DATA_EMIT: {base}/v1/data/emit (POST, auth = Bearer worker_handle or ?token=)\n \
             DATA_GET: {base}/v1/data/<out_id|out_name> (the next SubAgent fetches from $IN_REFS)\n\
             When a SubAgent produces a Big Response, POST it to DATA_EMIT\n\
             and return only the one-line out_id ref (do not mix the body\n\
             in; the MainAgent must not answer directly).\n\
             \n"
        ),
        None => String::new(),
    };
    let main_ai_reminder = match project_name_alias {
        Some(a) => format!(
            "\n\
             [LDS Session Alias Reminder — MainAI mandatory]\n\
             Before kicking the SubAgent below, call:\n \
             mcp__lds__session_create(root=<working_dir>, alias=\"{a}\")\n\
             (= establish a single task-level lds session; reuse on repeated dispatch).\n\
             Then add this literal line to the SubAgent prompt body below:\n \
             LDS Session Alias: {a}\n\
             The SubAgent will call mcp__lds__session_start(alias=\"{a}\") on init,\n\
             keeping worktree ownership unified across dispatches.\n\
             (Full discipline rationale is inlined above; reach is via this directive itself,\n\
             not via any external doc path. The 2 steps above are the complete contract.)\n\
             \n"
        ),
        None => String::new(),
    };
    // Note: a `\` at end of line inside these literals skips the newline and the
    // next line's leading whitespace; a space written just before the `\` is kept,
    // which is how the one-space indented lines are produced.
    format!(
        "[agent_primitive dispatch=@{agent}]\n\
         worker endpoint:\n \
         GET <base_url>/v1/worker/prompt?task_id={task_id}\n \
         POST <base_url>/v1/worker/submit\n\
         auth: Bearer <worker_handle from THIS Spawn payload (= short `wh-XXXXXXXX` form)>\n\
         task_id: {task_id}\n\
         agent_id: {agent}\n\
         {project_alias_line}\
         {data_endpoint_block}\
         {main_ai_reminder}\
         Kick a SubAgent via Agent tool with subagent_type=\"mse-worker\" (= project-local \
         `.claude/agents/mse-worker.md`, a generic WS Operator worker). The prompt you pass \
         to it MUST be EXACTLY these 4 lines (no preamble, no extra text):\n\
         \n \
         agent_id: {agent}\n \
         worker_handle: <THIS Spawn payload's `worker_handle` field (short string `wh-XXXXXXXX`)>\n \
         base_url: <server HTTP root, e.g. http://127.0.0.1:7786>\n \
         task_id: {task_id}\n\
         \n\
         The mse-worker SubAgent self-fetches system + prompt via GET (Bearer = handle), \
         executes as agent @{agent}, POSTs raw body to /v1/worker/submit (Bearer = handle, \
         server resolves task_id from handle), and replies `OUTPUT` 1 word. You then forward \
         SpawnAck {{req_id, value:{{}}, ok:true}} through your operator client — MCP path: \
         mse_ack(sid, req_id, kind=\"spawn_ack\", ok=true) (= empty value because canonical \
         body lives in output_tail via the POST). \
         Do NOT fetch /v1/worker/prompt yourself. Do NOT wrap, summarize, or field-select \
         the SubAgent reply. Observation / debug is a separate channel (= agent-inspect MCP / \
         GET /v1/tasks/{{id}}), do NOT mix it into the forward path. \
         If mse-worker is not registered, fall back to subagent_type=\"general-purpose\" with \
         the same 4-line prompt + a 1-line hint pointing to this directive."
    )
}
/// Unit tests for `default_spawn_directive`: each optional section must appear
/// only when its input is `Some`, and must contain the expected literal lines.
#[cfg(test)]
mod tests {
use super::*;
// With no alias, neither the header line nor the LDS session reminder may appear.
#[test]
fn directive_omits_project_name_alias_when_none() {
let d = default_spawn_directive("impl-lead", "task-x", None, None);
assert!(!d.contains("project_name_alias:"));
assert!(!d.contains("LDS Session Alias"));
assert!(!d.contains("session_create"));
}
// With an alias, the header, the session_create call, and the SubAgent inject line
// must all carry it, and the reminder must be self-contained (no external doc path).
#[test]
fn directive_emits_project_name_alias_when_some() {
let d = default_spawn_directive("impl-lead", "task-x", Some("mse-task-7785"), None);
assert!(
d.contains("project_name_alias: mse-task-7785"),
"directive missing project_name_alias header: {d}"
);
assert!(
d.contains("mcp__lds__session_create(root=<working_dir>, alias=\"mse-task-7785\")"),
"directive missing session_create reminder: {d}"
);
assert!(
d.contains("LDS Session Alias: mse-task-7785"),
"directive missing SubAgent prompt inject line: {d}"
);
assert!(
d.contains("inlined above") || d.contains("complete contract"),
"directive should inline rationale rather than point at external doc: {d}"
);
// Presumably assembled at runtime so the forbidden path literal does not itself
// appear in this source file — NOTE(review): confirm the intent.
let forbidden_doc_ref = format!(".{}/CLAUDE.md", "claude");
assert!(
!d.contains(&forbidden_doc_ref),
"directive must not reference {forbidden_doc_ref} (out of MainAI scope): {d}"
);
}
// With no data sink endpoint, the data-path block must be absent entirely.
#[test]
fn directive_omits_data_endpoint_when_none() {
let d = default_spawn_directive("impl-lead", "task-x", None, None);
assert!(!d.contains("[Data path endpoint"));
assert!(!d.contains("DATA_EMIT"));
assert!(!d.contains("DATA_GET"));
}
// With an endpoint, both URLs are built from the given base, the auth hint is
// present, and the old `emit-auth` endpoint name must not appear.
#[test]
fn directive_emits_data_endpoint_when_some() {
let base = "http://127.0.0.1:7785";
let d = default_spawn_directive("impl-lead", "task-x", None, Some(base));
assert!(
d.contains("[Data path endpoint"),
"directive missing data endpoint block header: {d}"
);
assert!(
d.contains(&format!("DATA_EMIT: {base}/v1/data/emit")),
"directive missing single-mouth emit line: {d}"
);
assert!(
d.contains("Bearer worker_handle or ?token="),
"directive missing auth transport hint: {d}"
);
assert!(
d.contains(&format!("DATA_GET: {base}/v1/data/<out_id|out_name>")),
"directive missing GET line: {d}"
);
assert!(
!d.contains("emit-auth"),
"old split endpoint must not leak into directive: {d}"
);
assert!(
d.contains("bypassing the MainAgent") && d.contains("out_id ref"),
"directive should carry the ownership + bypass reasoning: {d}"
);
}
}