use serde::{Deserialize, Serialize};
use crate::ToolReturn;
/// A tool invocation whose execution has been deferred pending a decision.
///
/// Derives `PartialEq` so whole calls (including their JSON arguments) can be
/// compared, e.g. after a serde roundtrip.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeferredToolCall {
    /// Name of the tool that was requested.
    pub tool_name: String,
    /// Arguments for the tool, as an arbitrary JSON value.
    pub args: serde_json::Value,
    /// Optional identifier of the originating call; omitted from the
    /// serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
impl DeferredToolCall {
    /// Builds a deferred call for `tool_name` with the given JSON `args` and no call id.
    #[must_use]
    pub fn new(tool_name: impl Into<String>, args: serde_json::Value) -> Self {
        let tool_name = tool_name.into();
        Self {
            tool_name,
            args,
            tool_call_id: None,
        }
    }

    /// Returns this call with its `tool_call_id` set to `id`.
    #[must_use]
    pub fn with_tool_call_id(self, id: impl Into<String>) -> Self {
        Self {
            tool_call_id: Some(id.into()),
            ..self
        }
    }

    /// Produces an `Approved` decision. The call itself is not inspected.
    #[must_use]
    pub fn approve(&self) -> DeferredToolDecision {
        DeferredToolDecision::Approved
    }

    /// Produces a `Denied` decision carrying `message`.
    #[must_use]
    pub fn deny(&self, message: impl Into<String>) -> DeferredToolDecision {
        let reason: String = message.into();
        DeferredToolDecision::Denied(reason)
    }

    /// Produces a `CustomResult` decision wrapping `result`.
    #[must_use]
    pub fn with_result(&self, result: ToolReturn) -> DeferredToolDecision {
        DeferredToolDecision::CustomResult(result)
    }
}
/// An ordered batch of deferred tool calls.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct DeferredToolRequests {
    /// The pending calls, in insertion order.
    pub calls: Vec<DeferredToolCall>,
}
impl DeferredToolRequests {
    /// Creates an empty batch.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `call` to the end of the batch.
    pub fn add(&mut self, call: DeferredToolCall) {
        self.calls.push(call);
    }

    /// Returns `true` when the batch holds no calls.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.calls.is_empty()
    }

    /// Number of calls in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.calls.len()
    }

    /// Returns the call at `index`, or `None` if out of range.
    #[must_use]
    pub fn get(&self, index: usize) -> Option<&DeferredToolCall> {
        self.calls.get(index)
    }

    /// Iterates over the calls in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = &DeferredToolCall> {
        self.calls.iter()
    }

    /// Collects every call whose `tool_name` equals `name`, preserving order.
    #[must_use]
    pub fn by_tool(&self, name: &str) -> Vec<&DeferredToolCall> {
        self.iter().filter(|call| call.tool_name == name).collect()
    }

    /// Removes all calls.
    pub fn clear(&mut self) {
        self.calls.clear();
    }

    /// Builds one `Approved` decision per call; decision `i` corresponds to call `i`.
    #[must_use]
    pub fn approve_all(&self) -> DeferredToolDecisions {
        DeferredToolDecisions {
            decisions: vec![DeferredToolDecision::Approved; self.len()],
        }
    }

    /// Builds one `Denied(message)` decision per call; decision `i` corresponds to call `i`.
    #[must_use]
    pub fn deny_all(&self, message: impl Into<String>) -> DeferredToolDecisions {
        let reason: String = message.into();
        let count = self.len();
        DeferredToolDecisions {
            decisions: vec![DeferredToolDecision::Denied(reason); count],
        }
    }
}
impl FromIterator<DeferredToolCall> for DeferredToolRequests {
    /// Collects calls into a batch, keeping iteration order.
    fn from_iter<I: IntoIterator<Item = DeferredToolCall>>(iter: I) -> Self {
        let calls: Vec<DeferredToolCall> = iter.into_iter().collect();
        Self { calls }
    }
}
/// The outcome chosen for a single deferred tool call.
#[derive(Clone, Debug)]
pub enum DeferredToolDecision {
    /// The call was approved.
    Approved,
    /// The call was denied; the string is the denial message.
    Denied(String),
    /// A caller-supplied result. Presumably it stands in for actually
    /// running the tool — confirm against the executor.
    CustomResult(ToolReturn),
}
impl DeferredToolDecision {
#[must_use]
pub fn is_approved(&self) -> bool {
matches!(self, Self::Approved)
}
#[must_use]
pub fn is_denied(&self) -> bool {
matches!(self, Self::Denied(_))
}
#[must_use]
pub fn is_custom(&self) -> bool {
matches!(self, Self::CustomResult(_))
}
#[must_use]
pub fn denial_message(&self) -> Option<&str> {
match self {
Self::Denied(msg) => Some(msg),
_ => None,
}
}
}
/// An ordered collection of decisions for deferred tool calls.
#[derive(Clone, Debug, Default)]
pub struct DeferredToolDecisions {
    /// The decisions, in insertion order.
    pub decisions: Vec<DeferredToolDecision>,
}
impl DeferredToolDecisions {
#[must_use]
pub fn new() -> Self {
Self {
decisions: Vec::new(),
}
}
pub fn add(&mut self, decision: DeferredToolDecision) {
self.decisions.push(decision);
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.decisions.is_empty()
}
#[must_use]
pub fn len(&self) -> usize {
self.decisions.len()
}
#[must_use]
pub fn all_approved(&self) -> bool {
self.decisions.iter().all(|d| d.is_approved())
}
#[must_use]
pub fn any_denied(&self) -> bool {
self.decisions.iter().any(|d| d.is_denied())
}
}
impl FromIterator<DeferredToolDecision> for DeferredToolDecisions {
    /// Collects decisions, keeping iteration order.
    fn from_iter<I: IntoIterator<Item = DeferredToolDecision>>(iter: I) -> Self {
        let decisions: Vec<DeferredToolDecision> = iter.into_iter().collect();
        Self { decisions }
    }
}
/// The result reported back for one deferred tool call.
#[derive(Clone, Debug)]
pub struct DeferredToolResult {
    /// Identifier of the call this result belongs to, if known.
    pub tool_call_id: Option<String>,
    /// The tool return value to report.
    pub result: ToolReturn,
}
impl DeferredToolResult {
#[must_use]
pub fn new(result: ToolReturn) -> Self {
Self {
tool_call_id: None,
result,
}
}
#[must_use]
pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
self.tool_call_id = Some(id.into());
self
}
#[must_use]
pub fn approved() -> Self {
Self::new(ToolReturn::text("Tool execution approved"))
}
#[must_use]
pub fn denied(message: impl Into<String>) -> Self {
Self::new(ToolReturn::error(message))
}
}
/// An ordered collection of results for deferred tool calls.
#[derive(Clone, Debug, Default)]
pub struct DeferredToolResults {
    /// The results, in insertion order.
    pub results: Vec<DeferredToolResult>,
}
impl DeferredToolResults {
#[must_use]
pub fn new() -> Self {
Self {
results: Vec::new(),
}
}
pub fn add(&mut self, result: DeferredToolResult) {
self.results.push(result);
}
#[must_use]
pub fn approved(id: Option<String>) -> Self {
let mut result = DeferredToolResult::approved();
if let Some(id) = id {
result = result.with_tool_call_id(id);
}
Self {
results: vec![result],
}
}
#[must_use]
pub fn denied(id: Option<String>, message: impl Into<String>) -> Self {
let mut result = DeferredToolResult::denied(message);
if let Some(id) = id {
result = result.with_tool_call_id(id);
}
Self {
results: vec![result],
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.results.is_empty()
}
#[must_use]
pub fn len(&self) -> usize {
self.results.len()
}
}
impl FromIterator<DeferredToolResult> for DeferredToolResults {
    /// Collects results, keeping iteration order.
    fn from_iter<I: IntoIterator<Item = DeferredToolResult>>(iter: I) -> Self {
        let results: Vec<DeferredToolResult> = iter.into_iter().collect();
        Self { results }
    }
}
/// Zero-sized marker with no data; presumably signals that a tool call was approved.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ToolApproved;
/// A denial carrying a human-readable message.
///
/// Derives `PartialEq`/`Eq` for parity with `ToolApproved`, so both outcomes
/// can be compared in assertions and matching code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolDenied {
    /// The reason given for the denial.
    pub message: String,
}

impl ToolDenied {
    /// Creates a denial with the given `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}
/// Asynchronously decides what happens to a deferred tool call.
// `async fn` in a public trait trips the `async_fn_in_trait` lint because callers
// cannot require the returned future to be `Send`. The lint is silenced on
// purpose to keep implementations simple. NOTE(review): if approvers must be
// driven from `tokio::spawn` on a multi-threaded runtime, consider desugaring to
// `fn approve(..) -> impl Future<Output = DeferredToolDecision> + Send`.
#[allow(async_fn_in_trait)]
pub trait ToolApprover {
    /// Returns the decision for `call`.
    async fn approve(&self, call: &DeferredToolCall) -> DeferredToolDecision;
}
/// A `ToolApprover` that approves every call without inspecting it.
#[derive(Clone, Copy, Debug, Default)]
pub struct AutoApprover;
impl ToolApprover for AutoApprover {
    /// Always returns `Approved`; the call is ignored.
    async fn approve(&self, _: &DeferredToolCall) -> DeferredToolDecision {
        DeferredToolDecision::Approved
    }
}
/// A tool approver that denies every call with one fixed message.
#[derive(Clone, Debug)]
pub struct AutoDenier {
    /// Message attached to every denial.
    message: String,
}

impl AutoDenier {
    /// Creates a denier that rejects calls with `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}
impl ToolApprover for AutoDenier {
    /// Always returns `Denied` with a copy of the configured message.
    async fn approve(&self, _: &DeferredToolCall) -> DeferredToolDecision {
        let reason = self.message.clone();
        DeferredToolDecision::Denied(reason)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deferred_tool_call() {
        let call = DeferredToolCall::new("my_tool", serde_json::json!({"x": 1}))
            .with_tool_call_id("call_123");
        assert_eq!(call.tool_name, "my_tool");
        assert_eq!(call.args, serde_json::json!({"x": 1}));
        assert_eq!(call.tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_deferred_tool_call_defaults_to_no_id() {
        let call = DeferredToolCall::new("my_tool", serde_json::json!({}));
        assert_eq!(call.tool_call_id, None);
    }

    #[test]
    fn test_deferred_tool_call_decision_helpers() {
        let call = DeferredToolCall::new("my_tool", serde_json::json!({}));
        assert!(call.approve().is_approved());
        let denied = call.deny("nope");
        assert!(denied.is_denied());
        assert_eq!(denied.denial_message(), Some("nope"));
        assert!(call.with_result(ToolReturn::text("done")).is_custom());
    }

    #[test]
    fn test_deferred_tool_requests() {
        let mut requests = DeferredToolRequests::new();
        assert!(requests.is_empty());
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));
        assert_eq!(requests.len(), 2);
        assert!(!requests.is_empty());
    }

    #[test]
    fn test_get_iter_collect_and_clear() {
        let mut requests: DeferredToolRequests = vec![
            DeferredToolCall::new("tool1", serde_json::json!({})),
            DeferredToolCall::new("tool2", serde_json::json!({})),
        ]
        .into_iter()
        .collect();
        assert_eq!(requests.get(1).map(|c| c.tool_name.as_str()), Some("tool2"));
        assert!(requests.get(2).is_none());
        let names: Vec<&str> = requests.iter().map(|c| c.tool_name.as_str()).collect();
        assert_eq!(names, ["tool1", "tool2"]);
        requests.clear();
        assert!(requests.is_empty());
    }

    #[test]
    fn test_by_tool() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        let tool1_calls = requests.by_tool("tool1");
        assert_eq!(tool1_calls.len(), 2);
        assert!(requests.by_tool("missing").is_empty());
    }

    #[test]
    fn test_approve_all() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));
        let decisions = requests.approve_all();
        assert_eq!(decisions.len(), 2);
        assert!(decisions.all_approved());
        assert!(!decisions.any_denied());

        let empty = DeferredToolRequests::new().approve_all();
        assert!(empty.is_empty());
    }

    #[test]
    fn test_deny_all() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));
        let decisions = requests.deny_all("Not allowed");
        assert_eq!(decisions.len(), 2);
        assert!(decisions.any_denied());
        assert!(!decisions.all_approved());
        assert!(decisions
            .decisions
            .iter()
            .all(|d| d.denial_message() == Some("Not allowed")));
    }

    #[test]
    fn test_deferred_tool_decision() {
        let approved = DeferredToolDecision::Approved;
        assert!(approved.is_approved());
        assert!(!approved.is_denied());
        assert_eq!(approved.denial_message(), None);
        let denied = DeferredToolDecision::Denied("No".into());
        assert!(denied.is_denied());
        assert_eq!(denied.denial_message(), Some("No"));
        let custom = DeferredToolDecision::CustomResult(ToolReturn::text("custom"));
        assert!(custom.is_custom());
        assert!(!custom.is_approved());
        assert_eq!(custom.denial_message(), None);
    }

    #[test]
    fn test_deferred_tool_result() {
        let result = DeferredToolResult::approved().with_tool_call_id("id1");
        assert_eq!(result.tool_call_id, Some("id1".to_string()));
        let denied = DeferredToolResult::denied("Not allowed");
        assert!(denied.result.is_error());
        assert_eq!(denied.tool_call_id, None);
    }

    #[test]
    fn test_deferred_tool_results() {
        let results = DeferredToolResults::approved(Some("id1".to_string()));
        assert_eq!(results.len(), 1);
        assert_eq!(results.results[0].tool_call_id.as_deref(), Some("id1"));
        let denied = DeferredToolResults::denied(None, "Nope");
        assert_eq!(denied.len(), 1);
        assert_eq!(denied.results[0].tool_call_id, None);
        assert!(denied.results[0].result.is_error());
    }

    #[test]
    fn test_tool_denied() {
        let denied = ToolDenied::new("Not allowed");
        assert_eq!(denied.message, "Not allowed");
    }

    #[tokio::test]
    async fn test_auto_approver() {
        let approver = AutoApprover;
        let call = DeferredToolCall::new("test", serde_json::json!({}));
        let decision = approver.approve(&call).await;
        assert!(decision.is_approved());
    }

    #[tokio::test]
    async fn test_auto_denier() {
        let denier = AutoDenier::new("Denied");
        let call = DeferredToolCall::new("test", serde_json::json!({}));
        let decision = denier.approve(&call).await;
        assert!(decision.is_denied());
        assert_eq!(decision.denial_message(), Some("Denied"));
    }

    #[test]
    fn test_serde_roundtrip() {
        let call =
            DeferredToolCall::new("test", serde_json::json!({"x": 1})).with_tool_call_id("id");
        let json = serde_json::to_string(&call).unwrap();
        let parsed: DeferredToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(call.tool_name, parsed.tool_name);
        assert_eq!(call.args, parsed.args);
        assert_eq!(call.tool_call_id, parsed.tool_call_id);
    }

    #[test]
    fn test_serde_omits_missing_tool_call_id() {
        let call = DeferredToolCall::new("test", serde_json::json!({}));
        let value = serde_json::to_value(&call).unwrap();
        // `skip_serializing_if = "Option::is_none"` must drop the key entirely.
        assert!(value.get("tool_call_id").is_none());
        let parsed: DeferredToolCall = serde_json::from_value(value).unwrap();
        assert_eq!(parsed.tool_call_id, None);
    }
}