Skip to main content

grain_pi_compat/
extension.rs

1//! Public surface: discover pi extension files, transform each, load
2//! the transformed bundle through [`grain_script_boa::BoaExtension`].
3
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::sync::Arc;
7
8use grain_agent_core::{AgentEvent, AgentTool, EventListener};
9use grain_script_boa::{BoaExtension, BoaExtensionError};
10use tempfile::TempDir;
11
12use crate::transform::transform_pi_source;
13
/// A slash command contributed by a pi extension script.
///
/// Each value mirrors one `pi.registerCommand(name, { description, handler })`
/// call. The handler itself stays on the JS side and is reached through
/// [`PiExtension::invoke_command`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PiCommand {
    /// Command name exactly as passed to `pi.registerCommand`.
    pub name: String,
    /// Human-readable description supplied by the script.
    pub description: String,
}
21
/// A keyboard shortcut contributed by a pi extension script.
///
/// Each value mirrors one `pi.registerShortcut(keys, { description, handler })`
/// call. `keys` is kept verbatim from pi's API (e.g. `"ctrl+x"`,
/// `"shift+alt+a"`); turning it into a concrete key event is left to the TUI.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PiShortcut {
    /// Key spec string exactly as passed to `pi.registerShortcut`.
    pub keys: String,
    /// Human-readable description supplied by the script.
    pub description: String,
}
32
/// A UI request raised by a pi extension script for the host to render.
///
/// There are two families:
/// - `Notify` is fire-and-forget.
/// - `Confirm` / `Input` / `Select` are modal round-trips. The JS caller
///   waits for an answer.
///
/// Every modal variant carries a `request_id`. The host MUST eventually
/// answer it via [`PiExtension::resolve_modal`], otherwise the worker thread
/// running the script stays blocked forever.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PiNotification {
    /// A toast message. Presentation (transcript line, toast widget, ...)
    /// is entirely up to the host.
    Notify { text: String },
    /// A yes/no question. Answer with a JSON boolean.
    Confirm { request_id: u64, prompt: String },
    /// A free-text prompt. Answer with a JSON string.
    Input { request_id: u64, prompt: String },
    /// A choice among `items`. Answer with the chosen item as a JSON string.
    Select {
        request_id: u64,
        prompt: String,
        items: Vec<String>,
    },
}
57
58/// Errors raised while constructing a [`PiExtension`].
59#[derive(Debug, thiserror::Error)]
60pub enum PiCompatError {
61    #[error("io: {0}")]
62    Io(#[from] std::io::Error),
63    #[error("boa: {0}")]
64    Boa(#[from] BoaExtensionError),
65}
66
67/// One pi-compat scripting environment. Owns the transformed-script
68/// temp directory + a [`BoaExtension`] for the underlying JS runtime.
69pub struct PiExtension {
70    name: &'static str,
71    /// Wrapped in Arc so [`Self::listeners`] can hand out clones that
72    /// outlive `&self` borrows — each EventListener captures its own
73    /// strong ref to the underlying Boa runtime.
74    inner: Arc<BoaExtension>,
75    /// Holds the transformed-script dir open so the Boa worker can
76    /// read them. Dropped together with the rest of the extension.
77    _tempdir: TempDir,
78}
79
80impl PiExtension {
81    /// Stable name for logging.
82    pub fn name(&self) -> &'static str {
83        self.name
84    }
85
86    /// Tools registered by all loaded pi extension files.
87    pub fn tools(&self) -> Vec<Arc<dyn AgentTool>> {
88        self.inner.tools()
89    }
90
91    /// Slash commands registered via `pi.registerCommand(name, {...})`.
92    /// Sorted by name for deterministic display in pickers.
93    pub fn commands(&self) -> Vec<PiCommand> {
94        let mut entries: Vec<PiCommand> = self
95            .inner
96            .list_metas("command")
97            .into_iter()
98            .map(|(name, attrs)| {
99                let description = attrs
100                    .get("description")
101                    .and_then(|v| v.as_str())
102                    .unwrap_or("")
103                    .to_string();
104                PiCommand { name, description }
105            })
106            .collect();
107        entries.sort_by(|a, b| a.name.cmp(&b.name));
108        entries
109    }
110
111    /// Dispatch a slash command registered via `pi.registerCommand`.
112    /// `args` is forwarded as the first argument to the JS handler.
113    /// JS-side throws come back as `Err(msg)`.
114    pub async fn invoke_command(&self, name: &str, args: serde_json::Value) -> Result<(), String> {
115        self.inner
116            .invoke_callback(&format!("cmd:{name}"), args)
117            .await
118    }
119
120    /// Keyboard shortcuts registered via
121    /// `pi.registerShortcut(keys, {...})`. Sorted by `keys` for
122    /// deterministic display in the TUI's help / cheatsheet.
123    pub fn shortcuts(&self) -> Vec<PiShortcut> {
124        let mut entries: Vec<PiShortcut> = self
125            .inner
126            .list_metas("shortcut")
127            .into_iter()
128            .map(|(keys, attrs)| {
129                let description = attrs
130                    .get("description")
131                    .and_then(|v| v.as_str())
132                    .unwrap_or("")
133                    .to_string();
134                PiShortcut { keys, description }
135            })
136            .collect();
137        entries.sort_by(|a, b| a.keys.cmp(&b.keys));
138        entries
139    }
140
141    /// Dispatch a shortcut registered via `pi.registerShortcut`. The
142    /// TUI matches key events against `shortcuts()` and calls this
143    /// with the matched spec. No args are forwarded — shortcuts are
144    /// nullary in pi's API.
145    pub async fn invoke_shortcut(&self, keys: &str) -> Result<(), String> {
146        self.inner
147            .invoke_callback(
148                &format!("shortcut:{keys}"),
149                serde_json::Value::Object(Default::default()),
150            )
151            .await
152    }
153
154    /// Drain every `pi.ui.notify(text)` payload that the scripts
155    /// have pushed since the last call. The TUI polls this each
156    /// tick / event loop iteration and renders entries however it
157    /// wants. Unknown payload shapes are silently dropped.
158    pub fn drain_notifications(&self) -> Vec<PiNotification> {
159        self.inner
160            .drain_notifications()
161            .into_iter()
162            .filter_map(decode_notification)
163            .collect()
164    }
165
166    /// Resolve a pending modal initiated via
167    /// `pi.ui.confirm/input/select`. The JS caller blocks until
168    /// this is called; `request_id` matches the field in the
169    /// `PiNotification` payload. `response` must be the right shape
170    /// for the modal kind: bool for Confirm, string for Input or
171    /// Select. Mismatched shapes will simply land in JS as whatever
172    /// type they decode to — the JS code is responsible for
173    /// checking.
174    pub fn resolve_modal(
175        &self,
176        request_id: u64,
177        response: serde_json::Value,
178    ) -> Result<(), String> {
179        self.inner.resolve_modal(request_id, response)
180    }
181
182    /// One [`EventListener`] that translates supported `AgentEvent`
183    /// variants into the pi event schema and dispatches them into
184    /// JS handlers registered via `pi.on(event_name, fn)`.
185    ///
186    /// Subscribe this to an [`grain_agent_core::Agent`] via
187    /// `agent.subscribe(listener).await` and pi scripts will start
188    /// receiving events.
189    pub fn listeners(&self) -> Vec<EventListener> {
190        let inner = self.inner.clone();
191        let dispatch: EventListener = Arc::new(move |event, _signal| {
192            let inner = inner.clone();
193            Box::pin(async move {
194                let Some((pi_name, payload)) = map_agent_event_to_pi(&event) else {
195                    return;
196                };
197                let key = format!("on:{pi_name}");
198                // Swallow errors — listeners can't return diagnostics
199                // to the agent. JS-side throws are stringified at the
200                // worker boundary and stay there; they don't break
201                // the agent's run.
202                let _ = inner.invoke_callback(&key, payload).await;
203            })
204        });
205        vec![dispatch]
206    }
207
208    /// Scan pi's conventional locations and load every `*.js` file:
209    ///
210    /// - `<workspace>/.pi/extensions/` (per-project)
211    /// - `~/.pi/agent/extensions/` (global)
212    ///
213    /// Missing locations are not an error — they're simply skipped.
214    pub fn from_pi_dirs(workspace_root: &Path) -> Result<Self, PiCompatError> {
215        let dirs = pi_search_paths(workspace_root);
216        Self::from_dirs(&dirs)
217    }
218
219    /// Explicit-paths variant. Useful for tests and for callers who
220    /// want to override pi's default search behavior.
221    pub fn from_dirs(dirs: &[PathBuf]) -> Result<Self, PiCompatError> {
222        let tempdir = tempfile::tempdir()?;
223        let mut count = 0usize;
224        for dir in dirs {
225            if !dir.exists() {
226                continue;
227            }
228            let entries = match fs::read_dir(dir) {
229                Ok(rd) => rd,
230                Err(_) => continue,
231            };
232            for entry in entries.flatten() {
233                let path = entry.path();
234                let Some(ext) = path.extension().and_then(|s| s.to_str()) else {
235                    continue;
236                };
237                // Phase 1: JS only. TypeScript lands in Phase 3 via
238                // an swc transpile step.
239                if ext != "js" {
240                    continue;
241                }
242                let source = fs::read_to_string(&path)?;
243                let transformed = transform_pi_source(&source);
244                let stem = path
245                    .file_stem()
246                    .and_then(|s| s.to_str())
247                    .unwrap_or("anonymous");
248                // Numeric prefix preserves a stable load order even
249                // across directories.
250                let out_name = format!("{count:03}_{stem}.js");
251                fs::write(tempdir.path().join(&out_name), transformed)?;
252                count += 1;
253            }
254        }
255        let inner = Arc::new(BoaExtension::from_scripts_dir(tempdir.path())?);
256        Ok(PiExtension {
257            name: "grain-pi-compat",
258            inner,
259            _tempdir: tempdir,
260        })
261    }
262}
263
264/// Map our [`AgentEvent`] variants to pi's documented event schema.
265/// Returns `None` for events pi doesn't declare today (e.g. our
266/// turn-level lifecycle is pi's per-message lifecycle plus tool
267/// hooks, so `TurnStart` / `TurnEnd` have no direct pi equivalent).
268fn map_agent_event_to_pi(event: &AgentEvent) -> Option<(&'static str, serde_json::Value)> {
269    match event {
270        AgentEvent::AgentStart => Some(("agent_start", serde_json::json!({}))),
271        AgentEvent::AgentEnd { messages } => Some((
272            "agent_end",
273            serde_json::json!({ "message_count": messages.len() }),
274        )),
275        AgentEvent::MessageStart { message } => Some((
276            "message_start",
277            serde_json::json!({ "role": message.role() }),
278        )),
279        AgentEvent::MessageEnd { message } => {
280            Some(("message_end", serde_json::json!({ "role": message.role() })))
281        }
282        AgentEvent::ToolExecutionStart {
283            tool_call_id,
284            tool_name,
285            args,
286        } => Some((
287            "tool_call",
288            serde_json::json!({
289                "tool_call_id": tool_call_id,
290                "tool_name": tool_name,
291                "args": args,
292            }),
293        )),
294        AgentEvent::ToolExecutionEnd {
295            tool_call_id,
296            tool_name,
297            result,
298            is_error,
299        } => Some((
300            "tool_result",
301            serde_json::json!({
302                "tool_call_id": tool_call_id,
303                "tool_name": tool_name,
304                "is_error": is_error,
305                // AgentToolResult is `Serialize`; project just the
306                // text content list to keep the JS payload simple.
307                "content": result.content,
308            }),
309        )),
310        _ => None,
311    }
312}
313
314/// Map one raw queue payload from the Boa worker into a typed
315/// [`PiNotification`]. Returns `None` for unknown shapes so the
316/// queue stays forward-compatible.
317fn decode_notification(v: serde_json::Value) -> Option<PiNotification> {
318    let kind = v.get("kind")?.as_str()?;
319    match kind {
320        "notify" => {
321            let text = v.get("text")?.as_str()?.to_string();
322            Some(PiNotification::Notify { text })
323        }
324        "confirm" => {
325            let request_id = v.get("request_id")?.as_u64()?;
326            let prompt = v.get("prompt")?.as_str()?.to_string();
327            Some(PiNotification::Confirm { request_id, prompt })
328        }
329        "input" => {
330            let request_id = v.get("request_id")?.as_u64()?;
331            let prompt = v.get("prompt")?.as_str()?.to_string();
332            Some(PiNotification::Input { request_id, prompt })
333        }
334        "select" => {
335            let request_id = v.get("request_id")?.as_u64()?;
336            let prompt = v.get("prompt")?.as_str()?.to_string();
337            let items = v
338                .get("items")?
339                .as_array()?
340                .iter()
341                .filter_map(|v| v.as_str().map(str::to_string))
342                .collect();
343            Some(PiNotification::Select {
344                request_id,
345                prompt,
346                items,
347            })
348        }
349        _ => None,
350    }
351}
352
353/// pi's conventional discovery paths, in load order.
354fn pi_search_paths(workspace_root: &Path) -> Vec<PathBuf> {
355    let mut paths = vec![workspace_root.join(".pi").join("extensions")];
356    if let Some(home) = dirs::home_dir() {
357        paths.push(home.join(".pi").join("agent").join("extensions"));
358    }
359    paths
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use grain_agent_core::{AgentEvent, AgentToolError, ToolUpdateCallback, UserContent};
    use std::sync::Arc;
    use tokio_util::sync::CancellationToken;

    /// Write `body` to `dir/name`, panicking on I/O failure.
    fn write_script(dir: &Path, name: &str, body: &str) {
        std::fs::write(dir.join(name), body).unwrap();
    }

    /// Execute `tool` with `args` and return the first text item of its
    /// result (empty if the result has no text content).
    async fn run_tool(
        tool: &Arc<dyn AgentTool>,
        args: serde_json::Value,
    ) -> Result<String, AgentToolError> {
        let cb: ToolUpdateCallback = Arc::new(|_| {});
        let result = tool
            .execute("tc-1", args, CancellationToken::new(), cb)
            .await?;
        let text = result
            .content
            .iter()
            .find_map(|c| match c {
                UserContent::Text(t) => Some(t.text.clone()),
                _ => None,
            })
            .unwrap_or_default();
        Ok(text)
    }

    /// Poll the notification queue until `pick` accepts a payload.
    ///
    /// Modal requests are raised on the Boa worker thread, so tests can't
    /// await them directly. Gives up (returning `None`) after 200 polls
    /// spaced 10 ms apart. Any other payloads drained along the way are
    /// discarded.
    async fn wait_for<T>(
        ext: &PiExtension,
        mut pick: impl FnMut(PiNotification) -> Option<T>,
    ) -> Option<T> {
        for _ in 0..200 {
            if let Some(found) = ext.drain_notifications().into_iter().find_map(&mut pick) {
                return Some(found);
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
        None
    }

    #[tokio::test]
    async fn factory_style_pi_extension_works() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "shout.js",
            r#"
            export default (pi) => {
                pi.registerTool({
                    name: "shout",
                    description: "Uppercases the input",
                    parameters: { type: "object", properties: { text: { type: "string" }}},
                    execute: (args) => args.text.toUpperCase(),
                });
            };
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let tools = ext.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].definition().name, "shout");
        let out = run_tool(&tools[0], serde_json::json!({ "text": "hi" }))
            .await
            .unwrap();
        assert_eq!(out, "HI");
    }

    #[tokio::test]
    async fn top_level_pi_call_also_works_without_factory() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "reverse.js",
            r#"
            pi.registerTool({
                name: "reverse",
                description: "Reverses text",
                parameters: { type: "object" },
                execute: (args) => args.text.split("").reverse().join(""),
            });
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let tools = ext.tools();
        assert_eq!(tools.len(), 1);
        let out = run_tool(&tools[0], serde_json::json!({ "text": "hello" }))
            .await
            .unwrap();
        assert_eq!(out, "olleh");
    }

    #[tokio::test]
    async fn ignores_non_js_files() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(tmp.path(), "should-be-ignored.ts", "throw 'this is TS';");
        write_script(
            tmp.path(),
            "ok.js",
            r#"pi.registerTool({ name: "ok", description: "", parameters: {}, execute: () => "" });"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let tools = ext.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].definition().name, "ok");
    }

    #[tokio::test]
    async fn missing_dirs_are_skipped_silently() {
        // A child of a fresh tempdir is guaranteed not to exist and works on
        // every platform, unlike a hard-coded `/tmp/...` path.
        let tmp = tempfile::tempdir().unwrap();
        let nonexistent = tmp.path().join("no-such-dir");
        let ext = PiExtension::from_dirs(&[nonexistent]).unwrap();
        assert!(ext.tools().is_empty());
    }

    #[tokio::test]
    async fn pi_on_routes_through_invoke_callback() {
        // The pi.on shim should land the JS handler in
        // grain.register_callback, keyed `on:<event>`. We prove it by
        // invoking directly via BoaExtension::invoke_callback and checking
        // whether the JS throws based on the payload.
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "listener.js",
            r#"
            pi.on("tool_call", (event) => {
                if (event.tool_name !== "expected") {
                    throw new Error("got tool_name=" + event.tool_name);
                }
            });
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();

        // Happy path: the payload matches what the JS expects.
        let ok = ext
            .inner
            .invoke_callback(
                "on:tool_call",
                serde_json::json!({ "tool_name": "expected" }),
            )
            .await;
        assert!(ok.is_ok(), "expected Ok, got {ok:?}");

        // Sad path: the JS throws and the error string surfaces back.
        let err = ext
            .inner
            .invoke_callback("on:tool_call", serde_json::json!({ "tool_name": "wrong" }))
            .await;
        let Err(msg) = err else {
            panic!("expected JS throw to surface as Err");
        };
        assert!(msg.contains("got tool_name=wrong"), "{msg}");
    }

    #[tokio::test]
    async fn unregistered_callback_name_is_a_noop() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(tmp.path(), "x.js", r#"pi.on("tool_call", () => {});"#);
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        // No handler is subscribed to "agent_end"; this must NOT error.
        let res = ext
            .inner
            .invoke_callback("on:agent_end", serde_json::json!({}))
            .await;
        assert!(res.is_ok(), "unregistered event must be silent: {res:?}");
    }

    #[tokio::test]
    async fn listeners_dispatches_supported_agent_events() {
        // Exercises the public `listeners()` path (the one subscribed to a
        // real Agent). The listener swallows JS errors, so a throwing
        // handler could never fail this test. Instead the handler reports
        // back through the notification queue, which proves the event
        // actually reached JS with the mapped payload.
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "tap.js",
            r#"
            pi.on("agent_end", (event) => {
                pi.ui.notify("agent_end:" + event.message_count);
            });
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let listeners = ext.listeners();
        assert_eq!(listeners.len(), 1, "single dispatching listener");

        let signal = CancellationToken::new();
        let evt = AgentEvent::AgentEnd { messages: vec![] };
        listeners[0](evt, signal).await;
        assert_eq!(
            ext.drain_notifications(),
            vec![PiNotification::Notify {
                text: "agent_end:0".into()
            }]
        );
    }

    #[tokio::test]
    async fn register_command_surfaces_in_commands_list() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "cmds.js",
            r#"
            export default (pi) => {
                pi.registerCommand("audit", {
                    description: "Print an audit log",
                    handler: () => {},
                });
                pi.registerCommand("aaa-first", {
                    description: "Comes first alphabetically",
                    handler: () => {},
                });
            };
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let cmds = ext.commands();
        assert_eq!(cmds.len(), 2);
        // Sorted by name.
        assert_eq!(cmds[0].name, "aaa-first");
        assert_eq!(cmds[1].name, "audit");
        assert_eq!(cmds[1].description, "Print an audit log");
    }

    #[tokio::test]
    async fn invoke_command_dispatches_to_js_handler() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "ck.js",
            r#"
            pi.registerCommand("check", {
                description: "Throws if the magic number is wrong",
                handler: (args) => {
                    if (args.magic !== 42) {
                        throw new Error("magic was " + args.magic);
                    }
                },
            });
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        // Happy path.
        let ok = ext
            .invoke_command("check", serde_json::json!({ "magic": 42 }))
            .await;
        assert!(ok.is_ok(), "expected Ok, got {ok:?}");
        // Sad path: the JS throws.
        let err = ext
            .invoke_command("check", serde_json::json!({ "magic": 7 }))
            .await;
        let Err(msg) = err else {
            panic!("expected JS throw to surface as Err");
        };
        assert!(msg.contains("magic was 7"), "{msg}");
    }

    #[tokio::test]
    async fn commands_is_empty_when_no_script_registers_any() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "just_tool.js",
            r#"pi.registerTool({ name: "t", description: "", parameters: {}, execute: () => "" });"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        assert!(ext.commands().is_empty());
    }

    #[tokio::test]
    async fn register_shortcut_surfaces_in_shortcuts_list_and_dispatches() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "sc.js",
            r#"
            export default (pi) => {
                pi.registerShortcut("ctrl+x", {
                    description: "Cut",
                    handler: () => { /* nothing */ },
                });
                pi.registerShortcut("ctrl+s", {
                    description: "Save — throws if 'saving' state mismatched",
                    handler: () => { throw new Error("not saving!"); },
                });
            };
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let scs = ext.shortcuts();
        // Sorted by `keys`.
        assert_eq!(scs.len(), 2);
        assert_eq!(scs[0].keys, "ctrl+s");
        assert_eq!(
            scs[0].description,
            "Save — throws if 'saving' state mismatched"
        );
        assert_eq!(scs[1].keys, "ctrl+x");

        // Dispatch the no-op shortcut.
        let ok = ext.invoke_shortcut("ctrl+x").await;
        assert!(ok.is_ok(), "expected Ok, got {ok:?}");
        // Dispatch the throwing shortcut.
        let err = ext.invoke_shortcut("ctrl+s").await;
        let Err(msg) = err else {
            panic!("expected JS throw to surface as Err");
        };
        assert!(msg.contains("not saving!"), "{msg}");
    }

    #[tokio::test]
    async fn pi_ui_notify_pushes_into_the_queue_and_drain_clears_it() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "noisy.js",
            r#"
            // Top-level notifications fire at load time; handlers
            // can also use pi.ui.notify after registration.
            pi.ui.notify("hello from script");
            pi.ui.notify("second line");
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let drained = ext.drain_notifications();
        assert_eq!(drained.len(), 2);
        assert_eq!(
            drained[0],
            PiNotification::Notify {
                text: "hello from script".into()
            }
        );
        assert_eq!(
            drained[1],
            PiNotification::Notify {
                text: "second line".into()
            }
        );
        // A second drain returns empty because the queue was cleared.
        assert!(ext.drain_notifications().is_empty());
    }

    #[tokio::test]
    async fn pi_ui_notify_inside_command_handler_routes_through_queue() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "cmd.js",
            r#"
            pi.registerCommand("say", {
                description: "Push a notification",
                handler: (args) => { pi.ui.notify("said: " + args.what); },
            });
            "#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        // No notifications yet: registration alone shouldn't fire any.
        assert!(ext.drain_notifications().is_empty());
        ext.invoke_command("say", serde_json::json!({ "what": "hi" }))
            .await
            .unwrap();
        let drained = ext.drain_notifications();
        assert_eq!(drained.len(), 1);
        assert_eq!(
            drained[0],
            PiNotification::Notify {
                text: "said: hi".into()
            }
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pi_ui_confirm_blocks_until_host_resolves() {
        // Multi-thread tokio, so the JS handler can block on one worker
        // while the modal is resolved from another.
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "ask.js",
            r#"
            pi.registerCommand("ask", {
                description: "Ask a yes/no question",
                handler: () => {
                    const ok = pi.ui.confirm("really?");
                    pi.ui.notify("answer was " + ok);
                },
            });
            "#,
        );
        let ext = Arc::new(PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap());
        // Spawn the command invocation; it blocks inside the Boa worker
        // until the modal is resolved.
        let ext_for_invoke = Arc::clone(&ext);
        let invoke_task = tokio::spawn(async move {
            ext_for_invoke
                .invoke_command("ask", serde_json::json!({}))
                .await
        });
        let (confirm_id, prompt) = wait_for(&ext, |note| match note {
            PiNotification::Confirm { request_id, prompt } => Some((request_id, prompt)),
            _ => None,
        })
        .await
        .expect("confirm modal never appeared");
        assert_eq!(prompt, "really?");
        // Resolve with `true`.
        ext.resolve_modal(confirm_id, serde_json::json!(true))
            .unwrap();
        // The handler should now finish and post the answer.
        invoke_task.await.unwrap().unwrap();
        let leftover = ext.drain_notifications();
        assert!(
            leftover.iter().any(|n| matches!(n,
                PiNotification::Notify { text } if text == "answer was true"
            )),
            "expected post-confirm notify, got {leftover:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pi_ui_input_returns_resolved_string() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "name.js",
            r#"
            pi.registerCommand("name", {
                description: "Ask for a name",
                handler: () => {
                    const who = pi.ui.input("who are you?");
                    pi.ui.notify("hello " + who);
                },
            });
            "#,
        );
        let ext = Arc::new(PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap());
        let ext_for_invoke = Arc::clone(&ext);
        let invoke_task = tokio::spawn(async move {
            ext_for_invoke
                .invoke_command("name", serde_json::json!({}))
                .await
        });
        let input_id = wait_for(&ext, |note| match note {
            PiNotification::Input { request_id, .. } => Some(request_id),
            _ => None,
        })
        .await
        .expect("input modal never appeared");
        ext.resolve_modal(input_id, serde_json::json!("Yoda"))
            .unwrap();
        invoke_task.await.unwrap().unwrap();
        let leftover = ext.drain_notifications();
        assert!(
            leftover.iter().any(|n| matches!(n,
                PiNotification::Notify { text } if text == "hello Yoda"
            )),
            "expected greeting notify, got {leftover:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pi_ui_select_round_trip() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "pick.js",
            r#"
            pi.registerCommand("pick", {
                description: "Pick a fruit",
                handler: () => {
                    const fruit = pi.ui.select("which?", ["apple", "banana", "cherry"]);
                    pi.ui.notify("picked " + fruit);
                },
            });
            "#,
        );
        let ext = Arc::new(PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap());
        let ext_for_invoke = Arc::clone(&ext);
        let invoke_task = tokio::spawn(async move {
            ext_for_invoke
                .invoke_command("pick", serde_json::json!({}))
                .await
        });
        let (select_id, received_items) = wait_for(&ext, |note| match note {
            PiNotification::Select {
                request_id, items, ..
            } => Some((request_id, items)),
            _ => None,
        })
        .await
        .expect("select modal never appeared");
        assert_eq!(received_items, vec!["apple", "banana", "cherry"]);
        ext.resolve_modal(select_id, serde_json::json!("banana"))
            .unwrap();
        invoke_task.await.unwrap().unwrap();
        let leftover = ext.drain_notifications();
        assert!(
            leftover.iter().any(|n| matches!(n,
                PiNotification::Notify { text } if text == "picked banana"
            )),
            "got {leftover:?}"
        );
    }

    #[tokio::test]
    async fn from_pi_dirs_resolves_workspace_dot_pi() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_dir = tmp.path().join(".pi").join("extensions");
        std::fs::create_dir_all(&ext_dir).unwrap();
        write_script(
            &ext_dir,
            "demo.js",
            r#"
            export default (pi) => {
                pi.registerTool({
                    name: "demo",
                    description: "",
                    parameters: {},
                    execute: () => "ok",
                });
            };
            "#,
        );
        let ext = PiExtension::from_pi_dirs(tmp.path()).unwrap();
        // from_pi_dirs also scans the real ~/.pi/agent/extensions, so don't
        // assume the project script is the only one loaded on this machine.
        let tools = ext.tools();
        assert!(
            tools.iter().any(|t| t.definition().name == "demo"),
            "demo tool from <workspace>/.pi/extensions was not loaded"
        );
    }
}