Skip to main content

pi/
extension_tools.rs

1//! Extension tools integration.
2//!
3//! This module provides adapters that allow JavaScript extension-registered tools to be used as
4//! normal Rust `Tool` implementations inside the agent tool registry.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashSet;
9use std::sync::Arc;
10#[cfg(feature = "wasm-host")]
11use std::time::Duration;
12
13use crate::error::{Error, Result};
14#[cfg(test)]
15use crate::extensions::JsExtensionRuntimeHandle;
16#[cfg(feature = "wasm-host")]
17use crate::extensions::WasmExtensionHandle;
18use crate::extensions::{ExtensionManager, ExtensionRuntimeHandle};
19use crate::extensions_js::ExtensionToolDef;
20use crate::tools::{Tool, ToolOutput, ToolUpdate};
21#[cfg(feature = "wasm-host")]
22use asupersync::time::{timeout, wall_now};
23
/// Per-call deadline, in milliseconds, that both extension tool wrappers start with
/// until `with_timeout_ms` overrides it (one minute).
const DEFAULT_EXTENSION_TOOL_TIMEOUT_MS: u64 = 60 * 1_000;
25
/// Wraps a JS extension-registered tool so it can be used as a Rust [`Tool`].
///
/// Note: This wrapper uses [`ExtensionRuntimeHandle`] rather than
/// [`crate::extensions_js::PiJsRuntime`] so it remains `Send + Sync` and can be
/// stored in the shared tool registry.
///
/// Build one with `ExtensionToolWrapper::new` and adjust it with the `with_*` builders.
pub struct ExtensionToolWrapper {
    /// Tool metadata (name, optional label, description, parameters) reported by the extension.
    def: ExtensionToolDef,
    /// Handle through which `execute` calls are dispatched to the extension runtime.
    runtime: ExtensionRuntimeHandle,
    /// Context JSON forwarded with every call; an empty object unless overridden, and
    /// shareable between wrappers because it is reference-counted.
    ctx_payload: Arc<Value>,
    /// Per-call timeout in milliseconds forwarded to the runtime; `with_timeout_ms`
    /// never lets it drop below 1.
    timeout_ms: u64,
}
37
38impl ExtensionToolWrapper {
39    #[must_use]
40    pub fn new<R>(def: ExtensionToolDef, runtime: R) -> Self
41    where
42        R: Into<ExtensionRuntimeHandle>,
43    {
44        Self {
45            def,
46            runtime: runtime.into(),
47            ctx_payload: Arc::new(Value::Object(serde_json::Map::new())),
48            timeout_ms: DEFAULT_EXTENSION_TOOL_TIMEOUT_MS,
49        }
50    }
51
52    #[must_use]
53    pub fn with_ctx_payload(mut self, ctx_payload: Value) -> Self {
54        self.ctx_payload = Arc::new(ctx_payload);
55        self
56    }
57
58    #[must_use]
59    pub fn with_ctx_payload_shared(mut self, ctx_payload: Arc<Value>) -> Self {
60        self.ctx_payload = ctx_payload;
61        self
62    }
63
64    #[must_use]
65    pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
66        self.timeout_ms = timeout_ms.max(1);
67        self
68    }
69}
70
#[cfg(feature = "wasm-host")]
/// Wraps a tool declared by a WASM extension so it can be used as a Rust [`Tool`].
///
/// Calls are forwarded to the extension's `handle_tool` and, when `timeout_ms` is
/// non-zero, bounded by that many milliseconds.
pub struct WasmExtensionToolWrapper {
    /// Tool metadata (name, optional label, description, parameters) declared by the extension.
    def: ExtensionToolDef,
    /// Handle of the WASM extension that implements this tool.
    handle: WasmExtensionHandle,
    /// Per-call timeout in milliseconds; `with_timeout_ms` never lets it drop below 1.
    timeout_ms: u64,
}
77
#[cfg(feature = "wasm-host")]
impl WasmExtensionToolWrapper {
    /// Creates a wrapper for `def` backed by the WASM extension `handle`, starting with
    /// the default per-call timeout (`DEFAULT_EXTENSION_TOOL_TIMEOUT_MS`).
    #[must_use]
    pub const fn new(def: ExtensionToolDef, handle: WasmExtensionHandle) -> Self {
        let timeout_ms = DEFAULT_EXTENSION_TOOL_TIMEOUT_MS;
        Self {
            def,
            handle,
            timeout_ms,
        }
    }

    /// Sets the per-call timeout in milliseconds; a value of `0` is raised to `1`.
    #[must_use]
    pub fn with_timeout_ms(self, timeout_ms: u64) -> Self {
        Self {
            timeout_ms: timeout_ms.max(1),
            ..self
        }
    }
}
95
96/// Collect all registered extension tools and wrap them as Rust [`Tool`]s.
97///
98/// This is intended to be called after extensions are loaded/activated so the returned tool list
99/// can be injected into the agent [`crate::tools::ToolRegistry`].
100pub async fn collect_extension_tool_wrappers(
101    manager: &ExtensionManager,
102    ctx_payload: Value,
103) -> Result<Vec<Box<dyn Tool>>> {
104    let shared_ctx_payload = Arc::new(ctx_payload);
105    let active = manager
106        .active_tools()
107        .map(|tools| tools.into_iter().collect::<HashSet<_>>());
108
109    let mut wrappers: Vec<Box<dyn Tool>> = Vec::new();
110    let mut seen = HashSet::new();
111
112    if let Some(runtime) = manager.runtime() {
113        let mut defs = runtime.get_registered_tools().await?;
114        if let Some(active) = active.as_ref() {
115            defs.retain(|def| active.contains(&def.name));
116        }
117
118        defs.sort_by(|a, b| a.name.cmp(&b.name));
119        for def in defs {
120            if !seen.insert(def.name.clone()) {
121                tracing::warn!(tool = %def.name, "Duplicate extension tool name; ignoring");
122                continue;
123            }
124
125            wrappers.push(Box::new(
126                ExtensionToolWrapper::new(def, runtime.clone())
127                    .with_ctx_payload_shared(Arc::clone(&shared_ctx_payload)),
128            ));
129        }
130    }
131
132    #[cfg(feature = "wasm-host")]
133    {
134        let mut wasm_defs: Vec<(ExtensionToolDef, WasmExtensionHandle)> = Vec::new();
135        for handle in manager.wasm_extensions() {
136            for def in handle.tool_defs() {
137                wasm_defs.push((def.clone(), handle.clone()));
138            }
139        }
140
141        wasm_defs.sort_by(|a, b| a.0.name.cmp(&b.0.name));
142        for (def, handle) in wasm_defs {
143            if let Some(active) = active.as_ref() {
144                if !active.contains(&def.name) {
145                    continue;
146                }
147            }
148            if !seen.insert(def.name.clone()) {
149                tracing::warn!(tool = %def.name, "Duplicate extension tool name; ignoring");
150                continue;
151            }
152
153            wrappers.push(Box::new(WasmExtensionToolWrapper::new(def, handle)));
154        }
155    }
156
157    Ok(wrappers)
158}
159
160#[async_trait]
161impl Tool for ExtensionToolWrapper {
162    fn name(&self) -> &str {
163        &self.def.name
164    }
165
166    fn label(&self) -> &str {
167        self.def.label.as_deref().unwrap_or(&self.def.name)
168    }
169
170    fn description(&self) -> &str {
171        &self.def.description
172    }
173
174    fn parameters(&self) -> Value {
175        self.def.parameters.clone()
176    }
177
178    async fn execute(
179        &self,
180        tool_call_id: &str,
181        input: Value,
182        _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
183    ) -> Result<ToolOutput> {
184        let result = self
185            .runtime
186            .execute_tool_ref(
187                &self.def.name,
188                tool_call_id,
189                input,
190                Arc::clone(&self.ctx_payload),
191                self.timeout_ms,
192            )
193            .await
194            .map_err(|err| Error::tool(self.name(), err.to_string()))?;
195
196        serde_json::from_value(result).map_err(|err| {
197            Error::tool(
198                self.name(),
199                format!("Invalid extension tool output (expected ToolOutput JSON): {err}"),
200            )
201        })
202    }
203}
204
#[cfg(feature = "wasm-host")]
#[async_trait]
impl Tool for WasmExtensionToolWrapper {
    fn name(&self) -> &str {
        &self.def.name
    }

    /// The extension-provided label, or the tool name when none was given.
    fn label(&self) -> &str {
        match self.def.label.as_deref() {
            Some(label) => label,
            None => self.name(),
        }
    }

    fn description(&self) -> &str {
        &self.def.description
    }

    fn parameters(&self) -> Value {
        Value::clone(&self.def.parameters)
    }

    /// Invokes the tool in its WASM extension and decodes the returned JSON string.
    ///
    /// With a non-zero `timeout_ms`, the call is raced against a wall-clock deadline.
    /// With zero, it is awaited without a bound. The call id and update callback are unused.
    ///
    /// # Errors
    ///
    /// Returns a tool error on timeout, on a failed extension call, or when the output
    /// does not deserialize into [`ToolOutput`].
    async fn execute(
        &self,
        _tool_call_id: &str,
        input: Value,
        _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
    ) -> Result<ToolOutput> {
        let pending = self.handle.handle_tool(&self.def.name, &input);

        let outcome = if self.timeout_ms == 0 {
            pending.await
        } else {
            let limit = Duration::from_millis(self.timeout_ms);
            match timeout(wall_now(), limit, Box::pin(pending)).await {
                Ok(finished) => finished,
                Err(_) => {
                    let message = format!(
                        "WASM tool '{}' timed out after {}ms",
                        self.name(),
                        self.timeout_ms
                    );
                    return Err(Error::tool(self.name(), message));
                }
            }
        };

        let output_json = outcome.map_err(|err| Error::tool(self.name(), err.to_string()))?;

        serde_json::from_str(&output_json).map_err(|err| {
            Error::tool(
                self.name(),
                format!("Invalid WASM tool output (expected ToolOutput JSON): {err}"),
            )
        })
    }
}
264
265#[cfg(test)]
266mod tests {
267    use super::*;
268
269    use crate::agent::{Agent, AgentConfig, AgentEvent, AgentSession};
270    use crate::extensions::{ExtensionManager, JsExtensionLoadSpec};
271    use crate::extensions_js::PiJsRuntimeConfig;
272    use crate::model::{
273        AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ToolCall,
274        Usage,
275    };
276    use crate::provider::{Context, Provider, StreamOptions};
277    use crate::session::Session;
278    use crate::tools::ToolRegistry;
279    use asupersync::runtime::RuntimeBuilder;
280    use asupersync::sync::Mutex;
281    use async_trait::async_trait;
282    use futures::Stream;
283    use serde_json::json;
284    use std::pin::Pin;
285    use std::sync::Arc;
286
287    async fn setup_js_tool(
288        source: &str,
289        tool_name: &str,
290    ) -> (
291        tempfile::TempDir,
292        ExtensionManager,
293        JsExtensionRuntimeHandle,
294        ExtensionToolDef,
295    ) {
296        let temp_dir = tempfile::tempdir().expect("tempdir");
297        let entry_path = temp_dir.path().join("ext.mjs");
298        std::fs::write(&entry_path, source).expect("write extension entry");
299
300        let manager = ExtensionManager::new();
301        let tools = Arc::new(ToolRegistry::new(&[], temp_dir.path(), None));
302        let js_runtime = JsExtensionRuntimeHandle::start(
303            PiJsRuntimeConfig {
304                cwd: temp_dir.path().display().to_string(),
305                ..Default::default()
306            },
307            Arc::clone(&tools),
308            manager.clone(),
309        )
310        .await
311        .expect("start js runtime");
312        manager.set_js_runtime(js_runtime.clone());
313
314        let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("spec");
315        manager
316            .load_js_extensions(vec![spec])
317            .await
318            .expect("load js extensions");
319
320        let def = js_runtime
321            .get_registered_tools()
322            .await
323            .expect("get registered tools")
324            .into_iter()
325            .find(|tool| tool.name == tool_name)
326            .expect("tool registered");
327
328        (temp_dir, manager, js_runtime, def)
329    }
330
331    #[test]
332    fn extension_tool_wrapper_executes_registered_tool() {
333        let runtime = RuntimeBuilder::current_thread()
334            .build()
335            .expect("runtime build");
336
337        runtime.block_on(async {
338            let temp_dir = tempfile::tempdir().expect("tempdir");
339            let entry_path = temp_dir.path().join("ext.mjs");
340            std::fs::write(
341                &entry_path,
342                r#"
343                export default function init(pi) {
344                  pi.registerTool({
345                    name: "hello_tool",
346                    label: "hello_tool",
347                    description: "test tool",
348                    parameters: { type: "object", properties: { name: { type: "string" } } },
349                    execute: async (_callId, input, _onUpdate, _abort, ctx) => {
350                      const who = input && input.name ? String(input.name) : "world";
351                      const cwd = ctx && ctx.cwd ? String(ctx.cwd) : "";
352                      return {
353                        content: [{ type: "text", text: `hello ${who}` }],
354                        details: { from: "extension", cwd: cwd },
355                        isError: false
356                      };
357                    }
358                  });
359                }
360                "#,
361            )
362            .expect("write extension entry");
363
364            let manager = ExtensionManager::new();
365            let tools = Arc::new(ToolRegistry::new(&[], temp_dir.path(), None));
366            let js_runtime = JsExtensionRuntimeHandle::start(
367                PiJsRuntimeConfig {
368                    cwd: temp_dir.path().display().to_string(),
369                    ..Default::default()
370                },
371                Arc::clone(&tools),
372                manager.clone(),
373            )
374            .await
375            .expect("start js runtime");
376            manager.set_js_runtime(js_runtime.clone());
377
378            let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("spec");
379            manager
380                .load_js_extensions(vec![spec])
381                .await
382                .expect("load js extensions");
383
384            let tool_defs = js_runtime
385                .get_registered_tools()
386                .await
387                .expect("get registered tools");
388            let def = tool_defs
389                .into_iter()
390                .find(|tool| tool.name == "hello_tool")
391                .expect("hello_tool registered");
392
393            let wrapper = ExtensionToolWrapper::new(def, js_runtime).with_ctx_payload(json!({
394                "cwd": temp_dir.path().display().to_string()
395            }));
396
397            let output = wrapper
398                .execute("call-1", json!({ "name": "pi" }), None)
399                .await
400                .expect("execute tool");
401
402            assert!(!output.is_error);
403
404            match output.content.as_slice() {
405                [ContentBlock::Text(text)] => assert_eq!(text.text, "hello pi"),
406                other => assert!(
407                    matches!(other, [ContentBlock::Text(_)]),
408                    "Expected single text content block, got: {other:?}"
409                ),
410            }
411
412            let details = output.details.expect("details present");
413            assert_eq!(
414                details.get("from").and_then(Value::as_str),
415                Some("extension")
416            );
417            let cwd = temp_dir.path().display().to_string();
418            assert_eq!(
419                details.get("cwd").and_then(Value::as_str),
420                Some(cwd.as_str())
421            );
422        });
423    }
424
425    #[test]
426    fn extension_tool_wrapper_metadata_and_timeout_clamp() {
427        let runtime = RuntimeBuilder::current_thread()
428            .build()
429            .expect("runtime build");
430
431        runtime.block_on(async {
432            let source = r#"
433                export default function init(pi) {
434                  pi.registerTool({
435                    name: "meta_tool",
436                    label: "Meta Tool",
437                    description: "metadata test tool",
438                    parameters: { type: "object", properties: { x: { type: "number" } } },
439                    execute: async (_callId, _input, _onUpdate, _abort, _ctx) => ({
440                      content: [{ type: "text", text: "ok" }],
441                      isError: false
442                    })
443                  });
444                }
445                "#;
446            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "meta_tool").await;
447
448            let wrapper = ExtensionToolWrapper::new(def.clone(), js_runtime.clone())
449                .with_timeout_ms(0)
450                .with_ctx_payload(json!({"cwd": "/tmp"}));
451            assert_eq!(wrapper.timeout_ms, 1);
452            assert_eq!(wrapper.name(), "meta_tool");
453            assert_eq!(wrapper.label(), "Meta Tool");
454            assert_eq!(wrapper.description(), "metadata test tool");
455            assert_eq!(
456                wrapper.parameters(),
457                json!({ "type": "object", "properties": { "x": { "type": "number" } } })
458            );
459
460            let mut no_label = def;
461            no_label.label = None;
462            let fallback = ExtensionToolWrapper::new(no_label, js_runtime).with_timeout_ms(25);
463            assert_eq!(fallback.timeout_ms, 25);
464            assert_eq!(fallback.label(), "meta_tool");
465        });
466    }
467
468    #[test]
469    fn extension_tool_wrapper_maps_invalid_output_to_tool_error() {
470        let runtime = RuntimeBuilder::current_thread()
471            .build()
472            .expect("runtime build");
473
474        runtime.block_on(async {
475            let source = r#"
476                export default function init(pi) {
477                  pi.registerTool({
478                    name: "broken_tool",
479                    label: "broken_tool",
480                    description: "returns invalid output payload",
481                    parameters: { type: "object", properties: {} },
482                    execute: async (_callId, _input, _onUpdate, _abort, _ctx) => ({
483                      nope: true
484                    })
485                  });
486                }
487                "#;
488            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "broken_tool").await;
489
490            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
491            let err = wrapper
492                .execute("call-1", json!({}), None)
493                .await
494                .expect_err("invalid tool output should fail");
495
496            match err {
497                Error::Tool { tool, message } => {
498                    assert_eq!(tool, "broken_tool");
499                    assert!(message.contains("Invalid extension tool output"));
500                }
501                other => panic!("expected tool error, got {other:?}"),
502            }
503        });
504    }
505
506    #[derive(Debug)]
507    struct ToolCallingProvider;
508
509    #[async_trait]
510    #[allow(clippy::unnecessary_literal_bound)]
511    impl Provider for ToolCallingProvider {
512        fn name(&self) -> &str {
513            "test-provider"
514        }
515
516        fn api(&self) -> &str {
517            "test-api"
518        }
519
520        fn model_id(&self) -> &str {
521            "test-model"
522        }
523
524        async fn stream(
525            &self,
526            context: &Context<'_>,
527            _options: &StreamOptions,
528        ) -> crate::error::Result<
529            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
530        > {
531            fn assistant_message(content: Vec<ContentBlock>) -> AssistantMessage {
532                AssistantMessage {
533                    content,
534                    api: "test-api".to_string(),
535                    provider: "test-provider".to_string(),
536                    model: "test-model".to_string(),
537                    usage: Usage::default(),
538                    stop_reason: StopReason::Stop,
539                    error_message: None,
540                    timestamp: 0,
541                }
542            }
543
544            let tool_def_present = context.tools.iter().any(|tool| tool.name == "hello_tool");
545            let tool_result = context.messages.iter().find_map(|message| match message {
546                Message::ToolResult(result) if result.tool_name == "hello_tool" => Some(result),
547                _ => None,
548            });
549
550            if let Some(result) = tool_result {
551                match result.content.as_slice() {
552                    [ContentBlock::Text(text)] => assert_eq!(text.text, "hello pi"),
553                    other => panic!("Expected single text content block, got: {other:?}"),
554                }
555
556                let events = vec![
557                    Ok(StreamEvent::Start {
558                        partial: assistant_message(Vec::new()),
559                    }),
560                    Ok(StreamEvent::Done {
561                        reason: StopReason::Stop,
562                        message: assistant_message(vec![ContentBlock::Text(TextContent::new(
563                            "done",
564                        ))]),
565                    }),
566                ];
567                return Ok(Box::pin(futures::stream::iter(events)));
568            }
569
570            assert!(
571                tool_def_present,
572                "Expected extension tool to be present in provider tool defs"
573            );
574
575            let tool_call = ToolCall {
576                id: "call-1".to_string(),
577                name: "hello_tool".to_string(),
578                arguments: json!({ "name": "pi" }),
579                thought_signature: None,
580            };
581
582            let events = vec![
583                Ok(StreamEvent::Start {
584                    partial: assistant_message(Vec::new()),
585                }),
586                Ok(StreamEvent::Done {
587                    reason: StopReason::Stop,
588                    message: assistant_message(vec![ContentBlock::ToolCall(tool_call)]),
589                }),
590            ];
591            Ok(Box::pin(futures::stream::iter(events)))
592        }
593    }
594
595    #[test]
596    fn agent_executes_extension_tool_registered_via_js() {
597        let runtime = RuntimeBuilder::current_thread()
598            .build()
599            .expect("runtime build");
600
601        runtime.block_on(async {
602            let temp_dir = tempfile::tempdir().expect("tempdir");
603            let entry_path = temp_dir.path().join("ext.mjs");
604            std::fs::write(
605                &entry_path,
606                r#"
607                export default function init(pi) {
608                  pi.registerTool({
609                    name: "hello_tool",
610                    label: "hello_tool",
611                    description: "test tool",
612                    parameters: { type: "object", properties: { name: { type: "string" } } },
613                    execute: async (_callId, input, _onUpdate, _abort, _ctx) => {
614                      const who = input && input.name ? String(input.name) : "world";
615                      return {
616                        content: [{ type: "text", text: `hello ${who}` }],
617                        details: { from: "extension" },
618                        isError: false
619                      };
620                    }
621                  });
622                }
623                "#,
624            )
625            .expect("write extension entry");
626
627            let manager = ExtensionManager::new();
628            let tools_for_runtime = Arc::new(ToolRegistry::new(&[], temp_dir.path(), None));
629            let js_runtime = JsExtensionRuntimeHandle::start(
630                PiJsRuntimeConfig {
631                    cwd: temp_dir.path().display().to_string(),
632                    ..Default::default()
633                },
634                Arc::clone(&tools_for_runtime),
635                manager.clone(),
636            )
637            .await
638            .expect("start js runtime");
639            manager.set_js_runtime(js_runtime.clone());
640
641            let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("spec");
642            manager
643                .load_js_extensions(vec![spec])
644                .await
645                .expect("load js extensions");
646
647            let wrappers = collect_extension_tool_wrappers(
648                &manager,
649                json!({ "cwd": temp_dir.path().display().to_string() }),
650            )
651            .await
652            .expect("collect wrappers");
653            assert_eq!(wrappers.len(), 1);
654
655            let provider = Arc::new(ToolCallingProvider);
656            let tools = ToolRegistry::new(&[], temp_dir.path(), None);
657            let mut agent = Agent::new(provider, tools, AgentConfig::default());
658            agent.extend_tools(wrappers);
659
660            let session = Arc::new(Mutex::new(Session::in_memory()));
661            let mut agent_session = AgentSession::new(
662                agent,
663                session,
664                false,
665                crate::compaction::ResolvedCompactionSettings::default(),
666            );
667            let message = agent_session
668                .run_text("hi".to_string(), |_event: AgentEvent| {})
669                .await
670                .expect("run_text");
671
672            match message.content.as_slice() {
673                [ContentBlock::Text(text)] => assert_eq!(text.text, "done"),
674                other => panic!("Expected single text content block, got: {other:?}"),
675            }
676        });
677    }
678
679    // -- Constructor & builder tests --
680
681    #[test]
682    fn extension_tool_wrapper_default_timeout() {
683        let runtime = RuntimeBuilder::current_thread()
684            .build()
685            .expect("runtime build");
686
687        runtime.block_on(async {
688            let source = r#"
689                export default function init(pi) {
690                  pi.registerTool({
691                    name: "t",
692                    description: "d",
693                    parameters: { type: "object" },
694                    execute: async () => ({ content: [], isError: false })
695                  });
696                }
697            "#;
698            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "t").await;
699            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
700            assert_eq!(wrapper.timeout_ms, DEFAULT_EXTENSION_TOOL_TIMEOUT_MS);
701            assert_eq!(wrapper.timeout_ms, 60_000);
702        });
703    }
704
705    #[test]
706    fn extension_tool_wrapper_timeout_clamp_boundary() {
707        let runtime = RuntimeBuilder::current_thread()
708            .build()
709            .expect("runtime build");
710
711        runtime.block_on(async {
712            let source = r#"
713                export default function init(pi) {
714                  pi.registerTool({
715                    name: "t",
716                    description: "d",
717                    parameters: { type: "object" },
718                    execute: async () => ({ content: [], isError: false })
719                  });
720                }
721            "#;
722            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "t").await;
723
724            // timeout=0 clamped to 1
725            let w0 = ExtensionToolWrapper::new(def.clone(), js_runtime.clone()).with_timeout_ms(0);
726            assert_eq!(w0.timeout_ms, 1);
727
728            // timeout=1 stays 1
729            let w1 = ExtensionToolWrapper::new(def.clone(), js_runtime.clone()).with_timeout_ms(1);
730            assert_eq!(w1.timeout_ms, 1);
731
732            // timeout=u64::MAX stays u64::MAX
733            let wmax = ExtensionToolWrapper::new(def, js_runtime).with_timeout_ms(u64::MAX);
734            assert_eq!(wmax.timeout_ms, u64::MAX);
735        });
736    }
737
738    #[test]
739    fn extension_tool_wrapper_ctx_payload_default_empty() {
740        let runtime = RuntimeBuilder::current_thread()
741            .build()
742            .expect("runtime build");
743
744        runtime.block_on(async {
745            let source = r#"
746                export default function init(pi) {
747                  pi.registerTool({
748                    name: "t",
749                    description: "d",
750                    parameters: { type: "object" },
751                    execute: async () => ({ content: [], isError: false })
752                  });
753                }
754            "#;
755            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "t").await;
756            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
757            assert_eq!(wrapper.ctx_payload.as_ref(), &json!({}));
758        });
759    }
760
761    #[test]
762    fn extension_tool_wrapper_ctx_payload_override() {
763        let runtime = RuntimeBuilder::current_thread()
764            .build()
765            .expect("runtime build");
766
767        runtime.block_on(async {
768            let source = r#"
769                export default function init(pi) {
770                  pi.registerTool({
771                    name: "t",
772                    description: "d",
773                    parameters: { type: "object" },
774                    execute: async () => ({ content: [], isError: false })
775                  });
776                }
777            "#;
778            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "t").await;
779            let custom_ctx = json!({"cwd": "/tmp", "user": "test"});
780            let wrapper =
781                ExtensionToolWrapper::new(def, js_runtime).with_ctx_payload(custom_ctx.clone());
782            assert_eq!(wrapper.ctx_payload.as_ref(), &custom_ctx);
783        });
784    }
785
786    // -- collect_extension_tool_wrappers tests --
787
788    #[test]
789    fn collect_wrappers_no_js_runtime_returns_empty() {
790        let runtime = RuntimeBuilder::current_thread()
791            .build()
792            .expect("runtime build");
793
794        runtime.block_on(async {
795            let manager = ExtensionManager::new();
796            let wrappers = collect_extension_tool_wrappers(&manager, json!({}))
797                .await
798                .expect("collect wrappers");
799            assert!(wrappers.is_empty());
800        });
801    }
802
803    #[test]
804    fn collect_wrappers_multiple_tools_from_one_extension() {
805        let runtime = RuntimeBuilder::current_thread()
806            .build()
807            .expect("runtime build");
808
809        runtime.block_on(async {
810            let temp_dir = tempfile::tempdir().expect("tempdir");
811            let entry_path = temp_dir.path().join("ext.mjs");
812            std::fs::write(
813                &entry_path,
814                r#"
815                export default function init(pi) {
816                  pi.registerTool({
817                    name: "tool_alpha",
818                    description: "first tool",
819                    parameters: { type: "object" },
820                    execute: async () => ({ content: [{ type: "text", text: "alpha" }], isError: false })
821                  });
822                  pi.registerTool({
823                    name: "tool_beta",
824                    description: "second tool",
825                    parameters: { type: "object" },
826                    execute: async () => ({ content: [{ type: "text", text: "beta" }], isError: false })
827                  });
828                }
829                "#,
830            )
831            .expect("write extension");
832
833            let manager = ExtensionManager::new();
834            let tools = Arc::new(ToolRegistry::new(&[], temp_dir.path(), None));
835            let js_runtime = JsExtensionRuntimeHandle::start(
836                PiJsRuntimeConfig {
837                    cwd: temp_dir.path().display().to_string(),
838                    ..Default::default()
839                },
840                Arc::clone(&tools),
841                manager.clone(),
842            )
843            .await
844            .expect("start js runtime");
845            manager.set_js_runtime(js_runtime.clone());
846
847            let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("spec");
848            manager
849                .load_js_extensions(vec![spec])
850                .await
851                .expect("load js extensions");
852
853            let wrappers = collect_extension_tool_wrappers(&manager, json!({}))
854                .await
855                .expect("collect wrappers");
856            assert_eq!(wrappers.len(), 2);
857
858            // Sorted alphabetically
859            assert_eq!(wrappers[0].name(), "tool_alpha");
860            assert_eq!(wrappers[1].name(), "tool_beta");
861        });
862    }
863
864    #[test]
865    fn collect_wrappers_respects_active_tools_filter() {
866        let runtime = RuntimeBuilder::current_thread()
867            .build()
868            .expect("runtime build");
869
870        runtime.block_on(async {
871            let temp_dir = tempfile::tempdir().expect("tempdir");
872            let entry_path = temp_dir.path().join("ext.mjs");
873            std::fs::write(
874                &entry_path,
875                r#"
876                export default function init(pi) {
877                  pi.registerTool({
878                    name: "tool_keep",
879                    description: "kept",
880                    parameters: { type: "object" },
881                    execute: async () => ({ content: [], isError: false })
882                  });
883                  pi.registerTool({
884                    name: "tool_skip",
885                    description: "skipped",
886                    parameters: { type: "object" },
887                    execute: async () => ({ content: [], isError: false })
888                  });
889                }
890                "#,
891            )
892            .expect("write extension");
893
894            let manager = ExtensionManager::new();
895            let tools = Arc::new(ToolRegistry::new(&[], temp_dir.path(), None));
896            let js_runtime = JsExtensionRuntimeHandle::start(
897                PiJsRuntimeConfig {
898                    cwd: temp_dir.path().display().to_string(),
899                    ..Default::default()
900                },
901                Arc::clone(&tools),
902                manager.clone(),
903            )
904            .await
905            .expect("start js runtime");
906            manager.set_js_runtime(js_runtime.clone());
907
908            let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("spec");
909            manager
910                .load_js_extensions(vec![spec])
911                .await
912                .expect("load js extensions");
913
914            // Set active_tools to only include tool_keep
915            manager.set_active_tools(vec!["tool_keep".to_string()]);
916
917            let wrappers = collect_extension_tool_wrappers(&manager, json!({}))
918                .await
919                .expect("collect wrappers");
920            assert_eq!(wrappers.len(), 1);
921            assert_eq!(wrappers[0].name(), "tool_keep");
922        });
923    }
924
925    #[test]
926    fn extension_tool_wrapper_js_error_maps_to_tool_error() {
927        let runtime = RuntimeBuilder::current_thread()
928            .build()
929            .expect("runtime build");
930
931        runtime.block_on(async {
932            let source = r#"
933                export default function init(pi) {
934                  pi.registerTool({
935                    name: "throwing_tool",
936                    description: "throws an error",
937                    parameters: { type: "object" },
938                    execute: async () => { throw new Error("boom!"); }
939                  });
940                }
941            "#;
942            let (_temp_dir, _manager, js_runtime, def) =
943                setup_js_tool(source, "throwing_tool").await;
944
945            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
946            let err = wrapper
947                .execute("call-1", json!({}), None)
948                .await
949                .expect_err("throwing tool should fail");
950
951            match err {
952                Error::Tool { tool, message } => {
953                    assert_eq!(tool, "throwing_tool");
954                    assert!(
955                        message.contains("boom") || message.contains("error"),
956                        "Expected error message to reference the thrown error, got: {message}"
957                    );
958                }
959                other => panic!("expected tool error, got {other:?}"),
960            }
961        });
962    }
963
964    #[test]
965    fn extension_tool_wrapper_empty_content_result() {
966        let runtime = RuntimeBuilder::current_thread()
967            .build()
968            .expect("runtime build");
969
970        runtime.block_on(async {
971            let source = r#"
972                export default function init(pi) {
973                  pi.registerTool({
974                    name: "empty_tool",
975                    description: "returns empty content",
976                    parameters: { type: "object" },
977                    execute: async () => ({
978                      content: [],
979                      isError: false
980                    })
981                  });
982                }
983            "#;
984            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "empty_tool").await;
985
986            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
987            let output = wrapper
988                .execute("call-1", json!({}), None)
989                .await
990                .expect("execute tool");
991
992            assert!(!output.is_error);
993            assert!(output.content.is_empty());
994        });
995    }
996
997    #[test]
998    fn extension_tool_wrapper_is_error_flag() {
999        let runtime = RuntimeBuilder::current_thread()
1000            .build()
1001            .expect("runtime build");
1002
1003        runtime.block_on(async {
1004            let source = r#"
1005                export default function init(pi) {
1006                  pi.registerTool({
1007                    name: "error_tool",
1008                    description: "returns error flag",
1009                    parameters: { type: "object" },
1010                    execute: async () => ({
1011                      content: [{ type: "text", text: "something went wrong" }],
1012                      isError: true
1013                    })
1014                  });
1015                }
1016            "#;
1017            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "error_tool").await;
1018
1019            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
1020            let output = wrapper
1021                .execute("call-1", json!({}), None)
1022                .await
1023                .expect("execute tool");
1024
1025            assert!(output.is_error);
1026            match output.content.as_slice() {
1027                [ContentBlock::Text(text)] => {
1028                    assert_eq!(text.text, "something went wrong");
1029                }
1030                other => panic!("expected text content, got {other:?}"),
1031            }
1032        });
1033    }
1034
1035    #[test]
1036    fn extension_tool_wrapper_passes_input_to_handler() {
1037        let runtime = RuntimeBuilder::current_thread()
1038            .build()
1039            .expect("runtime build");
1040
1041        runtime.block_on(async {
1042            let source = r#"
1043                export default function init(pi) {
1044                  pi.registerTool({
1045                    name: "echo_tool",
1046                    description: "echoes input",
1047                    parameters: { type: "object", properties: { msg: { type: "string" } } },
1048                    execute: async (_callId, input) => ({
1049                      content: [{ type: "text", text: JSON.stringify(input) }],
1050                      isError: false
1051                    })
1052                  });
1053                }
1054            "#;
1055            let (_temp_dir, _manager, js_runtime, def) = setup_js_tool(source, "echo_tool").await;
1056
1057            let wrapper = ExtensionToolWrapper::new(def, js_runtime);
1058            let output = wrapper
1059                .execute("call-1", json!({"msg": "hello world"}), None)
1060                .await
1061                .expect("execute tool");
1062
1063            assert!(!output.is_error);
1064            match output.content.as_slice() {
1065                [ContentBlock::Text(text)] => {
1066                    let parsed: serde_json::Value =
1067                        serde_json::from_str(&text.text).expect("parse JSON");
1068                    assert_eq!(parsed["msg"], "hello world");
1069                }
1070                other => panic!("expected text content, got {other:?}"),
1071            }
1072        });
1073    }
1074}