1use std::{path::PathBuf, sync::Arc};
2
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5
6use crate::{ContentBlock, Error};
7
/// A tool call together with its kind, status, content, and locations.
///
/// Serialized with camelCase keys, except `id`, which is written as `toolCallId`.
/// Fields holding their default value (`kind` = `ToolKind::Other`, `status` =
/// `ToolCallStatus::Pending`, empty `content`/`locations`, `None` raw payloads) are
/// omitted when serializing, and fall back to that same default when absent on
/// deserialization. Only `toolCallId` and `title` are required.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
    /// Identifier of the tool call, serialized under the key `toolCallId`.
    #[serde(rename = "toolCallId")]
    pub id: ToolCallId,
    /// Title of the tool call. Required when deserializing.
    pub title: String,
    /// Kind of tool. Omitted when it is `ToolKind::Other`, and defaults to it when absent.
    #[serde(default, skip_serializing_if = "ToolKind::is_default")]
    pub kind: ToolKind,
    /// Status of the call. Omitted when it is `ToolCallStatus::Pending`, and defaults to it
    /// when absent.
    #[serde(default, skip_serializing_if = "ToolCallStatus::is_default")]
    pub status: ToolCallStatus,
    /// Content blocks and diffs attached to the call. Omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub content: Vec<ToolCallContent>,
    /// Locations associated with the call. Omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub locations: Vec<ToolCallLocation>,
    /// Raw input as an arbitrary JSON value (`rawInput`). Omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw_input: Option<serde_json::Value>,
    /// Raw output as an arbitrary JSON value (`rawOutput`). Omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw_output: Option<serde_json::Value>,
}
27
28impl ToolCall {
29 pub fn update(&mut self, fields: ToolCallUpdateFields) {
32 if let Some(title) = fields.title {
33 self.title = title;
34 }
35 if let Some(kind) = fields.kind {
36 self.kind = kind;
37 }
38 if let Some(status) = fields.status {
39 self.status = status;
40 }
41 if let Some(content) = fields.content {
42 self.content = content;
43 }
44 if let Some(locations) = fields.locations {
45 self.locations = locations;
46 }
47 if let Some(raw_input) = fields.raw_input {
48 self.raw_input = Some(raw_input);
49 }
50 if let Some(raw_output) = fields.raw_output {
51 self.raw_output = Some(raw_output);
52 }
53 }
54}
55
/// An update for the tool call identified by `toolCallId`.
///
/// The keys of [`ToolCallUpdateFields`] are flattened into this object. On the wire
/// it is therefore a single JSON object holding `toolCallId` plus whichever update
/// fields are present.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ToolCallUpdate {
    /// Identifier of the tool call this update applies to, serialized as `toolCallId`.
    #[serde(rename = "toolCallId")]
    pub id: ToolCallId,
    /// The fields to change, inlined into the same JSON object as `toolCallId`.
    #[serde(flatten)]
    pub fields: ToolCallUpdateFields,
}
64
/// Optional replacements for the fields of a [`ToolCall`].
///
/// Every field is optional. `None` is omitted when serializing, and a missing key
/// deserializes to `None`. [`ToolCall::update`] treats `Some` as "replace" and `None`
/// as "leave unchanged". As a result, this type cannot express clearing `raw_input` or
/// `raw_output` back to `None`.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ToolCallUpdateFields {
    /// New kind, if it changes.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kind: Option<ToolKind>,
    /// New status, if it changes.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub status: Option<ToolCallStatus>,
    /// New title, if it changes.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Replacement content list; replaces the existing list rather than appending.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ToolCallContent>>,
    /// Replacement locations list; replaces the existing list rather than appending.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub locations: Option<Vec<ToolCallLocation>>,
    /// New raw input (`rawInput`), if it changes.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw_input: Option<serde_json::Value>,
    /// New raw output (`rawOutput`), if it changes.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw_output: Option<serde_json::Value>,
}
83
84impl TryFrom<ToolCallUpdate> for ToolCall {
87 type Error = Error;
88
89 fn try_from(update: ToolCallUpdate) -> Result<Self, Self::Error> {
90 let ToolCallUpdate {
91 id,
92 fields:
93 ToolCallUpdateFields {
94 kind,
95 status,
96 title,
97 content,
98 locations,
99 raw_input,
100 raw_output,
101 },
102 } = update;
103
104 Ok(Self {
105 id,
106 title: title.ok_or_else(|| {
107 Error::invalid_params()
108 .with_data(serde_json::json!("title is required for a tool call"))
109 })?,
110 kind: kind.unwrap_or_default(),
111 status: status.unwrap_or_default(),
112 content: content.unwrap_or_default(),
113 locations: locations.unwrap_or_default(),
114 raw_input,
115 raw_output,
116 })
117 }
118}
119
120impl From<ToolCall> for ToolCallUpdate {
121 fn from(value: ToolCall) -> Self {
122 let ToolCall {
123 id,
124 title,
125 kind,
126 status,
127 content,
128 locations,
129 raw_input,
130 raw_output,
131 } = value;
132 Self {
133 id,
134 fields: ToolCallUpdateFields {
135 kind: Some(kind),
136 status: Some(status),
137 title: Some(title),
138 content: Some(content),
139 locations: Some(locations),
140 raw_input,
141 raw_output,
142 },
143 }
144 }
145}
146
/// Identifier of a tool call.
///
/// `#[serde(transparent)]` makes it serialize and deserialize as the bare string it
/// wraps. It is backed by `Arc<str>`, so cloning only bumps a reference count. It
/// derives `Eq` and `Hash`, so it can be used as a hash-map key.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(transparent)]
pub struct ToolCallId(pub Arc<str>);
150
/// Kind of tool invoked by a tool call.
///
/// Serialized in snake_case: `"read"`, `"edit"`, `"delete"`, `"move"`, `"search"`,
/// `"execute"`, `"think"`, `"fetch"`, `"other"`. The default is [`ToolKind::Other`].
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolKind {
    Read,
    Edit,
    Delete,
    Move,
    Search,
    Execute,
    Think,
    Fetch,
    /// The default kind; a `ToolCall` with this kind omits `kind` when serialized.
    #[default]
    Other,
}
165
166impl ToolKind {
167 fn is_default(&self) -> bool {
168 matches!(self, ToolKind::Other)
169 }
170}
171
/// Status of a tool call.
///
/// Serialized in snake_case: `"pending"`, `"in_progress"`, `"completed"`, `"failed"`.
/// The default is [`ToolCallStatus::Pending`].
#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// The default status; a `ToolCall` with this status omits `status` when serialized.
    #[default]
    Pending,
    InProgress,
    Completed,
    Failed,
}
186
187impl ToolCallStatus {
188 fn is_default(&self) -> bool {
189 matches!(self, ToolCallStatus::Pending)
190 }
191}
192
/// A piece of content attached to a tool call.
///
/// Internally tagged by a `"type"` key:
/// - `"content"`: the [`ContentBlock`] is nested under a `"content"` key.
/// - `"diff"`: the fields of [`Diff`] (`path`, `oldText`, `newText`) are inlined next
///   to the tag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolCallContent {
    /// A content block, tagged `"type": "content"`.
    Content {
        content: ContentBlock,
    },
    /// A diff, tagged `"type": "diff"`, with its fields flattened into the same object.
    Diff {
        #[serde(flatten)]
        diff: Diff,
    },
}
204
205impl<T: Into<ContentBlock>> From<T> for ToolCallContent {
206 fn from(content: T) -> Self {
207 ToolCallContent::Content {
208 content: content.into(),
209 }
210 }
211}
212
213impl From<Diff> for ToolCallContent {
214 fn from(diff: Diff) -> Self {
215 ToolCallContent::Diff { diff }
216 }
217}
218
/// A diff for the file at `path`, given as its old and new text.
///
/// Serialized with camelCase keys (`path`, `oldText`, `newText`).
// NOTE(review): unlike the optional fields elsewhere in this file, `old_text` has no
// `skip_serializing_if`, so `None` is written out as `"oldText": null`. Presumably
// `None` means "no previous content" (e.g. a newly created file) — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct Diff {
    /// Path of the file the diff applies to.
    pub path: PathBuf,
    /// Text before the change, if any; serialized as `oldText`.
    pub old_text: Option<String>,
    /// Text after the change; serialized as `newText`.
    pub new_text: String,
}
226
// NOTE(review): `tag = "type"` on a *struct* (not an enum) makes serde emit an extra
// `"type": "ToolCallLocation"` entry ahead of `path`/`line` when serializing (see
// serde's container-attribute docs for `tag`). This looks like it may have been copied
// from an enum definition. Confirm it is the intended wire format before relying on it
// or removing it, since removal changes the serialized output.
/// A location (a file path and an optional line) associated with a tool call.
///
/// Serialized with camelCase keys; `line` is omitted when `None`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "camelCase")]
pub struct ToolCallLocation {
    /// Path of the file.
    pub path: PathBuf,
    /// Optional line within the file. Omitted when `None`; defaults to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub line: Option<u32>,
}