agent_client_protocol/
tool_call.rs

1use std::{path::PathBuf, sync::Arc};
2
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5
6use crate::{ContentBlock, Error};
7
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
9#[serde(rename_all = "camelCase")]
10pub struct ToolCall {
11    #[serde(rename = "toolCallId")]
12    pub id: ToolCallId,
13    pub title: String,
14    #[serde(default, skip_serializing_if = "ToolKind::is_default")]
15    pub kind: ToolKind,
16    #[serde(default, skip_serializing_if = "ToolCallStatus::is_default")]
17    pub status: ToolCallStatus,
18    #[serde(default, skip_serializing_if = "Vec::is_empty")]
19    pub content: Vec<ToolCallContent>,
20    #[serde(default, skip_serializing_if = "Vec::is_empty")]
21    pub locations: Vec<ToolCallLocation>,
22    #[serde(default, skip_serializing_if = "Option::is_none")]
23    pub raw_input: Option<serde_json::Value>,
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub raw_output: Option<serde_json::Value>,
26}
27
28impl ToolCall {
29    /// Update an existing tool call with the values in the provided update
30    /// fields. Fields with collections of values are overwritten, not extended.
31    pub fn update(&mut self, fields: ToolCallUpdateFields) {
32        if let Some(title) = fields.title {
33            self.title = title;
34        }
35        if let Some(kind) = fields.kind {
36            self.kind = kind;
37        }
38        if let Some(status) = fields.status {
39            self.status = status;
40        }
41        if let Some(content) = fields.content {
42            self.content = content;
43        }
44        if let Some(locations) = fields.locations {
45            self.locations = locations;
46        }
47        if let Some(raw_input) = fields.raw_input {
48            self.raw_input = Some(raw_input);
49        }
50        if let Some(raw_output) = fields.raw_output {
51            self.raw_output = Some(raw_output);
52        }
53    }
54}
55
56#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
57#[serde(rename_all = "camelCase")]
58pub struct ToolCallUpdate {
59    #[serde(rename = "toolCallId")]
60    pub id: ToolCallId,
61    #[serde(flatten)]
62    pub fields: ToolCallUpdateFields,
63}
64
65#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
66#[serde(rename_all = "camelCase")]
67pub struct ToolCallUpdateFields {
68    #[serde(default, skip_serializing_if = "Option::is_none")]
69    pub kind: Option<ToolKind>,
70    #[serde(default, skip_serializing_if = "Option::is_none")]
71    pub status: Option<ToolCallStatus>,
72    #[serde(default, skip_serializing_if = "Option::is_none")]
73    pub title: Option<String>,
74    #[serde(default, skip_serializing_if = "Option::is_none")]
75    pub content: Option<Vec<ToolCallContent>>,
76    #[serde(default, skip_serializing_if = "Option::is_none")]
77    pub locations: Option<Vec<ToolCallLocation>>,
78    #[serde(default, skip_serializing_if = "Option::is_none")]
79    pub raw_input: Option<serde_json::Value>,
80    #[serde(default, skip_serializing_if = "Option::is_none")]
81    pub raw_output: Option<serde_json::Value>,
82}
83
84/// If a given tool call doesn't exist yet, allows for attempting to construct
85/// one from a tool call update if possible.
86impl TryFrom<ToolCallUpdate> for ToolCall {
87    type Error = Error;
88
89    fn try_from(update: ToolCallUpdate) -> Result<Self, Self::Error> {
90        let ToolCallUpdate {
91            id,
92            fields:
93                ToolCallUpdateFields {
94                    kind,
95                    status,
96                    title,
97                    content,
98                    locations,
99                    raw_input,
100                    raw_output,
101                },
102        } = update;
103
104        Ok(Self {
105            id,
106            title: title.ok_or_else(|| {
107                Error::invalid_params()
108                    .with_data(serde_json::json!("title is required for a tool call"))
109            })?,
110            kind: kind.unwrap_or_default(),
111            status: status.unwrap_or_default(),
112            content: content.unwrap_or_default(),
113            locations: locations.unwrap_or_default(),
114            raw_input,
115            raw_output,
116        })
117    }
118}
119
120impl From<ToolCall> for ToolCallUpdate {
121    fn from(value: ToolCall) -> Self {
122        let ToolCall {
123            id,
124            title,
125            kind,
126            status,
127            content,
128            locations,
129            raw_input,
130            raw_output,
131        } = value;
132        Self {
133            id,
134            fields: ToolCallUpdateFields {
135                kind: Some(kind),
136                status: Some(status),
137                title: Some(title),
138                content: Some(content),
139                locations: Some(locations),
140                raw_input,
141                raw_output,
142            },
143        }
144    }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
148#[serde(transparent)]
149pub struct ToolCallId(pub Arc<str>);
150
151#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
152#[serde(rename_all = "snake_case")]
153pub enum ToolKind {
154    Read,
155    Edit,
156    Delete,
157    Move,
158    Search,
159    Execute,
160    Think,
161    Fetch,
162    #[default]
163    Other,
164}
165
166impl ToolKind {
167    fn is_default(&self) -> bool {
168        matches!(self, ToolKind::Other)
169    }
170}
171
172#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
173#[serde(rename_all = "snake_case")]
174pub enum ToolCallStatus {
175    /// The tool call hasn't started running yet because the input is either
176    /// streaming or we're awaiting approval.
177    #[default]
178    Pending,
179    /// The tool call is currently running.
180    InProgress,
181    /// The tool call completed successfully.
182    Completed,
183    /// The tool call failed.
184    Failed,
185}
186
187impl ToolCallStatus {
188    fn is_default(&self) -> bool {
189        matches!(self, ToolCallStatus::Pending)
190    }
191}
192
193#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
194#[serde(tag = "type", rename_all = "snake_case")]
195pub enum ToolCallContent {
196    Content {
197        content: ContentBlock,
198    },
199    Diff {
200        #[serde(flatten)]
201        diff: Diff,
202    },
203}
204
205impl<T: Into<ContentBlock>> From<T> for ToolCallContent {
206    fn from(content: T) -> Self {
207        ToolCallContent::Content {
208            content: content.into(),
209        }
210    }
211}
212
213impl From<Diff> for ToolCallContent {
214    fn from(diff: Diff) -> Self {
215        ToolCallContent::Diff { diff }
216    }
217}
218
219#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
220#[serde(rename_all = "camelCase")]
221pub struct Diff {
222    pub path: PathBuf,
223    pub old_text: Option<String>,
224    pub new_text: String,
225}
226
227#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
228#[serde(tag = "type", rename_all = "camelCase")]
229pub struct ToolCallLocation {
230    pub path: PathBuf,
231    #[serde(default, skip_serializing_if = "Option::is_none")]
232    pub line: Option<u32>,
233}