use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use super::event::AgentEvent;
use super::event_sink::EventSink;
use super::identity::RunIdentity;
use super::progress::{ProgressStatus, TOOL_CALL_PROGRESS_ACTIVITY_TYPE, ToolCallProgressState};
use super::suspension::ToolCallResume;
use crate::cancellation::CancellationToken;
use crate::registry_spec::AgentSpec;
use crate::state::{Snapshot, StateCommand, StateKey};
/// Outcome category of a single tool invocation.
///
/// Serialized in snake_case (`"success"`, `"pending"`, `"error"`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolStatus {
    /// The tool completed and produced its data.
    Success,
    /// The call is suspended; the `suspended*` constructors of `ToolResult` produce this.
    Pending,
    /// The tool failed; see the result's `message`.
    Error,
}
/// The value a tool reports back for one call.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ToolResult {
    /// Name of the tool that produced this result.
    pub tool_name: String,
    /// Success / pending / error classification.
    pub status: ToolStatus,
    /// Result payload (`Value::Null` for the plain error and suspended constructors).
    pub data: Value,
    /// Optional human-readable message accompanying the result.
    pub message: Option<String>,
    /// Suspension ticket for pending results; omitted from output when absent.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suspension: Option<Box<crate::contract::suspension::SuspendTicket>>,
    /// Free-form extra key/value data; omitted from output when empty.
    #[serde(default)]
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
impl ToolResult {
pub fn success(tool_name: impl Into<String>, data: impl Into<Value>) -> Self {
Self {
tool_name: tool_name.into(),
status: ToolStatus::Success,
data: data.into(),
message: None,
suspension: None,
metadata: HashMap::new(),
}
}
pub fn success_with_message(
tool_name: impl Into<String>,
data: impl Into<Value>,
message: impl Into<String>,
) -> Self {
Self {
tool_name: tool_name.into(),
status: ToolStatus::Success,
data: data.into(),
message: Some(message.into()),
suspension: None,
metadata: HashMap::new(),
}
}
pub fn error(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
Self {
tool_name: tool_name.into(),
status: ToolStatus::Error,
data: Value::Null,
message: Some(message.into()),
suspension: None,
metadata: HashMap::new(),
}
}
pub fn error_with_code(
tool_name: impl Into<String>,
code: impl Into<String>,
message: impl Into<String>,
) -> Self {
let code = code.into();
let message = message.into();
Self {
tool_name: tool_name.into(),
status: ToolStatus::Error,
data: serde_json::json!({
"error": {
"code": code,
"message": message,
}
}),
message: Some(format!("[{code}] {message}")),
suspension: None,
metadata: HashMap::new(),
}
}
pub fn suspended(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
Self {
tool_name: tool_name.into(),
status: ToolStatus::Pending,
data: Value::Null,
message: Some(message.into()),
suspension: None,
metadata: HashMap::new(),
}
}
pub fn suspended_with(
tool_name: impl Into<String>,
message: impl Into<String>,
ticket: crate::contract::suspension::SuspendTicket,
) -> Self {
Self {
tool_name: tool_name.into(),
status: ToolStatus::Pending,
data: Value::Null,
message: Some(message.into()),
suspension: Some(Box::new(ticket)),
metadata: HashMap::new(),
}
}
pub fn is_success(&self) -> bool {
matches!(self.status, ToolStatus::Success)
}
pub fn is_pending(&self) -> bool {
matches!(self.status, ToolStatus::Pending)
}
pub fn is_error(&self) -> bool {
matches!(self.status, ToolStatus::Error)
}
pub fn to_json(&self) -> Value {
serde_json::to_value(self).unwrap_or(Value::Null)
}
#[must_use]
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
}
/// Everything a tool execution returns: the [`ToolResult`] plus a
/// [`StateCommand`] produced alongside it.
pub struct ToolOutput {
    /// The result reported for this tool call.
    pub result: ToolResult,
    /// State command returned with the result. It defaults to `StateCommand::new()`
    /// unless the output was built with [`ToolOutput::with_command`].
    pub command: StateCommand,
}

impl std::fmt::Debug for ToolOutput {
    // Only `result` is printed. `finish_non_exhaustive` renders a trailing `..`
    // for the omitted `command` field. NOTE(review): presumably this is because
    // `StateCommand` has no `Debug` impl; confirm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ToolOutput");
        builder.field("result", &self.result);
        builder.finish_non_exhaustive()
    }
}

impl ToolOutput {
    /// Wraps `result` together with a fresh `StateCommand::new()`.
    pub fn new(result: ToolResult) -> Self {
        Self::with_command(result, StateCommand::new())
    }

    /// Pairs `result` with an explicit state `command`.
    pub fn with_command(result: ToolResult, command: StateCommand) -> Self {
        Self { result, command }
    }
}

impl From<ToolResult> for ToolOutput {
    /// Same as [`ToolOutput::new`].
    fn from(result: ToolResult) -> Self {
        ToolOutput::new(result)
    }
}
#[derive(Debug, Error)]
pub enum ToolError {
#[error("Invalid arguments: {0}")]
InvalidArguments(String),
#[error("Execution failed: {0}")]
ExecutionFailed(String),
#[error("Timeout: {0}")]
Timeout(String),
#[error("Cancelled: {0}")]
Cancelled(String),
#[error("Denied: {0}")]
Denied(String),
#[error("Not found: {0}")]
NotFound(String),
#[error("Internal error: {0}")]
Internal(String),
}
#[derive(Debug, Error)]
pub enum ToolValidationError {
#[error("{message}")]
InvalidArgument {
message: String,
},
}
/// Static description of a tool as advertised to callers.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ToolDescriptor {
    /// Identifier of the tool.
    pub id: String,
    /// Display name of the tool.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// JSON schema for the arguments. `ToolDescriptor::new` defaults it to an
    /// object schema with no properties.
    pub parameters: Value,
    /// Optional grouping category.
    pub category: Option<String>,
    /// Free-form extra key/value data; omitted from output when empty.
    #[serde(default)]
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
impl ToolDescriptor {
    /// Creates a descriptor with no category, no metadata, and the parameter
    /// schema `{"type": "object", "properties": {}}`.
    pub fn new(
        id: impl Into<String>,
        name: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        let id = id.into();
        let name = name.into();
        let description = description.into();
        let empty_object_schema = serde_json::json!({"type": "object", "properties": {}});
        Self {
            id,
            name,
            description,
            parameters: empty_object_schema,
            category: None,
            metadata: HashMap::new(),
        }
    }

    /// Replaces the argument JSON schema.
    #[must_use]
    pub fn with_parameters(self, schema: Value) -> Self {
        Self {
            parameters: schema,
            ..self
        }
    }

    /// Sets the category.
    #[must_use]
    pub fn with_category(self, category: impl Into<String>) -> Self {
        Self {
            category: Some(category.into()),
            ..self
        }
    }

    /// Adds (or overwrites) one `metadata` entry.
    #[must_use]
    pub fn with_metadata(self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        let mut metadata = self.metadata;
        metadata.insert(key.into(), value.into());
        Self { metadata, ..self }
    }
}
/// Per-invocation context handed to [`Tool::execute`].
#[derive(Clone)]
pub struct ToolCallContext {
    /// Identifier of this tool call. It is used as the `message_id` of emitted
    /// activity events.
    pub call_id: String,
    /// Name of the tool being called.
    pub tool_name: String,
    /// Identity of the run this call belongs to (run, thread, and parent ids).
    pub run_identity: RunIdentity,
    /// Spec of the agent performing the call.
    pub agent_spec: Arc<AgentSpec>,
    /// State snapshot readable through [`ToolCallContext::state`].
    pub snapshot: Snapshot,
    /// Destination for activity, progress, and stream events. When `None`, the
    /// reporting methods do nothing.
    pub activity_sink: Option<Arc<dyn EventSink>>,
    /// Optional cancellation token for this call.
    pub cancellation_token: Option<CancellationToken>,
    /// Set when the call is being resumed after a suspension.
    pub resume_input: Option<ToolCallResume>,
    /// Suspension identifier associated with this call, if any.
    /// NOTE(review): not read within this module.
    pub suspension_id: Option<String>,
    /// Suspension reason associated with this call, if any.
    /// NOTE(review): not read within this module.
    pub suspension_reason: Option<String>,
}
impl ToolCallContext {
    /// Looks up the value stored under state key `K` in this call's snapshot.
    ///
    /// Delegates to `Snapshot::get` and returns its result unchanged.
    pub fn state<K: StateKey>(&self) -> Option<&K::Value> {
        self.snapshot.get::<K>()
    }

    /// Emits an `ActivitySnapshot` event whose content is the string `content`.
    ///
    /// The event uses this call's `call_id` as `message_id` and sets
    /// `replace: Some(true)`. It does nothing when no `activity_sink` is attached.
    pub async fn report_activity(&self, activity_type: &str, content: &str) {
        if let Some(sink) = &self.activity_sink {
            sink.emit(AgentEvent::ActivitySnapshot {
                message_id: self.call_id.clone(),
                activity_type: activity_type.to_string(),
                content: serde_json::Value::String(content.to_string()),
                replace: Some(true),
            })
            .await;
        }
    }

    /// Emits an `ActivityDelta` event for this call.
    ///
    /// If `patch` is a JSON array, its elements are forwarded as the patch list.
    /// Any other value is sent as a single-element list. It does nothing when no
    /// `activity_sink` is attached.
    pub async fn report_activity_delta(&self, activity_type: &str, patch: serde_json::Value) {
        if let Some(sink) = &self.activity_sink {
            let patches = match patch {
                serde_json::Value::Array(items) => items,
                single => vec![single],
            };
            sink.emit(AgentEvent::ActivityDelta {
                message_id: self.call_id.clone(),
                activity_type: activity_type.to_string(),
                patch: patches,
            })
            .await;
        }
    }

    /// Publishes a `tool-call-progress.v1` progress snapshot for this call.
    ///
    /// The snapshot's node id is `tool_call:<call_id>`. Its parent node is
    /// `tool_call:<parent_tool_call_id>` when the run has a parent tool call, and
    /// `run:<run_id>` otherwise. It is emitted as an `ActivitySnapshot` of type
    /// `TOOL_CALL_PROGRESS_ACTIVITY_TYPE` with `replace: Some(true)`.
    ///
    /// It does nothing when no `activity_sink` is attached. If the state cannot be
    /// serialized, the event content falls back to `Value::Null`; reporting is
    /// best-effort.
    pub async fn report_progress(
        &self,
        status: ProgressStatus,
        message: Option<&str>,
        progress: Option<f64>,
    ) {
        if let Some(sink) = &self.activity_sink {
            let parent_call_id = self.run_identity.parent_tool_call_id.clone();
            // There is always a parent node: the parent tool call if known, else the run.
            let parent_node_id = Some(parent_call_id.as_ref().map_or_else(
                || format!("run:{}", self.run_identity.run_id),
                |id| format!("tool_call:{id}"),
            ));
            let state = ToolCallProgressState {
                schema: "tool-call-progress.v1".into(),
                node_id: format!("tool_call:{}", self.call_id),
                call_id: self.call_id.clone(),
                tool_name: self.tool_name.clone(),
                status,
                progress,
                loaded: None,
                total: None,
                message: message.map(ToOwned::to_owned),
                parent_node_id,
                parent_call_id,
                run_id: Some(self.run_identity.run_id.clone()),
                parent_run_id: self.run_identity.parent_run_id.clone(),
                thread_id: Some(self.run_identity.thread_id.clone()),
            };
            // `Value::default()` is `Null`: a serialization failure degrades the event
            // content instead of aborting the report.
            let content = serde_json::to_value(&state).unwrap_or_default();
            sink.emit(AgentEvent::ActivitySnapshot {
                message_id: self.call_id.clone(),
                activity_type: TOOL_CALL_PROGRESS_ACTIVITY_TYPE.into(),
                content,
                replace: Some(true),
            })
            .await;
        }
    }

    /// Emits a `ToolCallStreamDelta` event carrying `delta` for this call.
    ///
    /// It does nothing when no `activity_sink` is attached.
    pub async fn stream_output(&self, delta: &str) {
        if let Some(sink) = &self.activity_sink {
            sink.emit(AgentEvent::ToolCallStreamDelta {
                id: self.call_id.clone(),
                name: self.tool_name.clone(),
                delta: delta.to_string(),
            })
            .await;
        }
    }

    /// Builds a context with the following values:
    /// - empty call id and tool name;
    /// - default `RunIdentity` and `AgentSpec`;
    /// - a snapshot over an empty `StateMap`;
    /// - no sink, cancellation token, or resume/suspension data.
    ///
    /// As the name suggests, it is intended for tests, although it is not
    /// `cfg(test)`-gated.
    pub fn test_default() -> Self {
        Self {
            call_id: String::new(),
            tool_name: String::new(),
            run_identity: RunIdentity::default(),
            agent_spec: Arc::new(AgentSpec::default()),
            snapshot: Snapshot::new(0, Arc::new(crate::state::StateMap::default())),
            activity_sink: None,
            cancellation_token: None,
            resume_input: None,
            suspension_id: None,
            suspension_reason: None,
        }
    }
}
/// A tool that the client (front end) executes instead of the agent runtime.
///
/// It holds only the [`ToolDescriptor`] advertised for the tool. Its `Tool`
/// implementation suspends the call and, on resume, returns the client's result.
pub struct FrontEndTool {
    descriptor: ToolDescriptor,
}

impl FrontEndTool {
    /// Wraps the descriptor that `Tool::descriptor` will return.
    pub fn new(descriptor: ToolDescriptor) -> Self {
        FrontEndTool { descriptor }
    }
}
#[async_trait]
impl Tool for FrontEndTool {
fn descriptor(&self) -> ToolDescriptor {
self.descriptor.clone()
}
async fn execute(&self, args: Value, ctx: &ToolCallContext) -> Result<ToolOutput, ToolError> {
let tool_name = &self.descriptor.id;
if let Some(resume) = &ctx.resume_input {
return Ok(ToolResult::success(tool_name, resume.result.clone()).into());
}
let pending_id = if ctx.call_id.trim().is_empty() {
tool_name.clone()
} else {
ctx.call_id.clone()
};
let ticket = crate::contract::suspension::SuspendTicket::use_decision_as_tool_result(
crate::contract::suspension::Suspension {
id: format!("suspend_{pending_id}"),
action: format!("tool:{tool_name}"),
message: format!("Frontend tool '{tool_name}' requires client execution"),
parameters: args.clone(),
response_schema: None,
},
crate::contract::suspension::PendingToolCall::new(pending_id, tool_name, args),
);
Ok(ToolResult::suspended_with(
tool_name,
format!("Tool '{tool_name}' suspended: awaiting decision"),
ticket,
)
.into())
}
}
/// An executable tool that works on raw JSON arguments.
///
/// Types implementing [`TypedTool`] get this trait through a blanket impl in this
/// module.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Describes the tool: id, name, description, and argument schema.
    fn descriptor(&self) -> ToolDescriptor;

    /// Checks raw arguments before execution. By default every input is accepted.
    fn validate_args(&self, _args: &Value) -> Result<(), ToolError> {
        Ok(())
    }

    /// Runs the tool with raw JSON `args` in the given call context.
    async fn execute(&self, args: Value, ctx: &ToolCallContext) -> Result<ToolOutput, ToolError>;
}
/// A tool with strongly typed arguments.
///
/// The blanket `Tool` impl deserializes the JSON arguments into `Args`, runs
/// [`TypedTool::validate`], and then calls [`TypedTool::execute`]. The JSON schema
/// in the descriptor is generated from `Args`.
#[async_trait]
pub trait TypedTool: Send + Sync {
    /// Argument type. It must be deserializable and able to describe its own JSON schema.
    type Args: for<'de> Deserialize<'de> + schemars::JsonSchema + Send;

    /// Identifier used as the descriptor `id`.
    fn tool_id(&self) -> &str;

    /// Display name used as the descriptor `name`.
    fn name(&self) -> &str;

    /// Human-readable description used as the descriptor `description`.
    fn description(&self) -> &str;

    /// Optional descriptor category; `None` by default.
    fn category(&self) -> Option<&str> {
        None
    }

    /// Semantic validation of deserialized arguments; accepts everything by default.
    fn validate(&self, _args: &Self::Args) -> Result<(), ToolValidationError> {
        Ok(())
    }

    /// Runs the tool with already deserialized and validated arguments.
    async fn execute(
        &self,
        args: Self::Args,
        ctx: &ToolCallContext,
    ) -> Result<ToolOutput, ToolError>;
}
/// Adapts every [`TypedTool`] to the raw-JSON [`Tool`] interface.
#[async_trait]
impl<T: TypedTool> Tool for T {
    /// Builds the descriptor from the typed metadata. The parameter schema is
    /// generated from `T::Args`.
    fn descriptor(&self) -> ToolDescriptor {
        let args_schema = super::tool_schema::generate_tool_schema::<T::Args>();
        let base = ToolDescriptor::new(self.tool_id(), self.name(), self.description())
            .with_parameters(args_schema);
        match self.category() {
            Some(category) => base.with_category(category),
            None => base,
        }
    }

    /// Always accepts. Typed checking needs the deserialized `T::Args`, so it
    /// happens inside `execute`.
    fn validate_args(&self, _args: &Value) -> Result<(), ToolError> {
        Ok(())
    }

    /// Deserializes `args` into `T::Args`, validates them, then delegates to
    /// [`TypedTool::execute`].
    ///
    /// A `null` payload is treated as `{}`. Deserialization and validation
    /// failures both become [`ToolError::InvalidArguments`].
    async fn execute(&self, args: Value, ctx: &ToolCallContext) -> Result<ToolOutput, ToolError> {
        let normalized = match args {
            Value::Null => Value::Object(Default::default()),
            other => other,
        };
        let typed = serde_json::from_value::<T::Args>(normalized)
            .map_err(|err| ToolError::InvalidArguments(err.to_string()))?;
        if let Err(err) = self.validate(&typed) {
            return Err(ToolError::InvalidArguments(err.to_string()));
        }
        TypedTool::execute(self, typed, ctx).await
    }
}
#[cfg(test)]
mod tests;