1use std::collections::HashMap;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Message {
8 pub role: MessageRole,
9 pub content: String,
10 #[serde(skip_serializing_if = "Option::is_none", default)]
11 pub attachments: Option<Vec<Attachment>>,
12 #[serde(skip_serializing_if = "Option::is_none", default)]
14 pub tool_call_id: Option<String>,
15 #[serde(skip_serializing_if = "Option::is_none", default)]
17 pub tool_calls: Option<Vec<ToolCall>>,
18 #[serde(skip_serializing_if = "Option::is_none", default)]
20 pub is_error: Option<bool>,
21}
22
23impl Default for Message {
24 fn default() -> Self {
25 Self {
26 role: MessageRole::User,
27 content: String::new(),
28 attachments: None,
29 tool_call_id: None,
30 tool_calls: None,
31 is_error: None,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ToolCall {
39 pub id: String,
40 pub name: String,
41 pub arguments: serde_json::Value,
42}
43
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
45#[serde(rename_all = "lowercase")]
46pub enum MessageRole {
47 #[default]
48 User,
49 Assistant,
50 #[serde(rename = "tool")]
51 Tool,
52 System,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(tag = "type")]
58pub enum Attachment {
59 File { path: String },
61 AlreadyReadFile { path: String, content: String },
63 PdfReference { path: String },
65 EditedTextFile { filename: String, snippet: String },
67 EditedImageFile { filename: String },
69 Directory {
71 path: String,
72 content: String,
73 display_path: String,
74 },
75 SelectedLinesInIde {
77 ide_name: String,
78 filename: String,
79 start_line: u32,
80 end_line: u32,
81 },
82 MemoryFile { path: String },
84 SkillListing { skills: Vec<SkillInfo> },
86 InvokedSkills { skills: Vec<InvokedSkill> },
88 TaskStatus {
90 task_id: String,
91 description: String,
92 status: String,
93 },
94 PlanFileReference { path: String },
96 McpResources { tools: Vec<String> },
98 DeferredTools { tools: Vec<String> },
100 AgentListing { agents: Vec<String> },
102 Custom {
104 name: String,
105 content: serde_json::Value,
106 },
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct SkillInfo {
111 pub name: String,
112 pub description: String,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct InvokedSkill {
117 pub name: String,
118 pub path: String,
119 pub content: String,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, Default)]
123pub struct TokenUsage {
124 pub input_tokens: u64,
125 pub output_tokens: u64,
126 #[serde(skip_serializing_if = "Option::is_none")]
127 pub cache_creation_input_tokens: Option<u64>,
128 #[serde(skip_serializing_if = "Option::is_none")]
129 pub cache_read_input_tokens: Option<u64>,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ToolDefinition {
134 pub name: String,
135 pub description: String,
136 pub input_schema: ToolInputSchema,
137 #[serde(default, skip_serializing_if = "Option::is_none")]
139 pub annotations: Option<ToolAnnotations>,
140}
141
142impl Default for ToolDefinition {
143 fn default() -> Self {
144 Self {
145 name: String::new(),
146 description: String::new(),
147 input_schema: ToolInputSchema::default(),
148 annotations: None,
149 }
150 }
151}
152
153impl ToolDefinition {
154 pub fn new(name: &str, description: &str, input_schema: ToolInputSchema) -> Self {
156 Self {
157 name: name.to_string(),
158 description: description.to_string(),
159 input_schema,
160 annotations: None,
161 }
162 }
163
164 pub fn is_concurrency_safe(&self, _input: &serde_json::Value) -> bool {
166 self.annotations
167 .as_ref()
168 .and_then(|a| a.concurrency_safe)
169 .unwrap_or(false)
170 }
171
172 pub fn is_read_only(&self, _input: &serde_json::Value) -> bool {
174 if let Some(ref a) = self.annotations {
175 if let Some(ro) = a.read_only {
176 return ro;
177 }
178 }
179 matches!(
181 self.name.as_str(),
182 "Read" | "Glob" | "Grep" | "Search" | "WebFetch" | "WebSearch"
183 )
184 }
185
186 pub fn is_destructive(&self, input: &serde_json::Value) -> bool {
188 if let Some(ref a) = self.annotations {
189 if let Some(d) = a.destructive {
190 return d;
191 }
192 }
193 let input_str = input.to_string();
195 matches!(self.name.as_str(), "Bash" | "Write" | "Edit")
196 && (input_str.contains("rm -rf")
197 || input_str.contains("rm /")
198 || input_str.contains("dd if=")
199 || input_str.contains("format"))
200 }
201
202 pub fn is_idempotent(&self) -> bool {
204 self.annotations
205 .as_ref()
206 .and_then(|a| a.idempotent)
207 .unwrap_or(false)
208 }
209
210 pub fn get_use_summary(&self, input: &serde_json::Value) -> String {
212 match self.name.as_str() {
213 "Bash" => {
214 if let Some(cmd) = input.get("command").and_then(|v| v.as_str()) {
215 let truncated = if cmd.len() > 50 {
216 format!("{}...", &cmd[..50])
217 } else {
218 cmd.to_string()
219 };
220 format!("Bash: {}", truncated)
221 } else {
222 "Bash".to_string()
223 }
224 }
225 "Read" => {
226 if let Some(path) = input.get("path").and_then(|v| v.as_str()) {
227 format!("Read: {}", path)
228 } else {
229 "Read".to_string()
230 }
231 }
232 "Write" => {
233 if let Some(path) = input.get("path").and_then(|v| v.as_str()) {
234 format!("Write: {}", path)
235 } else {
236 "Write".to_string()
237 }
238 }
239 "Edit" => {
240 if let Some(path) = input.get("file_path").and_then(|v| v.as_str()) {
241 format!("Edit: {}", path)
242 } else {
243 "Edit".to_string()
244 }
245 }
246 "Glob" => {
247 if let Some(pattern) = input.get("pattern").and_then(|v| v.as_str()) {
248 format!("Glob: {}", pattern)
249 } else {
250 "Glob".to_string()
251 }
252 }
253 "Grep" => {
254 if let Some(pattern) = input.get("pattern").and_then(|v| v.as_str()) {
255 format!("Grep: {}", pattern)
256 } else {
257 "Grep".to_string()
258 }
259 }
260 _ => self.name.clone(),
261 }
262 }
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize, Default)]
267pub struct ToolAnnotations {
268 #[serde(rename = "concurrencySafe", skip_serializing_if = "Option::is_none")]
270 pub concurrency_safe: Option<bool>,
271 #[serde(rename = "readOnly", skip_serializing_if = "Option::is_none")]
273 pub read_only: Option<bool>,
274 #[serde(rename = "destructive", skip_serializing_if = "Option::is_none")]
276 pub destructive: Option<bool>,
277 #[serde(skip_serializing_if = "Option::is_none")]
279 pub idempotent: Option<bool>,
280 #[serde(rename = "openWorld", skip_serializing_if = "Option::is_none")]
282 pub open_world: Option<bool>,
283}
284
285impl ToolAnnotations {
286 pub fn read_only() -> Self {
288 Self {
289 read_only: Some(true),
290 ..Default::default()
291 }
292 }
293
294 pub fn destructive() -> Self {
296 Self {
297 destructive: Some(true),
298 ..Default::default()
299 }
300 }
301
302 pub fn concurrency_safe() -> Self {
304 Self {
305 concurrency_safe: Some(true),
306 ..Default::default()
307 }
308 }
309}
310
311#[derive(Debug, Clone, Serialize, Deserialize, Default)]
312pub struct ToolInputSchema {
313 #[serde(rename = "type")]
314 pub schema_type: String,
315 pub properties: serde_json::Value,
316 pub required: Option<Vec<String>>,
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize, Default)]
320pub struct ToolContext {
321 pub cwd: String,
322 #[serde(skip_serializing_if = "Option::is_none")]
323 pub abort_signal: Option<()>,
324}
325
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct ToolResult {
328 #[serde(rename = "type")]
329 pub result_type: String,
330 pub tool_use_id: String,
331 pub content: String,
332 #[serde(skip_serializing_if = "Option::is_none")]
333 pub is_error: Option<bool>,
334}
335
336#[derive(Clone, Serialize, Deserialize)]
337pub struct AgentOptions {
338 pub model: Option<String>,
339 pub api_key: Option<String>,
340 pub base_url: Option<String>,
341 pub cwd: Option<String>,
342 pub system_prompt: Option<String>,
343 pub max_turns: Option<u32>,
344 pub max_budget_usd: Option<f64>,
345 pub max_tokens: Option<u32>,
346 #[serde(default)]
347 pub tools: Vec<ToolDefinition>,
348 #[serde(default)]
349 pub allowed_tools: Vec<String>,
350 #[serde(default)]
351 pub disallowed_tools: Vec<String>,
352 #[serde(default)]
354 pub mcp_servers: Option<std::collections::HashMap<String, McpServerConfig>>,
355 #[serde(skip)]
358 pub on_event: Option<std::sync::Arc<dyn Fn(AgentEvent) + Send + Sync>>,
359}
360
361impl Default for AgentOptions {
362 fn default() -> Self {
363 Self {
364 model: None,
365 api_key: None,
366 base_url: None,
367 cwd: None,
368 system_prompt: None,
369 max_turns: None,
370 max_budget_usd: None,
371 max_tokens: None,
372 tools: Vec::new(),
373 allowed_tools: Vec::new(),
374 disallowed_tools: Vec::new(),
375 mcp_servers: None,
376 on_event: None,
377 }
378 }
379}
380
381#[derive(Debug, Clone, Serialize, Deserialize)]
383pub enum ExitReason {
384 Completed,
386 MaxTurns { max_turns: u32, turn_count: u32 },
388 AbortedStreaming { reason: String },
390 AbortedTools { reason: String },
392 HookStopped,
394 StopHookPrevented,
396 PromptTooLong { error: Option<String> },
398 ImageError { error: String },
400 ModelError { error: String },
402 BlockingLimit,
404 TokenBudgetExhausted { reason: String },
406}
407
408impl Default for ExitReason {
409 fn default() -> Self {
410 ExitReason::Completed
411 }
412}
413
414#[derive(Debug, Clone, Serialize, Deserialize)]
415pub struct QueryResult {
416 pub text: String,
417 pub usage: TokenUsage,
418 pub num_turns: u32,
419 pub duration_ms: u64,
420 pub exit_reason: ExitReason,
422}
423
424#[derive(Debug, Clone)]
426pub enum AgentEvent {
427 ToolStart {
429 tool_name: String,
430 tool_call_id: String,
431 input: serde_json::Value,
432 },
433 ToolComplete {
435 tool_name: String,
436 tool_call_id: String,
437 result: ToolResult,
438 },
439 ToolError {
441 tool_name: String,
442 tool_call_id: String,
443 error: String,
444 },
445 Thinking { turn: u32 },
447 Done { result: QueryResult },
449 MessageStart { message_id: String },
451 ContentBlockStart { index: u32, block_type: String },
453 ContentBlockDelta { index: u32, delta: ContentDelta },
455 ContentBlockStop { index: u32 },
457 MessageStop,
459 RequestStart,
461 MaxTurnsReached { max_turns: u32, turn_count: u32 },
463 Tombstone { message: String },
466}
467
468#[derive(Debug, Clone)]
470pub enum ContentDelta {
471 Text { text: String },
473 ToolUse {
475 id: String,
476 name: String,
477 input: serde_json::Value,
478 is_complete: bool,
479 },
480}
481
482#[derive(Debug, Clone, Serialize, Deserialize)]
488#[serde(untagged)]
489pub enum McpServerConfig {
490 Stdio(McpStdioConfig),
491 Sse(McpSseConfig),
492 Http(McpHttpConfig),
493}
494
495#[derive(Debug, Clone, Serialize, Deserialize)]
496#[serde(rename_all = "camelCase")]
497pub struct McpStdioConfig {
498 #[serde(default = "default_stdio_type")]
499 pub transport_type: Option<String>,
500 pub command: String,
501 pub args: Option<Vec<String>>,
502 pub env: Option<std::collections::HashMap<String, String>>,
503}
504
/// Serde default for `McpStdioConfig::transport_type`: always `Some("stdio")`.
fn default_stdio_type() -> Option<String> {
    let transport = String::from("stdio");
    Some(transport)
}
508
509#[derive(Debug, Clone, Serialize, Deserialize)]
510#[serde(rename_all = "camelCase")]
511pub struct McpSseConfig {
512 pub transport_type: String,
513 pub url: String,
514 pub headers: Option<std::collections::HashMap<String, String>>,
515}
516
517#[derive(Debug, Clone, Serialize, Deserialize)]
518#[serde(rename_all = "camelCase")]
519pub struct McpHttpConfig {
520 pub transport_type: String,
521 pub url: String,
522 pub headers: Option<std::collections::HashMap<String, String>>,
523}
524
525#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
527#[serde(rename_all = "lowercase")]
528pub enum McpConnectionStatus {
529 Connected,
530 Disconnected,
531 Error,
532}
533
534#[derive(Debug, Clone, Serialize, Deserialize)]
536pub struct McpTool {
537 pub name: String,
538 pub description: Option<String>,
539 #[serde(rename = "inputSchema")]
540 pub input_schema: Option<serde_json::Value>,
541}
542
543#[derive(Debug, Clone, Serialize, Deserialize)]
549pub struct QueryChainTracking {
550 pub chain_id: String,
551 pub depth: u32,
552}
553
554#[derive(Debug, Clone, Serialize, Deserialize)]
557#[serde(tag = "result")]
558pub enum ValidationResult {
559 #[serde(rename = "true")]
561 Valid,
562 Invalid {
564 message: String,
566 #[serde(rename = "errorCode")]
568 error_code: i32,
569 },
570}
571
572impl ValidationResult {
573 pub fn valid() -> Self {
575 ValidationResult::Valid
576 }
577
578 pub fn invalid(message: String, error_code: i32) -> Self {
580 ValidationResult::Invalid { message, error_code }
581 }
582
583 pub fn is_valid(&self) -> bool {
585 matches!(self, ValidationResult::Valid)
586 }
587}
588
589#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
591#[serde(rename_all = "lowercase")]
592pub enum PermissionMode {
593 Default,
594 Auto,
595 #[serde(rename = "auto-accept")]
596 AutoAccept,
597 #[serde(rename = "auto-deny")]
598 AutoDeny,
599 Bypass,
600}
601
602#[derive(Debug, Clone, Serialize, Deserialize)]
604pub struct AdditionalWorkingDirectory {
605 pub path: String,
606 #[serde(rename = "permissionMode")]
607 pub permission_mode: Option<PermissionMode>,
608}
609
610#[derive(Debug, Clone, Serialize, Deserialize)]
612pub struct PermissionResult {
613 pub behavior: PermissionBehavior,
614 #[serde(rename = "updatedInput")]
615 pub updated_input: Option<serde_json::Value>,
616 #[serde(skip_serializing_if = "Option::is_none")]
617 pub message: Option<String>,
618}
619
620#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
622#[serde(rename_all = "kebab-case")]
623pub enum PermissionBehavior {
624 Allow,
625 Deny,
626 Ask,
627}
628
629pub type ToolPermissionRulesBySource = HashMap<String, Vec<String>>;
631
632#[derive(Debug, Clone, Serialize, Deserialize)]
634pub struct ToolPermissionContext {
635 pub mode: PermissionMode,
636 #[serde(rename = "additionalWorkingDirectories")]
637 pub additional_working_directories: HashMap<String, AdditionalWorkingDirectory>,
638 #[serde(rename = "alwaysAllowRules")]
639 pub always_allow_rules: ToolPermissionRulesBySource,
640 #[serde(rename = "alwaysDenyRules")]
641 pub always_deny_rules: ToolPermissionRulesBySource,
642 #[serde(rename = "alwaysAskRules")]
643 pub always_ask_rules: ToolPermissionRulesBySource,
644 #[serde(rename = "isBypassPermissionsModeAvailable")]
645 pub is_bypass_permissions_mode_available: bool,
646 #[serde(
647 rename = "isAutoModeAvailable",
648 skip_serializing_if = "Option::is_none"
649 )]
650 pub is_auto_mode_available: Option<bool>,
651 #[serde(
652 rename = "strippedDangerousRules",
653 skip_serializing_if = "Option::is_none"
654 )]
655 pub stripped_dangerous_rules: Option<ToolPermissionRulesBySource>,
656 #[serde(
657 rename = "shouldAvoidPermissionPrompts",
658 skip_serializing_if = "Option::is_none"
659 )]
660 pub should_avoid_permission_prompts: Option<bool>,
661 #[serde(
662 rename = "awaitAutomatedChecksBeforeDialog",
663 skip_serializing_if = "Option::is_none"
664 )]
665 pub await_automated_checks_before_dialog: Option<bool>,
666 #[serde(rename = "prePlanMode", skip_serializing_if = "Option::is_none")]
667 pub pre_plan_mode: Option<PermissionMode>,
668}
669
670impl Default for ToolPermissionContext {
671 fn default() -> Self {
672 Self {
673 mode: PermissionMode::Default,
674 additional_working_directories: HashMap::new(),
675 always_allow_rules: HashMap::new(),
676 always_deny_rules: HashMap::new(),
677 always_ask_rules: HashMap::new(),
678 is_bypass_permissions_mode_available: false,
679 is_auto_mode_available: None,
680 stripped_dangerous_rules: None,
681 should_avoid_permission_prompts: None,
682 await_automated_checks_before_dialog: None,
683 pre_plan_mode: None,
684 }
685 }
686}
687
688pub fn get_empty_tool_permission_context() -> ToolPermissionContext {
690 ToolPermissionContext::default()
691}
692
693#[derive(Debug, Clone, Serialize, Deserialize)]
695#[serde(tag = "type")]
696pub enum CompactProgressEvent {
697 #[serde(rename = "hooks_start")]
698 HooksStart {
699 #[serde(rename = "hookType")]
700 hook_type: CompactHookType,
701 },
702 #[serde(rename = "compact_start")]
703 CompactStart,
704 #[serde(rename = "compact_end")]
705 CompactEnd,
706}
707
708#[derive(Debug, Clone, Serialize, Deserialize)]
710#[serde(rename_all = "snake_case")]
711pub enum CompactHookType {
712 PreCompact,
713 PostCompact,
714 SessionStart,
715}
716
717#[derive(Debug, Clone, Serialize, Deserialize)]
719pub struct ToolInputJSONSchema {
720 #[serde(flatten)]
721 pub properties: serde_json::Value,
722 #[serde(rename = "type")]
723 pub schema_type: String,
724}
725
726#[derive(Debug, Clone, Serialize, Deserialize)]
732pub struct BashProgress {
733 #[serde(rename = "shell")]
734 pub shell: Option<String>,
735 #[serde(rename = "command")]
736 pub command: Option<String>,
737}
738
739#[derive(Debug, Clone, Serialize, Deserialize)]
741pub struct ReplProgress {
742 #[serde(rename = "input")]
743 pub input: Option<String>,
744 #[serde(rename = "toolName")]
745 pub tool_name: Option<String>,
746 #[serde(rename = "toolCallId")]
747 pub tool_call_id: Option<String>,
748}
749
750#[derive(Debug, Clone, Serialize, Deserialize)]
752pub struct McpProgress {
753 #[serde(rename = "serverName")]
754 pub server_name: String,
755 #[serde(rename = "toolName")]
756 pub tool_name: String,
757 #[serde(rename = "progress")]
758 pub progress: Option<serde_json::Value>,
759}
760
761#[derive(Debug, Clone, Serialize, Deserialize)]
763pub struct WebSearchProgress {
764 #[serde(rename = "query")]
765 pub query: String,
766 #[serde(rename = "currentStep")]
767 pub current_step: Option<String>,
768}
769
770#[derive(Debug, Clone, Serialize, Deserialize)]
772pub struct TaskOutputProgress {
773 #[serde(rename = "taskId")]
774 pub task_id: String,
775 #[serde(rename = "output")]
776 pub output: Option<String>,
777}
778
779#[derive(Debug, Clone, Serialize, Deserialize)]
781pub struct SkillToolProgress {
782 #[serde(rename = "skill")]
783 pub skill: String,
784 #[serde(rename = "step")]
785 pub step: Option<String>,
786}
787
788#[derive(Debug, Clone, Serialize, Deserialize)]
790pub struct AgentToolProgress {
791 #[serde(rename = "description")]
792 pub description: String,
793 #[serde(rename = "subagentType")]
794 pub subagent_type: Option<String>,
795}
796
797#[derive(Debug, Clone, Serialize, Deserialize)]
799#[serde(tag = "type")]
800pub enum ToolProgressData {
801 #[serde(rename = "bash_progress")]
802 BashProgress(BashProgress),
803 #[serde(rename = "repl_progress")]
804 ReplProgress(ReplProgress),
805 #[serde(rename = "mcp_progress")]
806 McpProgress(McpProgress),
807 #[serde(rename = "web_search_progress")]
808 WebSearchProgress(WebSearchProgress),
809 #[serde(rename = "task_output_progress")]
810 TaskOutputProgress(TaskOutputProgress),
811 #[serde(rename = "skill_progress")]
812 SkillProgress(SkillToolProgress),
813 #[serde(rename = "agent_progress")]
814 AgentProgress(AgentToolProgress),
815}
816
817#[derive(Debug, Clone, Serialize, Deserialize)]
819pub struct ToolProgress<P: Clone + serde::Serialize> {
820 #[serde(rename = "toolUseID")]
821 pub tool_use_id: String,
822 pub data: P,
823}
824
825pub fn filter_tool_progress_messages(
827 progress_messages: &[serde_json::Value],
828) -> Vec<serde_json::Value> {
829 progress_messages
830 .iter()
831 .filter(|msg| {
832 let data_type = msg.get("data").and_then(|d| d.get("type"));
833 data_type.map(|t| t != "hook_progress").unwrap_or(true)
834 })
835 .cloned()
836 .collect()
837}