Skip to main content

lash_core/
tool_provider.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use crate::plugin::{DirectCompletion, PluginError, SessionHandle, SessionSnapshot, ToolHookHost};
5use crate::{
6    AttachmentCreateMeta, AttachmentRef, AttachmentStore, AttachmentStoreError, ToolContract,
7    ToolManifest, ToolResult,
8};
9
/// One progress event emitted by a tool inside the sandbox and delivered to
/// the host while the tool is still running.
#[derive(Clone, Debug)]
pub struct SandboxMessage {
    /// Human-readable payload of the event.
    pub text: String,
    /// Event category: `"tool_output"`, or any other progress kind the host
    /// knows how to render.
    pub kind: String,
}
17
18/// Sender for streaming progress messages from tools (e.g. live bash output).
19pub type ProgressSender = tokio::sync::mpsc::UnboundedSender<SandboxMessage>;
20
21/// Per-call environment for [`ToolProvider::execute`]. Fields are sealed so
22/// the runtime can add capabilities without breaking tool authors.
23#[derive(Clone)]
24pub struct ToolContext {
25    pub(crate) session_id: String,
26    pub(crate) host: Arc<dyn ToolHookHost>,
27    pub(crate) cancellation_token: Option<tokio_util::sync::CancellationToken>,
28    pub(crate) async_task_id: Option<String>,
29    pub(crate) turn_context: crate::TurnContext,
30    pub(crate) attachment_store: Arc<dyn AttachmentStore>,
31    /// The id of the in-flight tool call that is invoking this tool. Set by
32    /// the runtime tool dispatcher; tools should propagate it onto any
33    /// `DirectRequest::originating_tool_call_id` they issue so the trace
34    /// renderer can group fan-out LLM calls under the parent tool entry.
35    pub(crate) tool_call_id: Option<String>,
36    pub(crate) attempt_number: u32,
37    pub(crate) max_attempts: u32,
38    pub(crate) idempotency_key: Option<String>,
39}
40
/// Model selection of a session, as returned by [`ToolContext::session_model`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolSessionModel {
    /// Model identifier taken from the session policy.
    pub model: String,
    /// Optional model variant taken from the session policy.
    pub model_variant: Option<String>,
}
46
47#[derive(Clone)]
48pub struct ToolSessionControl {
49    host: Arc<dyn ToolHookHost>,
50}
51
52impl ToolSessionControl {
53    pub async fn create_session(
54        &self,
55        request: crate::SessionCreateRequest,
56    ) -> Result<SessionHandle, PluginError> {
57        self.host.create_session(request).await
58    }
59
60    pub async fn close_session(&self, session_id: &str) -> Result<(), PluginError> {
61        self.host.close_session(session_id).await
62    }
63
64    pub async fn start_turn_stream(
65        &self,
66        session_id: &str,
67        input: crate::TurnInput,
68    ) -> Result<crate::plugin::SessionTurnHandle, PluginError> {
69        self.host.start_turn_stream(session_id, input).await
70    }
71
72    pub async fn await_turn(&self, turn_id: &str) -> Result<crate::AssembledTurn, PluginError> {
73        self.host.await_turn(turn_id).await
74    }
75
76    pub async fn cancel_turn(&self, turn_id: &str) -> Result<(), PluginError> {
77        self.host.cancel_turn(turn_id).await
78    }
79}
80
81#[async_trait::async_trait]
82impl crate::plugin::SessionLifecycleHost for ToolSessionControl {
83    async fn create_session(
84        &self,
85        request: crate::SessionCreateRequest,
86    ) -> Result<SessionHandle, PluginError> {
87        ToolSessionControl::create_session(self, request).await
88    }
89
90    async fn close_session(&self, session_id: &str) -> Result<(), PluginError> {
91        ToolSessionControl::close_session(self, session_id).await
92    }
93}
94
95#[derive(Clone)]
96pub struct ToolTaskControl {
97    session_id: String,
98    host: Arc<dyn ToolHookHost>,
99}
100
101impl ToolTaskControl {
102    pub async fn register_background_task(
103        &self,
104        spec: crate::BackgroundTaskRegistration,
105        cancel: Option<crate::LocalBackgroundTaskCancel>,
106    ) -> Result<(), PluginError> {
107        self.host
108            .register_background_task(&self.session_id, spec, cancel)
109            .await
110    }
111
112    pub async fn unregister_background_task(&self, task_id: &str) {
113        self.unregister_background_task_for_session(&self.session_id, task_id)
114            .await;
115    }
116
117    pub async fn complete_background_task(&self, task_id: &str, state: crate::BackgroundTaskState) {
118        self.complete_background_task_for_session(&self.session_id, task_id, state)
119            .await;
120    }
121
122    pub async fn transition_background_task_live_state(
123        &self,
124        task_id: &str,
125        state: crate::BackgroundTaskState,
126    ) {
127        self.transition_background_task_live_state_for_session(&self.session_id, task_id, state)
128            .await;
129    }
130
131    pub async fn unregister_background_task_for_session(&self, session_id: &str, task_id: &str) {
132        self.host
133            .unregister_background_task(session_id, task_id)
134            .await;
135    }
136
137    pub async fn complete_background_task_for_session(
138        &self,
139        session_id: &str,
140        task_id: &str,
141        state: crate::BackgroundTaskState,
142    ) {
143        self.host
144            .complete_background_task(session_id, task_id, state)
145            .await;
146    }
147
148    pub async fn transition_background_task_live_state_for_session(
149        &self,
150        session_id: &str,
151        task_id: &str,
152        state: crate::BackgroundTaskState,
153    ) {
154        self.host
155            .transition_background_task_live_state(session_id, task_id, state)
156            .await;
157    }
158
159    pub async fn validate_async_handles_visible(
160        &self,
161        handle_ids: &[String],
162    ) -> Result<(), PluginError> {
163        self.host
164            .validate_async_handles_visible(&self.session_id, handle_ids)
165            .await
166    }
167
168    pub async fn transfer_async_handles_to_session(
169        &self,
170        successor_session_id: &str,
171        handle_ids: &[String],
172    ) -> Result<(), PluginError> {
173        self.host
174            .transfer_async_handles(&self.session_id, successor_session_id, handle_ids)
175            .await
176    }
177
178    pub async fn cancel_unreferenced_async_handles(
179        &self,
180        keep_handle_ids: &[String],
181    ) -> Result<Vec<crate::BackgroundTaskRecord>, PluginError> {
182        self.host
183            .cancel_unreferenced_async_handles(&self.session_id, keep_handle_ids)
184            .await
185    }
186}
187
188impl ToolContext {
189    pub(crate) fn new(
190        session_id: String,
191        host: Arc<dyn ToolHookHost>,
192        turn_context: crate::TurnContext,
193        attachment_store: Arc<dyn AttachmentStore>,
194        tool_call_id: Option<String>,
195    ) -> Self {
196        Self {
197            session_id,
198            host,
199            cancellation_token: None,
200            async_task_id: None,
201            turn_context,
202            attachment_store,
203            tool_call_id,
204            attempt_number: 1,
205            max_attempts: 1,
206            idempotency_key: None,
207        }
208    }
209
210    pub fn session_id(&self) -> &str {
211        &self.session_id
212    }
213
214    pub async fn session_model(&self) -> Result<ToolSessionModel, PluginError> {
215        let snapshot = self.session_snapshot().await?;
216        Ok(ToolSessionModel {
217            model: snapshot.policy.model,
218            model_variant: snapshot.policy.model_variant,
219        })
220    }
221
222    pub async fn session_snapshot(&self) -> Result<SessionSnapshot, PluginError> {
223        self.snapshot_current_session().await
224    }
225
226    pub async fn snapshot_current_session(&self) -> Result<SessionSnapshot, PluginError> {
227        self.snapshot_session(&self.session_id).await
228    }
229
230    pub async fn snapshot_session(
231        &self,
232        session_id: impl AsRef<str>,
233    ) -> Result<SessionSnapshot, PluginError> {
234        self.host.snapshot_session(session_id.as_ref()).await
235    }
236
237    pub async fn tool_catalog(&self) -> Result<Vec<serde_json::Value>, PluginError> {
238        self.host.tool_catalog(&self.session_id).await
239    }
240
241    pub async fn set_tools_availability(
242        &self,
243        names: &[String],
244        availability: Option<crate::ToolAvailability>,
245    ) -> Result<u64, PluginError> {
246        self.host
247            .set_tools_availability(&self.session_id, names, availability)
248            .await
249    }
250
251    pub fn sessions(&self) -> ToolSessionControl {
252        ToolSessionControl {
253            host: Arc::clone(&self.host),
254        }
255    }
256
257    pub fn tasks(&self) -> ToolTaskControl {
258        ToolTaskControl {
259            session_id: self.session_id.clone(),
260            host: Arc::clone(&self.host),
261        }
262    }
263
264    pub async fn direct_completion(
265        &self,
266        mut request: crate::DirectRequest,
267        usage_source: &str,
268    ) -> Result<DirectCompletion, PluginError> {
269        if request.session_id.is_none() {
270            request.session_id = Some(self.session_id.clone());
271        }
272        if request.originating_tool_call_id.is_none() {
273            request.originating_tool_call_id = self.tool_call_id.clone();
274        }
275        self.host.direct_completion(request, usage_source).await
276    }
277
278    pub fn cancellation_token(&self) -> Option<&tokio_util::sync::CancellationToken> {
279        self.cancellation_token.as_ref()
280    }
281
282    pub fn async_task_id(&self) -> Option<&str> {
283        self.async_task_id.as_deref()
284    }
285
286    pub fn turn_context(&self) -> &crate::TurnContext {
287        &self.turn_context
288    }
289
290    pub fn tool_call_id(&self) -> Option<&str> {
291        self.tool_call_id.as_deref()
292    }
293
294    pub fn attempt_number(&self) -> u32 {
295        self.attempt_number
296    }
297
298    pub fn max_attempts(&self) -> u32 {
299        self.max_attempts
300    }
301
302    pub fn idempotency_key(&self) -> Option<&str> {
303        self.idempotency_key.as_deref()
304    }
305
306    pub fn put_attachment(
307        &self,
308        data: Vec<u8>,
309        meta: AttachmentCreateMeta,
310    ) -> Result<AttachmentRef, AttachmentStoreError> {
311        self.attachment_store.put(data, meta)
312    }
313
314    /// Shortcut for [`TurnContext::plugin_input`](crate::TurnContext::plugin_input).
315    pub fn plugin_input<T: 'static>(&self, plugin_id: &'static str) -> Option<&T> {
316        self.turn_context.plugin_input::<T>(plugin_id)
317    }
318
319    pub fn with_async_task(
320        mut self,
321        task_id: impl Into<String>,
322        cancellation_token: tokio_util::sync::CancellationToken,
323    ) -> Self {
324        self.async_task_id = Some(task_id.into());
325        self.cancellation_token = Some(cancellation_token);
326        self
327    }
328
329    pub(crate) fn with_retry_context(
330        mut self,
331        tool_name: &str,
332        attempt_number: u32,
333        max_attempts: u32,
334    ) -> Self {
335        self.attempt_number = attempt_number.max(1);
336        self.max_attempts = max_attempts.max(1);
337        self.idempotency_key = self
338            .tool_call_id
339            .as_ref()
340            .map(|call_id| format!("lash-tool:{}:{call_id}:{tool_name}", self.session_id));
341        self
342    }
343
344    /// Constructor reserved for `lash_core::testing` helpers. Do not call directly;
345    /// use [`lash_core::testing::mock_tool_context`] instead.
346    #[cfg(any(test, feature = "testing"))]
347    #[doc(hidden)]
348    pub fn __for_testing(
349        session_id: String,
350        host: Arc<dyn ToolHookHost>,
351        turn_context: crate::TurnContext,
352        attachment_store: Arc<dyn AttachmentStore>,
353        tool_call_id: Option<String>,
354    ) -> Self {
355        Self::new(
356            session_id,
357            host,
358            turn_context,
359            attachment_store,
360            tool_call_id,
361        )
362    }
363}
364
365/// Per-call inputs handed to [`ToolProvider::execute`].
366///
367/// Fields are `pub` because `ToolCall` is a transient borrow; consumers
368/// typically destructure (`let ToolCall { name, args, .. } = call`). The
369/// stable surface lives on [`ToolContext`] (sealed) and the runtime's
370/// dispatcher, which constructs `ToolCall` values.
371pub struct ToolCall<'a> {
372    pub name: &'a str,
373    pub args: &'a serde_json::Value,
374    pub context: &'a ToolContext,
375    pub progress: Option<&'a ProgressSender>,
376}
377
378/// Trait for providing tools to the sandbox. Implement this per-project.
379///
380/// Implementations supply cheap [`ToolManifest`]s, lazily resolved
381/// [`ToolContract`]s, and a single
382/// [`execute`](Self::execute) method that handles every call. Tools that
383/// need session state read it from `call.context`; tools that stream
384/// progress send through `call.progress`.
385#[async_trait::async_trait]
386pub trait ToolProvider: Send + Sync + 'static {
387    fn tool_manifests(&self) -> Vec<ToolManifest>;
388    fn resolve_manifest(&self, name: &str) -> Option<ToolManifest> {
389        self.tool_manifests()
390            .into_iter()
391            .find(|manifest| manifest.name == name)
392    }
393    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>>;
394    async fn execute(&self, call: ToolCall<'_>) -> ToolResult;
395}
396
397pub(crate) struct CompositeToolProvider {
398    tools: std::sync::RwLock<BTreeMap<String, (ToolManifest, usize)>>,
399    providers: Vec<(Arc<dyn ToolProvider>, Vec<String>)>,
400}
401
402impl CompositeToolProvider {
403    pub(crate) fn from_providers(providers: Vec<Arc<dyn ToolProvider>>) -> Self {
404        let mut tools = BTreeMap::new();
405        let mut entries = Vec::new();
406        for provider in providers {
407            let tool_names = provider
408                .tool_manifests()
409                .into_iter()
410                .map(|manifest| {
411                    let name = manifest.name.clone();
412                    tools.insert(name.clone(), (manifest, entries.len()));
413                    name
414                })
415                .collect::<Vec<_>>();
416            entries.push((provider, tool_names));
417        }
418        Self {
419            tools: std::sync::RwLock::new(tools),
420            providers: entries,
421        }
422    }
423}
424
425#[async_trait::async_trait]
426impl ToolProvider for CompositeToolProvider {
427    fn tool_manifests(&self) -> Vec<ToolManifest> {
428        self.tools
429            .read()
430            .expect("composite tool provider lock poisoned")
431            .values()
432            .map(|(manifest, _)| manifest.clone())
433            .collect()
434    }
435
436    fn resolve_manifest(&self, name: &str) -> Option<ToolManifest> {
437        if let Some((manifest, _)) = self
438            .tools
439            .read()
440            .expect("composite tool provider lock poisoned")
441            .get(name)
442        {
443            return Some(manifest.clone());
444        }
445        for (provider_idx, (provider, _)) in self.providers.iter().enumerate() {
446            if let Some(manifest) = provider.resolve_manifest(name) {
447                self.tools
448                    .write()
449                    .expect("composite tool provider lock poisoned")
450                    .insert(name.to_string(), (manifest.clone(), provider_idx));
451                return Some(manifest);
452            }
453        }
454        None
455    }
456
457    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
458        let provider_idx = self.resolve_manifest(name).and_then(|_| {
459            self.tools
460                .read()
461                .expect("composite tool provider lock poisoned")
462                .get(name)
463                .map(|(_, provider_idx)| *provider_idx)
464        })?;
465        self.providers[provider_idx].0.resolve_contract(name)
466    }
467
468    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
469        let provider_idx = self.resolve_manifest(call.name).and_then(|_| {
470            self.tools
471                .read()
472                .expect("composite tool provider lock poisoned")
473                .get(call.name)
474                .map(|(_, provider_idx)| *provider_idx)
475        });
476        match provider_idx {
477            Some(provider_idx) => self.providers[provider_idx].0.execute(call).await,
478            None => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
479        }
480    }
481}