1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashSet;
9use std::sync::Arc;
10#[cfg(feature = "wasm-host")]
11use std::time::Duration;
12
13use crate::error::{Error, Result};
14#[cfg(test)]
15use crate::extensions::JsExtensionRuntimeHandle;
16#[cfg(feature = "wasm-host")]
17use crate::extensions::WasmExtensionHandle;
18use crate::extensions::{ExtensionManager, ExtensionRuntimeHandle};
19use crate::extensions_js::ExtensionToolDef;
20use crate::tools::{Tool, ToolOutput, ToolUpdate};
21#[cfg(feature = "wasm-host")]
22use asupersync::time::{timeout, wall_now};
23
/// Default per-call execution budget for extension tools, in milliseconds (one minute).
const DEFAULT_EXTENSION_TOOL_TIMEOUT_MS: u64 = 60 * 1_000;
25
26pub struct ExtensionToolWrapper {
32 def: ExtensionToolDef,
33 runtime: ExtensionRuntimeHandle,
34 ctx_payload: Arc<Value>,
35 timeout_ms: u64,
36}
37
38impl ExtensionToolWrapper {
39 #[must_use]
40 pub fn new<R>(def: ExtensionToolDef, runtime: R) -> Self
41 where
42 R: Into<ExtensionRuntimeHandle>,
43 {
44 Self {
45 def,
46 runtime: runtime.into(),
47 ctx_payload: Arc::new(Value::Object(serde_json::Map::new())),
48 timeout_ms: DEFAULT_EXTENSION_TOOL_TIMEOUT_MS,
49 }
50 }
51
52 #[must_use]
53 pub fn with_ctx_payload(mut self, ctx_payload: Value) -> Self {
54 self.ctx_payload = Arc::new(ctx_payload);
55 self
56 }
57
58 #[must_use]
59 pub fn with_ctx_payload_shared(mut self, ctx_payload: Arc<Value>) -> Self {
60 self.ctx_payload = ctx_payload;
61 self
62 }
63
64 #[must_use]
65 pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
66 self.timeout_ms = timeout_ms.max(1);
67 self
68 }
69}
70
/// Adapts a tool exported by a WASM extension to the agent's [`Tool`] trait.
#[cfg(feature = "wasm-host")]
pub struct WasmExtensionToolWrapper {
    /// Tool metadata as declared by the WASM extension.
    def: ExtensionToolDef,
    /// Handle used to invoke the tool inside the WASM extension.
    handle: WasmExtensionHandle,
    /// Upper bound on a single call, in milliseconds.
    timeout_ms: u64,
}
77
#[cfg(feature = "wasm-host")]
impl WasmExtensionToolWrapper {
    /// Wraps `def` so that calls are dispatched through `handle`, using
    /// [`DEFAULT_EXTENSION_TOOL_TIMEOUT_MS`] as the timeout.
    #[must_use]
    pub const fn new(def: ExtensionToolDef, handle: WasmExtensionHandle) -> Self {
        Self {
            timeout_ms: DEFAULT_EXTENSION_TOOL_TIMEOUT_MS,
            def,
            handle,
        }
    }

    /// Sets the per-call timeout in milliseconds; a value of `0` is raised to `1`.
    #[must_use]
    pub fn with_timeout_ms(self, timeout_ms: u64) -> Self {
        Self {
            timeout_ms: std::cmp::max(timeout_ms, 1),
            ..self
        }
    }
}
95
96pub async fn collect_extension_tool_wrappers(
101 manager: &ExtensionManager,
102 ctx_payload: Value,
103) -> Result<Vec<Box<dyn Tool>>> {
104 let shared_ctx_payload = Arc::new(ctx_payload);
105 let active = manager
106 .active_tools()
107 .map(|tools| tools.into_iter().collect::<HashSet<_>>());
108
109 let mut wrappers: Vec<Box<dyn Tool>> = Vec::new();
110 let mut seen = HashSet::new();
111
112 if let Some(runtime) = manager.runtime() {
113 let mut defs = runtime.get_registered_tools().await?;
114 if let Some(active) = active.as_ref() {
115 defs.retain(|def| active.contains(&def.name));
116 }
117
118 defs.sort_by(|a, b| a.name.cmp(&b.name));
119 for def in defs {
120 if !seen.insert(def.name.clone()) {
121 tracing::warn!(tool = %def.name, "Duplicate extension tool name; ignoring");
122 continue;
123 }
124
125 wrappers.push(Box::new(
126 ExtensionToolWrapper::new(def, runtime.clone())
127 .with_ctx_payload_shared(Arc::clone(&shared_ctx_payload)),
128 ));
129 }
130 }
131
132 #[cfg(feature = "wasm-host")]
133 {
134 let mut wasm_defs: Vec<(ExtensionToolDef, WasmExtensionHandle)> = Vec::new();
135 for handle in manager.wasm_extensions() {
136 for def in handle.tool_defs() {
137 wasm_defs.push((def.clone(), handle.clone()));
138 }
139 }
140
141 wasm_defs.sort_by(|a, b| a.0.name.cmp(&b.0.name));
142 for (def, handle) in wasm_defs {
143 if let Some(active) = active.as_ref() {
144 if !active.contains(&def.name) {
145 continue;
146 }
147 }
148 if !seen.insert(def.name.clone()) {
149 tracing::warn!(tool = %def.name, "Duplicate extension tool name; ignoring");
150 continue;
151 }
152
153 wrappers.push(Box::new(WasmExtensionToolWrapper::new(def, handle)));
154 }
155 }
156
157 Ok(wrappers)
158}
159
160#[async_trait]
161impl Tool for ExtensionToolWrapper {
162 fn name(&self) -> &str {
163 &self.def.name
164 }
165
166 fn label(&self) -> &str {
167 self.def.label.as_deref().unwrap_or(&self.def.name)
168 }
169
170 fn description(&self) -> &str {
171 &self.def.description
172 }
173
174 fn parameters(&self) -> Value {
175 self.def.parameters.clone()
176 }
177
178 async fn execute(
179 &self,
180 tool_call_id: &str,
181 input: Value,
182 _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
183 ) -> Result<ToolOutput> {
184 let result = self
185 .runtime
186 .execute_tool_ref(
187 &self.def.name,
188 tool_call_id,
189 input,
190 Arc::clone(&self.ctx_payload),
191 self.timeout_ms,
192 )
193 .await
194 .map_err(|err| Error::tool(self.name(), err.to_string()))?;
195
196 serde_json::from_value(result).map_err(|err| {
197 Error::tool(
198 self.name(),
199 format!("Invalid extension tool output (expected ToolOutput JSON): {err}"),
200 )
201 })
202 }
203}
204
#[cfg(feature = "wasm-host")]
#[async_trait]
impl Tool for WasmExtensionToolWrapper {
    fn name(&self) -> &str {
        &self.def.name
    }

    /// Uses the declared label, or the tool name if none was supplied.
    fn label(&self) -> &str {
        self.def.label.as_deref().unwrap_or(&self.def.name)
    }

    fn description(&self) -> &str {
        &self.def.description
    }

    fn parameters(&self) -> Value {
        self.def.parameters.clone()
    }

    /// Invokes the tool inside the WASM extension, bounded by `timeout_ms`.
    ///
    /// The call id and progress callback are not used. The extension's reply is
    /// parsed from JSON text into a [`ToolOutput`].
    ///
    /// # Errors
    ///
    /// Returns a tool error if the call exceeds the timeout, if the extension
    /// reports a failure, or if the reply is not valid `ToolOutput` JSON.
    async fn execute(
        &self,
        _tool_call_id: &str,
        input: Value,
        _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
    ) -> Result<ToolOutput> {
        // `timeout_ms` is always >= 1: `new` uses the non-zero default and
        // `with_timeout_ms` clamps, so every call is bounded. (A previous
        // "no timeout when 0" branch could never be reached.)
        let call = Box::pin(self.handle.handle_tool(&self.def.name, &input));
        let output_json = timeout(wall_now(), Duration::from_millis(self.timeout_ms), call)
            .await
            .map_err(|_| {
                Error::tool(
                    self.name(),
                    format!(
                        "WASM tool '{}' timed out after {}ms",
                        self.name(),
                        self.timeout_ms
                    ),
                )
            })?
            .map_err(|err| Error::tool(self.name(), err.to_string()))?;

        serde_json::from_str(&output_json).map_err(|err| {
            Error::tool(
                self.name(),
                format!("Invalid WASM tool output (expected ToolOutput JSON): {err}"),
            )
        })
    }
}
264
#[cfg(test)]
mod tests {
    use super::*;

    use crate::agent::{Agent, AgentConfig, AgentEvent, AgentSession};
    use crate::extensions::{ExtensionManager, JsExtensionLoadSpec};
    use crate::extensions_js::PiJsRuntimeConfig;
    use crate::model::{
        AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ToolCall,
        Usage,
    };
    use crate::provider::{Context, Provider, StreamOptions};
    use crate::session::Session;
    use crate::tools::ToolRegistry;
    use asupersync::runtime::RuntimeBuilder;
    use asupersync::sync::Mutex;
    use async_trait::async_trait;
    use futures::Stream;
    use serde_json::json;
    use std::pin::Pin;
    use std::sync::Arc;

    /// Extension registering a single no-op tool named `t`; used by tests that
    /// only inspect wrapper configuration.
    const TRIVIAL_TOOL_SOURCE: &str = r#"
        export default function init(pi) {
            pi.registerTool({
                name: "t",
                description: "d",
                parameters: { type: "object" },
                execute: async () => ({ content: [], isError: false })
            });
        }
    "#;

    /// Writes `source` to `ext.mjs` in a fresh temp dir, starts a JS runtime
    /// rooted there, attaches it to a new `ExtensionManager`, and loads the extension.
    ///
    /// Callers must keep the returned `TempDir` (and the manager and runtime
    /// handle) bound for as long as the extension is in use.
    async fn load_js_extension(
        source: &str,
    ) -> (
        tempfile::TempDir,
        ExtensionManager,
        JsExtensionRuntimeHandle,
    ) {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let entry_path = temp_dir.path().join("ext.mjs");
        std::fs::write(&entry_path, source).expect("write extension entry");

        let manager = ExtensionManager::new();
        let tools = Arc::new(ToolRegistry::new(&[], temp_dir.path(), None));
        let js_runtime = JsExtensionRuntimeHandle::start(
            PiJsRuntimeConfig {
                cwd: temp_dir.path().display().to_string(),
                ..Default::default()
            },
            Arc::clone(&tools),
            manager.clone(),
        )
        .await
        .expect("start js runtime");
        manager.set_js_runtime(js_runtime.clone());

        let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("spec");
        manager
            .load_js_extensions(vec![spec])
            .await
            .expect("load js extensions");

        (temp_dir, manager, js_runtime)
    }

    /// Loads `source` via [`load_js_extension`] and returns the definition of
    /// the registered tool called `tool_name`.
    async fn setup_js_tool(
        source: &str,
        tool_name: &str,
    ) -> (
        tempfile::TempDir,
        ExtensionManager,
        JsExtensionRuntimeHandle,
        ExtensionToolDef,
    ) {
        let (temp_dir, manager, js_runtime) = load_js_extension(source).await;

        let def = js_runtime
            .get_registered_tools()
            .await
            .expect("get registered tools")
            .into_iter()
            .find(|tool| tool.name == tool_name)
            .expect("tool registered");

        (temp_dir, manager, js_runtime, def)
    }

    #[test]
    fn extension_tool_wrapper_executes_registered_tool() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "hello_tool",
                        label: "hello_tool",
                        description: "test tool",
                        parameters: { type: "object", properties: { name: { type: "string" } } },
                        execute: async (_callId, input, _onUpdate, _abort, ctx) => {
                            const who = input && input.name ? String(input.name) : "world";
                            const cwd = ctx && ctx.cwd ? String(ctx.cwd) : "";
                            return {
                                content: [{ type: "text", text: `hello ${who}` }],
                                details: { from: "extension", cwd: cwd },
                                isError: false
                            };
                        }
                    });
                }
            "#;
            let (temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "hello_tool").await;
            let cwd = temp_dir.path().display().to_string();

            let wrapper = ExtensionToolWrapper::new(def, js_runtime)
                .with_ctx_payload(json!({ "cwd": cwd.clone() }));

            let output = wrapper
                .execute("call-1", json!({ "name": "pi" }), None)
                .await
                .expect("execute tool");

            assert!(!output.is_error);

            match output.content.as_slice() {
                [ContentBlock::Text(text)] => assert_eq!(text.text, "hello pi"),
                other => panic!("Expected single text content block, got: {other:?}"),
            }

            let details = output.details.expect("details present");
            assert_eq!(
                details.get("from").and_then(Value::as_str),
                Some("extension")
            );
            assert_eq!(
                details.get("cwd").and_then(Value::as_str),
                Some(cwd.as_str())
            );
        });
    }

    #[test]
    fn extension_tool_wrapper_metadata_and_timeout_clamp() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "meta_tool",
                        label: "Meta Tool",
                        description: "metadata test tool",
                        parameters: { type: "object", properties: { x: { type: "number" } } },
                        execute: async (_callId, _input, _onUpdate, _abort, _ctx) => ({
                            content: [{ type: "text", text: "ok" }],
                            isError: false
                        })
                    });
                }
            "#;
            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "meta_tool").await;

            let wrapper = ExtensionToolWrapper::new(def.clone(), js_runtime.clone())
                .with_timeout_ms(0)
                .with_ctx_payload(json!({"cwd": "/tmp"}));
            assert_eq!(wrapper.timeout_ms, 1);
            assert_eq!(wrapper.name(), "meta_tool");
            assert_eq!(wrapper.label(), "Meta Tool");
            assert_eq!(wrapper.description(), "metadata test tool");
            assert_eq!(
                wrapper.parameters(),
                json!({ "type": "object", "properties": { "x": { "type": "number" } } })
            );

            let mut no_label = def;
            no_label.label = None;
            let fallback = ExtensionToolWrapper::new(no_label, js_runtime).with_timeout_ms(25);
            assert_eq!(fallback.timeout_ms, 25);
            assert_eq!(fallback.label(), "meta_tool");
        });
    }

    #[test]
    fn extension_tool_wrapper_maps_invalid_output_to_tool_error() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "broken_tool",
                        label: "broken_tool",
                        description: "returns invalid output payload",
                        parameters: { type: "object", properties: {} },
                        execute: async (_callId, _input, _onUpdate, _abort, _ctx) => ({
                            nope: true
                        })
                    });
                }
            "#;
            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "broken_tool").await;

            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
            let err = wrapper
                .execute("call-1", json!({}), None)
                .await
                .expect_err("invalid tool output should fail");

            match err {
                Error::Tool { tool, message } => {
                    assert_eq!(tool, "broken_tool");
                    assert!(message.contains("Invalid extension tool output"));
                }
                other => panic!("expected tool error, got {other:?}"),
            }
        });
    }

    /// Scripted provider: on the first turn it asserts `hello_tool` is offered
    /// and requests it; once a `hello_tool` result is in the context it checks
    /// the result text and replies "done".
    #[derive(Debug)]
    struct ToolCallingProvider;

    #[async_trait]
    // The trait fixes the `&str` return types, so clippy's suggestion to return
    // `&'static str` for these literals cannot be applied.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for ToolCallingProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        async fn stream(
            &self,
            context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            fn assistant_message(content: Vec<ContentBlock>) -> AssistantMessage {
                AssistantMessage {
                    content,
                    api: "test-api".to_string(),
                    provider: "test-provider".to_string(),
                    model: "test-model".to_string(),
                    usage: Usage::default(),
                    stop_reason: StopReason::Stop,
                    error_message: None,
                    timestamp: 0,
                }
            }

            let tool_def_present = context.tools.iter().any(|tool| tool.name == "hello_tool");
            let tool_result = context.messages.iter().find_map(|message| match message {
                Message::ToolResult(result) if result.tool_name == "hello_tool" => Some(result),
                _ => None,
            });

            if let Some(result) = tool_result {
                match result.content.as_slice() {
                    [ContentBlock::Text(text)] => assert_eq!(text.text, "hello pi"),
                    other => panic!("Expected single text content block, got: {other:?}"),
                }

                let events = vec![
                    Ok(StreamEvent::Start {
                        partial: assistant_message(Vec::new()),
                    }),
                    Ok(StreamEvent::Done {
                        reason: StopReason::Stop,
                        message: assistant_message(vec![ContentBlock::Text(TextContent::new(
                            "done",
                        ))]),
                    }),
                ];
                return Ok(Box::pin(futures::stream::iter(events)));
            }

            assert!(
                tool_def_present,
                "Expected extension tool to be present in provider tool defs"
            );

            let tool_call = ToolCall {
                id: "call-1".to_string(),
                name: "hello_tool".to_string(),
                arguments: json!({ "name": "pi" }),
                thought_signature: None,
            };

            let events = vec![
                Ok(StreamEvent::Start {
                    partial: assistant_message(Vec::new()),
                }),
                Ok(StreamEvent::Done {
                    reason: StopReason::Stop,
                    message: assistant_message(vec![ContentBlock::ToolCall(tool_call)]),
                }),
            ];
            Ok(Box::pin(futures::stream::iter(events)))
        }
    }

    #[test]
    fn agent_executes_extension_tool_registered_via_js() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "hello_tool",
                        label: "hello_tool",
                        description: "test tool",
                        parameters: { type: "object", properties: { name: { type: "string" } } },
                        execute: async (_callId, input, _onUpdate, _abort, _ctx) => {
                            const who = input && input.name ? String(input.name) : "world";
                            return {
                                content: [{ type: "text", text: `hello ${who}` }],
                                details: { from: "extension" },
                                isError: false
                            };
                        }
                    });
                }
            "#;
            let (temp_dir, manager, _js_runtime) = load_js_extension(source).await;

            let wrappers = collect_extension_tool_wrappers(
                &manager,
                json!({ "cwd": temp_dir.path().display().to_string() }),
            )
            .await
            .expect("collect wrappers");
            assert_eq!(wrappers.len(), 1);

            let provider = Arc::new(ToolCallingProvider);
            let tools = ToolRegistry::new(&[], temp_dir.path(), None);
            let mut agent = Agent::new(provider, tools, AgentConfig::default());
            agent.extend_tools(wrappers);

            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                session,
                false,
                crate::compaction::ResolvedCompactionSettings::default(),
            );
            let message = agent_session
                .run_text("hi".to_string(), |_event: AgentEvent| {})
                .await
                .expect("run_text");

            match message.content.as_slice() {
                [ContentBlock::Text(text)] => assert_eq!(text.text, "done"),
                other => panic!("Expected single text content block, got: {other:?}"),
            }
        });
    }

    #[test]
    fn extension_tool_wrapper_default_timeout() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (_temp_dir, _manager, js_runtime, def) =
                setup_js_tool(TRIVIAL_TOOL_SOURCE, "t").await;
            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
            assert_eq!(wrapper.timeout_ms, DEFAULT_EXTENSION_TOOL_TIMEOUT_MS);
            assert_eq!(wrapper.timeout_ms, 60_000);
        });
    }

    #[test]
    fn extension_tool_wrapper_timeout_clamp_boundary() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (_temp_dir, _manager, js_runtime, def) =
                setup_js_tool(TRIVIAL_TOOL_SOURCE, "t").await;

            let w0 = ExtensionToolWrapper::new(def.clone(), js_runtime.clone()).with_timeout_ms(0);
            assert_eq!(w0.timeout_ms, 1);

            let w1 = ExtensionToolWrapper::new(def.clone(), js_runtime.clone()).with_timeout_ms(1);
            assert_eq!(w1.timeout_ms, 1);

            let wmax = ExtensionToolWrapper::new(def, js_runtime).with_timeout_ms(u64::MAX);
            assert_eq!(wmax.timeout_ms, u64::MAX);
        });
    }

    #[test]
    fn extension_tool_wrapper_ctx_payload_default_empty() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (_temp_dir, _manager, js_runtime, def) =
                setup_js_tool(TRIVIAL_TOOL_SOURCE, "t").await;
            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
            assert_eq!(wrapper.ctx_payload.as_ref(), &json!({}));
        });
    }

    #[test]
    fn extension_tool_wrapper_ctx_payload_override() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (_temp_dir, _manager, js_runtime, def) =
                setup_js_tool(TRIVIAL_TOOL_SOURCE, "t").await;
            let custom_ctx = json!({"cwd": "/tmp", "user": "test"});
            let wrapper =
                ExtensionToolWrapper::new(def, js_runtime).with_ctx_payload(custom_ctx.clone());
            assert_eq!(wrapper.ctx_payload.as_ref(), &custom_ctx);
        });
    }

    #[test]
    fn collect_wrappers_no_js_runtime_returns_empty() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let manager = ExtensionManager::new();
            let wrappers = collect_extension_tool_wrappers(&manager, json!({}))
                .await
                .expect("collect wrappers");
            assert!(wrappers.is_empty());
        });
    }

    #[test]
    fn collect_wrappers_multiple_tools_from_one_extension() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "tool_alpha",
                        description: "first tool",
                        parameters: { type: "object" },
                        execute: async () => ({ content: [{ type: "text", text: "alpha" }], isError: false })
                    });
                    pi.registerTool({
                        name: "tool_beta",
                        description: "second tool",
                        parameters: { type: "object" },
                        execute: async () => ({ content: [{ type: "text", text: "beta" }], isError: false })
                    });
                }
            "#;
            let (_temp_dir, manager, _js_runtime) = load_js_extension(source).await;

            let wrappers = collect_extension_tool_wrappers(&manager, json!({}))
                .await
                .expect("collect wrappers");
            assert_eq!(wrappers.len(), 2);

            assert_eq!(wrappers[0].name(), "tool_alpha");
            assert_eq!(wrappers[1].name(), "tool_beta");
        });
    }

    #[test]
    fn collect_wrappers_respects_active_tools_filter() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "tool_keep",
                        description: "kept",
                        parameters: { type: "object" },
                        execute: async () => ({ content: [], isError: false })
                    });
                    pi.registerTool({
                        name: "tool_skip",
                        description: "skipped",
                        parameters: { type: "object" },
                        execute: async () => ({ content: [], isError: false })
                    });
                }
            "#;
            let (_temp_dir, manager, _js_runtime) = load_js_extension(source).await;

            manager.set_active_tools(vec!["tool_keep".to_string()]);

            let wrappers = collect_extension_tool_wrappers(&manager, json!({}))
                .await
                .expect("collect wrappers");
            assert_eq!(wrappers.len(), 1);
            assert_eq!(wrappers[0].name(), "tool_keep");
        });
    }

    #[test]
    fn extension_tool_wrapper_js_error_maps_to_tool_error() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "throwing_tool",
                        description: "throws an error",
                        parameters: { type: "object" },
                        execute: async () => { throw new Error("boom!"); }
                    });
                }
            "#;
            let (_temp_dir, _manager, js_runtime, def) =
                setup_js_tool(source, "throwing_tool").await;

            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
            let err = wrapper
                .execute("call-1", json!({}), None)
                .await
                .expect_err("throwing tool should fail");

            match err {
                Error::Tool { tool, message } => {
                    assert_eq!(tool, "throwing_tool");
                    assert!(
                        message.contains("boom") || message.contains("error"),
                        "Expected error message to reference the thrown error, got: {message}"
                    );
                }
                other => panic!("expected tool error, got {other:?}"),
            }
        });
    }

    #[test]
    fn extension_tool_wrapper_empty_content_result() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "empty_tool",
                        description: "returns empty content",
                        parameters: { type: "object" },
                        execute: async () => ({
                            content: [],
                            isError: false
                        })
                    });
                }
            "#;
            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "empty_tool").await;

            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
            let output = wrapper
                .execute("call-1", json!({}), None)
                .await
                .expect("execute tool");

            assert!(!output.is_error);
            assert!(output.content.is_empty());
        });
    }

    #[test]
    fn extension_tool_wrapper_is_error_flag() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "error_tool",
                        description: "returns error flag",
                        parameters: { type: "object" },
                        execute: async () => ({
                            content: [{ type: "text", text: "something went wrong" }],
                            isError: true
                        })
                    });
                }
            "#;
            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "error_tool").await;

            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
            let output = wrapper
                .execute("call-1", json!({}), None)
                .await
                .expect("execute tool");

            assert!(output.is_error);
            match output.content.as_slice() {
                [ContentBlock::Text(text)] => {
                    assert_eq!(text.text, "something went wrong");
                }
                other => panic!("expected text content, got {other:?}"),
            }
        });
    }

    #[test]
    fn extension_tool_wrapper_passes_input_to_handler() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let source = r#"
                export default function init(pi) {
                    pi.registerTool({
                        name: "echo_tool",
                        description: "echoes input",
                        parameters: { type: "object", properties: { msg: { type: "string" } } },
                        execute: async (_callId, input) => ({
                            content: [{ type: "text", text: JSON.stringify(input) }],
                            isError: false
                        })
                    });
                }
            "#;
            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "echo_tool").await;

            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
            let output = wrapper
                .execute("call-1", json!({"msg": "hello world"}), None)
                .await
                .expect("execute tool");

            assert!(!output.is_error);
            match output.content.as_slice() {
                [ContentBlock::Text(text)] => {
                    let parsed: Value = serde_json::from_str(&text.text).expect("parse JSON");
                    assert_eq!(parsed["msg"], "hello world");
                }
                other => panic!("expected text content, got {other:?}"),
            }
        });
    }
}