1use async_trait::async_trait;
14use mlua_swarm::{
15 CapToken, Ctx, Operator, SeniorBridge, SpawnHook, TaskId, WorkerBinding, WorkerError,
16 WorkerResult,
17};
18use serde_json::Value;
19use std::collections::HashMap;
20use tokio::sync::{mpsc, oneshot, Mutex};
21
22use super::protocol::{current_parent_req_id, PendingReply, ServerMsg};
23
24pub struct WSOperatorSession {
26 sid: String,
27 tx: Mutex<Option<mpsc::UnboundedSender<ServerMsg>>>,
30 pending: Mutex<HashMap<String, oneshot::Sender<PendingReply>>>,
33}
34
35impl WSOperatorSession {
36 pub(super) fn new(sid: String, tx: mpsc::UnboundedSender<ServerMsg>) -> Self {
40 Self {
41 sid,
42 tx: Mutex::new(Some(tx)),
43 pending: Mutex::new(HashMap::new()),
44 }
45 }
46
47 pub(super) async fn replace_tx(&self, new_tx: mpsc::UnboundedSender<ServerMsg>) {
49 *self.tx.lock().await = Some(new_tx);
50 }
51
52 pub(crate) async fn clear_tx(&self) {
54 *self.tx.lock().await = None;
55 }
56
57 pub(super) async fn resolve_pending(&self, req_id: &str, reply: PendingReply) {
60 if let Some(otx) = self.pending.lock().await.remove(req_id) {
61 let _ = otx.send(reply);
62 }
63 }
64
65 async fn send_and_await(&self, req_id: String, msg: ServerMsg) -> Result<PendingReply, String> {
69 let (otx, orx) = oneshot::channel::<PendingReply>();
70 self.pending.lock().await.insert(req_id.clone(), otx);
71
72 let send_result = {
74 let guard = self.tx.lock().await;
75 match guard.as_ref() {
76 Some(tx) => tx
77 .send(msg)
78 .map_err(|_| "ws send channel closed".to_string()),
79 None => Err("ws operator disconnected".to_string()),
80 }
81 };
82 if let Err(e) = send_result {
83 self.pending.lock().await.remove(&req_id);
84 return Err(e);
85 }
86
87 orx.await
88 .map_err(|_| "ws operator: oneshot cancelled (= reply path closed)".to_string())
89 }
90
91 async fn send_oneway(&self, msg: ServerMsg) -> Result<(), String> {
93 let guard = self.tx.lock().await;
94 match guard.as_ref() {
95 Some(tx) => tx
96 .send(msg)
97 .map_err(|_| "ws send channel closed".to_string()),
98 None => Err("ws operator disconnected".to_string()),
99 }
100 }
101}
102
103#[async_trait]
104impl SeniorBridge for WSOperatorSession {
105 async fn ask(&self, task_id: &TaskId, question: Value) -> Result<Value, String> {
106 let req_id = format!("{}-ask-{}", self.sid, uuid::Uuid::new_v4());
107 let msg = ServerMsg::Ask {
108 req_id: req_id.clone(),
109 parent_req_id: current_parent_req_id(),
110 task_id: task_id.0.clone(),
111 question,
112 };
113 match self.send_and_await(req_id, msg).await? {
114 PendingReply::Answer(v) => Ok(v),
115 PendingReply::HookAck { .. } => {
116 Err("ws operator: unexpected hook_ack reply to ask".into())
117 }
118 PendingReply::SpawnAck { .. } => {
119 Err("ws operator: unexpected spawn_ack reply to ask".into())
120 }
121 }
122 }
123}
124
125#[async_trait]
126impl SpawnHook for WSOperatorSession {
127 async fn before(&self, ctx: &Ctx) -> Result<(), String> {
128 let req_id = format!("{}-hb-{}", self.sid, uuid::Uuid::new_v4());
129 let msg = ServerMsg::HookBefore {
130 req_id: req_id.clone(),
131 parent_req_id: current_parent_req_id(),
132 task_id: ctx.task_id.0.clone(),
133 agent: ctx.agent.clone(),
134 attempt: ctx.attempt,
135 };
136 match self.send_and_await(req_id, msg).await? {
137 PendingReply::HookAck { ok: true, .. } => Ok(()),
138 PendingReply::HookAck { ok: false, reason } => {
139 Err(reason.unwrap_or_else(|| "ws operator: spawn rejected".into()))
140 }
141 PendingReply::Answer(_) => {
142 Err("ws operator: unexpected answer reply to hook_before".into())
143 }
144 PendingReply::SpawnAck { .. } => {
145 Err("ws operator: unexpected spawn_ack reply to hook_before".into())
146 }
147 }
148 }
149
150 async fn after(&self, ctx: &Ctx, result: &Value) -> Result<(), String> {
151 let req_id = format!("{}-ha-{}", self.sid, uuid::Uuid::new_v4());
152 let msg = ServerMsg::HookAfter {
153 req_id,
154 parent_req_id: current_parent_req_id(),
155 task_id: ctx.task_id.0.clone(),
156 agent: ctx.agent.clone(),
157 attempt: ctx.attempt,
158 result: result.clone(),
159 };
160 let _ = self.send_oneway(msg).await;
162 Ok(())
163 }
164}
165
166#[async_trait]
167impl Operator for WSOperatorSession {
168 async fn execute(
188 &self,
189 ctx: &Ctx,
190 _system: Option<String>,
191 _prompt: String,
192 worker: Option<WorkerBinding>,
193 worker_token: CapToken,
194 ) -> Result<WorkerResult, WorkerError> {
195 let Some(worker) = worker else {
196 return Err(WorkerError::Failed(format!(
197 "agent '{}' has no worker_binding; WS thin-path requires one \
198 (Blueprint AgentDef.profile.worker_binding)",
199 ctx.agent
200 )));
201 };
202 let req_id = format!("{}-spawn-{}", self.sid, uuid::Uuid::new_v4());
203 let worker_handle = ctx
204 .meta
205 .runtime
206 .get("worker_handle")
207 .and_then(|v| v.as_str())
208 .map(|s| s.to_string());
209 let project_name_alias = ctx
210 .meta
211 .runtime
212 .get("project_name_alias")
213 .and_then(|v| v.as_str());
214 let data_sink_endpoint = ctx
215 .meta
216 .runtime
217 .get("data_sink_endpoint")
218 .and_then(|v| v.as_str());
219 let directive = default_spawn_directive(
220 &ctx.agent,
221 &ctx.task_id.0,
222 &worker.variant,
223 project_name_alias,
224 data_sink_endpoint,
225 );
226 let msg = ServerMsg::Spawn {
227 req_id: req_id.clone(),
228 parent_req_id: current_parent_req_id(),
229 task_id: ctx.task_id.0.clone(),
230 agent: ctx.agent.clone(),
231 attempt: ctx.attempt,
232 capability_token: worker_token.encode(),
233 worker_handle,
234 worker: Some(worker),
235 directive,
236 };
237 match self.send_and_await(req_id, msg).await {
238 Ok(PendingReply::SpawnAck {
239 value,
240 ok,
241 error: None,
242 }) => Ok(WorkerResult { value, ok }),
243 Ok(PendingReply::SpawnAck {
244 error: Some(msg), ..
245 }) => Err(WorkerError::Failed(msg)),
246 Ok(_) => Err(WorkerError::Failed(
247 "ws operator: unexpected non-spawn reply".into(),
248 )),
249 Err(e) => Err(WorkerError::Failed(format!("ws operator spawn: {e}"))),
250 }
251 }
252
253 fn requires_worker_binding(&self) -> bool {
254 true
255 }
256}
257
258pub(super) fn default_spawn_directive(
292 agent: &str,
293 task_id: &str,
294 subagent_type: &str,
295 project_name_alias: Option<&str>,
296 data_sink_endpoint: Option<&str>,
297) -> String {
298 let project_alias_line = match project_name_alias {
302 Some(a) => format!("project_name_alias: {a}\n"),
303 None => String::new(),
304 };
305 let data_endpoint_block = match data_sink_endpoint {
312 Some(base) => format!(
313 "\n\
314 [Data path endpoint — MainAgent reminder]\n\
315 When you kick a SubAgent, inject the following two lines into\n\
316 its prompt / environment so Big Response payloads (4k+ tokens,\n\
317 files, intermediate artifacts) flow directly to the Store owner,\n\
318 bypassing the MainAgent (context stays small; only the out_id\n\
319 ref is passed around).\n \
320 DATA_EMIT: {base}/v1/data/emit (POST, auth = Bearer worker_handle or ?token=)\n \
321 DATA_GET: {base}/v1/data/<out_id|out_name> (the next SubAgent fetches from $IN_REFS)\n\
322 When a SubAgent produces a Big Response, POST it to DATA_EMIT\n\
323 and return only the one-line out_id ref (do not mix the body\n\
324 in; the MainAgent must not answer directly).\n\
325 \n"
326 ),
327 None => String::new(),
328 };
329 let main_ai_reminder = match project_name_alias {
330 Some(a) => format!(
331 "\n\
332 [LDS Session Alias Reminder — MainAI mandatory]\n\
333 Before kicking the SubAgent below, call:\n \
334 mcp__lds__session_create(root=<working_dir>, alias=\"{a}\")\n\
335 (= establish a single task-level lds session; reuse on repeated dispatch).\n\
336 Then add this literal line to the SubAgent prompt body below:\n \
337 LDS Session Alias: {a}\n\
338 The SubAgent will call mcp__lds__session_start(alias=\"{a}\") on init,\n\
339 keeping worktree ownership unified across dispatches.\n\
340 (Full discipline rationale is inlined above; reach is via this directive itself,\n\
341 not via any external doc path. The 2 steps above are the complete contract.)\n\
342 \n"
343 ),
344 None => String::new(),
345 };
346 format!(
347 "[agent_primitive dispatch=@{agent}]\n\
348 worker endpoint:\n \
349 GET <base_url>/v1/worker/prompt?task_id={task_id}\n \
350 POST <base_url>/v1/worker/submit\n\
351 auth: Bearer <worker_handle from THIS Spawn payload (= short `wh-XXXXXXXX` form)>\n\
352 task_id: {task_id}\n\
353 agent_id: {agent}\n\
354 {project_alias_line}\
355 {data_endpoint_block}\
356 {main_ai_reminder}\
357 Kick a SubAgent via Agent tool with subagent_type=\"{subagent_type}\" (= project-local \
358 `.claude/agents/{subagent_type}.md`, this agent's Blueprint-declared worker binding). \
359 The prompt you pass to it MUST be EXACTLY these 4 lines (no preamble, no extra text):\n\
360 \n \
361 agent_id: {agent}\n \
362 worker_handle: <THIS Spawn payload's `worker_handle` field (short string `wh-XXXXXXXX`)>\n \
363 base_url: <server HTTP root, e.g. http://127.0.0.1:7786>\n \
364 task_id: {task_id}\n\
365 \n\
366 The SubAgent self-fetches system + prompt via GET (Bearer = handle), \
367 executes as agent @{agent}, POSTs raw body to /v1/worker/submit (Bearer = handle, \
368 server resolves task_id from handle), and replies `OUTPUT` 1 word. You then forward \
369 SpawnAck {{req_id, value:{{}}, ok:true}} through your operator client — MCP path: \
370 mse_ack(sid, req_id, kind=\"spawn_ack\", ok=true) (= empty value because canonical \
371 body lives in output_tail via the POST). \
372 Do NOT fetch /v1/worker/prompt yourself. Do NOT wrap, summarize, or field-select \
373 the SubAgent reply. Observation / debug is a separate channel (= agent-inspect MCP / \
374 GET /v1/tasks/{{id}}), do NOT mix it into the forward path. \
375 If the SubAgent type is not registered, FAIL LOUD: reply SpawnAck ok=false with an \
376 error explaining the missing `.claude/agents/{subagent_type}.md` — do NOT fall back \
377 to another subagent_type."
378 )
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders a directive for the agent / task / subagent fixture shared by
    /// every test, varying only the two optional inputs.
    fn render(project_name_alias: Option<&str>, data_sink_endpoint: Option<&str>) -> String {
        default_spawn_directive(
            "impl-lead",
            "task-x",
            "mse-worker-coder",
            project_name_alias,
            data_sink_endpoint,
        )
    }

    /// Without an alias neither the header line nor the reminder block appears.
    #[test]
    fn directive_omits_project_name_alias_when_none() {
        let d = render(None, None);
        assert!(!d.contains("project_name_alias:"));
        assert!(!d.contains("LDS Session Alias"));
        assert!(!d.contains("session_create"));
    }

    /// With an alias the header line, the `session_create` reminder and the
    /// prompt inject line are all present, and no external doc is referenced.
    #[test]
    fn directive_emits_project_name_alias_when_some() {
        let d = render(Some("mse-task-7785"), None);
        assert!(
            d.contains("project_name_alias: mse-task-7785"),
            "directive missing project_name_alias header: {d}"
        );
        assert!(
            d.contains("mcp__lds__session_create(root=<working_dir>, alias=\"mse-task-7785\")"),
            "directive missing session_create reminder: {d}"
        );
        assert!(
            d.contains("LDS Session Alias: mse-task-7785"),
            "directive missing SubAgent prompt inject line: {d}"
        );
        assert!(
            d.contains("inlined above") || d.contains("complete contract"),
            "directive should inline rationale rather than point at external doc: {d}"
        );
        // The path is assembled at runtime rather than written as one literal
        // (presumably so the literal itself never appears in this file —
        // TODO confirm).
        let forbidden_doc_ref = format!(".{}/CLAUDE.md", "claude");
        assert!(
            !d.contains(&forbidden_doc_ref),
            "directive must not reference {forbidden_doc_ref} (out of MainAI scope): {d}"
        );
    }

    /// Without an endpoint the data-path block is absent.
    #[test]
    fn directive_omits_data_endpoint_when_none() {
        let d = render(None, None);
        assert!(!d.contains("[Data path endpoint"));
        assert!(!d.contains("DATA_EMIT"));
        assert!(!d.contains("DATA_GET"));
    }

    /// With an endpoint the data-path block carries both URLs, the auth hint
    /// and the bypass reasoning, and nothing of the old split endpoint.
    #[test]
    fn directive_emits_data_endpoint_when_some() {
        let base = "http://127.0.0.1:7785";
        let d = render(None, Some(base));
        assert!(
            d.contains("[Data path endpoint"),
            "directive missing data endpoint block header: {d}"
        );
        assert!(
            d.contains(&format!("DATA_EMIT: {base}/v1/data/emit")),
            "directive missing single-mouth emit line: {d}"
        );
        assert!(
            d.contains("Bearer worker_handle or ?token="),
            "directive missing auth transport hint: {d}"
        );
        assert!(
            d.contains(&format!("DATA_GET: {base}/v1/data/<out_id|out_name>")),
            "directive missing GET line: {d}"
        );
        assert!(
            !d.contains("emit-auth"),
            "old split endpoint must not leak into directive: {d}"
        );
        assert!(
            d.contains("bypassing the MainAgent") && d.contains("out_id ref"),
            "directive should carry the ownership + bypass reasoning: {d}"
        );
    }

    /// The declared subagent type is carried literally, and the directive
    /// demands a loud failure instead of any fallback type.
    #[test]
    fn directive_carries_declared_subagent_type_and_has_no_fallback() {
        let d = render(None, None);
        assert!(
            d.contains("subagent_type=\"mse-worker-coder\""),
            "directive must carry the Blueprint-declared subagent_type literally: {d}"
        );
        assert!(
            d.contains(".claude/agents/mse-worker-coder.md"),
            "directive must reference the declared subagent's own .md path: {d}"
        );
        assert!(
            !d.contains("general-purpose"),
            "directive must not fall back to subagent_type=\"general-purpose\": {d}"
        );
        assert!(
            !d.contains("mse-worker\""),
            "directive must not carry the old hardcoded \"mse-worker\" literal: {d}"
        );
        assert!(
            d.contains("FAIL LOUD"),
            "directive must instruct the MainAI to fail loud instead of falling back: {d}"
        );
    }
}