#[allow(unused_imports)]
use crate::sync_util::LockExt;
use std::sync::{Arc, Mutex, OnceLock};
use rig::tool::{ToolDyn, ToolError};
use super::{PluginManager, escape_janet_string};
/// Process-wide plugin manager slot.
///
/// Written at most once (by [`init_global`], via `OnceLock::set`) and read by
/// [`global`]. Dead-code warnings are suppressed when the `plugin` feature is off.
#[cfg_attr(not(feature = "plugin"), allow(dead_code))]
static PLUGIN_MANAGER: OnceLock<Arc<Mutex<PluginManager>>> = OnceLock::new();
/// Installs `manager` as the process-wide plugin manager.
///
/// Only the first call has any effect: once a manager is installed, later
/// calls are silent no-ops and the `manager` they pass in is simply dropped.
#[cfg_attr(not(feature = "plugin"), allow(dead_code))]
pub fn init_global(manager: Arc<Mutex<PluginManager>>) {
    match PLUGIN_MANAGER.set(manager) {
        Ok(()) => {}
        // Already initialised: first registration wins; the rejected handle
        // returned by `OnceLock::set` is dropped here.
        Err(_rejected) => {}
    }
}
/// Returns a new handle (an `Arc` clone) to the process-wide plugin manager,
/// or `None` if [`init_global`] has not installed one yet.
#[cfg_attr(not(feature = "plugin"), allow(dead_code))]
pub fn global() -> Option<Arc<Mutex<PluginManager>>> {
    PLUGIN_MANAGER.get().map(Arc::clone)
}
/// A [`ToolDyn`] decorator that runs plugin hooks around a wrapped tool.
///
/// When `pm` is `Some`, every `call` dispatches `on-tool-start` (which may block
/// the call or rewrite its arguments) before the inner tool runs, and
/// `on-tool-end` (which may replace the result) after it. When `pm` is `None`,
/// calls pass straight through to `inner`.
pub struct HookedToolDyn {
    /// The tool whose name, definition and calls are forwarded.
    inner: Box<dyn ToolDyn>,
    /// Plugin manager used to dispatch hooks; `None` disables hooking.
    pm: Option<Arc<Mutex<PluginManager>>>,
}
impl HookedToolDyn {
    /// Wraps `inner` with the process-wide plugin manager (see `init_global`).
    ///
    /// If no global manager has been installed, `inner` is returned as-is, so
    /// unhooked tools carry no wrapper at all. Otherwise it is boxed inside a
    /// `HookedToolDyn` bound to that manager.
    // NOTE(review): `dead_code` is allowed presumably because some build
    // configurations never call this constructor — confirm and document why.
    #[allow(dead_code)]
    pub fn wrap_global(inner: Box<dyn ToolDyn>) -> Box<dyn ToolDyn> {
        let Some(pm) = global() else {
            return inner;
        };
        Box::new(HookedToolDyn { inner, pm: Some(pm) })
    }

    /// Wraps `inner` with an explicitly supplied manager.
    ///
    /// Passing `None` produces a wrapper whose calls go straight to `inner`
    /// without dispatching any hooks (useful for tests and opt-out paths).
    #[allow(dead_code)]
    pub fn with_manager(inner: Box<dyn ToolDyn>, pm: Option<Arc<Mutex<PluginManager>>>) -> Self {
        HookedToolDyn { inner, pm }
    }
}
/// Forwards `name` and `definition` untouched and surrounds `call` with the
/// `on-tool-start` / `on-tool-end` plugin hooks.
impl ToolDyn for HookedToolDyn {
    fn name(&self) -> String {
        self.inner.name()
    }

    fn definition<'a>(
        &'a self,
        prompt: String,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = rig::completion::ToolDefinition> + Send + 'a>,
    > {
        // Hooks never touch the schema advertised to the model.
        self.inner.definition(prompt)
    }

    /// Runs the wrapped tool with plugin hooks on either side.
    ///
    /// 1. `on-tool-start` receives `@{:tool "<name>" :args "<args>"}` (both values
    ///    escaped with `escape_janet_string`). If it reports a block reason, the
    ///    inner tool is skipped and `ToolError::ToolCallError` with the message
    ///    `"blocked by plugin: <reason>"` is returned; a block therefore wins over
    ///    any input rewrite. Otherwise a `mutate_input` value, if any, replaces
    ///    `args`.
    /// 2. The inner tool is awaited with the (possibly rewritten) arguments.
    /// 3. `on-tool-end` receives `@{:tool "<name>" :output "<output>"}`, where
    ///    `<output>` is the success text or the error's `Display` rendering. A
    ///    `replace_result` value becomes `Ok(replacement)`, even if the inner tool
    ///    failed.
    ///
    /// Errors from `dispatch_tool_hook` are discarded via `unwrap_or_default`, so a
    /// failing hook behaves as if it returned no directives. The manager lock is
    /// taken only while a hook runs and is released before the inner `.await`.
    fn call<'a>(
        &'a self,
        args: String,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>>
    {
        Box::pin(async move {
            let tool_name = self.inner.name();

            // Phase 1: on-tool-start — may block the call or rewrite its input.
            let (block_reason, rewritten_args) = if let Some(pm) = &self.pm {
                let start_ctx = format!(
                    "@{{:tool \"{}\" :args \"{}\"}}",
                    escape_janet_string(&tool_name),
                    escape_janet_string(&args),
                );
                // The guard lives only inside this branch, so it is dropped
                // before the inner tool is awaited below.
                let mut guard = pm.lock_ignore_poison();
                let directives = guard
                    .dispatch_tool_hook("on-tool-start", &start_ctx)
                    .unwrap_or_default();
                (directives.block, directives.mutate_input)
            } else {
                (None, None)
            };

            if let Some(reason) = block_reason {
                return Err(ToolError::ToolCallError(
                    format!("blocked by plugin: {}", reason).into(),
                ));
            }

            // Phase 2: run the wrapped tool.
            let effective_args = rewritten_args.unwrap_or(args);
            let outcome = self.inner.call(effective_args).await;

            // Phase 3: on-tool-end — may replace the outcome wholesale.
            let replacement = if let Some(pm) = &self.pm {
                let rendered = outcome.as_ref().map_or_else(|err| err.to_string(), String::clone);
                let end_ctx = format!(
                    "@{{:tool \"{}\" :output \"{}\"}}",
                    escape_janet_string(&tool_name),
                    escape_janet_string(&rendered),
                );
                let mut guard = pm.lock_ignore_poison();
                guard
                    .dispatch_tool_hook("on-tool-end", &end_ctx)
                    .unwrap_or_default()
                    .replace_result
            } else {
                None
            };

            match replacement {
                Some(text) => Ok(text),
                None => outcome,
            }
        })
    }
}
// Tests for `HookedToolDyn`, compiled only for test builds with the `plugin`
// feature enabled. Each test evaluates Janet hook functions on a fresh
// `PluginManager`, registers them for a hook name, and checks how the wrapper
// reacts.
#[cfg(all(test, feature = "plugin"))]
mod tests {
    use super::*;
    use rig::completion::ToolDefinition;

    /// Minimal tool named "echo" whose `call` returns its argument string unchanged.
    struct Echo;

    impl ToolDyn for Echo {
        fn name(&self) -> String {
            "echo".to_string()
        }

        fn definition<'a>(
            &'a self,
            _prompt: String,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolDefinition> + Send + 'a>>
        {
            Box::pin(async move {
                ToolDefinition {
                    name: "echo".to_string(),
                    description: "echo".to_string(),
                    parameters: serde_json::json!({}),
                }
            })
        }

        fn call<'a>(
            &'a self,
            args: String,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>,
        > {
            // Echoing the args lets tests observe exactly what the wrapper forwarded.
            Box::pin(async move { Ok(args) })
        }
    }

    /// Builds a fresh plugin manager; panics if `PluginManager::try_new` fails.
    fn pm() -> Arc<Mutex<PluginManager>> {
        Arc::new(Mutex::new(PluginManager::try_new().unwrap()))
    }

    /// Wraps an `Echo` tool with `pm_arc` and performs a single call with `args`.
    async fn wrap_and_call(
        pm_arc: Arc<Mutex<PluginManager>>,
        args: &str,
    ) -> Result<String, ToolError> {
        let wrapper = HookedToolDyn::with_manager(Box::new(Echo), Some(pm_arc));
        wrapper.call(args.to_string()).await
    }

    #[tokio::test]
    async fn passthrough_when_no_hooks_registered() {
        // Nothing registered: the inner tool's output comes back unchanged.
        let result = wrap_and_call(pm(), r#"{"x":1}"#).await.unwrap();
        assert_eq!(result, r#"{"x":1}"#);
    }

    #[tokio::test]
    async fn block_returns_tool_error_with_reason() {
        let pm_arc = pm();
        // Scoped so this guard is released before the wrapper locks the same mutex.
        {
            let mut mgr = pm_arc.lock().unwrap();
            mgr.eval(r#"(defn deny [ctx] (harness/block "danger"))"#)
                .unwrap();
            mgr.register("on-tool-start", "deny");
        }
        // The error must carry both the wrapper's prefix and the plugin's reason.
        let err = wrap_and_call(pm_arc, r#"{"x":1}"#).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("blocked by plugin"), "got: {msg}");
        assert!(msg.contains("danger"), "got: {msg}");
    }

    #[tokio::test]
    async fn mutate_input_replaces_args_before_inner_call() {
        let pm_arc = pm();
        {
            let mut mgr = pm_arc.lock().unwrap();
            mgr.eval(r#"(defn rewrite [ctx] (harness/mutate-input "{\"x\":42}"))"#)
                .unwrap();
            mgr.register("on-tool-start", "rewrite");
        }
        // Echo returns what it received, so this shows the rewritten args reached it.
        let result = wrap_and_call(pm_arc, r#"{"x":1}"#).await.unwrap();
        assert_eq!(result, r#"{"x":42}"#);
    }

    #[tokio::test]
    async fn replace_result_swaps_inner_output() {
        let pm_arc = pm();
        {
            let mut mgr = pm_arc.lock().unwrap();
            mgr.eval(r#"(defn truncate [ctx] (harness/replace-result "[filtered]"))"#)
                .unwrap();
            mgr.register("on-tool-end", "truncate");
        }
        let result = wrap_and_call(pm_arc, r#"{"x":1}"#).await.unwrap();
        assert_eq!(result, "[filtered]");
    }

    #[tokio::test]
    async fn block_precedence_over_mutate_when_both_set() {
        let pm_arc = pm();
        {
            let mut mgr = pm_arc.lock().unwrap();
            // The hook both rewrites the input and blocks; the block must win.
            mgr.eval(
                r#"(defn paranoid [ctx]
(harness/mutate-input "{\"x\":99}")
(harness/block "no way"))"#,
            )
            .unwrap();
            mgr.register("on-tool-start", "paranoid");
        }
        let err = wrap_and_call(pm_arc, r#"{"x":1}"#).await.unwrap_err();
        assert!(err.to_string().contains("no way"));
    }

    #[tokio::test]
    async fn slots_reset_between_calls() {
        let pm_arc = pm();
        {
            let mut mgr = pm_arc.lock().unwrap();
            // Blocks only on its first invocation (tracked by the Janet `seen` counter).
            mgr.eval(
                r#"(var seen 0)
(defn once-blocker [ctx]
(set seen (+ seen 1))
(when (= seen 1) (harness/block "first only")))"#,
            )
            .unwrap();
            mgr.register("on-tool-start", "once-blocker");
        }
        let err = wrap_and_call(pm_arc.clone(), r#"{"x":1}"#)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("first only"));
        // The second call must not see a stale block left over from the first.
        let result = wrap_and_call(pm_arc, r#"{"x":2}"#).await.unwrap();
        assert_eq!(result, r#"{"x":2}"#);
    }

    #[tokio::test]
    async fn on_tool_end_fires_when_inner_returns_error() {
        /// Tool whose `call` always fails with "deliberate failure".
        struct AlwaysFail;

        impl ToolDyn for AlwaysFail {
            fn name(&self) -> String {
                "always_fail".to_string()
            }

            fn definition<'a>(
                &'a self,
                _prompt: String,
            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolDefinition> + Send + 'a>>
            {
                Box::pin(async move {
                    ToolDefinition {
                        name: "always_fail".to_string(),
                        description: "always errors".to_string(),
                        parameters: serde_json::json!({}),
                    }
                })
            }

            fn call<'a>(
                &'a self,
                _args: String,
            ) -> std::pin::Pin<
                Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>,
            > {
                Box::pin(async move {
                    Err(ToolError::ToolCallError(Box::<
                        dyn std::error::Error + Send + Sync,
                    >::from(
                        "deliberate failure".to_string()
                    )))
                })
            }
        }

        let pm_arc = pm();
        {
            let mut mgr = pm_arc.lock().unwrap();
            mgr.eval(
                r#"(defn rewrite-error [ctx]
(harness/replace-result "[error swallowed by plugin]"))"#,
            )
            .unwrap();
            mgr.register("on-tool-end", "rewrite-error");
        }
        // on-tool-end must still run after an inner failure, and its replacement
        // turns the error into a successful result.
        let wrapper = HookedToolDyn::with_manager(Box::new(AlwaysFail), Some(pm_arc));
        let result = wrapper.call(String::new()).await;
        assert_eq!(result.unwrap(), "[error swallowed by plugin]");
    }
}