Skip to main content

lash_protocol_rlm/
control_tools.rs

1use async_trait::async_trait;
2use lash_core::{
3    ToolArgumentProjectionPolicy, ToolAvailabilityConfig, ToolCall, ToolContract, ToolControl,
4    ToolDefinition, ToolManifest, ToolProvider, ToolResult, ToolScheduling,
5};
6use serde_json::{Value, json};
7use std::sync::Arc;
8
9use crate::projection::RlmSeed;
10
11pub(crate) struct RlmControlToolsProvider;
12
13#[async_trait]
14impl ToolProvider for RlmControlToolsProvider {
15    fn tool_manifests(&self) -> Vec<ToolManifest> {
16        vec![continue_as_tool_definition().manifest()]
17    }
18
19    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
20        (name == "continue_as").then(|| Arc::new(continue_as_tool_definition().contract()))
21    }
22
23    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
24        let result = match call.name {
25            "continue_as" => continue_as_switch_frame(call.args),
26            _ => return ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
27        };
28        finalise_tool_result(result)
29    }
30}
31
32pub fn continue_as_tool_definition() -> ToolDefinition {
33    ToolDefinition::raw(
34        "tool:continue_as",
35        "continue_as",
36        "Tail-call into a fresh RLM AgentFrame inside the current session with a clean window.\n\nThe new frame inherits **nothing** implicitly — no variables or message history. Pass everything it needs via `seed: { name: value, ... }`. Seed values copied from read-only values stay read-only in the new frame; computed expressions become writable variables.\n\n- Use when the current trajectory is stale, dominated by failed attempts, or the context budget is tight.\n- Treat `control.continue_as(...)` as a terminal control action: make it the last meaningful statement in the lashlang block, and do not call `submit` or perform more work after it.\n- `task` packs the concrete goal, constraints, and next steps the new frame must act on.\n- `seed` packs the concrete state (paths, facts already learned, partial results, read-only values) the new frame needs in scope; leave bulky raw output behind.\n- `continue_as` only changes the active AgentFrame. It does not start, transfer, list, cancel, or otherwise manage processes.",
37        continue_as_input_schema(),
38        continue_as_output_schema(),
39    )
40    .with_examples(vec![
41        r#"await control.continue_as({ task: "continue the audit from the summarized findings", seed: { problem: input.prompt, findings: findings } })?"#.into(),
42    ])
43    .with_agent_surface(lash_core::ToolAgentSurface::new(["control"], "continue_as"))
44    .with_argument_projection(ToolArgumentProjectionPolicy::preserve_projected_refs_in_field(
45        "seed",
46    ))
47    .with_availability(ToolAvailabilityConfig::callable())
48    .with_scheduling(ToolScheduling::Parallel)
49}
50
51fn continue_as_output_schema() -> Value {
52    json!({
53        "type": "object",
54        "properties": {
55            "ok": { "type": "boolean" },
56            "frame_id": { "type": "string" },
57            "task": { "type": "string" },
58            "seed_keys": {
59                "type": "array",
60                "items": { "type": "string" }
61            },
62            "seed_count": { "type": "integer", "minimum": 0 }
63        },
64        "required": [
65            "ok",
66            "frame_id",
67            "task",
68            "seed_keys",
69            "seed_count"
70        ],
71        "additionalProperties": false
72    })
73}
74
75pub fn continue_as_input_schema() -> Value {
76    json!({
77        "type": "object",
78        "properties": {
79            "task": {
80                "type": "string",
81                "description": "Task for the new AgentFrame."
82            },
83            "seed": {
84                "type": "object",
85                "additionalProperties": true,
86                "description": "Optional record/dict of concrete state for the new AgentFrame."
87            }
88        },
89        "required": ["task"],
90        "additionalProperties": false
91    })
92}
93
94fn continue_as_switch_frame(args: &Value) -> Result<ContinueAsResult, String> {
95    let task = required_string(args, "task")?;
96    let seed = RlmSeed::from_tool_args(args).map_err(|err| format!("continue_as {err}"))?;
97    let mut seed_keys = seed
98        .globals
99        .keys()
100        .cloned()
101        .chain(seed.projected.entries.iter().map(|(name, _)| name.clone()))
102        .collect::<Vec<_>>();
103    seed_keys.sort();
104    let seed_count = seed_keys.len();
105    let frame_id = uuid::Uuid::new_v4().to_string();
106    let initial_nodes = crate::rlm_seed_initial_nodes(seed);
107    let initial_nodes = initial_nodes
108        .into_iter()
109        .map(|node| {
110            serde_json::to_value(node)
111                .map_err(|err| format!("failed to encode continue_as frame seed node: {err}"))
112        })
113        .collect::<Result<Vec<_>, _>>()?;
114
115    Ok(ContinueAsResult {
116        value: json!({
117            "ok": true,
118            "frame_id": frame_id.clone(),
119            "task": task.clone(),
120            "seed_keys": seed_keys,
121            "seed_count": seed_count,
122        }),
123        control: ToolControl::SwitchAgentFrame {
124            frame_id,
125            initial_nodes,
126            task: Some(task),
127        },
128    })
129}
130
131fn required_string(args: &Value, key: &str) -> Result<String, String> {
132    args.get(key)
133        .and_then(Value::as_str)
134        .map(str::trim)
135        .filter(|value| !value.is_empty())
136        .map(ToOwned::to_owned)
137        .ok_or_else(|| format!("missing required parameter: {key}"))
138}
139
140struct ContinueAsResult {
141    value: Value,
142    control: ToolControl,
143}
144
145fn finalise_tool_result(result: Result<ContinueAsResult, String>) -> ToolResult {
146    match result {
147        Ok(result) => ToolResult::ok(result.value).with_control(result.control),
148        Err(err) => ToolResult::err(json!(err)),
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::projection::{decode_rlm_protocol_event, rlm_protocol_event};
    use std::sync::{Arc, Mutex};

    use lash_core::plugin::runtime_host::{
        SessionGraphService, SessionLifecycleService, SessionStateService,
    };
    use lash_core::plugin::{PluginError, SessionHandle};
    use lash_core::runtime::RuntimeSessionState;
    use lash_core::{
        SessionAppendNode, SessionCreateRequest, SessionPolicy, SessionSnapshot, ToolProvider,
    };
    use lash_rlm_types::{RlmProtocolEvent, RlmTermination};

    // ------------------------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------------------------

    fn model_spec(model: &str) -> lash_core::ModelSpec {
        lash_core::ModelSpec::from_token_limits(model, None, 200_000, None)
            .expect("valid test model spec")
    }

    /// Stub runtime host used by the `continue_as` tests.
    ///
    /// It serves a fixed session snapshot and records session create/close requests. Every
    /// process operation except `list_visible` returns an error, so a test fails if
    /// `continue_as` is ever routed through process management.
    #[derive(Default)]
    struct BatonManager {
        snapshot: RuntimeSessionState,
        created: Mutex<Vec<SessionCreateRequest>>,
        closed: Mutex<Vec<String>>,
    }

    /// Builds a stub host whose snapshot only sets a test model on the session policy.
    fn manager_with_model() -> Arc<BatonManager> {
        Arc::new(BatonManager {
            snapshot: RuntimeSessionState {
                policy: SessionPolicy {
                    model: model_spec("model"),
                    ..SessionPolicy::default()
                },
                ..RuntimeSessionState::default()
            },
            ..BatonManager::default()
        })
    }

    #[async_trait]
    impl SessionStateService for BatonManager {
        async fn snapshot_current(&self) -> Result<SessionSnapshot, PluginError> {
            Ok(self.snapshot.to_snapshot())
        }

        async fn snapshot_session(
            &self,
            _session_id: &str,
        ) -> Result<SessionSnapshot, PluginError> {
            Ok(self.snapshot.to_snapshot())
        }

        async fn tool_catalog(
            &self,
            _session_id: &str,
        ) -> Result<Vec<serde_json::Value>, PluginError> {
            Ok(Vec::new())
        }
    }

    #[async_trait]
    impl SessionLifecycleService for BatonManager {
        async fn create_session(
            &self,
            request: SessionCreateRequest,
        ) -> Result<SessionHandle, PluginError> {
            self.created.lock().expect("created").push(request.clone());
            Ok(SessionHandle {
                session_id: request.session_id.unwrap_or_else(|| "child".to_string()),
                parent_session_id: request.relation.parent_session_id().map(ToOwned::to_owned),
                policy: request.policy.unwrap_or_default(),
            })
        }

        async fn close_session(&self, session_id: &str) -> Result<(), PluginError> {
            self.closed
                .lock()
                .expect("closed")
                .push(session_id.to_string());
            Ok(())
        }
    }

    #[async_trait]
    impl SessionGraphService for BatonManager {}

    #[async_trait]
    impl lash_core::ProcessService for BatonManager {
        async fn start(
            &self,
            _session_id: &str,
            _registration: lash_core::ProcessRegistration,
            _options: lash_core::ProcessStartOptions,
            _scope: lash_core::ProcessOpScope<'_>,
        ) -> Result<lash_core::ProcessRecord, PluginError> {
            Err(PluginError::Session(
                "process starts are unavailable in this test".to_string(),
            ))
        }

        async fn await_process(
            &self,
            _process_id: &str,
            _scope: lash_core::ProcessOpScope<'_>,
        ) -> Result<lash_core::ProcessAwaitOutput, PluginError> {
            Err(PluginError::Session(
                "process awaiting is unavailable in this test".to_string(),
            ))
        }

        async fn list_visible(
            &self,
            _session_id: &str,
            _mode: lash_core::ProcessListMode,
            _scope: lash_core::ProcessOpScope<'_>,
        ) -> Result<Vec<lash_core::runtime::ProcessHandleGrantEntry>, PluginError> {
            Ok(Vec::new())
        }

        async fn validate_visible(
            &self,
            _session_id: &str,
            _handle_ids: &[String],
            _scope: lash_core::ProcessOpScope<'_>,
        ) -> Result<(), PluginError> {
            Err(PluginError::Session(
                "continue_as must not validate process handles".to_string(),
            ))
        }

        async fn cancel(
            &self,
            _session_id: &str,
            _process_id: &str,
            _scope: lash_core::ProcessOpScope<'_>,
        ) -> Result<lash_core::ProcessRecord, PluginError> {
            Err(PluginError::Session(
                "process cancellation is unavailable in this test".to_string(),
            ))
        }

        async fn signal(
            &self,
            _session_id: &str,
            _process_id: &str,
            _signal_id: String,
            _payload: serde_json::Value,
            _scope: lash_core::ProcessOpScope<'_>,
        ) -> Result<lash_core::ProcessEvent, PluginError> {
            Err(PluginError::Session(
                "process signalling is unavailable in this test".to_string(),
            ))
        }

        async fn transfer(
            &self,
            _from_session_id: &str,
            _to_session_id: &str,
            _process_ids: Vec<String>,
            _scope: lash_core::ProcessOpScope<'_>,
        ) -> Result<(), PluginError> {
            Err(PluginError::Session(
                "continue_as must not transfer process handles".to_string(),
            ))
        }

        async fn cancel_unreferenced(
            &self,
            _session_id: &str,
            _keep_process_ids: Vec<String>,
            _scope: lash_core::ProcessOpScope<'_>,
        ) -> Result<Vec<lash_core::ProcessRecord>, PluginError> {
            Err(PluginError::Session(
                "continue_as must not cancel process handles".to_string(),
            ))
        }
    }

    /// Executes tool `name` on `provider` with a test `ToolContext` backed by `manager`.
    async fn run_tool(
        provider: &RlmControlToolsProvider,
        manager: Arc<BatonManager>,
        name: &'static str,
        args: &Value,
    ) -> ToolResult {
        let sessions: Arc<dyn SessionStateService> = manager.clone();
        let session_lifecycle: Arc<dyn SessionLifecycleService> = manager.clone();
        let session_graph: Arc<dyn SessionGraphService> = manager.clone();
        let processes: Arc<dyn lash_core::ProcessService> = manager;
        let context = lash_core::ToolContext::__for_testing(
            "test-session".to_string(),
            sessions,
            session_lifecycle,
            session_graph,
            processes,
            Arc::new(lash_core::InMemoryAttachmentStore::new()),
            lash_core::DirectCompletionClient::from_fn(|_, _| {
                Err(lash_core::PluginError::Session(
                    "direct completions are unavailable in continue_as tests".to_string(),
                ))
            }),
            Some("continue-as-test".to_string()),
        );
        provider
            .execute(lash_core::ToolCall {
                name,
                args,
                context: &context,
                progress: None,
            })
            .await
    }

    /// Executes `continue_as` with `args`; see [`run_tool`].
    async fn run_continue_as(
        provider: &RlmControlToolsProvider,
        manager: Arc<BatonManager>,
        args: &Value,
    ) -> ToolResult {
        run_tool(provider, manager, "continue_as", args).await
    }

    // ------------------------------------------------------------------------------------
    // Definition / contract
    // ------------------------------------------------------------------------------------

    #[test]
    fn continue_as_contract_documents_switch_result() {
        let definition = continue_as_tool_definition();

        assert_eq!(
            definition.contract.output_schema["required"],
            json!(["ok", "frame_id", "task", "seed_keys", "seed_count"])
        );
        let rendered = definition.compact_contract().render_signature();
        assert!(rendered.contains("frame_id"), "{rendered}");
        assert!(!rendered.contains("handle_count"), "{rendered}");
        assert!(!rendered.contains("projected_count"), "{rendered}");
        assert!(!rendered.contains("global_count"), "{rendered}");
    }

    #[test]
    fn continue_as_tool_definition_preserves_projected_seed_refs_by_metadata() {
        assert_eq!(
            continue_as_tool_definition().manifest.argument_projection,
            ToolArgumentProjectionPolicy::preserve_projected_refs_in_field("seed")
        );
    }

    #[test]
    fn rlm_control_definitions_include_continue_as_only() {
        let provider = RlmControlToolsProvider;
        let names = provider
            .tool_manifests()
            .into_iter()
            .map(|tool| tool.name)
            .collect::<Vec<_>>();
        assert_eq!(names, vec!["continue_as"]);
    }

    #[test]
    fn rlm_control_resolves_contract_only_for_continue_as() {
        let provider = RlmControlToolsProvider;
        assert!(provider.resolve_contract("continue_as").is_some());
        assert!(provider.resolve_contract("submit").is_none());
    }

    // ------------------------------------------------------------------------------------
    // Argument validation
    // ------------------------------------------------------------------------------------

    #[test]
    fn required_string_trims_and_rejects_blank_or_non_string_values() {
        assert_eq!(
            required_string(&json!({ "task": "  trimmed  " }), "task"),
            Ok("trimmed".to_string())
        );
        let missing = Err("missing required parameter: task".to_string());
        assert_eq!(required_string(&json!({}), "task"), missing);
        assert_eq!(required_string(&json!({ "task": "   " }), "task"), missing);
        assert_eq!(required_string(&json!({ "task": 42 }), "task"), missing);
    }

    #[test]
    fn continue_as_switch_frame_requires_task() {
        assert_eq!(
            continue_as_switch_frame(&json!({ "seed": { "x": 1 } })).err(),
            Some("missing required parameter: task".to_string())
        );
        assert_eq!(
            continue_as_switch_frame(&json!({ "task": " \t ", "seed": { "x": 1 } })).err(),
            Some("missing required parameter: task".to_string())
        );
    }

    #[tokio::test]
    async fn continue_as_rejects_blank_task_without_switching_frames() {
        let manager = manager_with_model();
        let provider = RlmControlToolsProvider;

        let args = json!({ "task": "   ", "seed": { "x": 1 } });
        let result = run_continue_as(&provider, manager.clone(), &args).await;

        assert!(!result.is_success(), "{:?}", result.value_for_projection());
        assert!(manager.created.lock().expect("created").is_empty());
    }

    #[tokio::test]
    async fn execute_rejects_unknown_tool_names() {
        let manager = manager_with_model();
        let provider = RlmControlToolsProvider;

        let result = run_tool(&provider, manager.clone(), "submit", &json!({ "task": "x" })).await;

        assert!(!result.is_success(), "{:?}", result.value_for_projection());
        assert!(manager.created.lock().expect("created").is_empty());
    }

    // ------------------------------------------------------------------------------------
    // Frame switching
    // ------------------------------------------------------------------------------------

    #[tokio::test]
    async fn continue_as_creates_empty_rlm_frame_with_seed_and_task() {
        // The parent session carries a `diary` global that the new frame must NOT inherit.
        let mut session_graph = lash_core::SessionGraph::default();
        session_graph.append_protocol_event(rlm_protocol_event(RlmProtocolEvent::RlmGlobalsPatch(
            lash_rlm_types::RlmGlobalsPatchPluginBody {
                set_default: serde_json::Map::from_iter([("diary".to_string(), json!([]))]),
            },
        )));
        let manager = Arc::new(BatonManager {
            snapshot: RuntimeSessionState {
                policy: SessionPolicy {
                    model: model_spec("model"),
                    ..SessionPolicy::default()
                },
                protocol_turn_options: lash_core::ProtocolTurnOptions::typed(
                    RlmTermination::SubmitRequired {
                        schema: Some(json!({
                            "type": "object",
                            "properties": { "answer": { "type": "string" } },
                            "required": ["answer"]
                        })),
                    },
                )
                .expect("valid rlm turn options"),
                session_graph,
                ..RuntimeSessionState::default()
            },
            ..BatonManager::default()
        });
        let provider = RlmControlToolsProvider;

        let args = json!({
            "task": "finish from here",
            "seed": { "x": 1, "query": "original" }
        });
        let result = run_continue_as(&provider, manager.clone(), &args).await;

        assert!(result.is_success(), "{:?}", result.value_for_projection());
        let value = result.value_for_projection();
        assert!(value.get("frame_id").and_then(Value::as_str).is_some());
        assert_eq!(value.get("seed_keys"), Some(&json!(["query", "x"])));
        assert_eq!(value.get("seed_count"), Some(&json!(2)));
        assert!(value.get("projected_count").is_none());
        assert!(value.get("global_count").is_none());
        let Some(ToolControl::SwitchAgentFrame {
            frame_id,
            initial_nodes,
            task,
        }) = result.as_output().control.as_ref()
        else {
            panic!("expected frame switch control");
        };
        assert_eq!(
            value.get("frame_id").and_then(Value::as_str),
            Some(frame_id.as_str())
        );
        assert_eq!(task.as_deref(), Some("finish from here"));
        assert_eq!(initial_nodes.len(), 1);
        let node = serde_json::from_value::<SessionAppendNode>(initial_nodes[0].clone())
            .expect("decode initial node");
        let SessionAppendNode::ProtocolEvent {
            event: protocol_event,
            ..
        } = node
        else {
            panic!("expected seed globals event");
        };
        let Some(RlmProtocolEvent::RlmSeed(seed)) = decode_rlm_protocol_event(&protocol_event)
        else {
            panic!("expected RlmSeed");
        };
        assert_eq!(seed.globals["x"], json!(1));
        assert_eq!(seed.globals["query"], json!("original"));
        // Only the explicit seed is carried over; the parent's `diary` global is not.
        assert_eq!(seed.globals.len(), 2);
        assert!(!seed.globals.contains_key("diary"));
        assert!(seed.projected.is_empty());
        assert!(manager.created.lock().expect("created").is_empty());
    }

    #[tokio::test]
    async fn continue_as_routes_projected_entries_and_globals_to_one_seed_event() {
        // Mixed seed: `proj` was a projected source on the parent (encoded with
        // the canonical `__projected__` JSON wrapper), `glob` was a regular
        // global. The new frame receives both through one durable RLM seed event.
        let manager = manager_with_model();
        let provider = RlmControlToolsProvider;

        let args = json!({
            "task": "finish from here",
            "seed": {
                "proj": { "__projected__": "carry-over" },
                "glob": 7
            }
        });
        let result = run_continue_as(&provider, manager.clone(), &args).await;
        assert!(result.is_success(), "{:?}", result.value_for_projection());
        let value = result.value_for_projection();
        assert_eq!(value.get("seed_keys"), Some(&json!(["glob", "proj"])));
        assert_eq!(value.get("seed_count"), Some(&json!(2)));
        assert!(value.get("projected_count").is_none());
        assert!(value.get("global_count").is_none());

        let Some(ToolControl::SwitchAgentFrame { initial_nodes, .. }) =
            result.as_output().control.as_ref()
        else {
            panic!("expected frame switch control");
        };
        assert_eq!(initial_nodes.len(), 1);
        let node = serde_json::from_value::<SessionAppendNode>(initial_nodes[0].clone())
            .expect("decode initial node");
        let SessionAppendNode::ProtocolEvent {
            event: protocol_event,
            ..
        } = node
        else {
            panic!("expected seed globals event");
        };
        let Some(RlmProtocolEvent::RlmSeed(seed)) = decode_rlm_protocol_event(&protocol_event)
        else {
            panic!("expected RlmSeed");
        };
        assert_eq!(seed.globals.len(), 1, "only `glob` should land as a global");
        assert_eq!(seed.globals["glob"], json!(7));
        assert!(!seed.globals.contains_key("proj"));
        assert_eq!(seed.projected.entries.len(), 1);
        assert_eq!(seed.projected.entries[0].0, "proj");
        assert_eq!(seed.projected.entries[0].1, json!("carry-over"));
        assert!(manager.created.lock().expect("created").is_empty());
    }

    #[tokio::test]
    async fn continue_as_preserves_process_shaped_seed_without_process_control() {
        let manager = manager_with_model();
        let provider = RlmControlToolsProvider;

        let args = json!({
            "task": "continue with background work",
            "seed": {
                "one": { "__handle__": "process", "id": "h1", "tool": "slow" },
                "nested": [{ "h": { "__handle__": "process", "id": "h2", "tool": "slow" } }]
            }
        });
        let result = run_continue_as(&provider, manager.clone(), &args).await;

        assert!(result.is_success(), "{:?}", result.value_for_projection());
        let Some(ToolControl::SwitchAgentFrame { initial_nodes, .. }) =
            result.as_output().control.as_ref()
        else {
            panic!("expected frame switch control");
        };
        assert_eq!(initial_nodes.len(), 1);
        let node = serde_json::from_value::<SessionAppendNode>(initial_nodes[0].clone())
            .expect("decode initial node");
        let SessionAppendNode::ProtocolEvent {
            event: protocol_event,
            ..
        } = node
        else {
            panic!("expected seed globals event");
        };
        let Some(RlmProtocolEvent::RlmSeed(seed)) = decode_rlm_protocol_event(&protocol_event)
        else {
            panic!("expected RlmSeed");
        };
        assert_eq!(
            seed.globals["one"],
            json!({ "__handle__": "process", "id": "h1", "tool": "slow" })
        );
        assert_eq!(
            seed.globals["nested"],
            json!([{ "h": { "__handle__": "process", "id": "h2", "tool": "slow" } }])
        );
        assert!(manager.created.lock().expect("created").is_empty());
    }

    #[tokio::test]
    async fn continue_as_does_not_validate_unknown_seed_handles() {
        let manager = manager_with_model();
        let provider = RlmControlToolsProvider;

        let args = json!({
            "task": "continue",
            "seed": { "h": { "__handle__": "process", "id": "missing", "tool": "slow" } }
        });
        let result = run_continue_as(&provider, manager.clone(), &args).await;

        assert!(result.is_success(), "{:?}", result.value_for_projection());
        assert!(manager.created.lock().expect("created").is_empty());
    }
}