1use std::collections::HashMap;
2
3use base64::Engine;
4use mcp_types::CallToolResult;
5use serde::Deserialize;
6use serde::Deserializer;
7use serde::Serialize;
8use serde::ser::Serializer;
9use ts_rs::TS;
10
11use crate::protocol::InputItem;
12
13#[derive(Debug, Clone, Copy, Default, Eq, Hash, PartialEq, Serialize, Deserialize, TS)]
15#[serde(rename_all = "snake_case")]
16pub enum SandboxPermissions {
17 #[default]
19 UseDefault,
20 RequireEscalated,
22}
23
24impl SandboxPermissions {
25 pub fn requires_escalated_permissions(self) -> bool {
26 matches!(self, SandboxPermissions::RequireEscalated)
27 }
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
31#[serde(tag = "type", rename_all = "snake_case")]
32pub enum ResponseInputItem {
33 Message {
34 role: String,
35 content: Vec<ContentItem>,
36 },
37 FunctionCallOutput {
38 call_id: String,
39 output: FunctionCallOutputPayload,
40 },
41 McpToolCallOutput {
42 call_id: String,
43 result: Result<CallToolResult, String>,
44 },
45 CustomToolCallOutput {
46 call_id: String,
47 output: String,
48 },
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum ContentItem {
54 InputText { text: String },
55 InputImage { image_url: String },
56 OutputText { text: String },
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
60#[serde(tag = "type", rename_all = "snake_case")]
61pub enum ResponseItem {
62 Message {
63 #[serde(skip_serializing)]
64 id: Option<String>,
65 role: String,
66 content: Vec<ContentItem>,
67 },
68 Reasoning {
69 #[serde(default, skip_serializing)]
70 id: String,
71 summary: Vec<ReasoningItemReasoningSummary>,
72 #[serde(default, skip_serializing_if = "should_serialize_reasoning_content")]
73 content: Option<Vec<ReasoningItemContent>>,
74 encrypted_content: Option<String>,
75 },
76 CompactionSummary {
79 encrypted_content: String,
80 },
81 LocalShellCall {
82 #[serde(skip_serializing)]
84 id: Option<String>,
85 call_id: Option<String>,
87 status: LocalShellStatus,
88 action: LocalShellAction,
89 },
90 FunctionCall {
91 #[serde(skip_serializing)]
92 id: Option<String>,
93 name: String,
94 arguments: String,
99 call_id: String,
100 },
101 FunctionCallOutput {
108 call_id: String,
109 output: FunctionCallOutputPayload,
110 },
111 CustomToolCall {
112 #[serde(skip_serializing)]
113 id: Option<String>,
114 #[serde(default, skip_serializing_if = "Option::is_none")]
115 status: Option<String>,
116
117 call_id: String,
118 name: String,
119 input: String,
120 },
121 CustomToolCallOutput {
122 call_id: String,
123 output: String,
124 },
125 WebSearchCall {
134 #[serde(skip_serializing)]
135 id: Option<String>,
136 #[serde(default, skip_serializing_if = "Option::is_none")]
137 status: Option<String>,
138 action: WebSearchAction,
139 },
140
141 #[serde(other)]
142 Other,
143}
144
145fn should_serialize_reasoning_content(content: &Option<Vec<ReasoningItemContent>>) -> bool {
146 match content {
147 Some(content) => !content
148 .iter()
149 .any(|c| matches!(c, ReasoningItemContent::ReasoningText { .. })),
150 None => false,
151 }
152}
153
154impl From<ResponseInputItem> for ResponseItem {
155 fn from(item: ResponseInputItem) -> Self {
156 match item {
157 ResponseInputItem::Message { role, content } => Self::Message {
158 role,
159 content,
160 id: None,
161 },
162 ResponseInputItem::FunctionCallOutput { call_id, output } => {
163 Self::FunctionCallOutput { call_id, output }
164 }
165 ResponseInputItem::McpToolCallOutput { call_id, result } => Self::FunctionCallOutput {
166 call_id,
167 output: FunctionCallOutputPayload {
168 success: Some(result.is_ok()),
169 content: result.map_or_else(
170 |tool_call_err| format!("err: {tool_call_err:?}"),
171 |result| {
172 serde_json::to_string(&result)
173 .unwrap_or_else(|e| format!("JSON serialization error: {e}"))
174 },
175 ),
176 },
177 },
178 ResponseInputItem::CustomToolCallOutput { call_id, output } => {
179 Self::CustomToolCallOutput { call_id, output }
180 }
181 }
182 }
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
186#[serde(rename_all = "snake_case")]
187pub enum LocalShellStatus {
188 Completed,
189 InProgress,
190 Incomplete,
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
194#[serde(tag = "type", rename_all = "snake_case")]
195pub enum LocalShellAction {
196 Exec(LocalShellExecAction),
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
200pub struct LocalShellExecAction {
201 pub command: Vec<String>,
202 pub timeout_ms: Option<u64>,
203 pub working_directory: Option<String>,
204 pub env: Option<HashMap<String, String>>,
205 pub user: Option<String>,
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
209#[serde(tag = "type", rename_all = "snake_case")]
210pub enum WebSearchAction {
211 Search {
212 query: String,
213 },
214 #[serde(other)]
215 Other,
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
219#[serde(tag = "type", rename_all = "snake_case")]
220pub enum ReasoningItemReasoningSummary {
221 SummaryText { text: String },
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, TS)]
225#[serde(tag = "type", rename_all = "snake_case")]
226pub enum ReasoningItemContent {
227 ReasoningText { text: String },
228 Text { text: String },
229}
230
231impl From<Vec<InputItem>> for ResponseInputItem {
232 fn from(items: Vec<InputItem>) -> Self {
233 Self::Message {
234 role: "user".to_string(),
235 content: items
236 .into_iter()
237 .filter_map(|c| match c {
238 InputItem::Text { text } => Some(ContentItem::InputText { text }),
239 InputItem::Image { image_url } => Some(ContentItem::InputImage { image_url }),
240 InputItem::LocalImage { path } => match std::fs::read(&path) {
241 Ok(bytes) => {
242 let mime = mime_guess::from_path(&path)
243 .first()
244 .map(|m| m.essence_str().to_owned())
245 .unwrap_or_else(|| "application/octet-stream".to_string());
246 let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
247 Some(ContentItem::InputImage {
248 image_url: format!("data:{mime};base64,{encoded}"),
249 })
250 }
251 Err(err) => {
252 tracing::warn!(
253 "Skipping image {} – could not read file: {}",
254 path.display(),
255 err
256 );
257 None
258 }
259 },
260 })
261 .collect::<Vec<ContentItem>>(),
262 }
263 }
264}
265
266#[derive(Deserialize, Debug, Clone, PartialEq, TS)]
269pub struct ShellToolCallParams {
270 pub command: Vec<String>,
271 pub workdir: Option<String>,
272
273 #[serde(alias = "timeout")]
275 pub timeout_ms: Option<u64>,
276 #[serde(default, skip_serializing_if = "Option::is_none")]
277 pub sandbox_permissions: Option<SandboxPermissions>,
278 #[serde(skip_serializing_if = "Option::is_none")]
279 pub justification: Option<String>,
280}
281
282#[derive(Debug, Clone, PartialEq, TS)]
283pub struct FunctionCallOutputPayload {
284 pub content: String,
285 pub success: Option<bool>,
286}
287
288impl Serialize for FunctionCallOutputPayload {
295 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
296 where
297 S: Serializer,
298 {
299 serializer.serialize_str(&self.content)
306 }
307}
308
309impl<'de> Deserialize<'de> for FunctionCallOutputPayload {
310 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
311 where
312 D: Deserializer<'de>,
313 {
314 let s = String::deserialize(deserializer)?;
315 Ok(FunctionCallOutputPayload {
316 content: s,
317 success: None,
318 })
319 }
320}
321
322impl std::fmt::Display for FunctionCallOutputPayload {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 f.write_str(&self.content)
329 }
330}
331
332impl std::ops::Deref for FunctionCallOutputPayload {
333 type Target = str;
334 fn deref(&self) -> &Self::Target {
335 &self.content
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Wraps `payload` in a `FunctionCallOutput` input item, serializes it to
    /// a JSON string and parses that string back into a generic JSON value.
    fn serialize_output_item(payload: FunctionCallOutputPayload) -> Result<serde_json::Value> {
        let item = ResponseInputItem::FunctionCallOutput {
            call_id: "call1".into(),
            output: payload,
        };
        let json = serde_json::to_string(&item)?;
        let value: serde_json::Value = serde_json::from_str(&json)?;
        Ok(value)
    }

    /// A payload without a success flag serializes `output` as a plain string.
    #[test]
    fn serializes_success_as_plain_string() -> Result<()> {
        let v = serialize_output_item(FunctionCallOutputPayload {
            content: "ok".into(),
            success: None,
        })?;

        assert_eq!(v.get("output").and_then(|out| out.as_str()), Some("ok"));
        Ok(())
    }

    /// A failed payload still serializes `output` as a plain string.
    #[test]
    fn serializes_failure_as_string() -> Result<()> {
        let v = serialize_output_item(FunctionCallOutputPayload {
            content: "bad".into(),
            success: Some(false),
        })?;

        assert_eq!(v.get("output").and_then(|out| out.as_str()), Some("bad"));
        Ok(())
    }

    /// The `timeout` alias populates `timeout_ms`; absent optional keys
    /// deserialize to `None`.
    #[test]
    fn deserialize_shell_tool_call_params() -> Result<()> {
        let json = r#"{
            "command": ["ls", "-l"],
            "workdir": "/tmp",
            "timeout": 1000
        }"#;

        let expected = ShellToolCallParams {
            command: vec![String::from("ls"), String::from("-l")],
            workdir: Some(String::from("/tmp")),
            timeout_ms: Some(1000),
            sandbox_permissions: None,
            justification: None,
        };

        let parsed: ShellToolCallParams = serde_json::from_str(json)?;
        assert_eq!(expected, parsed);
        Ok(())
    }
}