use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use grain_agent_core::{AgentEvent, AgentTool, EventListener};
use grain_script_boa::{BoaExtension, BoaExtensionError};
use tempfile::TempDir;
use crate::transform::transform_pi_source;
/// A command registered by a loaded script, as returned by
/// [`PiExtension::commands`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PiCommand {
/// Name the command was registered under; this is the value expected by
/// [`PiExtension::invoke_command`].
pub name: String,
/// The command's `description` attribute, or an empty string when the
/// attribute is missing or not a string.
pub description: String,
}
/// A keyboard shortcut registered by a loaded script, as returned by
/// [`PiExtension::shortcuts`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PiShortcut {
/// Key combination the shortcut was registered under (e.g. `"ctrl+x"`);
/// this is the value expected by [`PiExtension::invoke_shortcut`].
pub keys: String,
/// The shortcut's `description` attribute, or an empty string when the
/// attribute is missing or not a string.
pub description: String,
}
/// A UI message emitted by a script and returned from
/// [`PiExtension::drain_notifications`].
///
/// Each variant corresponds to one `kind` value understood by
/// `decode_notification`. The variants carrying a `request_id` are modal
/// requests; the host answers them through [`PiExtension::resolve_modal`]
/// using that id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PiNotification {
/// Plain text message (`kind == "notify"`); needs no reply.
Notify { text: String },
/// Yes/no question (`kind == "confirm"`).
Confirm { request_id: u64, prompt: String },
/// Free-text prompt (`kind == "input"`).
Input { request_id: u64, prompt: String },
/// Choice among `items` (`kind == "select"`).
Select {
request_id: u64,
prompt: String,
items: Vec<String>,
},
}
/// Errors that can occur while building a [`PiExtension`].
#[derive(Debug, thiserror::Error)]
pub enum PiCompatError {
/// A filesystem operation failed (creating the temp directory, reading a
/// script, or writing its transformed copy).
#[error("io: {0}")]
Io(#[from] std::io::Error),
/// The underlying [`BoaExtension`] failed to load the transformed scripts.
#[error("boa: {0}")]
Boa(#[from] BoaExtensionError),
}
/// A set of Pi-style JavaScript extensions loaded into a single
/// [`BoaExtension`] after being rewritten by `transform_pi_source`.
pub struct PiExtension {
// Fixed identifier returned by `name()`.
name: &'static str,
// Runtime that owns the loaded scripts; shared with the event listener
// created in `listeners()`, hence the `Arc`.
inner: Arc<BoaExtension>,
// Temporary directory containing the transformed scripts `inner` was loaded
// from. Never read; held only so the directory lives as long as `self`.
_tempdir: TempDir,
}
impl PiExtension {
    /// Returns the extension's fixed identifier.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Returns the agent tools exposed by the wrapped [`BoaExtension`].
    pub fn tools(&self) -> Vec<Arc<dyn AgentTool>> {
        self.inner.tools()
    }

    /// Lists every entry the runtime reports under the `"command"` meta kind,
    /// sorted by name.
    ///
    /// A missing or non-string `description` attribute becomes an empty
    /// string.
    pub fn commands(&self) -> Vec<PiCommand> {
        let mut entries: Vec<PiCommand> = self
            .inner
            .list_metas("command")
            .into_iter()
            .map(|(name, attrs)| {
                let description = attrs
                    .get("description")
                    .and_then(|v| v.as_str())
                    .unwrap_or("")
                    .to_string();
                PiCommand { name, description }
            })
            .collect();
        entries.sort_by(|a, b| a.name.cmp(&b.name));
        entries
    }

    /// Invokes the handler registered for command `name`, passing `args`.
    ///
    /// The handler is looked up under the callback key `cmd:{name}`.
    ///
    /// # Errors
    ///
    /// Returns the runtime's error message, e.g. when the script handler
    /// throws.
    pub async fn invoke_command(&self, name: &str, args: serde_json::Value) -> Result<(), String> {
        self.inner
            .invoke_callback(&format!("cmd:{name}"), args)
            .await
    }

    /// Lists every entry the runtime reports under the `"shortcut"` meta kind,
    /// sorted by key combination.
    ///
    /// A missing or non-string `description` attribute becomes an empty
    /// string.
    pub fn shortcuts(&self) -> Vec<PiShortcut> {
        let mut entries: Vec<PiShortcut> = self
            .inner
            .list_metas("shortcut")
            .into_iter()
            .map(|(keys, attrs)| {
                let description = attrs
                    .get("description")
                    .and_then(|v| v.as_str())
                    .unwrap_or("")
                    .to_string();
                PiShortcut { keys, description }
            })
            .collect();
        entries.sort_by(|a, b| a.keys.cmp(&b.keys));
        entries
    }

    /// Invokes the handler registered for the shortcut `keys`.
    ///
    /// The handler is looked up under the callback key `shortcut:{keys}` and
    /// receives an empty JSON object as its argument.
    ///
    /// # Errors
    ///
    /// Returns the runtime's error message, e.g. when the script handler
    /// throws.
    pub async fn invoke_shortcut(&self, keys: &str) -> Result<(), String> {
        self.inner
            .invoke_callback(
                &format!("shortcut:{keys}"),
                serde_json::Value::Object(Default::default()),
            )
            .await
    }

    /// Drains the runtime's pending notifications and decodes them.
    ///
    /// Entries that `decode_notification` does not recognise are dropped.
    pub fn drain_notifications(&self) -> Vec<PiNotification> {
        self.inner
            .drain_notifications()
            .into_iter()
            .filter_map(decode_notification)
            .collect()
    }

    /// Answers the modal request identified by `request_id` with `response`.
    ///
    /// `request_id` is the id carried by a [`PiNotification::Confirm`],
    /// [`PiNotification::Input`] or [`PiNotification::Select`].
    ///
    /// # Errors
    ///
    /// Forwards whatever error string the runtime reports.
    pub fn resolve_modal(
        &self,
        request_id: u64,
        response: serde_json::Value,
    ) -> Result<(), String> {
        self.inner.resolve_modal(request_id, response)
    }

    /// Builds the single event listener that forwards agent events to script
    /// handlers.
    ///
    /// Events that `map_agent_event_to_pi` does not translate are ignored;
    /// the rest are delivered to the callback key `on:{pi_name}`.
    pub fn listeners(&self) -> Vec<EventListener> {
        let inner = self.inner.clone();
        let dispatch: EventListener = Arc::new(move |event, _signal| {
            let inner = inner.clone();
            Box::pin(async move {
                let Some((pi_name, payload)) = map_agent_event_to_pi(&event) else {
                    return;
                };
                let key = format!("on:{pi_name}");
                // Best effort: the listener future resolves to `()`, so there is
                // no channel through which a handler error could be reported.
                let _ = inner.invoke_callback(&key, payload).await;
            })
        });
        vec![dispatch]
    }

    /// Loads extensions from the standard Pi locations for `workspace_root`
    /// (see `pi_search_paths`).
    ///
    /// # Errors
    ///
    /// Same as [`PiExtension::from_dirs`].
    pub fn from_pi_dirs(workspace_root: &Path) -> Result<Self, PiCompatError> {
        let dirs = pi_search_paths(workspace_root);
        Self::from_dirs(&dirs)
    }

    /// Loads every `*.js` file found directly inside `dirs`.
    ///
    /// Each script is rewritten with `transform_pi_source`, written into a
    /// private temporary directory as `NNN_<stem>.js`, and the directory is
    /// then handed to [`BoaExtension::from_scripts_dir`]. `NNN` is a running
    /// counter across all directories, so equal file names in different
    /// directories cannot collide.
    ///
    /// Directories that are missing or unreadable are skipped, as are entries
    /// that are not regular files (a directory or dangling symlink that merely
    /// ends in `.js`). Within one directory, scripts are numbered in sorted
    /// path order so the numbering is the same on every run and platform.
    ///
    /// # Errors
    ///
    /// * [`PiCompatError::Io`] if the temporary directory cannot be created,
    ///   or a script cannot be read or its transformed copy written.
    /// * [`PiCompatError::Boa`] if the runtime rejects the transformed scripts.
    pub fn from_dirs(dirs: &[PathBuf]) -> Result<Self, PiCompatError> {
        let tempdir = tempfile::tempdir()?;
        let mut count = 0usize;
        for dir in dirs {
            // `read_dir` already fails for a missing directory, so no separate
            // existence check is needed (and none can race with it).
            let Ok(entries) = fs::read_dir(dir) else {
                continue;
            };
            let mut scripts: Vec<PathBuf> = entries
                .flatten()
                .map(|entry| entry.path())
                .filter(|path| {
                    path.extension().and_then(|ext| ext.to_str()) == Some("js") && path.is_file()
                })
                .collect();
            // `read_dir` yields entries in a platform- and filesystem-dependent
            // order; sort so the numeric prefixes are reproducible.
            scripts.sort();
            for path in scripts {
                let source = fs::read_to_string(&path)?;
                let transformed = transform_pi_source(&source);
                let stem = path
                    .file_stem()
                    .and_then(|s| s.to_str())
                    .unwrap_or("anonymous");
                let out_name = format!("{count:03}_{stem}.js");
                fs::write(tempdir.path().join(&out_name), transformed)?;
                count += 1;
            }
        }
        let inner = Arc::new(BoaExtension::from_scripts_dir(tempdir.path())?);
        Ok(PiExtension {
            name: "grain-pi-compat",
            inner,
            _tempdir: tempdir,
        })
    }
}
fn map_agent_event_to_pi(event: &AgentEvent) -> Option<(&'static str, serde_json::Value)> {
match event {
AgentEvent::AgentStart => Some(("agent_start", serde_json::json!({}))),
AgentEvent::AgentEnd { messages } => Some((
"agent_end",
serde_json::json!({ "message_count": messages.len() }),
)),
AgentEvent::MessageStart { message } => Some((
"message_start",
serde_json::json!({ "role": message.role() }),
)),
AgentEvent::MessageEnd { message } => {
Some(("message_end", serde_json::json!({ "role": message.role() })))
}
AgentEvent::ToolExecutionStart {
tool_call_id,
tool_name,
args,
} => Some((
"tool_call",
serde_json::json!({
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"args": args,
}),
)),
AgentEvent::ToolExecutionEnd {
tool_call_id,
tool_name,
result,
is_error,
} => Some((
"tool_result",
serde_json::json!({
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"is_error": is_error,
"content": result.content,
}),
)),
_ => None,
}
}
/// Decodes one raw notification object into a [`PiNotification`].
///
/// The object's `kind` string selects the variant. Returns `None` when
/// `kind` is absent or unknown, or when a field the variant requires is
/// missing or has the wrong JSON type. For `"select"`, entries of `items`
/// that are not strings are left out.
fn decode_notification(v: serde_json::Value) -> Option<PiNotification> {
    // Field accessors shared by the variants below.
    let string_at = |key: &str| v.get(key).and_then(|field| field.as_str()).map(str::to_string);
    let id_at = |key: &str| v.get(key).and_then(|field| field.as_u64());

    match v.get("kind").and_then(|kind| kind.as_str())? {
        "notify" => Some(PiNotification::Notify {
            text: string_at("text")?,
        }),
        "confirm" => Some(PiNotification::Confirm {
            request_id: id_at("request_id")?,
            prompt: string_at("prompt")?,
        }),
        "input" => Some(PiNotification::Input {
            request_id: id_at("request_id")?,
            prompt: string_at("prompt")?,
        }),
        "select" => {
            let request_id = id_at("request_id")?;
            let prompt = string_at("prompt")?;
            let mut items = Vec::new();
            for entry in v.get("items")?.as_array()? {
                if let Some(label) = entry.as_str() {
                    items.push(label.to_string());
                }
            }
            Some(PiNotification::Select {
                request_id,
                prompt,
                items,
            })
        }
        _ => None,
    }
}
/// Returns the directories searched for Pi extensions, in priority order:
/// `<workspace_root>/.pi/extensions` first, then `<home>/.pi/agent/extensions`
/// when a home directory can be determined.
fn pi_search_paths(workspace_root: &Path) -> Vec<PathBuf> {
    let workspace_dir = workspace_root.join(".pi").join("extensions");
    let home_dir =
        dirs::home_dir().map(|home| home.join(".pi").join("agent").join("extensions"));
    // `Option` iterates over zero or one item, so a missing home adds nothing.
    std::iter::once(workspace_dir).chain(home_dir).collect()
}
#[cfg(test)]
mod tests {
    // NOTE: the raw-string script bodies below are intentionally flush-left so
    // their contents stay byte-for-byte what the scripts have always been.
    use super::*;
    use grain_agent_core::{AgentEvent, AgentToolError, ToolUpdateCallback, UserContent};
    use std::sync::Arc;
    use tokio_util::sync::CancellationToken;

    /// Writes `body` to `dir/name`, panicking if the write fails.
    fn write_script(dir: &Path, name: &str, body: &str) {
        std::fs::write(dir.join(name), body).unwrap();
    }

    /// Executes `tool` with `args` and returns the first text item of its
    /// result, or an empty string when the result contains no text.
    async fn run_tool(
        tool: &Arc<dyn AgentTool>,
        args: serde_json::Value,
    ) -> Result<String, AgentToolError> {
        let cb: ToolUpdateCallback = Arc::new(|_| {});
        let result = tool
            .execute("tc-1", args, CancellationToken::new(), cb)
            .await?;
        let text = result
            .content
            .iter()
            .filter_map(|c| match c {
                UserContent::Text(t) => Some(t.text.clone()),
                _ => None,
            })
            .next()
            .unwrap_or_default();
        Ok(text)
    }

    /// Polls the notification queue until `pick` accepts a notification.
    ///
    /// Drains up to 200 times, sleeping 10 ms after each unsuccessful drain,
    /// and returns the first value `pick` produces. Notifications `pick`
    /// rejects, and any left in the same drained batch after a match, are
    /// discarded. Returns `None` if nothing matched within the budget.
    async fn wait_for_modal<T>(
        ext: &PiExtension,
        mut pick: impl FnMut(PiNotification) -> Option<T>,
    ) -> Option<T> {
        for _ in 0..200 {
            if let Some(found) = ext.drain_notifications().into_iter().find_map(&mut pick) {
                return Some(found);
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
        None
    }

    /// A script whose default export is a factory taking `pi` can register a
    /// working tool.
    #[tokio::test]
    async fn factory_style_pi_extension_works() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "shout.js",
            r#"
export default (pi) => {
pi.registerTool({
name: "shout",
description: "Uppercases the input",
parameters: { type: "object", properties: { text: { type: "string" }}},
execute: (args) => args.text.toUpperCase(),
});
};
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let tools = ext.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].definition().name, "shout");
        let out = run_tool(&tools[0], serde_json::json!({ "text": "hi" }))
            .await
            .unwrap();
        assert_eq!(out, "HI");
    }

    /// A script may also call `pi.*` at top level, without a factory export.
    #[tokio::test]
    async fn top_level_pi_call_also_works_without_factory() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "reverse.js",
            r#"
pi.registerTool({
name: "reverse",
description: "Reverses text",
parameters: { type: "object" },
execute: (args) => args.text.split("").reverse().join(""),
});
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let tools = ext.tools();
        assert_eq!(tools.len(), 1);
        let out = run_tool(&tools[0], serde_json::json!({ "text": "hello" }))
            .await
            .unwrap();
        assert_eq!(out, "olleh");
    }

    /// Only `.js` files are loaded; a `.ts` file next to them is never run.
    #[tokio::test]
    async fn ignores_non_js_files() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(tmp.path(), "should-be-ignored.ts", "throw 'this is TS';");
        write_script(
            tmp.path(),
            "ok.js",
            r#"pi.registerTool({ name: "ok", description: "", parameters: {}, execute: () => "" });"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let tools = ext.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].definition().name, "ok");
    }

    /// A search directory that does not exist yields an empty extension
    /// instead of an error.
    #[tokio::test]
    async fn missing_dirs_are_skipped_silently() {
        // A child of a fresh temp dir is guaranteed absent on every platform,
        // unlike a hard-coded absolute path.
        let tmp = tempfile::tempdir().unwrap();
        let nonexistent = tmp.path().join("no-such-dir");
        let ext = PiExtension::from_dirs(&[nonexistent]).unwrap();
        assert!(ext.tools().is_empty());
    }

    /// `pi.on(name, handler)` registers under the callback key `on:<name>`,
    /// and a throw inside the handler surfaces as `Err`.
    #[tokio::test]
    async fn pi_on_routes_through_invoke_callback() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "listener.js",
            r#"
pi.on("tool_call", (event) => {
if (event.tool_name !== "expected") {
throw new Error("got tool_name=" + event.tool_name);
}
});
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let ok = ext
            .inner
            .invoke_callback(
                "on:tool_call",
                serde_json::json!({ "tool_name": "expected" }),
            )
            .await;
        assert!(ok.is_ok(), "expected Ok, got {ok:?}");
        let err = ext
            .inner
            .invoke_callback("on:tool_call", serde_json::json!({ "tool_name": "wrong" }))
            .await;
        let Err(msg) = err else {
            panic!("expected JS throw to surface as Err");
        };
        assert!(msg.contains("got tool_name=wrong"), "{msg}");
    }

    /// Invoking a callback key nobody registered succeeds without effect.
    #[tokio::test]
    async fn unregistered_callback_name_is_a_noop() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(tmp.path(), "x.js", r#"pi.on("tool_call", () => {});"#);
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let res = ext
            .inner
            .invoke_callback("on:agent_end", serde_json::json!({}))
            .await;
        assert!(res.is_ok(), "unregistered event must be silent: {res:?}");
    }

    /// `listeners()` returns exactly one listener, and feeding it a supported
    /// event completes without panicking.
    #[tokio::test]
    async fn listeners_dispatches_supported_agent_events() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "tap.js",
            r#"
pi.on("agent_end", (event) => {
if (event.message_count < 0) {
throw new Error("negative message_count?!");
}
});
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let listeners = ext.listeners();
        assert_eq!(listeners.len(), 1, "single dispatching listener");
        let signal = CancellationToken::new();
        let evt = AgentEvent::AgentEnd { messages: vec![] };
        listeners[0](evt, signal).await;
    }

    /// Registered commands are listed sorted by name, with descriptions.
    #[tokio::test]
    async fn register_command_surfaces_in_commands_list() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "cmds.js",
            r#"
export default (pi) => {
pi.registerCommand("audit", {
description: "Print an audit log",
handler: () => {},
});
pi.registerCommand("aaa-first", {
description: "Comes first alphabetically",
handler: () => {},
});
};
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let cmds = ext.commands();
        assert_eq!(cmds.len(), 2);
        assert_eq!(cmds[0].name, "aaa-first");
        assert_eq!(cmds[1].name, "audit");
        assert_eq!(cmds[1].description, "Print an audit log");
    }

    /// `invoke_command` passes its arguments to the JS handler and reports a
    /// handler throw as `Err`.
    #[tokio::test]
    async fn invoke_command_dispatches_to_js_handler() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "ck.js",
            r#"
pi.registerCommand("check", {
description: "Throws if the magic number is wrong",
handler: (args) => {
if (args.magic !== 42) {
throw new Error("magic was " + args.magic);
}
},
});
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let ok = ext
            .invoke_command("check", serde_json::json!({ "magic": 42 }))
            .await;
        assert!(ok.is_ok(), "expected Ok, got {ok:?}");
        let err = ext
            .invoke_command("check", serde_json::json!({ "magic": 7 }))
            .await;
        let Err(msg) = err else {
            panic!("expected JS throw to surface as Err");
        };
        assert!(msg.contains("magic was 7"), "{msg}");
    }

    /// Registering only a tool leaves the command list empty.
    #[tokio::test]
    async fn commands_is_empty_when_no_script_registers_any() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "just_tool.js",
            r#"pi.registerTool({ name: "t", description: "", parameters: {}, execute: () => "" });"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        assert!(ext.commands().is_empty());
    }

    /// Registered shortcuts are listed sorted by keys and can be invoked; a
    /// handler throw is reported as `Err`.
    #[tokio::test]
    async fn register_shortcut_surfaces_in_shortcuts_list_and_dispatches() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "sc.js",
            r#"
export default (pi) => {
pi.registerShortcut("ctrl+x", {
description: "Cut",
handler: () => { /* nothing */ },
});
pi.registerShortcut("ctrl+s", {
description: "Save — throws if 'saving' state mismatched",
handler: () => { throw new Error("not saving!"); },
});
};
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let scs = ext.shortcuts();
        assert_eq!(scs.len(), 2);
        assert_eq!(scs[0].keys, "ctrl+s");
        assert_eq!(
            scs[0].description,
            "Save — throws if 'saving' state mismatched"
        );
        assert_eq!(scs[1].keys, "ctrl+x");
        let ok = ext.invoke_shortcut("ctrl+x").await;
        assert!(ok.is_ok(), "expected Ok, got {ok:?}");
        let err = ext.invoke_shortcut("ctrl+s").await;
        let Err(msg) = err else {
            panic!("expected JS throw to surface as Err");
        };
        assert!(msg.contains("not saving!"), "{msg}");
    }

    /// Load-time `pi.ui.notify` calls are queued in order, and draining
    /// empties the queue.
    #[tokio::test]
    async fn pi_ui_notify_pushes_into_the_queue_and_drain_clears_it() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "noisy.js",
            r#"
// Top-level notifications fire at load time; handlers
// can also use pi.ui.notify after registration.
pi.ui.notify("hello from script");
pi.ui.notify("second line");
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        let drained = ext.drain_notifications();
        assert_eq!(drained.len(), 2);
        assert_eq!(
            drained[0],
            PiNotification::Notify {
                text: "hello from script".into()
            }
        );
        assert_eq!(
            drained[1],
            PiNotification::Notify {
                text: "second line".into()
            }
        );
        assert!(ext.drain_notifications().is_empty());
    }

    /// `pi.ui.notify` called from inside a command handler reaches the same
    /// queue.
    #[tokio::test]
    async fn pi_ui_notify_inside_command_handler_routes_through_queue() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "cmd.js",
            r#"
pi.registerCommand("say", {
description: "Push a notification",
handler: (args) => { pi.ui.notify("said: " + args.what); },
});
"#,
        );
        let ext = PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap();
        assert!(ext.drain_notifications().is_empty());
        ext.invoke_command("say", serde_json::json!({ "what": "hi" }))
            .await
            .unwrap();
        let drained = ext.drain_notifications();
        assert_eq!(drained.len(), 1);
        assert_eq!(
            drained[0],
            PiNotification::Notify {
                text: "said: hi".into()
            }
        );
    }

    /// `pi.ui.confirm` suspends the handler until the host resolves the modal,
    /// then returns the resolved value to the script.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pi_ui_confirm_blocks_until_host_resolves() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "ask.js",
            r#"
pi.registerCommand("ask", {
description: "Ask a yes/no question",
handler: () => {
const ok = pi.ui.confirm("really?");
pi.ui.notify("answer was " + ok);
},
});
"#,
        );
        let ext = Arc::new(PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap());
        let ext_for_invoke = ext.clone();
        // The command blocks on the modal, so it must run on another task
        // while this one plays the host.
        let invoke_task = tokio::spawn(async move {
            ext_for_invoke
                .invoke_command("ask", serde_json::json!({}))
                .await
        });
        let confirm_id = wait_for_modal(&ext, |note| match note {
            PiNotification::Confirm { request_id, prompt } => {
                assert_eq!(prompt, "really?");
                Some(request_id)
            }
            _ => None,
        })
        .await
        .expect("confirm modal never appeared");
        ext.resolve_modal(confirm_id, serde_json::json!(true))
            .unwrap();
        invoke_task.await.unwrap().unwrap();
        let leftover = ext.drain_notifications();
        assert!(
            leftover.iter().any(|n| matches!(n,
                PiNotification::Notify { text } if text == "answer was true"
            )),
            "expected post-confirm notify, got {leftover:?}"
        );
    }

    /// `pi.ui.input` returns the string the host resolves the modal with.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pi_ui_input_returns_resolved_string() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "name.js",
            r#"
pi.registerCommand("name", {
description: "Ask for a name",
handler: () => {
const who = pi.ui.input("who are you?");
pi.ui.notify("hello " + who);
},
});
"#,
        );
        let ext = Arc::new(PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap());
        let ext_for_invoke = ext.clone();
        let invoke_task = tokio::spawn(async move {
            ext_for_invoke
                .invoke_command("name", serde_json::json!({}))
                .await
        });
        let input_id = wait_for_modal(&ext, |note| match note {
            PiNotification::Input { request_id, .. } => Some(request_id),
            _ => None,
        })
        .await
        .expect("input modal never appeared");
        ext.resolve_modal(input_id, serde_json::json!("Yoda"))
            .unwrap();
        invoke_task.await.unwrap().unwrap();
        let leftover = ext.drain_notifications();
        assert!(
            leftover.iter().any(|n| matches!(n,
                PiNotification::Notify { text } if text == "hello Yoda"
            )),
            "expected greeting notify, got {leftover:?}"
        );
    }

    /// `pi.ui.select` exposes its items to the host and returns the chosen
    /// one to the script.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pi_ui_select_round_trip() {
        let tmp = tempfile::tempdir().unwrap();
        write_script(
            tmp.path(),
            "pick.js",
            r#"
pi.registerCommand("pick", {
description: "Pick a fruit",
handler: () => {
const fruit = pi.ui.select("which?", ["apple", "banana", "cherry"]);
pi.ui.notify("picked " + fruit);
},
});
"#,
        );
        let ext = Arc::new(PiExtension::from_dirs(&[tmp.path().to_path_buf()]).unwrap());
        let ext_for_invoke = ext.clone();
        let invoke_task = tokio::spawn(async move {
            ext_for_invoke
                .invoke_command("pick", serde_json::json!({}))
                .await
        });
        let (select_id, received_items) = wait_for_modal(&ext, |note| match note {
            PiNotification::Select {
                request_id, items, ..
            } => Some((request_id, items)),
            _ => None,
        })
        .await
        .expect("select modal never appeared");
        assert_eq!(received_items, vec!["apple", "banana", "cherry"]);
        ext.resolve_modal(select_id, serde_json::json!("banana"))
            .unwrap();
        invoke_task.await.unwrap().unwrap();
        let leftover = ext.drain_notifications();
        assert!(
            leftover.iter().any(|n| matches!(n,
                PiNotification::Notify { text } if text == "picked banana"
            )),
            "got {leftover:?}"
        );
    }

    /// `from_pi_dirs` picks up scripts under `<workspace>/.pi/extensions`.
    #[tokio::test]
    async fn from_pi_dirs_resolves_workspace_dot_pi() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_dir = tmp.path().join(".pi").join("extensions");
        std::fs::create_dir_all(&ext_dir).unwrap();
        write_script(
            &ext_dir,
            "demo.js",
            r#"
export default (pi) => {
pi.registerTool({
name: "demo",
description: "",
parameters: {},
execute: () => "ok",
});
};
"#,
        );
        let ext = PiExtension::from_pi_dirs(tmp.path()).unwrap();
        assert_eq!(ext.tools().len(), 1);
        assert_eq!(ext.tools()[0].definition().name, "demo");
    }
}