use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub attachments: Option<Vec<Attachment>>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub is_error: Option<bool>,
}
impl Default for Message {
fn default() -> Self {
Self {
role: MessageRole::User,
content: String::new(),
attachments: None,
tool_call_id: None,
tool_calls: None,
is_error: None,
}
}
}
/// A tool invocation: which tool to call and with what JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call.
/// NOTE(review): presumably matched against `Message::tool_call_id` of the result — confirm.
pub id: String,
/// Name of the tool to invoke.
pub name: String,
/// Tool arguments as arbitrary JSON.
pub arguments: serde_json::Value,
}
/// Author of a [`Message`].
///
/// Serialized in lowercase: `"user"`, `"assistant"`, `"tool"`, `"system"`. The explicit
/// `rename = "tool"` on `Tool` produces the same value `rename_all` already would.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
/// Default role (used by `MessageRole::default()`).
#[default]
User,
Assistant,
#[serde(rename = "tool")]
Tool,
System,
}
/// Extra context that can be attached to a [`Message`].
///
/// Internally tagged: the variant name is written verbatim (e.g. `"File"`,
/// `"SelectedLinesInIde"`) into a `"type"` field next to the variant's own fields, whose
/// JSON keys keep their snake_case Rust names (no `rename_all` is applied).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Attachment {
File { path: String },
/// A file whose content has already been read and is carried inline.
AlreadyReadFile { path: String, content: String },
PdfReference { path: String },
/// An edited text file together with a snippet of its content.
EditedTextFile { filename: String, snippet: String },
EditedImageFile { filename: String },
/// A directory, its content as text, and the path to display for it.
Directory {
path: String,
content: String,
display_path: String,
},
/// A line range selected in an IDE.
/// NOTE(review): whether the range is 0- or 1-based and inclusive is not visible here — confirm.
SelectedLinesInIde {
ide_name: String,
filename: String,
start_line: u32,
end_line: u32,
},
MemoryFile { path: String },
SkillListing { skills: Vec<SkillInfo> },
InvokedSkills { skills: Vec<InvokedSkill> },
/// Status of a task; `status` is free-form text.
TaskStatus {
task_id: String,
description: String,
status: String,
},
PlanFileReference { path: String },
McpResources { tools: Vec<String> },
DeferredTools { tools: Vec<String> },
AgentListing { agents: Vec<String> },
/// Escape hatch: a named attachment carrying arbitrary JSON content.
Custom {
name: String,
content: serde_json::Value,
},
}
/// Name and description of a skill, as listed in [`Attachment::SkillListing`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillInfo {
pub name: String,
pub description: String,
}
/// A skill that was invoked, with its path and content, as carried by
/// [`Attachment::InvokedSkills`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvokedSkill {
pub name: String,
pub path: String,
pub content: String,
}
/// Token counts for a request or run.
///
/// The cache counters are omitted from serialized output when `None`; `Default` is all
/// zeros with no cache counters.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TokenUsage {
pub input_tokens: u64,
pub output_tokens: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_creation_input_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_read_input_tokens: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub input_schema: ToolInputSchema,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub annotations: Option<ToolAnnotations>,
}
impl Default for ToolDefinition {
fn default() -> Self {
Self {
name: String::new(),
description: String::new(),
input_schema: ToolInputSchema::default(),
annotations: None,
}
}
}
impl ToolDefinition {
    /// Creates a tool definition with the given name, description and input schema and
    /// no annotations.
    pub fn new(name: &str, description: &str, input_schema: ToolInputSchema) -> Self {
        Self {
            name: name.to_string(),
            description: description.to_string(),
            input_schema,
            annotations: None,
        }
    }

    /// Whether this tool may run concurrently with other tool calls.
    ///
    /// Determined solely by the `concurrency_safe` annotation; `false` when the annotation
    /// is absent. `_input` is currently unused.
    pub fn is_concurrency_safe(&self, _input: &serde_json::Value) -> bool {
        self.annotations
            .as_ref()
            .and_then(|a| a.concurrency_safe)
            .unwrap_or(false)
    }

    /// Whether this tool only reads state.
    ///
    /// An explicit `read_only` annotation wins; otherwise a fixed set of built-in tool names
    /// (`Read`, `Glob`, `Grep`, `Search`, `WebFetch`, `WebSearch`) is treated as read-only.
    /// `_input` is currently unused.
    pub fn is_read_only(&self, _input: &serde_json::Value) -> bool {
        match self.annotations.as_ref().and_then(|a| a.read_only) {
            Some(read_only) => read_only,
            None => matches!(
                self.name.as_str(),
                "Read" | "Glob" | "Grep" | "Search" | "WebFetch" | "WebSearch"
            ),
        }
    }

    /// Whether invoking this tool with `input` looks destructive.
    ///
    /// An explicit `destructive` annotation wins. Otherwise only `Bash`, `Write` and `Edit`
    /// are considered, by scanning the JSON-serialized `input` for a few dangerous
    /// substrings. This is a coarse heuristic: it can miss destructive commands and can
    /// flag harmless input (e.g. any text containing `"format"`).
    pub fn is_destructive(&self, input: &serde_json::Value) -> bool {
        const DESTRUCTIVE_PATTERNS: [&str; 4] = ["rm -rf", "rm /", "dd if=", "format"];

        if let Some(destructive) = self.annotations.as_ref().and_then(|a| a.destructive) {
            return destructive;
        }
        if !matches!(self.name.as_str(), "Bash" | "Write" | "Edit") {
            return false;
        }
        // Serialize only once we know the name can match; the answer is unchanged, but
        // other tools no longer pay for stringifying their whole input.
        let input_str = input.to_string();
        DESTRUCTIVE_PATTERNS
            .iter()
            .any(|&pattern| input_str.contains(pattern))
    }

    /// Whether repeating a call to this tool has no additional effect.
    ///
    /// Determined solely by the `idempotent` annotation; `false` when absent.
    pub fn is_idempotent(&self) -> bool {
        self.annotations
            .as_ref()
            .and_then(|a| a.idempotent)
            .unwrap_or(false)
    }

    /// Builds a short, human-readable summary of a call to this tool.
    ///
    /// For the built-in tools the summary is `"<Name>: <detail>"`, where the detail is a
    /// string field of `input` (`command` for `Bash`, `path` for `Read`/`Write`,
    /// `file_path` for `Edit`, `pattern` for `Glob`/`Grep`). `Bash` commands longer than
    /// 50 characters are cut to 50 characters and suffixed with `"..."`. When that field is
    /// missing or not a string the summary is just the tool name; for any other tool it is
    /// the tool name.
    pub fn get_use_summary(&self, input: &serde_json::Value) -> String {
        const BASH_SUMMARY_MAX_CHARS: usize = 50;

        let name = self.name.as_str();
        let detail_key = match name {
            "Bash" => "command",
            "Read" | "Write" => "path",
            "Edit" => "file_path",
            "Glob" | "Grep" => "pattern",
            _ => return self.name.clone(),
        };
        match input.get(detail_key).and_then(|v| v.as_str()) {
            Some(command) if name == "Bash" => format!(
                "{}: {}",
                name,
                Self::truncate_chars(command, BASH_SUMMARY_MAX_CHARS)
            ),
            Some(detail) => format!("{}: {}", name, detail),
            None => name.to_string(),
        }
    }

    /// Returns `s` unchanged if it has at most `max_chars` characters, otherwise its first
    /// `max_chars` characters followed by `"..."`.
    ///
    /// Counts `char`s rather than bytes so the cut always lands on a UTF-8 boundary; the
    /// previous byte slice `&s[..50]` panicked whenever byte 50 fell inside a multi-byte
    /// character (e.g. a command containing non-ASCII text).
    fn truncate_chars(s: &str, max_chars: usize) -> String {
        match s.char_indices().nth(max_chars) {
            Some((cut, _)) => format!("{}...", &s[..cut]),
            None => s.to_string(),
        }
    }
}
/// Optional behavioral hints for a tool. Every hint is omitted from serialized output
/// when `None`; `Default` leaves all hints unset.
///
/// JSON keys are camelCase via explicit renames (`concurrencySafe`, `readOnly`,
/// `destructive`, `idempotent`, `openWorld`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolAnnotations {
#[serde(rename = "concurrencySafe", skip_serializing_if = "Option::is_none")]
pub concurrency_safe: Option<bool>,
#[serde(rename = "readOnly", skip_serializing_if = "Option::is_none")]
pub read_only: Option<bool>,
#[serde(rename = "destructive", skip_serializing_if = "Option::is_none")]
pub destructive: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub idempotent: Option<bool>,
#[serde(rename = "openWorld", skip_serializing_if = "Option::is_none")]
pub open_world: Option<bool>,
}
impl ToolAnnotations {
    /// Annotations with only `read_only` set to `true`; all other hints stay unset.
    pub fn read_only() -> Self {
        let read_only = Some(true);
        Self {
            read_only,
            ..Self::default()
        }
    }

    /// Annotations with only `destructive` set to `true`; all other hints stay unset.
    pub fn destructive() -> Self {
        let destructive = Some(true);
        Self {
            destructive,
            ..Self::default()
        }
    }

    /// Annotations with only `concurrency_safe` set to `true`; all other hints stay unset.
    pub fn concurrency_safe() -> Self {
        let concurrency_safe = Some(true);
        Self {
            concurrency_safe,
            ..Self::default()
        }
    }
}
/// Input schema of a tool: a schema type (serialized as `"type"`), its properties as raw
/// JSON, and an optional list of required property names.
///
/// `required` has no `skip_serializing_if`, so `None` is written as `"required": null`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolInputSchema {
#[serde(rename = "type")]
pub schema_type: String,
pub properties: serde_json::Value,
pub required: Option<Vec<String>>,
}
/// Execution context passed to a tool: the working directory and an abort marker.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolContext {
pub cwd: String,
/// Carries no data (`()`), only presence/absence.
/// NOTE(review): looks like a placeholder for a real cancellation handle — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub abort_signal: Option<()>,
}
/// Result of executing a tool call.
///
/// `result_type` is serialized under the key `"type"`; `is_error` is omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
#[serde(rename = "type")]
pub result_type: String,
/// Id of the tool call this result answers.
pub tool_use_id: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub is_error: Option<bool>,
}
/// Configuration for running an agent.
///
/// `Default` is derived: every `Option` is `None` and every list is empty, matching the
/// previous hand-written impl. `Debug` is not derived because `on_event` holds a closure,
/// which does not implement `Debug`.
///
/// NOTE(review): `api_key` is included when this struct is serialized — make sure
/// serialized options are not logged or persisted.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct AgentOptions {
    pub model: Option<String>,
    pub api_key: Option<String>,
    pub base_url: Option<String>,
    pub cwd: Option<String>,
    pub system_prompt: Option<String>,
    pub max_turns: Option<u32>,
    pub max_budget_usd: Option<f64>,
    pub max_tokens: Option<u32>,
    /// Tool definitions; empty when absent from the input.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Tool names; empty when absent from the input.
    #[serde(default)]
    pub allowed_tools: Vec<String>,
    /// Tool names; empty when absent from the input.
    #[serde(default)]
    pub disallowed_tools: Vec<String>,
    /// MCP server configurations keyed by name.
    #[serde(default)]
    pub mcp_servers: Option<std::collections::HashMap<String, McpServerConfig>>,
    /// Event callback; never serialized and always `None` after deserialization.
    #[serde(skip)]
    pub on_event: Option<std::sync::Arc<dyn Fn(AgentEvent) + Send + Sync>>,
}
/// Why an agent run stopped.
///
/// Uses serde's default externally-tagged representation. `Default` is
/// [`ExitReason::Completed`]; it is derived via `#[default]` instead of a hand-written impl
/// (the same mechanism `MessageRole` already uses).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum ExitReason {
    /// Default variant.
    #[default]
    Completed,
    MaxTurns { max_turns: u32, turn_count: u32 },
    AbortedStreaming { reason: String },
    AbortedTools { reason: String },
    HookStopped,
    StopHookPrevented,
    PromptTooLong { error: Option<String> },
    ImageError { error: String },
    ModelError { error: String },
    BlockingLimit,
    TokenBudgetExhausted { reason: String },
}
/// Final outcome of an agent query: output text, token usage, turn count, wall-clock
/// duration in milliseconds, and why the run stopped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
pub text: String,
pub usage: TokenUsage,
pub num_turns: u32,
pub duration_ms: u64,
pub exit_reason: ExitReason,
}
/// Events emitted during an agent run (delivered through `AgentOptions::on_event`).
///
/// Not serializable: only `Debug` and `Clone` are derived.
#[derive(Debug, Clone)]
pub enum AgentEvent {
/// A tool call is about to run with the given input.
ToolStart {
tool_name: String,
tool_call_id: String,
input: serde_json::Value,
},
/// A tool call finished and produced `result`.
ToolComplete {
tool_name: String,
tool_call_id: String,
result: ToolResult,
},
/// A tool call failed with the given error text.
ToolError {
tool_name: String,
tool_call_id: String,
error: String,
},
Thinking { turn: u32 },
/// The run finished; carries the final result.
Done { result: QueryResult },
/// Streaming events; `index` identifies the content block.
MessageStart { message_id: String },
ContentBlockStart { index: u32, block_type: String },
ContentBlockDelta { index: u32, delta: ContentDelta },
ContentBlockStop { index: u32 },
MessageStop,
RequestStart,
MaxTurnsReached { max_turns: u32, turn_count: u32 },
Tombstone { message: String },
}
/// Incremental content carried by [`AgentEvent::ContentBlockDelta`].
#[derive(Debug, Clone)]
pub enum ContentDelta {
/// A chunk of text.
Text { text: String },
/// Tool-use content; `is_complete` flags whether `input` is final.
/// NOTE(review): presumably `input` may be partial while `is_complete` is false — confirm.
ToolUse {
id: String,
name: String,
input: serde_json::Value,
is_complete: bool,
},
}
/// Configuration for one MCP server, in one of three transport shapes.
///
/// Untagged: serde tries the variants in declaration order and keeps the first that
/// deserializes. `Stdio` requires a `command` key. `Sse` and `Http` have identical fields
/// (`transportType`, `url`, optional `headers`).
///
/// NOTE(review): because `Sse` is tried before `Http` and their shapes are identical,
/// deserialization can never produce `Http` — an HTTP config becomes `Sse`. Unknown keys
/// are ignored too, so a config with a stray `command` key becomes `Stdio`. Dispatching on
/// `transportType` would fix this but changes accepted input — confirm intended format first.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum McpServerConfig {
Stdio(McpStdioConfig),
Sse(McpSseConfig),
Http(McpHttpConfig),
}
/// MCP server launched as a subprocess. JSON keys are camelCase (`transportType`,
/// `command`, `args`, `env`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpStdioConfig {
/// Becomes `Some("stdio")` when `transportType` is absent from the input.
#[serde(default = "default_stdio_type")]
pub transport_type: Option<String>,
pub command: String,
pub args: Option<Vec<String>>,
/// Environment variables as name → value.
pub env: Option<std::collections::HashMap<String, String>>,
}
/// Serde default for `McpStdioConfig::transport_type`: the literal transport name `"stdio"`.
fn default_stdio_type() -> Option<String> {
    let transport = String::from("stdio");
    Some(transport)
}
/// MCP server reached over SSE. JSON keys are camelCase (`transportType`, `url`, `headers`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpSseConfig {
pub transport_type: String,
pub url: String,
/// Extra HTTP headers as name → value.
pub headers: Option<std::collections::HashMap<String, String>>,
}
/// MCP server reached over HTTP. JSON keys are camelCase (`transportType`, `url`, `headers`).
///
/// Field-for-field identical to `McpSseConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpHttpConfig {
pub transport_type: String,
pub url: String,
/// Extra HTTP headers as name → value.
pub headers: Option<std::collections::HashMap<String, String>>,
}
/// Connection state of an MCP server; serialized as `"connected"`, `"disconnected"`, `"error"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum McpConnectionStatus {
Connected,
Disconnected,
Error,
}
/// A tool advertised by an MCP server; the schema is read from/written to `inputSchema`
/// as raw JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
pub name: String,
pub description: Option<String>,
#[serde(rename = "inputSchema")]
pub input_schema: Option<serde_json::Value>,
}
/// Identifies a chain of queries and the depth within it.
/// NOTE(review): presumably depth counts nested/sub-agent queries — confirm with callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryChainTracking {
pub chain_id: String,
pub depth: u32,
}
/// Outcome of validating a tool input.
///
/// Internally tagged on `"result"`: `Valid` serializes as `{"result":"true"}` (the string
/// `"true"`, not a boolean), and `Invalid` as
/// `{"result":"Invalid","message":…,"errorCode":…}`.
/// NOTE(review): if this must interoperate with a producer that emits a boolean
/// `result` (`true`/`false`), neither variant will match — confirm the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "result")]
pub enum ValidationResult {
#[serde(rename = "true")]
Valid,
Invalid {
message: String,
#[serde(rename = "errorCode")]
error_code: i32,
},
}
impl ValidationResult {
pub fn valid() -> Self {
ValidationResult::Valid
}
pub fn invalid(message: String, error_code: i32) -> Self {
ValidationResult::Invalid { message, error_code }
}
pub fn is_valid(&self) -> bool {
matches!(self, ValidationResult::Valid)
}
}
/// Permission mode for tool use.
///
/// Serialized as `"default"`, `"auto"`, `"auto-accept"`, `"auto-deny"`, `"bypass"`. The
/// `Default` variant is just a name: this enum does not implement the `Default` trait.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum PermissionMode {
Default,
Auto,
#[serde(rename = "auto-accept")]
AutoAccept,
#[serde(rename = "auto-deny")]
AutoDeny,
Bypass,
}
/// An extra directory the agent may work in, with an optional per-directory permission mode.
///
/// `permission_mode` is serialized as `"permissionMode"` and written as `null` when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdditionalWorkingDirectory {
pub path: String,
#[serde(rename = "permissionMode")]
pub permission_mode: Option<PermissionMode>,
}
/// Result of a permission check: the decision, an optional replacement input, and an
/// optional message.
///
/// `updated_input` is serialized as `"updatedInput"` (written as `null` when `None`);
/// `message` is omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionResult {
pub behavior: PermissionBehavior,
#[serde(rename = "updatedInput")]
pub updated_input: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
}
/// Permission decision; serialized as `"allow"`, `"deny"`, `"ask"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum PermissionBehavior {
Allow,
Deny,
Ask,
}
/// Permission rules grouped by key, each key mapping to a list of rule strings.
/// NOTE(review): the key is presumably the rule's source (per the name) — confirm.
pub type ToolPermissionRulesBySource = HashMap<String, Vec<String>>;
/// Everything the permission system consults when deciding on a tool call.
///
/// JSON keys are camelCase via explicit renames. The mode, the directory map, the three
/// rule maps and `isBypassPermissionsModeAvailable` carry no `serde(default)`, so they
/// must be present when deserializing. The remaining `Option` fields are omitted from
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolPermissionContext {
pub mode: PermissionMode,
/// Extra working directories keyed by string.
/// NOTE(review): presumably keyed by path — confirm.
#[serde(rename = "additionalWorkingDirectories")]
pub additional_working_directories: HashMap<String, AdditionalWorkingDirectory>,
#[serde(rename = "alwaysAllowRules")]
pub always_allow_rules: ToolPermissionRulesBySource,
#[serde(rename = "alwaysDenyRules")]
pub always_deny_rules: ToolPermissionRulesBySource,
#[serde(rename = "alwaysAskRules")]
pub always_ask_rules: ToolPermissionRulesBySource,
#[serde(rename = "isBypassPermissionsModeAvailable")]
pub is_bypass_permissions_mode_available: bool,
#[serde(
rename = "isAutoModeAvailable",
skip_serializing_if = "Option::is_none"
)]
pub is_auto_mode_available: Option<bool>,
#[serde(
rename = "strippedDangerousRules",
skip_serializing_if = "Option::is_none"
)]
pub stripped_dangerous_rules: Option<ToolPermissionRulesBySource>,
#[serde(
rename = "shouldAvoidPermissionPrompts",
skip_serializing_if = "Option::is_none"
)]
pub should_avoid_permission_prompts: Option<bool>,
#[serde(
rename = "awaitAutomatedChecksBeforeDialog",
skip_serializing_if = "Option::is_none"
)]
pub await_automated_checks_before_dialog: Option<bool>,
/// A mode value stored alongside the current `mode`.
/// NOTE(review): presumably the mode to restore after leaving plan mode — confirm.
#[serde(rename = "prePlanMode", skip_serializing_if = "Option::is_none")]
pub pre_plan_mode: Option<PermissionMode>,
}
impl Default for ToolPermissionContext {
fn default() -> Self {
Self {
mode: PermissionMode::Default,
additional_working_directories: HashMap::new(),
always_allow_rules: HashMap::new(),
always_deny_rules: HashMap::new(),
always_ask_rules: HashMap::new(),
is_bypass_permissions_mode_available: false,
is_auto_mode_available: None,
stripped_dangerous_rules: None,
should_avoid_permission_prompts: None,
await_automated_checks_before_dialog: None,
pre_plan_mode: None,
}
}
}
pub fn get_empty_tool_permission_context() -> ToolPermissionContext {
ToolPermissionContext::default()
}
/// Progress event emitted around compaction.
///
/// Internally tagged on `"type"` with values `"hooks_start"`, `"compact_start"` and
/// `"compact_end"`; `HooksStart` also carries `"hookType"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CompactProgressEvent {
#[serde(rename = "hooks_start")]
HooksStart {
#[serde(rename = "hookType")]
hook_type: CompactHookType,
},
#[serde(rename = "compact_start")]
CompactStart,
#[serde(rename = "compact_end")]
CompactEnd,
}
/// Hook kind reported by [`CompactProgressEvent::HooksStart`]; serialized as
/// `"pre_compact"`, `"post_compact"`, `"session_start"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompactHookType {
PreCompact,
PostCompact,
SessionStart,
}
/// A raw JSON schema: the `"type"` key is held in `schema_type`, and every other key is
/// flattened into `properties`.
/// NOTE(review): serde's `flatten` needs `properties` to serialize as a map, so it is
/// presumably always a JSON object — a non-object value would fail to serialize. Confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInputJSONSchema {
#[serde(flatten)]
pub properties: serde_json::Value,
#[serde(rename = "type")]
pub schema_type: String,
}
/// Progress payload for a shell command (see `ToolProgressData::BashProgress`).
/// The renames below keep the field names unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BashProgress {
#[serde(rename = "shell")]
pub shell: Option<String>,
#[serde(rename = "command")]
pub command: Option<String>,
}
/// Progress payload for a REPL tool; JSON keys `input`, `toolName`, `toolCallId`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplProgress {
#[serde(rename = "input")]
pub input: Option<String>,
#[serde(rename = "toolName")]
pub tool_name: Option<String>,
#[serde(rename = "toolCallId")]
pub tool_call_id: Option<String>,
}
/// Progress payload for an MCP tool call; `progress` is arbitrary JSON.
/// JSON keys `serverName`, `toolName`, `progress`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpProgress {
#[serde(rename = "serverName")]
pub server_name: String,
#[serde(rename = "toolName")]
pub tool_name: String,
#[serde(rename = "progress")]
pub progress: Option<serde_json::Value>,
}
/// Progress payload for a web search; JSON keys `query`, `currentStep`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSearchProgress {
#[serde(rename = "query")]
pub query: String,
#[serde(rename = "currentStep")]
pub current_step: Option<String>,
}
/// Progress payload carrying a task's output so far; JSON keys `taskId`, `output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskOutputProgress {
#[serde(rename = "taskId")]
pub task_id: String,
#[serde(rename = "output")]
pub output: Option<String>,
}
/// Progress payload for a skill; JSON keys `skill`, `step`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillToolProgress {
#[serde(rename = "skill")]
pub skill: String,
#[serde(rename = "step")]
pub step: Option<String>,
}
/// Progress payload for a sub-agent; JSON keys `description`, `subagentType`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentToolProgress {
#[serde(rename = "description")]
pub description: String,
#[serde(rename = "subagentType")]
pub subagent_type: Option<String>,
}
/// Tool-specific progress payload.
///
/// Internally tagged on `"type"`: the payload struct's fields are written next to the tag,
/// e.g. `{"type":"bash_progress","shell":…,"command":…}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ToolProgressData {
#[serde(rename = "bash_progress")]
BashProgress(BashProgress),
#[serde(rename = "repl_progress")]
ReplProgress(ReplProgress),
#[serde(rename = "mcp_progress")]
McpProgress(McpProgress),
#[serde(rename = "web_search_progress")]
WebSearchProgress(WebSearchProgress),
#[serde(rename = "task_output_progress")]
TaskOutputProgress(TaskOutputProgress),
#[serde(rename = "skill_progress")]
SkillProgress(SkillToolProgress),
#[serde(rename = "agent_progress")]
AgentProgress(AgentToolProgress),
}
/// A progress update for one tool call: the call's id (serialized as `"toolUseID"`) plus a
/// payload of type `P` (typically [`ToolProgressData`]).
///
/// No trait bounds are placed on the struct itself. The derives add exactly the bounds each
/// impl needs (`P: Clone` for `Clone`, `P: Serialize` for `Serialize`, …), so this is
/// strictly looser than the former `P: Clone + Serialize`. In particular,
/// deserialization no longer needs `P` to be serializable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolProgress<P> {
    #[serde(rename = "toolUseID")]
    pub tool_use_id: String,
    pub data: P,
}
/// Returns clones of the progress messages that are not hook progress.
///
/// A message is dropped only when `message["data"]["type"]` equals the string
/// `"hook_progress"`. Messages without `data`, or whose `data` has no `type`, are kept.
pub fn filter_tool_progress_messages(
    progress_messages: &[serde_json::Value],
) -> Vec<serde_json::Value> {
    let is_hook_progress = |message: &serde_json::Value| -> bool {
        match message.get("data").and_then(|data| data.get("type")) {
            Some(kind) => kind == "hook_progress",
            None => false,
        }
    };
    progress_messages
        .iter()
        .filter(|message| !is_hook_progress(*message))
        .cloned()
        .collect()
}