use crate::events::{EventContext, EventRequest, ToolCallRequestedData};
use crate::tool_types::{ToolCall, ToolDefinition, ToolResult};
use crate::traits::{EventEmitter, ToolContext};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use uuid::Uuid;
use super::AtomContext;
use super::act::ActResult;
/// Verdict returned by a [`PreToolUseHook`] for a pending tool call.
#[derive(Debug, Clone)]
pub enum PreToolUseDecision {
/// Proceed with this tool call, which the hook may have rewritten.
Continue(ToolCall),
/// Stop the call. `run_pre_tool_use_hooks` returns this value as-is and does
/// not consult any later hooks.
Block {
/// The call that was blocked.
tool_call: ToolCall,
/// Explanation for the block (NOTE(review): presumably for logs/model — confirm at call site).
reason: String,
/// Optional text for the end user (NOTE(review): how it is surfaced is not visible here).
user_message: Option<String>,
},
}
/// Hook consulted before a tool call runs; it may rewrite the call or block it.
///
/// Hooks are chained by `run_pre_tool_use_hooks`: each hook receives the call as
/// returned by the previous hook's `Continue`.
#[async_trait]
pub trait PreToolUseHook: Send + Sync {
/// Inspects `tool_call` (owned, so it can be modified and handed back) together
/// with its definition and execution context, and returns the verdict.
async fn before_exec(
&self,
tool_call: ToolCall,
tool_def: &ToolDefinition,
context: &ToolContext,
) -> PreToolUseDecision;
}
/// Runs every [`PreToolUseHook`] in order, threading the (possibly rewritten)
/// tool call from one hook into the next.
///
/// Returns the first `Block` decision unchanged, without running the remaining
/// hooks. If no hook blocks, returns `Continue` with the call as left by the
/// last hook (or the input call when `hooks` is empty).
pub(super) async fn run_pre_tool_use_hooks(
    hooks: &[Arc<dyn PreToolUseHook>],
    mut tool_call: ToolCall,
    tool_def: &ToolDefinition,
    context: &ToolContext,
) -> PreToolUseDecision {
    for hook in hooks {
        // Move the call into the hook instead of cloning it: on `Continue` the
        // hook hands the (possibly updated) call back, and on `Block` we return
        // right away, so the previous value is never needed again.
        match hook.before_exec(tool_call, tool_def, context).await {
            PreToolUseDecision::Continue(updated) => tool_call = updated,
            block @ PreToolUseDecision::Block { .. } => return block,
        }
    }
    PreToolUseDecision::Continue(tool_call)
}
/// Hook run after a tool has executed; it may inspect or mutate the result.
#[async_trait]
pub trait PostToolExecHook: Send + Sync {
/// Called with the executed call, its definition, a mutable reference to the
/// produced `result` (edits are visible to later hooks), and the context.
async fn after_exec(
&self,
tool_call: &ToolCall,
tool_def: &ToolDefinition,
result: &mut ToolResult,
context: &ToolContext,
);
}
/// Runs all post-execution hooks against `result`: first every hook in
/// `hooks`, then every hook in `final_hooks`, each in slice order.
pub(super) async fn run_post_tool_exec_hooks(
    hooks: &[Arc<dyn PostToolExecHook>],
    final_hooks: &[Arc<dyn PostToolExecHook>],
    tool_call: &ToolCall,
    tool_def: &ToolDefinition,
    result: &mut ToolResult,
    context: &ToolContext,
) {
    // `final_hooks` come strictly after the regular hooks, so they observe
    // every mutation the earlier hooks made to `result`.
    for post_hook in hooks.iter().chain(final_hooks) {
        post_hook
            .after_exec(tool_call, tool_def, result, context)
            .await;
    }
}
/// Byte limit enforced by [`OutputHardLimitHook`] on a tool's result text,
/// on its error text, and on the combined base64 size of its images.
const MAX_TOOL_RESULT_BYTES: usize = 64 * 1024;

/// Marker appended to text that had to be cut short.
///
/// The "64 KiB" wording must be kept in sync with [`MAX_TOOL_RESULT_BYTES`].
const TRUNCATION_SUFFIX: &str =
    "\n\n[Output truncated — exceeded 64 KiB limit. Try quiet flags, pipes, or redirect to file.]";

// Compile-time guard: `truncate` reserves room for the suffix inside the limit.
// If the suffix ever outgrew the limit, a truncated result (suffix only) would
// itself exceed `MAX_TOOL_RESULT_BYTES` and the "hard" limit would be violated.
const _: () = assert!(TRUNCATION_SUFFIX.len() < MAX_TOOL_RESULT_BYTES);

/// Post-execution hook that caps tool output at [`MAX_TOOL_RESULT_BYTES`].
pub struct OutputHardLimitHook;

impl OutputHardLimitHook {
    /// Returns `text` unchanged if it fits within [`MAX_TOOL_RESULT_BYTES`].
    ///
    /// Otherwise keeps the longest prefix that ends on a UTF-8 character
    /// boundary and leaves room for [`TRUNCATION_SUFFIX`], then appends the
    /// suffix. The returned string is never longer than the limit.
    fn truncate(text: String) -> String {
        if text.len() <= MAX_TOOL_RESULT_BYTES {
            return text;
        }
        let content_budget = MAX_TOOL_RESULT_BYTES.saturating_sub(TRUNCATION_SUFFIX.len());
        // Walk back to a char boundary so the slice below cannot split a
        // multi-byte character (index 0 is always a boundary).
        let mut end = content_budget;
        while end > 0 && !text.is_char_boundary(end) {
            end -= 1;
        }
        // Allocate the final size once instead of growing after the copy.
        let mut truncated = String::with_capacity(end + TRUNCATION_SUFFIX.len());
        truncated.push_str(&text[..end]);
        truncated.push_str(TRUNCATION_SUFFIX);
        truncated
    }
}
#[async_trait]
impl PostToolExecHook for OutputHardLimitHook {
    /// Caps the result payload, the error text and the attached images of
    /// `result` so that none of them exceeds [`MAX_TOOL_RESULT_BYTES`].
    async fn after_exec(
        &self,
        tool_call: &ToolCall,
        _tool_def: &ToolDefinition,
        result: &mut ToolResult,
        _context: &ToolContext,
    ) {
        Self::cap_result_value(tool_call, result);
        Self::cap_error_text(tool_call, result);
        Self::cap_images(tool_call, result);
    }
}

impl OutputHardLimitHook {
    /// Caps `result.result`.
    ///
    /// A JSON string is truncated in place. Any other JSON value is measured by
    /// its serialized form; if that is too large, the value is replaced by the
    /// truncated serialized text as a JSON string.
    fn cap_result_value(tool_call: &ToolCall, result: &mut ToolResult) {
        let capped = match result.result.take() {
            None => return,
            Some(serde_json::Value::String(text)) => {
                let byte_len = text.len();
                let shortened = Self::truncate(text);
                if shortened.len() < byte_len {
                    tracing::warn!(
                        tool_name = %tool_call.name,
                        tool_call_id = %tool_call.id,
                        result_bytes = byte_len,
                        limit = MAX_TOOL_RESULT_BYTES,
                        "Tool result exceeded hard limit, truncated"
                    );
                }
                serde_json::Value::String(shortened)
            }
            Some(structured) => {
                let encoded = serde_json::to_string(&structured).unwrap_or_default();
                if encoded.len() <= MAX_TOOL_RESULT_BYTES {
                    structured
                } else {
                    tracing::warn!(
                        tool_name = %tool_call.name,
                        tool_call_id = %tool_call.id,
                        result_bytes = encoded.len(),
                        limit = MAX_TOOL_RESULT_BYTES,
                        "Tool result exceeded hard limit, truncated"
                    );
                    serde_json::Value::String(Self::truncate(encoded))
                }
            }
        };
        result.result = Some(capped);
    }

    /// Caps `result.error`, logging a warning when it had to be shortened.
    fn cap_error_text(tool_call: &ToolCall, result: &mut ToolResult) {
        result.error = result.error.take().map(|message| {
            if message.len() > MAX_TOOL_RESULT_BYTES {
                tracing::warn!(
                    tool_name = %tool_call.name,
                    tool_call_id = %tool_call.id,
                    result_bytes = message.len(),
                    limit = MAX_TOOL_RESULT_BYTES,
                    "Tool error exceeded hard limit, truncated"
                );
            }
            Self::truncate(message)
        });
    }

    /// Keeps images, in order, while their cumulative base64 length stays
    /// within the limit; drops the rest. An emptied list becomes `None`.
    fn cap_images(tool_call: &ToolCall, result: &mut ToolResult) {
        let images = match result.images.as_mut() {
            Some(images) => images,
            None => return,
        };
        let before = images.len();
        let mut used_bytes = 0usize;
        images.retain(|image| {
            // An image larger than the whole budget can never fit, so this
            // single check also rejects individually oversized images.
            match used_bytes.checked_add(image.base64.len()) {
                Some(next) if next <= MAX_TOOL_RESULT_BYTES => {
                    used_bytes = next;
                    true
                }
                _ => false,
            }
        });
        let dropped = before.saturating_sub(images.len());
        if dropped > 0 {
            tracing::warn!(
                tool_name = %tool_call.name,
                tool_call_id = %tool_call.id,
                dropped_images = dropped,
                kept_images = images.len(),
                kept_bytes = used_bytes,
                limit = MAX_TOOL_RESULT_BYTES,
                "Tool images exceeded hard limit and were dropped"
            );
        }
        if images.is_empty() {
            result.images = None;
        }
    }
}
/// Follow-up action requested by a [`PostActHook`]; executed by `run_post_act_hooks`.
#[derive(Debug, Clone)]
pub enum PostActAction {
/// Emit a `tool.call_requested` event carrying these calls.
EmitToolCallRequested {
/// Tool calls to announce.
tool_calls: Vec<ToolCall>,
/// Definitions sent alongside the calls; may be empty (e.g. `ConnectionSetupHook` sends none).
tool_definitions: Vec<ToolDefinition>,
},
}
/// Synchronous hook run once an act phase has completed.
pub trait PostActHook: Send + Sync {
/// May mutate `result` (e.g. set `waiting_for_tool_results`) and returns the
/// follow-up actions to perform; an empty vec means nothing to do.
fn on_completed(
&self,
result: &mut ActResult,
tool_definitions: &[ToolDefinition],
) -> Vec<PostActAction>;
}
/// Post-act hook that turns "connection required" tool outcomes into
/// `setup_connection` tool calls.
pub struct ConnectionSetupHook;

impl PostActHook for ConnectionSetupHook {
    /// Builds one `setup_connection` call per distinct provider named in
    /// `connection_required`, in order of first occurrence.
    ///
    /// When at least one provider is found, marks the result as waiting for
    /// tool results and returns a single `EmitToolCallRequested` action (with no
    /// definitions). Otherwise returns no actions and leaves `result` untouched.
    fn on_completed(
        &self,
        result: &mut ActResult,
        _tool_definitions: &[ToolDefinition],
    ) -> Vec<PostActAction> {
        // Several calls in the same act can fail on the same missing
        // connection; the user should only be asked to set each provider up once.
        let mut seen = std::collections::HashSet::new();
        let tool_calls: Vec<ToolCall> = result
            .results
            .iter()
            .filter_map(|r| r.connection_required.as_deref())
            .filter(|provider| seen.insert(*provider))
            .map(|provider| ToolCall {
                id: format!("setup_conn_{}", Uuid::now_v7()),
                name: "setup_connection".to_string(),
                arguments: json!({ "provider": provider }),
            })
            .collect();
        if tool_calls.is_empty() {
            return vec![];
        }
        result.waiting_for_tool_results = true;
        vec![PostActAction::EmitToolCallRequested {
            tool_calls,
            tool_definitions: vec![],
        }]
    }
}
/// Post-act hook that hands client-side tool calls back to the client.
pub struct ClientSideToolHook;

impl PostActHook for ClientSideToolHook {
    /// When the act produced client tool calls, marks the result as waiting for
    /// tool results and returns one `EmitToolCallRequested` action carrying
    /// copies of those calls and their definitions. Otherwise returns no actions.
    fn on_completed(
        &self,
        result: &mut ActResult,
        _tool_definitions: &[ToolDefinition],
    ) -> Vec<PostActAction> {
        if result.client_tool_calls.is_empty() {
            Vec::new()
        } else {
            result.waiting_for_tool_results = true;
            let request = PostActAction::EmitToolCallRequested {
                tool_calls: result.client_tool_calls.clone(),
                tool_definitions: result.client_tool_definitions.clone(),
            };
            vec![request]
        }
    }
}
/// Runs every [`PostActHook`] in order and performs the actions each returns.
///
/// Each `EmitToolCallRequested` action becomes a turn-scoped event built from
/// `context` and `locale` and sent through `event_emitter`. A failed emit is
/// logged at `warn` level and otherwise ignored, so the remaining actions and
/// hooks still run.
pub(super) async fn run_post_act_hooks<E: EventEmitter>(
    hooks: &[Box<dyn PostActHook>],
    context: &AtomContext,
    result: &mut ActResult,
    tool_definitions: &[ToolDefinition],
    event_emitter: &E,
    locale: Option<&str>,
) {
    for hook in hooks {
        for action in hook.on_completed(result, tool_definitions) {
            match action {
                PostActAction::EmitToolCallRequested {
                    tool_calls,
                    tool_definitions: requested_definitions,
                } => {
                    let session = context.session_id;
                    let scope = EventContext::turn(context.turn_id, context.input_message_id);
                    let payload = ToolCallRequestedData::with_definitions_and_locale(
                        &tool_calls,
                        &requested_definitions,
                        locale,
                    );
                    let request = EventRequest::new(session, scope, payload);
                    if let Err(emit_error) = event_emitter.emit(request).await {
                        tracing::warn!(
                            error = %emit_error,
                            "PostActHook: failed to emit tool.call_requested event"
                        );
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the tool hooks. `use super::*` already brings in the
    //! parent's imports (`ToolCall`, `ToolResult`, `ToolContext`, `json!`, `Arc`, ...).
    use super::*;
    use crate::atoms::act::ToolCallResult;
    use crate::tool_types::{BuiltinTool, DeferrablePolicy, ToolPolicy};
    use crate::tools::ToolResultImage;
    use crate::typed_id::SessionId;
    use async_trait::async_trait;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// A successful executed call, optionally flagged as needing `connection_required`.
    fn make_tool_call_result(connection_required: Option<&str>) -> ToolCallResult {
        ToolCallResult {
            tool_call: ToolCall {
                id: "call_1".to_string(),
                name: "some_tool".to_string(),
                arguments: json!({}),
            },
            result: ToolResult {
                tool_call_id: "call_1".to_string(),
                result: Some(json!({})),
                images: None,
                error: None,
                connection_required: connection_required.map(|s| s.to_string()),
                raw_output: None,
            },
            success: true,
            status: "success".to_string(),
            connection_required: connection_required.map(|s| s.to_string()),
        }
    }

    /// A completed, not-yet-waiting `ActResult` with the given results and client calls.
    fn make_act_result(results: Vec<ToolCallResult>, client_tool_calls: Vec<ToolCall>) -> ActResult {
        ActResult {
            results,
            completed: true,
            success_count: 0,
            error_count: 0,
            waiting_for_tool_results: false,
            blocked: false,
            client_tool_calls,
            client_tool_definitions: vec![],
        }
    }

    fn make_tool_call() -> ToolCall {
        ToolCall {
            id: "call_test".to_string(),
            name: "test_tool".to_string(),
            arguments: json!({}),
        }
    }

    fn make_tool_def() -> ToolDefinition {
        ToolDefinition::Builtin(BuiltinTool {
            name: "test_tool".to_string(),
            display_name: None,
            description: "test".to_string(),
            parameters: json!({}),
            policy: ToolPolicy::Auto,
            category: None,
            deferrable: DeferrablePolicy::Never,
            hints: Default::default(),
        })
    }

    /// A `ToolResult` for `call_test` with the given payload, images and error.
    fn make_tool_result(
        result: Option<serde_json::Value>,
        images: Option<Vec<ToolResultImage>>,
        error: Option<String>,
    ) -> ToolResult {
        ToolResult {
            tool_call_id: "call_test".into(),
            result,
            images,
            error,
            connection_required: None,
            raw_output: None,
        }
    }

    /// A PNG image whose base64 payload is `fill` repeated `len` times.
    fn make_image(fill: &str, len: usize) -> ToolResultImage {
        ToolResultImage {
            base64: fill.repeat(len),
            media_type: "image/png".to_string(),
        }
    }

    /// Runs `OutputHardLimitHook` over `result` and returns the capped result.
    async fn apply_output_limit(mut result: ToolResult) -> ToolResult {
        let ctx = ToolContext::new(SessionId::new());
        OutputHardLimitHook
            .after_exec(&make_tool_call(), &make_tool_def(), &mut result, &ctx)
            .await;
        result
    }

    // ---------------------------------------------------------------------
    // Post-act hooks
    // ---------------------------------------------------------------------

    #[test]
    fn test_connection_setup_hook_no_connections() {
        let hook = ConnectionSetupHook;
        let mut result = ActResult {
            success_count: 1,
            ..make_act_result(vec![make_tool_call_result(None)], vec![])
        };
        let actions = hook.on_completed(&mut result, &[]);
        assert!(actions.is_empty());
        assert!(!result.waiting_for_tool_results);
    }

    #[test]
    fn test_connection_setup_hook_with_connection() {
        let hook = ConnectionSetupHook;
        let mut result = make_act_result(vec![make_tool_call_result(Some("github"))], vec![]);
        let actions = hook.on_completed(&mut result, &[]);
        assert_eq!(actions.len(), 1);
        assert!(result.waiting_for_tool_results);
        match &actions[0] {
            PostActAction::EmitToolCallRequested { tool_calls, .. } => {
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].name, "setup_connection");
                assert_eq!(tool_calls[0].arguments["provider"], "github");
            }
        }
    }

    #[test]
    fn test_client_side_tool_hook_no_client_tools() {
        let hook = ClientSideToolHook;
        let mut result = make_act_result(vec![], vec![]);
        let actions = hook.on_completed(&mut result, &[]);
        assert!(actions.is_empty());
        assert!(!result.waiting_for_tool_results);
    }

    #[test]
    fn test_client_side_tool_hook_with_client_tools() {
        let hook = ClientSideToolHook;
        let client_call = ToolCall {
            id: "call_client".to_string(),
            name: "browser_click".to_string(),
            arguments: json!({"selector": "#btn"}),
        };
        let mut result = make_act_result(vec![], vec![client_call.clone()]);
        let actions = hook.on_completed(&mut result, &[]);
        assert_eq!(actions.len(), 1);
        assert!(result.waiting_for_tool_results);
        match &actions[0] {
            PostActAction::EmitToolCallRequested { tool_calls, .. } => {
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].name, "browser_click");
            }
        }
    }

    // ---------------------------------------------------------------------
    // Pre-tool-use hook chaining
    // ---------------------------------------------------------------------

    /// Renames the call and lets it continue.
    struct RenameHook(&'static str);

    #[async_trait]
    impl PreToolUseHook for RenameHook {
        async fn before_exec(
            &self,
            mut tool_call: ToolCall,
            _tool_def: &ToolDefinition,
            _context: &ToolContext,
        ) -> PreToolUseDecision {
            tool_call.name = self.0.to_string();
            PreToolUseDecision::Continue(tool_call)
        }
    }

    /// Always blocks the call.
    struct BlockHook;

    #[async_trait]
    impl PreToolUseHook for BlockHook {
        async fn before_exec(
            &self,
            tool_call: ToolCall,
            _tool_def: &ToolDefinition,
            _context: &ToolContext,
        ) -> PreToolUseDecision {
            PreToolUseDecision::Block {
                tool_call,
                reason: "blocked".to_string(),
                user_message: None,
            }
        }
    }

    #[tokio::test]
    async fn test_pre_tool_use_hooks_thread_updates_through_chain() {
        let hooks: [Arc<dyn PreToolUseHook>; 2] =
            [Arc::new(RenameHook("first")), Arc::new(RenameHook("second"))];
        let td = make_tool_def();
        let ctx = ToolContext::new(SessionId::new());
        match run_pre_tool_use_hooks(&hooks, make_tool_call(), &td, &ctx).await {
            PreToolUseDecision::Continue(call) => assert_eq!(call.name, "second"),
            PreToolUseDecision::Block { .. } => panic!("no hook blocks; expected Continue"),
        }
    }

    #[tokio::test]
    async fn test_pre_tool_use_hooks_block_short_circuits() {
        let hooks: [Arc<dyn PreToolUseHook>; 3] = [
            Arc::new(RenameHook("renamed")),
            Arc::new(BlockHook),
            Arc::new(RenameHook("never")),
        ];
        let td = make_tool_def();
        let ctx = ToolContext::new(SessionId::new());
        match run_pre_tool_use_hooks(&hooks, make_tool_call(), &td, &ctx).await {
            PreToolUseDecision::Block {
                tool_call,
                reason,
                user_message,
            } => {
                assert_eq!(tool_call.name, "renamed");
                assert_eq!(reason, "blocked");
                assert!(user_message.is_none());
            }
            PreToolUseDecision::Continue(_) => panic!("BlockHook should stop the chain"),
        }
    }

    // ---------------------------------------------------------------------
    // OutputHardLimitHook
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_output_hard_limit_passthrough_small() {
        let result = apply_output_limit(make_tool_result(Some(json!("hello")), None, None)).await;
        assert_eq!(result.result, Some(json!("hello")));
    }

    #[tokio::test]
    async fn test_output_hard_limit_truncates_large_string() {
        let big = "x".repeat(MAX_TOOL_RESULT_BYTES + 1000);
        let result = apply_output_limit(make_tool_result(Some(json!(big)), None, None)).await;
        let text = result.result.unwrap();
        let s = text.as_str().unwrap();
        assert!(s.len() <= MAX_TOOL_RESULT_BYTES);
        assert!(s.ends_with(TRUNCATION_SUFFIX));
    }

    #[tokio::test]
    async fn test_output_hard_limit_at_exact_limit() {
        let exact = "a".repeat(MAX_TOOL_RESULT_BYTES);
        let result = apply_output_limit(make_tool_result(Some(json!(exact)), None, None)).await;
        let text = result.result.unwrap();
        let s = text.as_str().unwrap();
        assert_eq!(s.len(), MAX_TOOL_RESULT_BYTES);
        assert!(!s.contains("[Output truncated"));
    }

    #[tokio::test]
    async fn test_output_hard_limit_multibyte_boundary() {
        let ch = "€";
        let count = MAX_TOOL_RESULT_BYTES / ch.len() + 1;
        let big = ch.repeat(count);
        let result = apply_output_limit(make_tool_result(Some(json!(big)), None, None)).await;
        let text = result.result.unwrap();
        let s = text.as_str().unwrap();
        assert!(s.len() <= MAX_TOOL_RESULT_BYTES);
        assert!(s.ends_with(TRUNCATION_SUFFIX));
        // The kept prefix must consist of whole characters only.
        let kept = s.strip_suffix(TRUNCATION_SUFFIX).unwrap();
        assert!(kept.chars().all(|c| c == '€'));
    }

    #[tokio::test]
    async fn test_output_hard_limit_truncates_error() {
        let big_err = "e".repeat(MAX_TOOL_RESULT_BYTES + 500);
        let result = apply_output_limit(make_tool_result(None, None, Some(big_err))).await;
        let err = result.error.unwrap();
        assert!(err.len() <= MAX_TOOL_RESULT_BYTES);
        assert!(err.ends_with(TRUNCATION_SUFFIX));
    }

    #[tokio::test]
    async fn test_output_hard_limit_non_string_json() {
        let result = apply_output_limit(make_tool_result(
            Some(json!({"key": "value", "num": 42})),
            None,
            None,
        ))
        .await;
        assert_eq!(result.result, Some(json!({"key": "value", "num": 42})));
    }

    #[tokio::test]
    async fn test_output_hard_limit_flattens_oversized_non_string_json() {
        let big = json!({ "data": "x".repeat(MAX_TOOL_RESULT_BYTES) });
        let result = apply_output_limit(make_tool_result(Some(big), None, None)).await;
        let text = result.result.unwrap();
        let s = text
            .as_str()
            .expect("oversized structured result should become a truncated string");
        assert!(s.len() <= MAX_TOOL_RESULT_BYTES);
        assert!(s.starts_with("{\"data\":\"x"));
        assert!(s.ends_with(TRUNCATION_SUFFIX));
    }

    #[tokio::test]
    async fn test_output_hard_limit_drops_oversized_images() {
        let result = apply_output_limit(make_tool_result(
            Some(json!({"ok": true})),
            Some(vec![
                make_image("a", 32),
                make_image("b", MAX_TOOL_RESULT_BYTES + 1),
            ]),
            None,
        ))
        .await;
        let images = result.images.unwrap();
        assert_eq!(images.len(), 1);
        assert_eq!(images[0].base64.len(), 32);
    }

    #[tokio::test]
    async fn test_output_hard_limit_enforces_cumulative_image_budget() {
        let half = MAX_TOOL_RESULT_BYTES / 2;
        let result = apply_output_limit(make_tool_result(
            Some(json!({"ok": true})),
            Some(vec![
                make_image("a", half),
                make_image("b", half),
                make_image("c", half),
            ]),
            None,
        ))
        .await;
        let images = result.images.unwrap();
        assert_eq!(
            images.len(),
            2,
            "third image should be dropped by cumulative budget"
        );
        assert!(images.iter().all(|i| i.base64.len() == half));
    }

    #[tokio::test]
    async fn test_output_hard_limit_normalizes_empty_images_to_none() {
        let result = apply_output_limit(make_tool_result(
            Some(json!({"ok": true})),
            Some(vec![make_image("a", MAX_TOOL_RESULT_BYTES + 1)]),
            None,
        ))
        .await;
        assert!(
            result.images.is_none(),
            "images vec emptied by retain should normalize to None"
        );
    }

    #[test]
    fn test_truncate_helper_short() {
        let s = "hello".to_string();
        assert_eq!(OutputHardLimitHook::truncate(s.clone()), s);
    }

    #[test]
    fn test_truncate_helper_over() {
        let s = "a".repeat(MAX_TOOL_RESULT_BYTES + 100);
        let t = OutputHardLimitHook::truncate(s);
        assert!(t.len() <= MAX_TOOL_RESULT_BYTES);
        assert!(t.ends_with(TRUNCATION_SUFFIX));
    }
}