use juncture_core::state::messages::ToolCall;
use crate::tools::error::ToolError;
/// A hook that inspects and optionally rewrites a [`ToolCall`] in place.
///
/// The `'static + Send + Sync` bounds allow a transformer to be stored as a
/// `Box<dyn ToolCallTransformer>` (as [`CompositeTransformer`] does) and shared
/// across threads.
pub trait ToolCallTransformer: 'static + Send + Sync {
    /// Transforms `tool_call` in place.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the call should be rejected, for example
    /// `ToolError::Intercepted` for a blocked tool. How the caller reacts to the
    /// error is decided outside this module — NOTE(review): confirm at the call site.
    fn transform(&self, tool_call: &mut ToolCall) -> Result<(), ToolError>;
}
/// A transformer that leaves every [`ToolCall`] untouched and always succeeds.
///
/// Handy as a placeholder wherever a [`ToolCallTransformer`] is required but no
/// rewriting is wanted. It is a zero-sized unit struct, so it derives the common
/// value traits (`Clone`, `Copy`, `Default`, `PartialEq`, `Eq`) at no cost.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NopToolTransformer;

impl ToolCallTransformer for NopToolTransformer {
    /// Does nothing and returns `Ok(())`.
    fn transform(&self, _tool_call: &mut ToolCall) -> Result<(), ToolError> {
        Ok(())
    }
}
/// Chains several [`ToolCallTransformer`]s and applies them in insertion order.
///
/// `CompositeTransformer::default()` yields an empty chain, equivalent to
/// `CompositeTransformer::new(Vec::new())`; an empty chain leaves calls untouched.
#[derive(Default)]
pub struct CompositeTransformer {
    /// Transformers in the order they are applied.
    transformers: Vec<Box<dyn ToolCallTransformer>>,
}
// `dyn ToolCallTransformer` does not require `Debug`, so only the number of
// chained transformers is reported.
impl std::fmt::Debug for CompositeTransformer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.transformers.len();
        let mut builder = f.debug_struct("CompositeTransformer");
        builder.field("transformers", &count);
        builder.finish()
    }
}
impl CompositeTransformer {
#[must_use]
pub fn new(transformers: Vec<Box<dyn ToolCallTransformer>>) -> Self {
Self { transformers }
}
pub fn add(&mut self, transformer: Box<dyn ToolCallTransformer>) {
self.transformers.push(transformer);
}
}
impl ToolCallTransformer for CompositeTransformer {
    /// Runs each chained transformer, in order, on the same `tool_call`.
    ///
    /// Stops at the first error and returns it. Later transformers are then not
    /// run, and changes already made by earlier transformers remain applied to
    /// `tool_call`.
    fn transform(&self, tool_call: &mut ToolCall) -> Result<(), ToolError> {
        self.transformers
            .iter()
            .try_for_each(|step| step.transform(tool_call))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `ToolCall` fixture with a fixed id.
    fn tool_call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            name: name.to_string(),
            arguments,
        }
    }

    /// Adds `"limit": 10` to `search` calls that do not already specify a limit.
    struct LimitInjector;

    impl ToolCallTransformer for LimitInjector {
        fn transform(&self, tool_call: &mut ToolCall) -> Result<(), ToolError> {
            if tool_call.name == "search"
                && let Some(obj) = tool_call.arguments.as_object_mut()
                && !obj.contains_key("limit")
            {
                obj.insert("limit".to_string(), json!(10));
            }
            Ok(())
        }
    }

    /// Rejects calls whose name is in `blocked_tools`.
    struct BlockingTransformer {
        blocked_tools: Vec<String>,
    }

    impl ToolCallTransformer for BlockingTransformer {
        fn transform(&self, tool_call: &mut ToolCall) -> Result<(), ToolError> {
            if self.blocked_tools.contains(&tool_call.name) {
                return Err(ToolError::Intercepted(format!(
                    "Tool '{}' is blocked",
                    tool_call.name
                )));
            }
            Ok(())
        }
    }

    #[test]
    fn test_nop_transformer() {
        let transformer = NopToolTransformer;
        let mut call = tool_call("test", json!({}));
        transformer.transform(&mut call).unwrap();
        assert_eq!(call.arguments, json!({}));
    }

    #[test]
    fn test_limit_injector() {
        let mut call = tool_call("search", json!({"query": "test"}));
        LimitInjector.transform(&mut call).unwrap();
        assert_eq!(call.arguments["limit"], 10);
    }

    #[test]
    fn test_limit_injector_preserves_existing_limit() {
        let mut call = tool_call("search", json!({"query": "test", "limit": 5}));
        LimitInjector.transform(&mut call).unwrap();
        assert_eq!(call.arguments["limit"], 5);
    }

    #[test]
    fn test_limit_injector_non_search() {
        let mut call = tool_call("other", json!({"query": "test"}));
        LimitInjector.transform(&mut call).unwrap();
        assert!(!call.arguments.as_object().unwrap().contains_key("limit"));
    }

    #[test]
    fn test_blocking_transformer() {
        let transformer = BlockingTransformer {
            blocked_tools: vec!["dangerous".to_string()],
        };
        let mut call = tool_call("dangerous", json!({}));
        let result = transformer.transform(&mut call);
        assert!(matches!(result, Err(ToolError::Intercepted(_))));
    }

    #[test]
    fn test_blocking_transformer_allows_other_tools() {
        let transformer = BlockingTransformer {
            blocked_tools: vec!["dangerous".to_string()],
        };
        let mut call = tool_call("safe", json!({"x": 1}));
        transformer.transform(&mut call).unwrap();
        assert_eq!(call.arguments, json!({"x": 1}));
    }

    #[test]
    fn test_composite_transformer() {
        let transformer1 = Box::new(NopToolTransformer) as Box<dyn ToolCallTransformer>;
        let transformer2 = Box::new(LimitInjector) as Box<dyn ToolCallTransformer>;
        let composite = CompositeTransformer::new(vec![transformer1, transformer2]);
        let mut call = tool_call("search", json!({"query": "test"}));
        composite.transform(&mut call).unwrap();
        assert_eq!(call.arguments["limit"], 10);
    }

    #[test]
    fn test_composite_transformer_empty_is_noop() {
        let composite = CompositeTransformer::new(vec![]);
        let mut call = tool_call("search", json!({"query": "test"}));
        composite.transform(&mut call).unwrap();
        assert_eq!(call.arguments, json!({"query": "test"}));
    }

    #[test]
    fn test_composite_transformer_add() {
        let mut composite = CompositeTransformer::new(vec![]);
        composite.add(Box::new(NopToolTransformer));
        composite.add(Box::new(LimitInjector));
        assert_eq!(composite.transformers.len(), 2);
    }

    #[test]
    fn test_composite_transformer_blocking() {
        let transformer1 = Box::new(NopToolTransformer) as Box<dyn ToolCallTransformer>;
        let transformer2 = Box::new(BlockingTransformer {
            blocked_tools: vec!["blocked".to_string()],
        }) as Box<dyn ToolCallTransformer>;
        let composite = CompositeTransformer::new(vec![transformer1, transformer2]);
        let mut call = tool_call("blocked", json!({}));
        let result = composite.transform(&mut call);
        assert!(matches!(result, Err(ToolError::Intercepted(_))));
    }

    #[test]
    fn test_composite_transformer_short_circuits_on_error() {
        // The blocker runs first; the limit injector after it must not run.
        let blocker = Box::new(BlockingTransformer {
            blocked_tools: vec!["search".to_string()],
        }) as Box<dyn ToolCallTransformer>;
        let injector = Box::new(LimitInjector) as Box<dyn ToolCallTransformer>;
        let composite = CompositeTransformer::new(vec![blocker, injector]);
        let mut call = tool_call("search", json!({"query": "test"}));
        let result = composite.transform(&mut call);
        assert!(matches!(result, Err(ToolError::Intercepted(_))));
        assert!(!call.arguments.as_object().unwrap().contains_key("limit"));
    }
}