1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use crate::plugin::{DirectCompletion, PluginError, SessionHandle, SessionSnapshot, ToolHookHost};
5use crate::{
6 AttachmentCreateMeta, AttachmentRef, AttachmentStore, AttachmentStoreError, ToolContract,
7 ToolManifest, ToolResult,
8};
9
/// A text message tagged with a string `kind`.
///
/// Values of this type are what a [`ProgressSender`] transmits.
#[derive(Clone, Debug)]
pub struct SandboxMessage {
    /// Message body.
    pub text: String,
    /// Label attached to the message. The set of accepted values is not defined in this file.
    // NOTE(review): presumably a category tag that consumers dispatch on — confirm with producers.
    pub kind: String,
}
17
/// Sending half of an unbounded Tokio mpsc channel that carries [`SandboxMessage`] values.
pub type ProgressSender = tokio::sync::mpsc::UnboundedSender<SandboxMessage>;
20
/// Context handed to a tool invocation (see [`ToolCall::context`]).
///
/// Bundles the current session id, the host handle that services session/task/catalog
/// requests, the attachment store, the turn context, and cancellation/retry metadata.
/// Cloning copies the ids and bumps the `Arc` reference counts of the host and store.
#[derive(Clone)]
pub struct ToolContext {
    /// Id of the current session; used as the default session for host calls.
    pub(crate) session_id: String,
    /// Host that services the requests issued through this context.
    pub(crate) host: Arc<dyn ToolHookHost>,
    /// Cancellation token for the invocation. `new` leaves it `None`; `with_async_task` sets it.
    pub(crate) cancellation_token: Option<tokio_util::sync::CancellationToken>,
    /// Id of the async task this invocation runs as. `new` leaves it `None`;
    /// `with_async_task` sets it.
    pub(crate) async_task_id: Option<String>,
    /// Turn-level context; also the source consulted by `plugin_input`.
    pub(crate) turn_context: crate::TurnContext,
    /// Store that `put_attachment` writes to.
    pub(crate) attachment_store: Arc<dyn AttachmentStore>,
    /// Id of the tool call being served, if any. `direct_completion` uses it as the default
    /// `originating_tool_call_id`.
    pub(crate) tool_call_id: Option<String>,
    /// Attempt counter; `new` starts it at 1 and `with_retry_context` clamps it to at least 1.
    pub(crate) attempt_number: u32,
    /// Attempt limit; `new` starts it at 1 and `with_retry_context` clamps it to at least 1.
    pub(crate) max_attempts: u32,
    /// Key built by `with_retry_context` from the session id, tool call id and tool name.
    /// `None` until that method runs, or when there is no tool call id.
    pub(crate) idempotency_key: Option<String>,
}
40
/// Model selection of a session, as returned by `ToolContext::session_model`, which copies
/// both fields from the `policy` part of a session snapshot.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolSessionModel {
    /// Model identifier taken from the snapshot policy.
    pub model: String,
    /// Optional model variant taken from the snapshot policy.
    pub model_variant: Option<String>,
}
46
/// Cloneable handle exposing the host's session and turn operations.
///
/// Built by `ToolContext::sessions`, which shares the context's host `Arc`.
#[derive(Clone)]
pub struct ToolSessionControl {
    /// Host that every method on this handle forwards to.
    host: Arc<dyn ToolHookHost>,
}
51
52impl ToolSessionControl {
53 pub async fn create_session(
54 &self,
55 request: crate::SessionCreateRequest,
56 ) -> Result<SessionHandle, PluginError> {
57 self.host.create_session(request).await
58 }
59
60 pub async fn close_session(&self, session_id: &str) -> Result<(), PluginError> {
61 self.host.close_session(session_id).await
62 }
63
64 pub async fn start_turn_stream(
65 &self,
66 session_id: &str,
67 input: crate::TurnInput,
68 ) -> Result<crate::plugin::SessionTurnHandle, PluginError> {
69 self.host.start_turn_stream(session_id, input).await
70 }
71
72 pub async fn await_turn(&self, turn_id: &str) -> Result<crate::AssembledTurn, PluginError> {
73 self.host.await_turn(turn_id).await
74 }
75
76 pub async fn cancel_turn(&self, turn_id: &str) -> Result<(), PluginError> {
77 self.host.cancel_turn(turn_id).await
78 }
79}
80
// Implements the `SessionLifecycleHost` trait for `ToolSessionControl` by delegating to the
// inherent methods of the same names.
#[async_trait::async_trait]
impl crate::plugin::SessionLifecycleHost for ToolSessionControl {
    /// Creates a session by delegating to the inherent `ToolSessionControl::create_session`.
    async fn create_session(
        &self,
        request: crate::SessionCreateRequest,
    ) -> Result<SessionHandle, PluginError> {
        // The type-qualified path names the inherent method explicitly; an inherent associated
        // function takes precedence over a trait method of the same name in path resolution,
        // so this does not call back into this trait method.
        ToolSessionControl::create_session(self, request).await
    }

    /// Closes a session by delegating to the inherent `ToolSessionControl::close_session`.
    async fn close_session(&self, session_id: &str) -> Result<(), PluginError> {
        // Same qualified-path form as above, for the same reason.
        ToolSessionControl::close_session(self, session_id).await
    }
}
94
/// Cloneable handle exposing the host's background-task and async-handle operations.
///
/// Built by `ToolContext::tasks`, which copies the context's session id and shares its host.
#[derive(Clone)]
pub struct ToolTaskControl {
    /// Session used by every method that does not take an explicit session id.
    session_id: String,
    /// Host that every method on this handle forwards to.
    host: Arc<dyn ToolHookHost>,
}
100
101impl ToolTaskControl {
102 pub async fn register_background_task(
103 &self,
104 spec: crate::BackgroundTaskRegistration,
105 cancel: Option<crate::LocalBackgroundTaskCancel>,
106 ) -> Result<(), PluginError> {
107 self.host
108 .register_background_task(&self.session_id, spec, cancel)
109 .await
110 }
111
112 pub async fn unregister_background_task(&self, task_id: &str) {
113 self.unregister_background_task_for_session(&self.session_id, task_id)
114 .await;
115 }
116
117 pub async fn complete_background_task(&self, task_id: &str, state: crate::BackgroundTaskState) {
118 self.complete_background_task_for_session(&self.session_id, task_id, state)
119 .await;
120 }
121
122 pub async fn transition_background_task_live_state(
123 &self,
124 task_id: &str,
125 state: crate::BackgroundTaskState,
126 ) {
127 self.transition_background_task_live_state_for_session(&self.session_id, task_id, state)
128 .await;
129 }
130
131 pub async fn unregister_background_task_for_session(&self, session_id: &str, task_id: &str) {
132 self.host
133 .unregister_background_task(session_id, task_id)
134 .await;
135 }
136
137 pub async fn complete_background_task_for_session(
138 &self,
139 session_id: &str,
140 task_id: &str,
141 state: crate::BackgroundTaskState,
142 ) {
143 self.host
144 .complete_background_task(session_id, task_id, state)
145 .await;
146 }
147
148 pub async fn transition_background_task_live_state_for_session(
149 &self,
150 session_id: &str,
151 task_id: &str,
152 state: crate::BackgroundTaskState,
153 ) {
154 self.host
155 .transition_background_task_live_state(session_id, task_id, state)
156 .await;
157 }
158
159 pub async fn validate_async_handles_visible(
160 &self,
161 handle_ids: &[String],
162 ) -> Result<(), PluginError> {
163 self.host
164 .validate_async_handles_visible(&self.session_id, handle_ids)
165 .await
166 }
167
168 pub async fn transfer_async_handles_to_session(
169 &self,
170 successor_session_id: &str,
171 handle_ids: &[String],
172 ) -> Result<(), PluginError> {
173 self.host
174 .transfer_async_handles(&self.session_id, successor_session_id, handle_ids)
175 .await
176 }
177
178 pub async fn cancel_unreferenced_async_handles(
179 &self,
180 keep_handle_ids: &[String],
181 ) -> Result<Vec<crate::BackgroundTaskRecord>, PluginError> {
182 self.host
183 .cancel_unreferenced_async_handles(&self.session_id, keep_handle_ids)
184 .await
185 }
186}
187
188impl ToolContext {
189 pub(crate) fn new(
190 session_id: String,
191 host: Arc<dyn ToolHookHost>,
192 turn_context: crate::TurnContext,
193 attachment_store: Arc<dyn AttachmentStore>,
194 tool_call_id: Option<String>,
195 ) -> Self {
196 Self {
197 session_id,
198 host,
199 cancellation_token: None,
200 async_task_id: None,
201 turn_context,
202 attachment_store,
203 tool_call_id,
204 attempt_number: 1,
205 max_attempts: 1,
206 idempotency_key: None,
207 }
208 }
209
210 pub fn session_id(&self) -> &str {
211 &self.session_id
212 }
213
214 pub async fn session_model(&self) -> Result<ToolSessionModel, PluginError> {
215 let snapshot = self.session_snapshot().await?;
216 Ok(ToolSessionModel {
217 model: snapshot.policy.model,
218 model_variant: snapshot.policy.model_variant,
219 })
220 }
221
222 pub async fn session_snapshot(&self) -> Result<SessionSnapshot, PluginError> {
223 self.snapshot_current_session().await
224 }
225
226 pub async fn snapshot_current_session(&self) -> Result<SessionSnapshot, PluginError> {
227 self.snapshot_session(&self.session_id).await
228 }
229
230 pub async fn snapshot_session(
231 &self,
232 session_id: impl AsRef<str>,
233 ) -> Result<SessionSnapshot, PluginError> {
234 self.host.snapshot_session(session_id.as_ref()).await
235 }
236
237 pub async fn tool_catalog(&self) -> Result<Vec<serde_json::Value>, PluginError> {
238 self.host.tool_catalog(&self.session_id).await
239 }
240
241 pub async fn set_tools_availability(
242 &self,
243 names: &[String],
244 availability: Option<crate::ToolAvailability>,
245 ) -> Result<u64, PluginError> {
246 self.host
247 .set_tools_availability(&self.session_id, names, availability)
248 .await
249 }
250
251 pub fn sessions(&self) -> ToolSessionControl {
252 ToolSessionControl {
253 host: Arc::clone(&self.host),
254 }
255 }
256
257 pub fn tasks(&self) -> ToolTaskControl {
258 ToolTaskControl {
259 session_id: self.session_id.clone(),
260 host: Arc::clone(&self.host),
261 }
262 }
263
264 pub async fn direct_completion(
265 &self,
266 mut request: crate::DirectRequest,
267 usage_source: &str,
268 ) -> Result<DirectCompletion, PluginError> {
269 if request.session_id.is_none() {
270 request.session_id = Some(self.session_id.clone());
271 }
272 if request.originating_tool_call_id.is_none() {
273 request.originating_tool_call_id = self.tool_call_id.clone();
274 }
275 self.host.direct_completion(request, usage_source).await
276 }
277
278 pub fn cancellation_token(&self) -> Option<&tokio_util::sync::CancellationToken> {
279 self.cancellation_token.as_ref()
280 }
281
282 pub fn async_task_id(&self) -> Option<&str> {
283 self.async_task_id.as_deref()
284 }
285
286 pub fn turn_context(&self) -> &crate::TurnContext {
287 &self.turn_context
288 }
289
290 pub fn tool_call_id(&self) -> Option<&str> {
291 self.tool_call_id.as_deref()
292 }
293
294 pub fn attempt_number(&self) -> u32 {
295 self.attempt_number
296 }
297
298 pub fn max_attempts(&self) -> u32 {
299 self.max_attempts
300 }
301
302 pub fn idempotency_key(&self) -> Option<&str> {
303 self.idempotency_key.as_deref()
304 }
305
306 pub fn put_attachment(
307 &self,
308 data: Vec<u8>,
309 meta: AttachmentCreateMeta,
310 ) -> Result<AttachmentRef, AttachmentStoreError> {
311 self.attachment_store.put(data, meta)
312 }
313
314 pub fn plugin_input<T: 'static>(&self, plugin_id: &'static str) -> Option<&T> {
316 self.turn_context.plugin_input::<T>(plugin_id)
317 }
318
319 pub fn with_async_task(
320 mut self,
321 task_id: impl Into<String>,
322 cancellation_token: tokio_util::sync::CancellationToken,
323 ) -> Self {
324 self.async_task_id = Some(task_id.into());
325 self.cancellation_token = Some(cancellation_token);
326 self
327 }
328
329 pub(crate) fn with_retry_context(
330 mut self,
331 tool_name: &str,
332 attempt_number: u32,
333 max_attempts: u32,
334 ) -> Self {
335 self.attempt_number = attempt_number.max(1);
336 self.max_attempts = max_attempts.max(1);
337 self.idempotency_key = self
338 .tool_call_id
339 .as_ref()
340 .map(|call_id| format!("lash-tool:{}:{call_id}:{tool_name}", self.session_id));
341 self
342 }
343
344 #[cfg(any(test, feature = "testing"))]
347 #[doc(hidden)]
348 pub fn __for_testing(
349 session_id: String,
350 host: Arc<dyn ToolHookHost>,
351 turn_context: crate::TurnContext,
352 attachment_store: Arc<dyn AttachmentStore>,
353 tool_call_id: Option<String>,
354 ) -> Self {
355 Self::new(
356 session_id,
357 host,
358 turn_context,
359 attachment_store,
360 tool_call_id,
361 )
362 }
363}
364
/// Borrowed description of one tool invocation, as passed to `ToolProvider::execute`.
pub struct ToolCall<'a> {
    /// Name of the tool to run; `CompositeToolProvider` uses it to pick the provider.
    pub name: &'a str,
    /// Arguments for the tool as a JSON value.
    pub args: &'a serde_json::Value,
    /// Context the tool runs with.
    pub context: &'a ToolContext,
    /// Optional channel on which [`SandboxMessage`] values can be sent.
    // NOTE(review): presumably used for progress reporting during execution — confirm with callers.
    pub progress: Option<&'a ProgressSender>,
}
377
378#[async_trait::async_trait]
386pub trait ToolProvider: Send + Sync + 'static {
387 fn tool_manifests(&self) -> Vec<ToolManifest>;
388 fn resolve_manifest(&self, name: &str) -> Option<ToolManifest> {
389 self.tool_manifests()
390 .into_iter()
391 .find(|manifest| manifest.name == name)
392 }
393 fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>>;
394 async fn execute(&self, call: ToolCall<'_>) -> ToolResult;
395}
396
/// A [`ToolProvider`] that fronts several providers and routes each tool name to one of them.
pub(crate) struct CompositeToolProvider {
    /// Tool name → (manifest, index into `providers`).
    ///
    /// Behind an `RwLock` because `resolve_manifest` adds entries through `&self` when a name
    /// is first resolved by one of the providers.
    tools: std::sync::RwLock<BTreeMap<String, (ToolManifest, usize)>>,
    /// The wrapped providers, each paired with the tool names it listed at construction time.
    providers: Vec<(Arc<dyn ToolProvider>, Vec<String>)>,
}
401
402impl CompositeToolProvider {
403 pub(crate) fn from_providers(providers: Vec<Arc<dyn ToolProvider>>) -> Self {
404 let mut tools = BTreeMap::new();
405 let mut entries = Vec::new();
406 for provider in providers {
407 let tool_names = provider
408 .tool_manifests()
409 .into_iter()
410 .map(|manifest| {
411 let name = manifest.name.clone();
412 tools.insert(name.clone(), (manifest, entries.len()));
413 name
414 })
415 .collect::<Vec<_>>();
416 entries.push((provider, tool_names));
417 }
418 Self {
419 tools: std::sync::RwLock::new(tools),
420 providers: entries,
421 }
422 }
423}
424
425#[async_trait::async_trait]
426impl ToolProvider for CompositeToolProvider {
427 fn tool_manifests(&self) -> Vec<ToolManifest> {
428 self.tools
429 .read()
430 .expect("composite tool provider lock poisoned")
431 .values()
432 .map(|(manifest, _)| manifest.clone())
433 .collect()
434 }
435
436 fn resolve_manifest(&self, name: &str) -> Option<ToolManifest> {
437 if let Some((manifest, _)) = self
438 .tools
439 .read()
440 .expect("composite tool provider lock poisoned")
441 .get(name)
442 {
443 return Some(manifest.clone());
444 }
445 for (provider_idx, (provider, _)) in self.providers.iter().enumerate() {
446 if let Some(manifest) = provider.resolve_manifest(name) {
447 self.tools
448 .write()
449 .expect("composite tool provider lock poisoned")
450 .insert(name.to_string(), (manifest.clone(), provider_idx));
451 return Some(manifest);
452 }
453 }
454 None
455 }
456
457 fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
458 let provider_idx = self.resolve_manifest(name).and_then(|_| {
459 self.tools
460 .read()
461 .expect("composite tool provider lock poisoned")
462 .get(name)
463 .map(|(_, provider_idx)| *provider_idx)
464 })?;
465 self.providers[provider_idx].0.resolve_contract(name)
466 }
467
468 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
469 let provider_idx = self.resolve_manifest(call.name).and_then(|_| {
470 self.tools
471 .read()
472 .expect("composite tool provider lock poisoned")
473 .get(call.name)
474 .map(|(_, provider_idx)| *provider_idx)
475 });
476 match provider_idx {
477 Some(provider_idx) => self.providers[provider_idx].0.execute(call).await,
478 None => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
479 }
480 }
481}