use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
/// One block of content inside a [`Message`].
///
/// Serialized as an internally tagged JSON object. The `"type"` field holds the
/// camelCase variant name: `"text"`, `"image"`, `"thinking"` or `"toolCall"`.
/// `rename_all` on an enum renames variants only; field names keep their own
/// attributes below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum Content {
    /// Plain text.
    Text { text: String },
    /// An image payload with its MIME type (the encoding of `data` is not
    /// visible here — presumably base64; confirm with producers).
    Image {
        data: String,
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
    /// Model reasoning text, with an optional signature that is omitted from
    /// the JSON when `None`.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// A tool invocation requested by the model.
    ToolCall {
        id: String,
        name: String,
        arguments: serde_json::Value,
        /// Opaque provider-specific data; missing on input -> `None`, and
        /// omitted on output when `None`. Serialized as `provider_metadata`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_metadata: Option<serde_json::Value>,
    },
}
/// A message in the LLM conversation.
///
/// Internally tagged by `"role"`. `rename_all = "camelCase"` maps the variants
/// to `"user"`, `"assistant"` and `"toolResult"`. Individual fields keep their
/// explicit renames below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "camelCase")]
pub enum Message {
    /// Input from the user.
    User {
        content: Vec<Content>,
        /// Millisecond Unix timestamp (`Message::user` fills it from `now_ms`).
        timestamp: u64,
    },
    /// A reply produced by a model.
    Assistant {
        content: Vec<Content>,
        #[serde(rename = "stopReason")]
        stop_reason: StopReason,
        model: String,
        provider: String,
        usage: Usage,
        timestamp: u64,
        /// Error text, if any. This field carries no rename and therefore
        /// serializes as `error_message`; it is omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        error_message: Option<String>,
    },
    /// The outcome of a tool call.
    ToolResult {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        content: Vec<Content>,
        #[serde(rename = "isError")]
        is_error: bool,
        timestamp: u64,
    },
}
impl Message {
pub fn user(text: impl Into<String>) -> Self {
Self::User {
content: vec![Content::Text { text: text.into() }],
timestamp: now_ms(),
}
}
pub fn role(&self) -> &str {
match self {
Self::User { .. } => "user",
Self::Assistant { .. } => "assistant",
Self::ToolResult { .. } => "toolResult",
}
}
pub fn is_context_overflow(&self) -> bool {
match self {
Self::Assistant {
stop_reason: StopReason::Error,
error_message: Some(msg),
..
} => crate::provider::is_context_overflow_message(msg),
_ => false,
}
}
}
/// An application-defined message stored alongside LLM messages in
/// [`AgentMessage::Extension`].
///
/// Fields serialize in declaration order under their Rust names.
/// `ExtensionMessage::new` always sets `role` to `"extension"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExtensionMessage {
/// Role tag; `"extension"` when built via `ExtensionMessage::new`.
pub role: String,
/// Caller-chosen discriminator describing what `data` contains.
pub kind: String,
/// Arbitrary JSON payload.
pub data: serde_json::Value,
}
impl ExtensionMessage {
pub fn new(kind: impl Into<String>, data: impl Serialize) -> Self {
Self {
role: "extension".into(),
kind: kind.into(),
data: serde_json::to_value(data).unwrap_or(serde_json::Value::Null),
}
}
}
/// A message in the agent's history: either an LLM conversation [`Message`]
/// or an application-defined [`ExtensionMessage`].
///
/// Uses `#[serde(untagged)]`, so deserialization tries `Llm` first and falls
/// back to `Extension`. Do not reorder the variants.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AgentMessage {
/// A message in the LLM conversation format.
Llm(Message),
/// An application-defined extension message.
Extension(ExtensionMessage),
}
impl AgentMessage {
pub fn role(&self) -> &str {
match self {
Self::Llm(m) => m.role(),
Self::Extension(ext) => &ext.role,
}
}
pub fn as_llm(&self) -> Option<&Message> {
match self {
Self::Llm(m) => Some(m),
Self::Extension(_) => None,
}
}
}
impl From<Message> for AgentMessage {
fn from(m: Message) -> Self {
Self::Llm(m)
}
}
/// Why an assistant turn stopped.
///
/// Serializes in camelCase (`"stop"`, `"length"`, `"toolUse"`, `"error"`,
/// `"aborted"`), the same strings its `Display` impl writes.
///
/// `Copy` and `Hash` are derived to match the other fieldless enums in this
/// module (e.g. `ThinkingLevel`). Values can then be passed by value and used
/// as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum StopReason {
    /// Normal end of turn (presumably the model finished on its own).
    Stop,
    /// Presumably an output length/token limit was hit — confirm per provider.
    Length,
    /// The model stopped to request tool calls (inferred from the name).
    ToolUse,
    /// The turn ended in an error; `Message::Assistant::error_message` may hold details.
    Error,
    /// The turn was aborted (presumably cancelled by the caller).
    Aborted,
}
/// Token counts attached to an assistant message.
///
/// Fields serialize under their Rust (snake_case) names. `input` and `output`
/// are required when deserializing; the remaining counters default to `0`
/// when absent.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct Usage {
/// Input (prompt) tokens.
pub input: u64,
/// Output (completion) tokens.
pub output: u64,
/// Tokens read from the provider's prompt cache.
#[serde(default)]
pub cache_read: u64,
/// Tokens written to the provider's prompt cache.
#[serde(default)]
pub cache_write: u64,
/// Total token count as reported (not recomputed here).
#[serde(default)]
pub total_tokens: u64,
}
impl Usage {
    /// Fraction of prompt-side tokens served from cache, in `[0.0, 1.0]`.
    ///
    /// Computed as `cache_read / (input + cache_read + cache_write)`; returns
    /// `0.0` when all three counters are zero.
    ///
    /// NOTE(review): this assumes `input` excludes cached tokens. Some providers
    /// report prompt tokens inclusive of cache reads — confirm per provider.
    pub fn cache_hit_rate(&self) -> f64 {
        // Saturate instead of using `+`: plain addition panics in debug builds
        // (and wraps in release) if the counters are ever this large.
        let total_input = self
            .input
            .saturating_add(self.cache_read)
            .saturating_add(self.cache_write);
        if total_input == 0 {
            return 0.0;
        }
        self.cache_read as f64 / total_input as f64
    }
}
/// Prompt-caching settings.
///
/// The `Default` impl enables caching with [`CacheStrategy::Auto`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CacheConfig {
/// Whether caching is enabled at all.
pub enabled: bool,
/// Which parts of the request to cache.
pub strategy: CacheStrategy,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enabled: true,
strategy: CacheStrategy::Auto,
}
}
}
/// How tool calls are scheduled (inferred from the name — confirm with the executor).
///
/// Internally tagged by `"type"` with camelCase names, e.g.
/// `{"type":"sequential"}`, `{"type":"parallel"}`, `{"type":"batched","size":4}`.
/// Defaults to `Parallel`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum ToolExecutionStrategy {
/// One tool call at a time.
Sequential,
/// All tool calls concurrently (the default).
#[default]
Parallel,
/// Tool calls in groups; `size` presumably bounds each group — confirm.
Batched { size: usize },
}
/// Which parts of a request are marked for prompt caching.
///
/// Uses serde's default externally tagged representation with camelCase
/// variant names: `"auto"`, `"disabled"`, or `{"manual": {...}}`. The fields of
/// `Manual` keep their snake_case names, because `rename_all` on an enum only
/// renames variants. Defaults to `Auto`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub enum CacheStrategy {
/// Let the implementation decide (the default).
#[default]
Auto,
/// Never mark anything for caching.
Disabled,
/// Explicitly choose which sections to cache.
Manual {
cache_system: bool,
cache_tools: bool,
cache_messages: bool,
},
}
/// Requested amount of model reasoning.
///
/// Serialized as lowercase strings (`"off"`, `"minimal"`, `"low"`, `"medium"`,
/// `"high"`); defaults to `Off`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingLevel {
#[default]
Off,
Minimal,
Low,
Medium,
High,
}
/// Shared, thread-safe callback receiving a [`ToolResult`]; used by
/// `ToolContext::on_update` (presumably for partial results — confirm).
pub type ToolUpdateFn = Arc<dyn Fn(ToolResult) + Send + Sync>;
/// Shared, thread-safe callback receiving progress text; used by
/// `ToolContext::on_progress`.
pub type ProgressFn = Arc<dyn Fn(String) + Send + Sync>;
/// Per-invocation context passed to [`AgentTool::execute`].
///
/// `Clone` is derived: it produces exactly what the previous hand-written impl
/// did, field by field. Clones share the same callbacks because the `Arc`s are
/// reference-counted.
#[derive(Clone)]
pub struct ToolContext {
    /// Identifier of the tool call being executed.
    pub tool_call_id: String,
    /// Name of the tool being executed.
    pub tool_name: String,
    /// Cancellation token for this execution.
    pub cancel: tokio_util::sync::CancellationToken,
    /// Optional callback receiving intermediate [`ToolResult`]s.
    pub on_update: Option<ToolUpdateFn>,
    /// Optional callback receiving progress text.
    pub on_progress: Option<ProgressFn>,
}
impl std::fmt::Debug for ToolContext {
    /// Formats every field. The callbacks are closures with no `Debug` impl,
    /// so each one is shown as `Some("<callback>")` when set and `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PLACEHOLDER: &str = "<callback>";
        let on_update = self.on_update.as_ref().map(|_| PLACEHOLDER);
        let on_progress = self.on_progress.as_ref().map(|_| PLACEHOLDER);

        let mut out = f.debug_struct("ToolContext");
        out.field("tool_call_id", &self.tool_call_id);
        out.field("tool_name", &self.tool_name);
        out.field("cancel", &self.cancel);
        out.field("on_update", &on_update);
        out.field("on_progress", &on_progress);
        out.finish()
    }
}
/// A tool the agent can invoke.
///
/// Implementors must be `Send + Sync`. `execute` is an async method made usable
/// through trait objects by `#[async_trait]`, which is what allows
/// `AgentContext::tools` to hold `Box<dyn AgentTool>`.
#[async_trait::async_trait]
pub trait AgentTool: Send + Sync {
/// Tool name (presumably the identifier the model uses in tool calls — confirm).
fn name(&self) -> &str;
/// Human-readable label.
fn label(&self) -> &str;
/// Description of the tool (presumably shown to the model — confirm).
fn description(&self) -> &str;
/// Schema for `params` (presumably JSON Schema — confirm with providers).
fn parameters_schema(&self) -> serde_json::Value;
/// Runs the tool with JSON `params` and the per-call `ctx`.
///
/// # Errors
///
/// Returns a [`ToolError`] on failure; see its variants.
async fn execute(
&self,
params: serde_json::Value,
ctx: ToolContext,
) -> Result<ToolResult, ToolError>;
}
/// Output of a tool execution.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResult {
/// Content blocks produced by the tool.
pub content: Vec<Content>,
/// Extra structured data; deserializes to JSON `null` when absent.
#[serde(default)]
pub details: serde_json::Value,
}
/// Errors returned by [`AgentTool::execute`].
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
/// Generic failure; displays the message as-is.
#[error("{0}")]
Failed(String),
/// No tool with the given name; displays "Tool not found: <name>".
#[error("Tool not found: {0}")]
NotFound(String),
/// The supplied arguments were rejected; displays "Invalid arguments: <msg>".
#[error("Invalid arguments: {0}")]
InvalidArgs(String),
/// Execution was cancelled; displays "Cancelled".
#[error("Cancelled")]
Cancelled,
}
/// Events describing the progress of an agent run (presumably emitted by the
/// agent loop to observers — confirm against the emitter).
#[derive(Debug, Clone)]
pub enum AgentEvent {
/// The run started.
AgentStart,
/// The run finished, carrying a list of messages (whether this is the full
/// history or only new messages is not visible here — confirm).
AgentEnd {
messages: Vec<AgentMessage>,
},
/// A turn started.
TurnStart,
/// A turn finished, with the turn's message and any tool-result messages.
TurnEnd {
message: AgentMessage,
tool_results: Vec<Message>,
},
/// A message began.
MessageStart {
message: AgentMessage,
},
/// A message changed; `delta` is the streamed increment.
MessageUpdate {
message: AgentMessage,
delta: StreamDelta,
},
/// A message was completed.
MessageEnd {
message: AgentMessage,
},
/// A tool call began executing with `args`.
ToolExecutionStart {
tool_call_id: String,
tool_name: String,
args: serde_json::Value,
},
/// A tool call reported a partial result.
ToolExecutionUpdate {
tool_call_id: String,
tool_name: String,
partial_result: ToolResult,
},
/// A tool call finished; `is_error` flags a failed execution.
ToolExecutionEnd {
tool_call_id: String,
tool_name: String,
result: ToolResult,
is_error: bool,
},
/// Progress text reported by a running tool.
ProgressMessage {
tool_call_id: String,
tool_name: String,
text: String,
},
/// Input was rejected (presumably by an `InputFilter` returning
/// `FilterResult::Reject` — confirm).
InputRejected {
reason: String,
},
}
/// An incremental chunk of a streaming message.
#[derive(Debug, Clone)]
pub enum StreamDelta {
/// Appended text.
Text { delta: String },
/// Appended reasoning text.
Thinking { delta: String },
/// A fragment of tool-call data (presumably partial argument JSON — confirm).
ToolCallDelta { delta: String },
}
/// State for an agent run: system prompt, message history, and available tools.
///
/// No derives: `Box<dyn AgentTool>` is neither `Clone` nor `Debug`.
pub struct AgentContext {
pub system_prompt: String,
pub messages: Vec<AgentMessage>,
pub tools: Vec<Box<dyn AgentTool>>,
}
/// Verdict returned by an [`InputFilter`].
#[derive(Debug, Clone)]
pub enum FilterResult {
/// Accept the input.
Pass,
/// Accept the input, with a warning message.
Warn(String),
/// Reject the input, with a reason.
Reject(String),
}
/// Inspects input text and returns a [`FilterResult`]; must be `Send + Sync`.
pub trait InputFilter: Send + Sync {
/// Classifies `text` as pass, warn, or reject.
fn filter(&self, text: &str) -> FilterResult;
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch.
pub fn now_ms() -> u64 {
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    let elapsed = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        Err(_) => Duration::ZERO,
    };
    // u128 -> u64 only truncates after ~584 million years of milliseconds.
    elapsed.as_millis() as u64
}
impl fmt::Display for StopReason {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Stop => write!(f, "stop"),
Self::Length => write!(f, "length"),
Self::ToolUse => write!(f, "toolUse"),
Self::Error => write!(f, "error"),
Self::Aborted => write!(f, "aborted"),
}
}
}