1use async_trait::async_trait;
14use mlua_swarm::{
15 CapToken, Ctx, Operator, SeniorBridge, SpawnHook, TaskId, WorkerError, WorkerResult,
16};
17use serde_json::Value;
18use std::collections::HashMap;
19use tokio::sync::{mpsc, oneshot, Mutex};
20
21use super::protocol::{current_parent_req_id, PendingReply, ServerMsg};
22
23pub struct WSOperatorSession {
25 sid: String,
26 tx: Mutex<Option<mpsc::UnboundedSender<ServerMsg>>>,
29 pending: Mutex<HashMap<String, oneshot::Sender<PendingReply>>>,
32}
33
34impl WSOperatorSession {
35 pub(super) fn new(sid: String, tx: mpsc::UnboundedSender<ServerMsg>) -> Self {
39 Self {
40 sid,
41 tx: Mutex::new(Some(tx)),
42 pending: Mutex::new(HashMap::new()),
43 }
44 }
45
46 pub(super) async fn replace_tx(&self, new_tx: mpsc::UnboundedSender<ServerMsg>) {
48 *self.tx.lock().await = Some(new_tx);
49 }
50
51 pub(crate) async fn clear_tx(&self) {
53 *self.tx.lock().await = None;
54 }
55
56 pub(super) async fn resolve_pending(&self, req_id: &str, reply: PendingReply) {
59 if let Some(otx) = self.pending.lock().await.remove(req_id) {
60 let _ = otx.send(reply);
61 }
62 }
63
64 async fn send_and_await(&self, req_id: String, msg: ServerMsg) -> Result<PendingReply, String> {
68 let (otx, orx) = oneshot::channel::<PendingReply>();
69 self.pending.lock().await.insert(req_id.clone(), otx);
70
71 let send_result = {
73 let guard = self.tx.lock().await;
74 match guard.as_ref() {
75 Some(tx) => tx
76 .send(msg)
77 .map_err(|_| "ws send channel closed".to_string()),
78 None => Err("ws operator disconnected".to_string()),
79 }
80 };
81 if let Err(e) = send_result {
82 self.pending.lock().await.remove(&req_id);
83 return Err(e);
84 }
85
86 orx.await
87 .map_err(|_| "ws operator: oneshot cancelled (= reply path closed)".to_string())
88 }
89
90 async fn send_oneway(&self, msg: ServerMsg) -> Result<(), String> {
92 let guard = self.tx.lock().await;
93 match guard.as_ref() {
94 Some(tx) => tx
95 .send(msg)
96 .map_err(|_| "ws send channel closed".to_string()),
97 None => Err("ws operator disconnected".to_string()),
98 }
99 }
100}
101
102#[async_trait]
103impl SeniorBridge for WSOperatorSession {
104 async fn ask(&self, task_id: &TaskId, question: Value) -> Result<Value, String> {
105 let req_id = format!("{}-ask-{}", self.sid, uuid::Uuid::new_v4());
106 let msg = ServerMsg::Ask {
107 req_id: req_id.clone(),
108 parent_req_id: current_parent_req_id(),
109 task_id: task_id.0.clone(),
110 question,
111 };
112 match self.send_and_await(req_id, msg).await? {
113 PendingReply::Answer(v) => Ok(v),
114 PendingReply::HookAck { .. } => {
115 Err("ws operator: unexpected hook_ack reply to ask".into())
116 }
117 PendingReply::SpawnAck { .. } => {
118 Err("ws operator: unexpected spawn_ack reply to ask".into())
119 }
120 }
121 }
122}
123
124#[async_trait]
125impl SpawnHook for WSOperatorSession {
126 async fn before(&self, ctx: &Ctx) -> Result<(), String> {
127 let req_id = format!("{}-hb-{}", self.sid, uuid::Uuid::new_v4());
128 let msg = ServerMsg::HookBefore {
129 req_id: req_id.clone(),
130 parent_req_id: current_parent_req_id(),
131 task_id: ctx.task_id.0.clone(),
132 agent: ctx.agent.clone(),
133 attempt: ctx.attempt,
134 };
135 match self.send_and_await(req_id, msg).await? {
136 PendingReply::HookAck { ok: true, .. } => Ok(()),
137 PendingReply::HookAck { ok: false, reason } => {
138 Err(reason.unwrap_or_else(|| "ws operator: spawn rejected".into()))
139 }
140 PendingReply::Answer(_) => {
141 Err("ws operator: unexpected answer reply to hook_before".into())
142 }
143 PendingReply::SpawnAck { .. } => {
144 Err("ws operator: unexpected spawn_ack reply to hook_before".into())
145 }
146 }
147 }
148
149 async fn after(&self, ctx: &Ctx, result: &Value) -> Result<(), String> {
150 let req_id = format!("{}-ha-{}", self.sid, uuid::Uuid::new_v4());
151 let msg = ServerMsg::HookAfter {
152 req_id,
153 parent_req_id: current_parent_req_id(),
154 task_id: ctx.task_id.0.clone(),
155 agent: ctx.agent.clone(),
156 attempt: ctx.attempt,
157 result: result.clone(),
158 };
159 let _ = self.send_oneway(msg).await;
161 Ok(())
162 }
163}
164
165#[async_trait]
166impl Operator for WSOperatorSession {
167 async fn execute(
180 &self,
181 ctx: &Ctx,
182 _system: Option<String>,
183 _prompt: String,
184 worker_token: CapToken,
185 ) -> Result<WorkerResult, WorkerError> {
186 let req_id = format!("{}-spawn-{}", self.sid, uuid::Uuid::new_v4());
187 let worker_handle = ctx
188 .meta
189 .runtime
190 .get("worker_handle")
191 .and_then(|v| v.as_str())
192 .map(|s| s.to_string());
193 let project_name_alias = ctx
194 .meta
195 .runtime
196 .get("project_name_alias")
197 .and_then(|v| v.as_str());
198 let data_sink_endpoint = ctx
199 .meta
200 .runtime
201 .get("data_sink_endpoint")
202 .and_then(|v| v.as_str());
203 let directive = default_spawn_directive(
204 &ctx.agent,
205 &ctx.task_id.0,
206 project_name_alias,
207 data_sink_endpoint,
208 );
209 let msg = ServerMsg::Spawn {
210 req_id: req_id.clone(),
211 parent_req_id: current_parent_req_id(),
212 task_id: ctx.task_id.0.clone(),
213 agent: ctx.agent.clone(),
214 attempt: ctx.attempt,
215 capability_token: worker_token.encode(),
216 worker_handle,
217 directive,
218 };
219 match self.send_and_await(req_id, msg).await {
220 Ok(PendingReply::SpawnAck {
221 value,
222 ok,
223 error: None,
224 }) => Ok(WorkerResult { value, ok }),
225 Ok(PendingReply::SpawnAck {
226 error: Some(msg), ..
227 }) => Err(WorkerError::Failed(msg)),
228 Ok(_) => Err(WorkerError::Failed(
229 "ws operator: unexpected non-spawn reply".into(),
230 )),
231 Err(e) => Err(WorkerError::Failed(format!("ws operator spawn: {e}"))),
232 }
233 }
234}
235
236pub(super) fn default_spawn_directive(
260 agent: &str,
261 task_id: &str,
262 project_name_alias: Option<&str>,
263 data_sink_endpoint: Option<&str>,
264) -> String {
265 let project_alias_line = match project_name_alias {
269 Some(a) => format!("project_name_alias: {a}\n"),
270 None => String::new(),
271 };
272 let data_endpoint_block = match data_sink_endpoint {
279 Some(base) => format!(
280 "\n\
281 [Data path endpoint — MainAgent reminder]\n\
282 When you kick a SubAgent, inject the following two lines into\n\
283 its prompt / environment so Big Response payloads (4k+ tokens,\n\
284 files, intermediate artifacts) flow directly to the Store owner,\n\
285 bypassing the MainAgent (context stays small; only the out_id\n\
286 ref is passed around).\n \
287 DATA_EMIT: {base}/v1/data/emit (POST, auth = Bearer worker_handle or ?token=)\n \
288 DATA_GET: {base}/v1/data/<out_id|out_name> (the next SubAgent fetches from $IN_REFS)\n\
289 When a SubAgent produces a Big Response, POST it to DATA_EMIT\n\
290 and return only the one-line out_id ref (do not mix the body\n\
291 in; the MainAgent must not answer directly).\n\
292 \n"
293 ),
294 None => String::new(),
295 };
296 let main_ai_reminder = match project_name_alias {
297 Some(a) => format!(
298 "\n\
299 [LDS Session Alias Reminder — MainAI mandatory]\n\
300 Before kicking the SubAgent below, call:\n \
301 mcp__lds__session_create(root=<working_dir>, alias=\"{a}\")\n\
302 (= establish a single task-level lds session; reuse on repeated dispatch).\n\
303 Then add this literal line to the SubAgent prompt body below:\n \
304 LDS Session Alias: {a}\n\
305 The SubAgent will call mcp__lds__session_start(alias=\"{a}\") on init,\n\
306 keeping worktree ownership unified across dispatches.\n\
307 (Full discipline rationale is inlined above; reach is via this directive itself,\n\
308 not via any external doc path. The 2 steps above are the complete contract.)\n\
309 \n"
310 ),
311 None => String::new(),
312 };
313 format!(
314 "[agent_primitive dispatch=@{agent}]\n\
315 worker endpoint:\n \
316 GET <base_url>/v1/worker/prompt?task_id={task_id}\n \
317 POST <base_url>/v1/worker/submit\n\
318 auth: Bearer <worker_handle from THIS Spawn payload (= short `wh-XXXXXXXX` form)>\n\
319 task_id: {task_id}\n\
320 agent_id: {agent}\n\
321 {project_alias_line}\
322 {data_endpoint_block}\
323 {main_ai_reminder}\
324 Kick a SubAgent via Agent tool with subagent_type=\"mse-worker\" (= project-local \
325 `.claude/agents/mse-worker.md`, a generic WS Operator worker). The prompt you pass \
326 to it MUST be EXACTLY these 4 lines (no preamble, no extra text):\n\
327 \n \
328 agent_id: {agent}\n \
329 worker_handle: <THIS Spawn payload's `worker_handle` field (short string `wh-XXXXXXXX`)>\n \
330 base_url: <server HTTP root, e.g. http://127.0.0.1:7786>\n \
331 task_id: {task_id}\n\
332 \n\
333 The mse-worker SubAgent self-fetches system + prompt via GET (Bearer = handle), \
334 executes as agent @{agent}, POSTs raw body to /v1/worker/submit (Bearer = handle, \
335 server resolves task_id from handle), and replies `OUTPUT` 1 word. You then forward \
336 SpawnAck {{req_id, value:{{}}, ok:true}} through your operator client — MCP path: \
337 mse_ack(sid, req_id, kind=\"spawn_ack\", ok=true) (= empty value because canonical \
338 body lives in output_tail via the POST). \
339 Do NOT fetch /v1/worker/prompt yourself. Do NOT wrap, summarize, or field-select \
340 the SubAgent reply. Observation / debug is a separate channel (= agent-inspect MCP / \
341 GET /v1/tasks/{{id}}), do NOT mix it into the forward path. \
342 If mse-worker is not registered, fall back to subagent_type=\"general-purpose\" with \
343 the same 4-line prompt + a 1-line hint pointing to this directive."
344 )
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders the directive for the fixed `impl-lead` / `task-x` dispatch shared by
    /// every test in this module.
    fn render(alias: Option<&str>, data_sink: Option<&str>) -> String {
        default_spawn_directive("impl-lead", "task-x", alias, data_sink)
    }

    #[test]
    fn directive_omits_project_name_alias_when_none() {
        let d = render(None, None);
        assert!(!d.contains("project_name_alias:"));
        assert!(!d.contains("LDS Session Alias"));
        assert!(!d.contains("session_create"));
    }

    #[test]
    fn directive_emits_project_name_alias_when_some() {
        let d = render(Some("mse-task-7785"), None);
        let required = [
            (
                "project_name_alias: mse-task-7785",
                "directive missing project_name_alias header",
            ),
            (
                "mcp__lds__session_create(root=<working_dir>, alias=\"mse-task-7785\")",
                "directive missing session_create reminder",
            ),
            (
                "LDS Session Alias: mse-task-7785",
                "directive missing SubAgent prompt inject line",
            ),
        ];
        for (needle, what) in required {
            assert!(d.contains(needle), "{what}: {d}");
        }
        assert!(
            d.contains("inlined above") || d.contains("complete contract"),
            "directive should inline rationale rather than point at external doc: {d}"
        );
        // Built with `format!` rather than a literal; presumably to keep the path
        // string itself out of this source file.
        let forbidden_doc_ref = format!(".{}/CLAUDE.md", "claude");
        assert!(
            !d.contains(&forbidden_doc_ref),
            "directive must not reference {forbidden_doc_ref} (out of MainAI scope): {d}"
        );
    }

    #[test]
    fn directive_omits_data_endpoint_when_none() {
        let d = render(None, None);
        assert!(!d.contains("[Data path endpoint"));
        assert!(!d.contains("DATA_EMIT"));
        assert!(!d.contains("DATA_GET"));
    }

    #[test]
    fn directive_emits_data_endpoint_when_some() {
        let base = "http://127.0.0.1:7785";
        let d = render(None, Some(base));
        let required = [
            (
                "[Data path endpoint".to_string(),
                "directive missing data endpoint block header",
            ),
            (
                format!("DATA_EMIT: {base}/v1/data/emit"),
                "directive missing single-mouth emit line",
            ),
            (
                "Bearer worker_handle or ?token=".to_string(),
                "directive missing auth transport hint",
            ),
            (
                format!("DATA_GET: {base}/v1/data/<out_id|out_name>"),
                "directive missing GET line",
            ),
        ];
        for (needle, what) in &required {
            assert!(d.contains(needle.as_str()), "{what}: {d}");
        }
        assert!(
            !d.contains("emit-auth"),
            "old split endpoint must not leak into directive: {d}"
        );
        assert!(
            d.contains("bypassing the MainAgent") && d.contains("out_id ref"),
            "directive should carry the ownership + bypass reasoning: {d}"
        );
    }
}