#![cfg(feature = "plugin")]
use super::*;
use crate::agent::agent_loop::hooks::AfterToolCallContext;
use crate::agent::agent_loop::message::{StreamEvent, UserMessage};
use crate::agent::agent_loop::plugin_hooks::{
after_hook_from_plugin_manager, before_hook_from_plugin_manager,
compaction_hooks_from_plugin_manager, get_followup_messages_from_plugin_manager,
get_steering_messages_from_plugin_manager, prepare_next_turn_from_plugin_manager,
should_stop_after_turn_from_plugin_manager, transform_context_from_plugin_manager,
};
use crate::agent::agent_loop::result::LoopToolResult;
use crate::agent::agent_loop::stream::StreamFn;
use crate::agent::agent_loop::tool::{AbortSignal, LoopTool, LoopToolUpdate};
use crate::agent::agent_loop::types::{ConvertToLlmFn, LoopConfig, ToolExecutionMode};
use crate::plugin::PluginManager;
use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;
/// Builds a shared `PluginManager`, or `None` when `PluginManager::try_new`
/// fails in this environment (callers treat that as "skip the test").
fn try_pm() -> Option<Arc<Mutex<PluginManager>>> {
    PluginManager::try_new()
        .ok()
        .map(|manager| Arc::new(Mutex::new(manager)))
}
fn identity_converter() -> ConvertToLlmFn {
Arc::new(|messages: &[Value]| {
messages
.iter()
.filter(|m| {
let role = m.get("role").and_then(|r| r.as_str()).unwrap_or("");
matches!(role, "user" | "assistant" | "tool" | "toolResult")
})
.cloned()
.collect()
})
}
fn build_config() -> LoopConfig {
LoopConfig {
convert_to_llm: identity_converter(),
transform_context: None,
compaction_hooks: None,
get_api_key: None,
api_key: None,
tool_execution: ToolExecutionMode::Sequential,
before_tool_call: None,
after_tool_call: None,
prepare_next_turn: None,
should_stop_after_turn: None,
get_steering_messages: None,
get_followup_messages: None,
reasoning: None,
thinking_budgets: None,
headers: std::collections::HashMap::new(),
metadata: std::collections::HashMap::new(),
request_timeout: None,
provider_name: None,
model_name: None,
compact_model: None,
storm_mutating_tools: None,
storm_exempt_tools: None,
repair_stats: Arc::new(crate::agent::agent_loop::tool_input_repair::RepairStats::new()),
truncation_notes: Arc::new(Mutex::new(std::collections::HashMap::new())),
tool_def_filter: None,
dynamic_tool_search: false,
escalation_stream_fn: None,
escalation_provider_name: None,
escalation_pending: Arc::new(Mutex::new(None)),
escalation_max_per_session: 3,
escalation_remaining: Arc::new(std::sync::atomic::AtomicUsize::new(3)),
file_touch_tracker: None,
verifier: None,
critic_fn: None,
goal: None,
max_turns: None,
}
}
/// A `Context` with an empty system prompt and no messages or tools.
fn empty_context() -> Context {
    let (system_prompt, messages, tools) = (String::new(), Vec::new(), Vec::new());
    Context {
        system_prompt,
        messages,
        tools,
    }
}
/// Wraps `text` as a user-role `LoopMessage`.
fn user(text: &str) -> LoopMessage {
    let content = String::from(text);
    LoopMessage::User(UserMessage { content })
}
/// Assistant reply consisting of a single text block, ending with `StopReason::Stop`.
fn text_response(text: &str) -> AssistantMessage {
    let block = ContentBlock::Text {
        text: text.to_owned(),
    };
    AssistantMessage::new(vec![block], StopReason::Stop)
}
/// Assistant reply requesting one tool call (`id`, `name`, `args`), ending
/// with `StopReason::ToolUse`.
fn tool_use_response(id: &str, name: &str, args: Value) -> AssistantMessage {
    let call = ContentBlock::ToolCall {
        id: id.to_owned(),
        name: name.to_owned(),
        arguments: args,
    };
    AssistantMessage::new(vec![call], StopReason::ToolUse)
}
/// Scripted `StreamFn`: the k-th invocation replays `responses[k]` as a single
/// `StreamEvent::Done`. Once the script is exhausted, every further call yields
/// an "end" text message with `StopReason::Stop`.
fn canned_factory(responses: Vec<AssistantMessage>) -> StreamFn {
    let script = Arc::new(responses);
    let cursor = Arc::new(AtomicUsize::new(0));
    Arc::new(move |_ctx, _opts| {
        let index = cursor.fetch_add(1, Ordering::SeqCst);
        let message = match script.get(index) {
            Some(scripted) => scripted.clone(),
            None => AssistantMessage::new(
                vec![ContentBlock::Text {
                    text: "end".to_string(),
                }],
                StopReason::Stop,
            ),
        };
        let reason = message.stop_reason;
        let done = StreamEvent::Done {
            reason,
            message,
            usage: None,
        };
        Box::pin(futures::stream::iter(vec![done]))
    })
}
/// Mock tool that logs every `(call id, arguments)` pair it is invoked with,
/// so tests can check whether (and with which input) it ran.
#[derive(Debug)]
struct RecordingTool {
    /// Name the tool is registered under.
    name_str: String,
    /// Invocation log, appended to by `execute`.
    calls: Arc<Mutex<Vec<(String, Value)>>>,
}
impl RecordingTool {
    /// Creates a tool called `name` with an empty invocation log.
    fn new(name: &str) -> Self {
        Self {
            name_str: name.into(),
            calls: Arc::default(),
        }
    }
    /// Snapshot of the invocations recorded so far.
    fn calls(&self) -> Vec<(String, Value)> {
        let log = self.calls.lock().unwrap();
        log.to_vec()
    }
}
impl LoopTool for RecordingTool {
    fn name(&self) -> &str {
        &self.name_str
    }
    fn description(&self) -> &str {
        "Recording mock"
    }
    fn label(&self) -> &str {
        "Recording"
    }
    /// Schema accepting any JSON object; built once and cached for the process.
    fn parameters(&self) -> &Value {
        static SCHEMA: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
        SCHEMA.get_or_init(|| serde_json::json!({"type": "object"}))
    }
    /// Appends `(id, args)` to the call log, then succeeds with a fixed
    /// "original output" text block and `args` echoed back as `details`.
    /// The abort signal and update callback are ignored.
    fn execute<'a>(
        &'a self,
        id: &'a str,
        args: Value,
        _signal: AbortSignal,
        _on_update: LoopToolUpdate,
    ) -> Pin<Box<dyn Future<Output = Result<LoopToolResult, String>> + Send + 'a>> {
        Box::pin(async move {
            // The guard is a statement temporary, released before the result is built.
            self.calls
                .lock()
                .unwrap()
                .push((id.to_owned(), args.clone()));
            let output = serde_json::json!({"type": "text", "text": "original output"});
            Ok(LoopToolResult {
                content: vec![output],
                details: args,
                terminate: None,
            })
        })
    }
}
/// Receives events until the channel reports closed (`recv` yields `None`),
/// returning them in arrival order. Callers drop their own `Sender` first;
/// otherwise this waits for as long as any sender is alive.
async fn drain(rx: &mut mpsc::Receiver<LoopEvent>) -> Vec<LoopEvent> {
    let mut events = Vec::new();
    loop {
        match rx.recv().await {
            Some(event) => events.push(event),
            None => break events,
        }
    }
}
#[tokio::test]
async fn ywj_before_tool_call_block_prevents_invocation() {
let Some(pm) = try_pm() else {
eprintln!("[skipped] PluginManager::try_new failed");
return;
};
{
let mut mgr = pm.lock().unwrap();
mgr.eval(r#"(defn deny [_ctx] (harness/block "policy denial"))"#)
.expect("install deny");
mgr.register("on-tool-start", "deny");
}
let tool = Arc::new(RecordingTool::new("bash"));
let mut ctx = empty_context();
ctx.tools.push(tool.clone());
let factory = canned_factory(vec![
tool_use_response("call-1", "bash", serde_json::json!({"cmd": "ls"})),
text_response("done"),
]);
let mut cfg = build_config();
cfg.before_tool_call = Some(before_hook_from_plugin_manager(pm.clone()));
let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
let messages = run_agent_loop(
vec![user("run bash")],
ctx,
cfg,
AbortSignal::new(),
&tx,
&factory,
None,
None, )
.await;
drop(tx);
assert!(
tool.calls().is_empty(),
"blocked tool must not be invoked; got calls: {:?}",
tool.calls(),
);
let saw_block_reason = messages.iter().any(|m| match m {
LoopMessage::ToolResult(t) => t.content.iter().any(|c| match c {
ContentBlock::Text { text } => text.contains("policy denial"),
_ => false,
}),
_ => false,
});
assert!(
saw_block_reason,
"tool result should carry the block reason text",
);
}
/// An `on-tool-start` hook that calls `harness/mutate-input` must replace the
/// tool's arguments before it runs: the tool observes the mutated JSON, never
/// the model's original `{"cmd": "ls"}`.
#[tokio::test]
async fn ywj_before_tool_call_mutation_threads_through() {
    let Some(pm) = try_pm() else {
        // Report the skip like the other tests do, so lost coverage is visible
        // instead of looking like a pass.
        eprintln!("[skipped] PluginManager::try_new failed");
        return;
    };
    {
        let mut mgr = pm.lock().unwrap();
        mgr.eval(r#"(defn rewrite [_ctx] (harness/mutate-input "{\"cmd\":\"echo mutated\"}"))"#)
            .expect("install rewrite");
        mgr.register("on-tool-start", "rewrite");
    }
    let tool = Arc::new(RecordingTool::new("bash"));
    let mut ctx = empty_context();
    ctx.tools.push(tool.clone());
    let factory = canned_factory(vec![
        tool_use_response("call-1", "bash", serde_json::json!({"cmd": "ls"})),
        text_response("done"),
    ]);
    let mut cfg = build_config();
    cfg.before_tool_call = Some(before_hook_from_plugin_manager(pm.clone()));
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    let _ = run_agent_loop(
        vec![user("run bash")],
        ctx,
        cfg,
        AbortSignal::new(),
        &tx,
        &factory,
        None,
        None,
    )
    .await;
    let calls = tool.calls();
    assert_eq!(
        calls.len(),
        1,
        "tool should run exactly once; got calls: {calls:?}",
    );
    assert_eq!(
        calls[0].1,
        serde_json::json!({"cmd": "echo mutated"}),
        "tool must observe MUTATED args, not original",
    );
}
#[tokio::test]
async fn ywj_after_tool_call_replaces_result_content() {
let Some(pm) = try_pm() else {
return;
};
{
let mut mgr = pm.lock().unwrap();
mgr.eval(r#"(defn rewrite [_ctx] (harness/replace-result "REPLACED"))"#)
.expect("install rewrite");
mgr.register("on-tool-end", "rewrite");
}
let tool = Arc::new(RecordingTool::new("bash"));
let mut ctx = empty_context();
ctx.tools.push(tool.clone());
let factory = canned_factory(vec![
tool_use_response("call-1", "bash", serde_json::json!({"cmd": "ls"})),
text_response("done"),
]);
let mut cfg = build_config();
cfg.after_tool_call = Some(after_hook_from_plugin_manager(pm.clone()));
let (tx, mut rx) = mpsc::channel::<LoopEvent>(128);
let messages = run_agent_loop(
vec![user("run bash")],
ctx,
cfg,
AbortSignal::new(),
&tx,
&factory,
None,
None, )
.await;
drop(tx);
assert_eq!(tool.calls().len(), 1);
let tool_result = messages
.iter()
.find_map(|m| match m {
LoopMessage::ToolResult(t) => Some(t.clone()),
_ => None,
})
.expect("tool result message present");
let saw_replaced = tool_result.content.iter().any(|c| match c {
ContentBlock::Text { text } => text.contains("REPLACED"),
_ => false,
});
let saw_original = tool_result.content.iter().any(|c| match c {
ContentBlock::Text { text } => text.contains("original output"),
_ => false,
});
assert!(saw_replaced, "tool result should carry replaced content");
assert!(
!saw_original,
"tool result must NOT carry original content after replacement",
);
let events = drain(&mut rx).await;
let exec_end_replaced = events.iter().any(|e| match e {
LoopEvent::ToolExecutionEnd { result, .. } => result.content.iter().any(|c| {
c.as_object()
.and_then(|o| o.get("text"))
.and_then(|t| t.as_str())
.map(|s| s.contains("REPLACED"))
.unwrap_or(false)
}),
_ => false,
});
assert!(exec_end_replaced, "ToolExecutionEnd should carry REPLACED");
}
#[tokio::test]
async fn ywj_prepare_next_turn_returns_turn_update() {
let Some(pm) = try_pm() else {
return;
};
{
let mut mgr = pm.lock().unwrap();
mgr.eval(r#"(defn bump [_ctx] (harness/set-next-thinking-level "high"))"#)
.unwrap();
mgr.register("on-tool-end", "bump");
}
let tool = Arc::new(RecordingTool::new("noop"));
let mut ctx = empty_context();
ctx.tools.push(tool.clone());
let factory = canned_factory(vec![
tool_use_response("call-1", "noop", serde_json::json!({})),
text_response("done"),
]);
let mut cfg = build_config();
cfg.prepare_next_turn = Some(prepare_next_turn_from_plugin_manager(pm.clone()));
let (tx, _rx) = mpsc::channel::<LoopEvent>(128);
let messages = run_agent_loop(
vec![user("hi")],
ctx,
cfg,
AbortSignal::new(),
&tx,
&factory,
None,
None, )
.await;
let roles: Vec<&'static str> = messages.iter().map(|m| m.role()).collect();
assert_eq!(roles, vec!["user", "assistant", "toolResult", "assistant"]);
let pending = pm.lock().unwrap().take_pending_next_thinking_level();
assert!(
pending.is_none(),
"prepare_next_turn should have drained the slot",
);
}
/// An `on-tool-end` hook that calls `harness/request-stop-after-turn` must end
/// the loop after the current turn: the model is queried only once and
/// `agent_end` is still emitted.
#[tokio::test]
async fn ywj_should_stop_after_turn_terminates_loop() {
    let Some(pm) = try_pm() else {
        // Report the skip instead of silently passing.
        eprintln!("[skipped] PluginManager::try_new failed");
        return;
    };
    {
        let mut mgr = pm.lock().unwrap();
        mgr.eval(r#"(defn stop [_ctx] (harness/request-stop-after-turn))"#)
            .expect("install stop");
        mgr.register("on-tool-end", "stop");
    }
    let tool = Arc::new(RecordingTool::new("noop"));
    let mut ctx = empty_context();
    ctx.tools.push(tool.clone());
    // Counts LLM invocations; only the first one requests a tool.
    let llm_calls = Arc::new(AtomicUsize::new(0));
    let llm_calls_clone = llm_calls.clone();
    let factory: StreamFn = Arc::new(move |_ctx, _opts| {
        let n = llm_calls_clone.fetch_add(1, Ordering::SeqCst);
        let msg = if n == 0 {
            tool_use_response("call-1", "noop", serde_json::json!({}))
        } else {
            text_response("should not appear")
        };
        let reason = msg.stop_reason;
        Box::pin(futures::stream::iter(vec![StreamEvent::Done {
            reason,
            message: msg,
            usage: None,
        }]))
    });
    let mut cfg = build_config();
    cfg.after_tool_call = Some(after_hook_from_plugin_manager(pm.clone()));
    cfg.should_stop_after_turn = Some(should_stop_after_turn_from_plugin_manager(pm.clone()));
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(128);
    let _messages = run_agent_loop(
        vec![user("hi")],
        ctx,
        cfg,
        AbortSignal::new(),
        &tx,
        &factory,
        None,
        None,
    )
    .await;
    // Drop our sender so `drain` can finish once the buffer is empty.
    drop(tx);
    // Precondition: the tool ran, so the `on-tool-end` stop hook had a chance to fire.
    assert_eq!(tool.calls().len(), 1, "noop tool should run exactly once");
    assert_eq!(
        llm_calls.load(Ordering::SeqCst),
        1,
        "should_stop_after_turn must prevent the second LLM call",
    );
    let kinds: Vec<&str> = drain(&mut rx).await.iter().map(|e| e.kind()).collect();
    assert!(kinds.contains(&"agent_end"), "agent_end fires on stop");
}
/// Steering text queued via `harness/add-steering` must be injected into the
/// conversation as a `User` message, with the original prompt preserved.
#[tokio::test]
async fn ywj_get_steering_messages_injects_user_message_at_boundary() {
    let Some(pm) = try_pm() else {
        // Report the skip instead of silently passing.
        eprintln!("[skipped] PluginManager::try_new failed");
        return;
    };
    {
        let mut mgr = pm.lock().unwrap();
        mgr.eval(r#"(harness/add-steering "queued steering")"#)
            .expect("queue steering message");
    }
    let factory = canned_factory(vec![text_response("ok")]);
    let mut cfg = build_config();
    cfg.get_steering_messages = Some(get_steering_messages_from_plugin_manager(pm.clone()));
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    let messages = run_agent_loop(
        vec![user("hi")],
        empty_context(),
        cfg,
        AbortSignal::new(),
        &tx,
        &factory,
        None,
        None,
    )
    .await;
    let user_contents: Vec<String> = messages
        .iter()
        .filter_map(|m| match m {
            LoopMessage::User(u) => Some(u.content.clone()),
            _ => None,
        })
        .collect();
    assert!(
        user_contents.contains(&"hi".to_string()),
        "original prompt preserved",
    );
    assert!(
        user_contents.contains(&"queued steering".to_string()),
        "steering message injected as User; got {user_contents:?}",
    );
}
/// A follow-up queued via `harness/add-followup` must re-enter the outer loop
/// once the first run settles, appearing as a new user prompt that gets its
/// own assistant reply.
#[tokio::test]
async fn ywj_get_followup_messages_reenters_outer_loop() {
    let Some(pm) = try_pm() else {
        // Report the skip instead of silently passing.
        eprintln!("[skipped] PluginManager::try_new failed");
        return;
    };
    {
        let mut mgr = pm.lock().unwrap();
        mgr.eval(r#"(harness/add-followup "followup question")"#)
            .expect("queue followup message");
    }
    let factory = canned_factory(vec![
        text_response("first done"),
        text_response("second done"),
    ]);
    let mut cfg = build_config();
    cfg.get_followup_messages = Some(get_followup_messages_from_plugin_manager(pm.clone()));
    let (tx, _rx) = mpsc::channel::<LoopEvent>(64);
    let messages = run_agent_loop(
        vec![user("hi")],
        empty_context(),
        cfg,
        AbortSignal::new(),
        &tx,
        &factory,
        None,
        None,
    )
    .await;
    let roles: Vec<&'static str> = messages.iter().map(|m| m.role()).collect();
    assert_eq!(
        roles,
        vec!["user", "assistant", "user", "assistant"],
        "outer-loop re-entered with followup as new user prompt",
    );
    let user_contents: Vec<String> = messages
        .iter()
        .filter_map(|m| match m {
            LoopMessage::User(u) => Some(u.content.clone()),
            _ => None,
        })
        .collect();
    assert_eq!(user_contents, vec!["hi", "followup question"]);
}
/// A `before-agent-start` hook can append to the system prompt. The append
/// slot is taken once: a second `take_system_prompt_append` returns `None`.
#[test]
fn before_agent_start_appends_system_prompt() {
    let Some(pm) = try_pm() else {
        eprintln!("[skipped] PluginManager::try_new failed");
        return;
    };
    let mut manager = pm.lock().unwrap();
    manager
        .eval(r#"(defn aug [_ctx] (harness/append-system-prompt "TEAM RULES"))"#)
        .expect("install aug");
    manager.register("before-agent-start", "aug");
    manager
        .dispatch("before-agent-start", "@{:system-prompt \"base preamble\"}")
        .expect("dispatch");

    let first_take = manager.take_system_prompt_append();
    assert_eq!(
        first_take.as_deref(),
        Some("TEAM RULES"),
        "append slot must carry the hook's text",
    );
    let second_take = manager.take_system_prompt_append();
    assert_eq!(second_take, None);
}
/// A `message-end` hook can rewrite the assistant message. The rewrite slot is
/// taken once: a second `take_message_rewrite` returns `None`.
#[test]
fn message_end_rewrites_response() {
    let Some(pm) = try_pm() else {
        eprintln!("[skipped] PluginManager::try_new failed");
        return;
    };
    let mut manager = pm.lock().unwrap();
    manager
        .eval(r#"(defn redact [_ctx] (harness/rewrite-message "[redacted]"))"#)
        .expect("install redact");
    manager.register("message-end", "redact");
    manager
        .dispatch("message-end", "@{:message \"secret output\"}")
        .expect("dispatch");

    let rewritten = manager.take_message_rewrite();
    assert_eq!(rewritten.as_deref(), Some("[redacted]"));
    let leftover = manager.take_message_rewrite();
    assert_eq!(leftover, None);
}
#[tokio::test]
async fn transform_context_replaces_messages() {
let Some(pm) = try_pm() else {
eprintln!("[skipped] PluginManager::try_new failed");
return;
};
{
let mut mgr = pm.lock().unwrap();
mgr.eval(
r#"(defn xform [_ctx] (harness/replace-context "[{\"role\":\"user\",\"content\":\"pruned\"}]"))"#,
)
.expect("install xform");
mgr.register("transform-context", "xform");
}
let f = transform_context_from_plugin_manager(pm.clone());
let original = vec![
serde_json::json!({"role": "user", "content": "a"}),
serde_json::json!({"role": "assistant", "content": "b"}),
];
let out = f(original).await;
assert_eq!(out.len(), 1, "context replaced by the hook's array");
assert_eq!(out[0]["content"], "pruned");
}
#[tokio::test]
async fn transform_context_passthrough_without_hook() {
let Some(pm) = try_pm() else {
eprintln!("[skipped] PluginManager::try_new failed");
return;
};
let f = transform_context_from_plugin_manager(pm.clone());
let original = vec![serde_json::json!({"role": "user", "content": "keep me"})];
let out = f(original.clone()).await;
assert_eq!(out, original, "no transform-context hook → unchanged");
}
#[tokio::test]
async fn transform_context_passthrough_on_malformed_json() {
let Some(pm) = try_pm() else {
eprintln!("[skipped] PluginManager::try_new failed");
return;
};
{
let mut mgr = pm.lock().unwrap();
mgr.eval(r#"(defn bad [_ctx] (harness/replace-context "not json {{{"))"#)
.expect("install bad");
mgr.register("transform-context", "bad");
}
let f = transform_context_from_plugin_manager(pm.clone());
let original = vec![serde_json::json!({"role": "user", "content": "keep me"})];
let out = f(original.clone()).await;
assert_eq!(out, original, "malformed JSON → original context preserved");
}
#[tokio::test]
async fn on_compact_hook_supplies_custom_summary() {
let Some(pm) = try_pm() else {
eprintln!("[skipped] PluginManager::try_new failed");
return;
};
{
let mut mgr = pm.lock().unwrap();
mgr.eval(
r#"(defn summ [_ctx] (harness/set-compact-summary "Active Task: plugin summary"))"#,
)
.expect("install summ");
mgr.register("on-compact", "summ");
}
let hooks = compaction_hooks_from_plugin_manager(pm.clone());
let middle = vec![serde_json::json!({"role": "user", "content": "old turn"})];
let summary = (hooks.on_compact)(middle).await;
assert_eq!(
summary.as_deref(),
Some("Active Task: plugin summary"),
"on-compact hook's summary must flow through the factory",
);
}
#[tokio::test]
async fn compaction_hooks_passthrough_without_hooks() {
let Some(pm) = try_pm() else {
eprintln!("[skipped] PluginManager::try_new failed");
return;
};
let hooks = compaction_hooks_from_plugin_manager(pm.clone());
(hooks.on_before)(5, 1234).await;
let summary =
(hooks.on_compact)(vec![serde_json::json!({"role": "user", "content": "x"})]).await;
assert_eq!(summary, None, "no on-compact hook → fall through to LLM");
}