1use std::{collections::BTreeMap, path::PathBuf, sync::Arc};
8
9use derive_more::{Display, From};
10use schemars::{JsonSchema, Schema};
11use serde::{Deserialize, Serialize};
12use serde_with::{DefaultOnError, VecSkipError, serde_as, skip_serializing_none};
13
14use super::{ContentBlock, Error, Meta};
15use crate::{IntoOption, SkipListener};
16
// A complete tool call record.
//
// Only `tool_call_id` and `title` are required when deserializing; every other field
// falls back to its default when absent. Keys are camelCase on the wire. `None` options,
// a default `kind`/`status`, and empty `content`/`locations` are omitted from output.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolCall {
    // Identifier of this tool call (wire key `toolCallId`).
    pub tool_call_id: ToolCallId,
    // Human-readable title; required (no serde default).
    pub title: String,
    // Omitted from output when equal to the default, `ToolKind::Other`.
    #[serde(default, skip_serializing_if = "ToolKind::is_default")]
    pub kind: ToolKind,
    // Omitted from output when equal to the default, `ToolCallStatus::Pending`.
    #[serde(default, skip_serializing_if = "ToolCallStatus::is_default")]
    pub status: ToolCallStatus,
    // Lenient on input: a value that cannot be read as a list becomes an empty list
    // (`DefaultOnError`), and individual entries that fail to parse are skipped
    // (`VecSkipError`); NOTE(review): skipped entries are presumably reported to
    // `SkipListener` — confirm against its definition.
    #[serde_as(deserialize_as = "DefaultOnError<VecSkipError<_, SkipListener>>")]
    #[schemars(extend("x-deserialize-default-on-error" = true, "x-deserialize-skip-invalid-items" = true))]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub content: Vec<ToolCallContent>,
    // Same lenient list handling as `content`.
    #[serde_as(deserialize_as = "DefaultOnError<VecSkipError<_, SkipListener>>")]
    #[schemars(extend("x-deserialize-default-on-error" = true, "x-deserialize-skip-invalid-items" = true))]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub locations: Vec<ToolCallLocation>,
    // Opaque JSON value kept as-is; omitted from output when `None`.
    pub raw_input: Option<serde_json::Value>,
    // Opaque JSON value kept as-is; omitted from output when `None`.
    pub raw_output: Option<serde_json::Value>,
    // Extension metadata, carried under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
63
64impl ToolCall {
65 #[must_use]
66 pub fn new(tool_call_id: impl Into<ToolCallId>, title: impl Into<String>) -> Self {
67 Self {
68 tool_call_id: tool_call_id.into(),
69 title: title.into(),
70 kind: ToolKind::default(),
71 status: ToolCallStatus::default(),
72 content: Vec::default(),
73 locations: Vec::default(),
74 raw_input: None,
75 raw_output: None,
76 meta: None,
77 }
78 }
79
80 #[must_use]
83 pub fn kind(mut self, kind: ToolKind) -> Self {
84 self.kind = kind;
85 self
86 }
87
88 #[must_use]
90 pub fn status(mut self, status: ToolCallStatus) -> Self {
91 self.status = status;
92 self
93 }
94
95 #[must_use]
97 pub fn content(mut self, content: Vec<ToolCallContent>) -> Self {
98 self.content = content;
99 self
100 }
101
102 #[must_use]
105 pub fn locations(mut self, locations: Vec<ToolCallLocation>) -> Self {
106 self.locations = locations;
107 self
108 }
109
110 #[must_use]
112 pub fn raw_input(mut self, raw_input: impl IntoOption<serde_json::Value>) -> Self {
113 self.raw_input = raw_input.into_option();
114 self
115 }
116
117 #[must_use]
119 pub fn raw_output(mut self, raw_output: impl IntoOption<serde_json::Value>) -> Self {
120 self.raw_output = raw_output.into_option();
121 self
122 }
123
124 #[must_use]
130 pub fn meta(mut self, meta: impl IntoOption<Meta>) -> Self {
131 self.meta = meta.into_option();
132 self
133 }
134
135 pub fn update(&mut self, fields: ToolCallUpdateFields) {
138 if let Some(title) = fields.title {
139 self.title = title;
140 }
141 if let Some(kind) = fields.kind {
142 self.kind = kind;
143 }
144 if let Some(status) = fields.status {
145 self.status = status;
146 }
147 if let Some(content) = fields.content {
148 self.content = content;
149 }
150 if let Some(locations) = fields.locations {
151 self.locations = locations;
152 }
153 if let Some(raw_input) = fields.raw_input {
154 self.raw_input = Some(raw_input);
155 }
156 if let Some(raw_output) = fields.raw_output {
157 self.raw_output = Some(raw_output);
158 }
159 }
160}
161
// A partial tool call: the id of the call plus a set of optional fields.
//
// `fields` is flattened, so its keys appear directly beside `toolCallId` and `_meta` in
// the JSON object rather than under a nested key. `ToolCall::update` applies the fields,
// and `TryFrom<ToolCallUpdate> for ToolCall` promotes an update into a full call.
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolCallUpdate {
    // Identifier of the tool call this update refers to.
    pub tool_call_id: ToolCallId,
    // Optional fields, flattened into this object on the wire.
    #[serde(flatten)]
    pub fields: ToolCallUpdateFields,
    // Extension metadata, carried under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
186
187impl ToolCallUpdate {
188 #[must_use]
189 pub fn new(tool_call_id: impl Into<ToolCallId>, fields: ToolCallUpdateFields) -> Self {
190 Self {
191 tool_call_id: tool_call_id.into(),
192 fields,
193 meta: None,
194 }
195 }
196
197 #[must_use]
203 pub fn meta(mut self, meta: impl IntoOption<Meta>) -> Self {
204 self.meta = meta.into_option();
205 self
206 }
207}
208
// The optional fields of a tool call update.
//
// `None` means "leave unchanged" (see `ToolCall::update`), and `None` fields are omitted
// from output. Deserialization is lenient: a present but malformed `kind`, `status`,
// `content`, or `locations` value becomes `None` (`DefaultOnError`) instead of failing
// the whole update.
#[serde_as]
#[skip_serializing_none]
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolCallUpdateFields {
    // New kind; an undeserializable value is treated as absent.
    #[serde_as(deserialize_as = "DefaultOnError")]
    #[schemars(extend("x-deserialize-default-on-error" = true))]
    #[serde(default)]
    pub kind: Option<ToolKind>,
    // New status; an undeserializable value is treated as absent.
    #[serde_as(deserialize_as = "DefaultOnError")]
    #[schemars(extend("x-deserialize-default-on-error" = true))]
    #[serde(default)]
    pub status: Option<ToolCallStatus>,
    // New title.
    pub title: Option<String>,
    // Replacement content list. A value that cannot be read as a list becomes `None`;
    // individual entries that fail to parse are skipped.
    #[serde_as(deserialize_as = "DefaultOnError<Option<VecSkipError<_, SkipListener>>>")]
    #[schemars(extend("x-deserialize-default-on-error" = true, "x-deserialize-skip-invalid-items" = true))]
    #[serde(default)]
    pub content: Option<Vec<ToolCallContent>>,
    // Replacement location list, with the same lenient handling as `content`.
    #[serde_as(deserialize_as = "DefaultOnError<Option<VecSkipError<_, SkipListener>>>")]
    #[schemars(extend("x-deserialize-default-on-error" = true, "x-deserialize-skip-invalid-items" = true))]
    #[serde(default)]
    pub locations: Option<Vec<ToolCallLocation>>,
    // New opaque raw input value.
    pub raw_input: Option<serde_json::Value>,
    // New opaque raw output value.
    pub raw_output: Option<serde_json::Value>,
}
248
249impl ToolCallUpdateFields {
250 #[must_use]
251 pub fn new() -> Self {
252 Self::default()
253 }
254
255 #[must_use]
257 pub fn kind(mut self, kind: impl IntoOption<ToolKind>) -> Self {
258 self.kind = kind.into_option();
259 self
260 }
261
262 #[must_use]
264 pub fn status(mut self, status: impl IntoOption<ToolCallStatus>) -> Self {
265 self.status = status.into_option();
266 self
267 }
268
269 #[must_use]
271 pub fn title(mut self, title: impl IntoOption<String>) -> Self {
272 self.title = title.into_option();
273 self
274 }
275
276 #[must_use]
278 pub fn content(mut self, content: impl IntoOption<Vec<ToolCallContent>>) -> Self {
279 self.content = content.into_option();
280 self
281 }
282
283 #[must_use]
285 pub fn locations(mut self, locations: impl IntoOption<Vec<ToolCallLocation>>) -> Self {
286 self.locations = locations.into_option();
287 self
288 }
289
290 #[must_use]
292 pub fn raw_input(mut self, raw_input: impl IntoOption<serde_json::Value>) -> Self {
293 self.raw_input = raw_input.into_option();
294 self
295 }
296
297 #[must_use]
299 pub fn raw_output(mut self, raw_output: impl IntoOption<serde_json::Value>) -> Self {
300 self.raw_output = raw_output.into_option();
301 self
302 }
303}
304
305impl TryFrom<ToolCallUpdate> for ToolCall {
308 type Error = Error;
309
310 fn try_from(update: ToolCallUpdate) -> Result<Self, Self::Error> {
311 let ToolCallUpdate {
312 tool_call_id,
313 fields:
314 ToolCallUpdateFields {
315 kind,
316 status,
317 title,
318 content,
319 locations,
320 raw_input,
321 raw_output,
322 },
323 meta,
324 } = update;
325
326 Ok(Self {
327 tool_call_id,
328 title: title.ok_or_else(|| {
329 Error::invalid_params().data(serde_json::json!("title is required for a tool call"))
330 })?,
331 kind: kind.unwrap_or_default(),
332 status: status.unwrap_or_default(),
333 content: content.unwrap_or_default(),
334 locations: locations.unwrap_or_default(),
335 raw_input,
336 raw_output,
337 meta,
338 })
339 }
340}
341
342impl From<ToolCall> for ToolCallUpdate {
343 fn from(value: ToolCall) -> Self {
344 let ToolCall {
345 tool_call_id,
346 title,
347 kind,
348 status,
349 content,
350 locations,
351 raw_input,
352 raw_output,
353 meta,
354 } = value;
355 Self {
356 tool_call_id,
357 fields: ToolCallUpdateFields {
358 kind: Some(kind),
359 status: Some(status),
360 title: Some(title),
361 content: Some(content),
362 locations: Some(locations),
363 raw_input,
364 raw_output,
365 },
366 meta,
367 }
368 }
369}
370
// Opaque identifier of a tool call.
//
// Backed by `Arc<str>`, so clones share one allocation. `#[serde(transparent)]` makes it
// (de)serialize as a bare JSON string. `From` is derived for `Arc<str>`, `String`, and
// `&'static str`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash, Display, From)]
#[serde(transparent)]
#[from(Arc<str>, String, &'static str)]
#[non_exhaustive]
pub struct ToolCallId(pub Arc<str>);
377
378impl ToolCallId {
379 #[must_use]
380 pub fn new(id: impl Into<Arc<str>>) -> Self {
381 Self(id.into())
382 }
383}
384
385impl IntoOption<ToolCallId> for &str {
386 fn into_option(self) -> Option<ToolCallId> {
387 Some(ToolCallId::new(self))
388 }
389}
390
// Broad category of a tool call.
//
// Known variants use snake_case strings on the wire (e.g. `SwitchMode` <-> "switch_mode").
// Any other string deserializes into `Unknown` via the untagged fallback and is written
// back unchanged. The default is `Other`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ToolKind {
    Read,
    Edit,
    Delete,
    Move,
    Search,
    Execute,
    Think,
    Fetch,
    SwitchMode,
    // Default kind; `ToolCall` omits `kind` from its output when it has this value.
    #[default]
    Other,
    // Catch-all that preserves an unrecognized kind string verbatim.
    #[serde(untagged)]
    Unknown(String),
}
430
431impl ToolKind {
432 fn is_default(&self) -> bool {
433 matches!(self, ToolKind::Other)
434 }
435}
436
// Execution status of a tool call.
//
// Known variants use snake_case strings on the wire (e.g. `InProgress` <-> "in_progress").
// Any other string deserializes into `Other` via the untagged fallback and is written
// back unchanged. The default is `Pending`.
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ToolCallStatus {
    // Default status; `ToolCall` omits `status` from its output when it has this value.
    #[default]
    Pending,
    InProgress,
    Completed,
    Failed,
    // Catch-all that preserves an unrecognized status string verbatim.
    #[serde(untagged)]
    Other(String),
}
464
465impl ToolCallStatus {
466 fn is_default(&self) -> bool {
467 matches!(self, ToolCallStatus::Pending)
468 }
469}
470
// Content attached to a tool call.
//
// Internally tagged by `type`: "content" selects `Content` and "diff" selects `Diff`.
// Objects with any other `type` fall through (untagged) to `Other`. Because
// `OtherToolCallContent` refuses the two known tags, a malformed "content"/"diff" object
// is an error rather than being captured as `Other`. The schema also carries an
// OpenAPI-style `discriminator` hint on `type`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
#[schemars(extend("discriminator" = {"propertyName": "type"}))]
#[non_exhaustive]
pub enum ToolCallContent {
    // `{"type": "content", ...}`; the remaining keys deserialize as `Content`.
    Content(Content),
    // `{"type": "diff", ...}`; the remaining keys deserialize as `Diff`.
    Diff(Diff),
    // Any object whose `type` is not a known tag.
    #[serde(untagged)]
    Other(OtherToolCallContent),
}
498
// Tool call content whose `type` is not one of the known tags.
//
// Serialization writes `type` and then flattens `fields` into the same object.
// Deserialization is implemented by hand (there is no `Deserialize` derive). The schema is
// inlined and post-processed by `other_tool_call_content_schema`; NOTE(review): judging by
// the helper's name, that step excludes the known tags from this variant's schema — confirm
// in `schema_util`.
#[derive(Debug, Clone, PartialEq, Serialize, JsonSchema)]
#[schemars(inline)]
#[schemars(transform = other_tool_call_content_schema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct OtherToolCallContent {
    // The raw `type` tag (trailing underscore because `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub type_: String,
    // Every other key of the object, flattened back into it on serialization.
    #[serde(flatten)]
    pub fields: BTreeMap<String, serde_json::Value>,
}
517
518impl OtherToolCallContent {
519 #[must_use]
520 pub fn new(type_: impl Into<String>, mut fields: BTreeMap<String, serde_json::Value>) -> Self {
521 fields.remove("type");
522 Self {
523 type_: type_.into(),
524 fields,
525 }
526 }
527}
528
529impl<'de> Deserialize<'de> for OtherToolCallContent {
530 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
531 where
532 D: serde::Deserializer<'de>,
533 {
534 let mut fields = BTreeMap::<String, serde_json::Value>::deserialize(deserializer)?;
535 let type_ = fields
536 .remove("type")
537 .ok_or_else(|| serde::de::Error::missing_field("type"))?;
538 let serde_json::Value::String(type_) = type_ else {
539 return Err(serde::de::Error::custom("`type` must be a string"));
540 };
541
542 if is_known_tool_call_content_type(&type_) {
543 return Err(serde::de::Error::custom(format!(
544 "known tool call content `{type_}` did not match its schema"
545 )));
546 }
547
548 Ok(Self { type_, fields })
549 }
550}
551
552fn is_known_tool_call_content_type(type_: &str) -> bool {
553 matches!(type_, "content" | "diff")
554}
555
556fn other_tool_call_content_schema(schema: &mut Schema) {
557 super::schema_util::reject_known_string_discriminators(schema, "type", &["content", "diff"]);
558}
559
560impl<T: Into<ContentBlock>> From<T> for ToolCallContent {
561 fn from(content: T) -> Self {
562 ToolCallContent::Content(Content::new(content))
563 }
564}
565
566impl From<Diff> for ToolCallContent {
567 fn from(diff: Diff) -> Self {
568 ToolCallContent::Diff(diff)
569 }
570}
571
// Tool call content that wraps a single `ContentBlock` (the `"content"` variant of
// `ToolCallContent`).
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct Content {
    // The wrapped content block.
    pub content: ContentBlock,
    // Extension metadata, carried under the `_meta` key; omitted when `None`.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
588
589impl Content {
590 #[must_use]
591 pub fn new(content: impl Into<ContentBlock>) -> Self {
592 Self {
593 content: content.into(),
594 meta: None,
595 }
596 }
597
598 #[must_use]
604 pub fn meta(mut self, meta: impl IntoOption<Meta>) -> Self {
605 self.meta = meta.into_option();
606 self
607 }
608}
609
// A text change to the file at `path` (the `"diff"` variant of `ToolCallContent`).
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct Diff {
    // Path of the affected file.
    pub path: PathBuf,
    // Previous text; omitted when `None`. NOTE(review): `None` presumably means the file
    // did not exist before — confirm against the protocol spec.
    pub old_text: Option<String>,
    // Resulting text.
    pub new_text: String,
    // Extension metadata, carried under the `_meta` key; omitted when `None`.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
634
635impl Diff {
636 #[must_use]
637 pub fn new(path: impl Into<PathBuf>, new_text: impl Into<String>) -> Self {
638 Self {
639 path: path.into(),
640 old_text: None,
641 new_text: new_text.into(),
642 meta: None,
643 }
644 }
645
646 #[must_use]
648 pub fn old_text(mut self, old_text: impl IntoOption<String>) -> Self {
649 self.old_text = old_text.into_option();
650 self
651 }
652
653 #[must_use]
659 pub fn meta(mut self, meta: impl IntoOption<Meta>) -> Self {
660 self.meta = meta.into_option();
661 self
662 }
663}
664
// A file location touched by a tool call: a path and an optional line number.
#[skip_serializing_none]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolCallLocation {
    // Path of the file.
    pub path: PathBuf,
    // Optional line number; omitted when `None`. NOTE(review): whether this is 0- or
    // 1-based is not established here — confirm against the protocol spec.
    #[serde(default)]
    pub line: Option<u32>,
    // Extension metadata, carried under the `_meta` key; omitted when `None`.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
689
690impl ToolCallLocation {
691 #[must_use]
692 pub fn new(path: impl Into<PathBuf>) -> Self {
693 Self {
694 path: path.into(),
695 line: None,
696 meta: None,
697 }
698 }
699
700 #[must_use]
702 pub fn line(mut self, line: impl IntoOption<u32>) -> Self {
703 self.line = line.into_option();
704 self
705 }
706
707 #[must_use]
713 pub fn meta(mut self, meta: impl IntoOption<Meta>) -> Self {
714 self.meta = meta.into_option();
715 self
716 }
717}
718
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // An unrecognized kind string round-trips through the `Unknown` fallback.
    #[test]
    fn tool_kind_preserves_unknown_variant() {
        let parsed: ToolKind = serde_json::from_str(r#""review""#).unwrap();
        assert_eq!(parsed, ToolKind::Unknown("review".into()));

        let reserialized = serde_json::to_value(&parsed).unwrap();
        assert_eq!(reserialized, "review");
    }

    // An unrecognized status string round-trips through the `Other` fallback.
    #[test]
    fn tool_call_status_preserves_unknown_variant() {
        let parsed: ToolCallStatus = serde_json::from_str(r#""deferred""#).unwrap();
        assert_eq!(parsed, ToolCallStatus::Other("deferred".into()));

        let reserialized = serde_json::to_value(&parsed).unwrap();
        assert_eq!(reserialized, "deferred");
    }

    // Content with an unknown `type` keeps its tag and extra keys, and serializes back to
    // the same object.
    #[test]
    fn tool_call_content_preserves_unknown_variant() {
        let wire = json!({
            "type": "_chart",
            "title": "Tests",
            "data": [1, 2, 3]
        });

        let parsed: ToolCallContent = serde_json::from_value(wire.clone()).unwrap();
        let ToolCallContent::Other(unknown) = parsed else {
            panic!("expected unknown tool call content");
        };

        assert_eq!(unknown.type_, "_chart");
        assert_eq!(unknown.fields.get("title"), Some(&json!("Tests")));

        let round_tripped = serde_json::to_value(ToolCallContent::Other(unknown)).unwrap();
        assert_eq!(round_tripped, wire);
    }

    // A known tag with a body that does not match its schema is an error, not `Other`.
    #[test]
    fn tool_call_content_does_not_hide_malformed_known_variant() {
        let outcome = serde_json::from_value::<ToolCallContent>(json!({ "type": "diff" }));
        assert!(outcome.is_err());
    }
}