use std::{collections::BTreeMap, path::PathBuf, sync::Arc};
use derive_more::{Display, From};
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use serde_with::{DefaultOnError, VecSkipError, serde_as, skip_serializing_none};
use super::{ContentBlock, Meta};
use crate::{IntoMaybeUndefined, IntoOption, MaybeUndefined, SkipListener};
/// Report about a tool call that can also serve as an upsert of its state.
///
/// Only `tool_call_id` is required. Every other field except `meta` is a
/// [`MaybeUndefined`]:
/// - `Undefined` fields are left out of the serialized JSON.
/// - `Null` serializes as an explicit `null`.
/// - `Value` carries a value.
///
/// On input, a missing key deserializes as `Undefined` and an explicit `null` as
/// `Null`. [`ToolCallUpdate::apply_update`] uses these states to merge one update
/// into another.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolCallUpdate {
    /// Id of the tool call this update refers to.
    pub tool_call_id: ToolCallId,
    /// Title of the tool call.
    #[serde(default, skip_serializing_if = "MaybeUndefined::is_undefined")]
    pub title: MaybeUndefined<String>,
    /// [`ToolKind`] of the tool call.
    #[serde(default, skip_serializing_if = "MaybeUndefined::is_undefined")]
    pub kind: MaybeUndefined<ToolKind>,
    /// [`ToolCallStatus`] of the tool call.
    #[serde(default, skip_serializing_if = "MaybeUndefined::is_undefined")]
    pub status: MaybeUndefined<ToolCallStatus>,
    /// List of [`ToolCallContent`] items.
    ///
    /// Deserialization is lenient, as advertised by the `x-deserialize-*` schema
    /// extensions. List items that fail to deserialize are dropped, with their
    /// errors routed through `SkipListener`. If the field as a whole fails, it
    /// falls back to `MaybeUndefined::default()`.
    #[serde_as(deserialize_as = "DefaultOnError<MaybeUndefined<VecSkipError<_, SkipListener>>>")]
    #[schemars(extend("x-deserialize-default-on-error" = true, "x-deserialize-skip-invalid-items" = true))]
    #[serde(default, skip_serializing_if = "MaybeUndefined::is_undefined")]
    pub content: MaybeUndefined<Vec<ToolCallContent>>,
    /// List of [`ToolCallLocation`]s.
    ///
    /// Deserialization is lenient in the same way as `content`.
    #[serde_as(deserialize_as = "DefaultOnError<MaybeUndefined<VecSkipError<_, SkipListener>>>")]
    #[schemars(extend("x-deserialize-default-on-error" = true, "x-deserialize-skip-invalid-items" = true))]
    #[serde(default, skip_serializing_if = "MaybeUndefined::is_undefined")]
    pub locations: MaybeUndefined<Vec<ToolCallLocation>>,
    /// Raw input of the tool call as arbitrary JSON.
    #[serde(default, skip_serializing_if = "MaybeUndefined::is_undefined")]
    pub raw_input: MaybeUndefined<serde_json::Value>,
    /// Raw output of the tool call as arbitrary JSON.
    #[serde(default, skip_serializing_if = "MaybeUndefined::is_undefined")]
    pub raw_output: MaybeUndefined<serde_json::Value>,
    /// Extension metadata, serialized as `_meta` and omitted when `None`.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl ToolCallUpdate {
#[must_use]
pub fn new(tool_call_id: impl Into<ToolCallId>) -> Self {
Self {
tool_call_id: tool_call_id.into(),
title: MaybeUndefined::Undefined,
kind: MaybeUndefined::Undefined,
status: MaybeUndefined::Undefined,
content: MaybeUndefined::Undefined,
locations: MaybeUndefined::Undefined,
raw_input: MaybeUndefined::Undefined,
raw_output: MaybeUndefined::Undefined,
meta: None,
}
}
#[must_use]
pub fn title(mut self, title: impl IntoMaybeUndefined<String>) -> Self {
self.title = title.into_maybe_undefined();
self
}
#[must_use]
pub fn kind(mut self, kind: impl IntoMaybeUndefined<ToolKind>) -> Self {
self.kind = kind.into_maybe_undefined();
self
}
#[must_use]
pub fn status(mut self, status: impl IntoMaybeUndefined<ToolCallStatus>) -> Self {
self.status = status.into_maybe_undefined();
self
}
#[must_use]
pub fn content(mut self, content: impl IntoMaybeUndefined<Vec<ToolCallContent>>) -> Self {
self.content = content.into_maybe_undefined();
self
}
#[must_use]
pub fn locations(mut self, locations: impl IntoMaybeUndefined<Vec<ToolCallLocation>>) -> Self {
self.locations = locations.into_maybe_undefined();
self
}
#[must_use]
pub fn raw_input(mut self, raw_input: impl IntoMaybeUndefined<serde_json::Value>) -> Self {
self.raw_input = raw_input.into_maybe_undefined();
self
}
#[must_use]
pub fn raw_output(mut self, raw_output: impl IntoMaybeUndefined<serde_json::Value>) -> Self {
self.raw_output = raw_output.into_maybe_undefined();
self
}
#[must_use]
pub fn meta(mut self, meta: impl IntoOption<Meta>) -> Self {
self.meta = meta.into_option();
self
}
pub fn apply_update(&mut self, update: ToolCallUpdate) {
debug_assert_eq!(self.tool_call_id, update.tool_call_id);
if !update.title.is_undefined() {
self.title = update.title;
}
if !update.kind.is_undefined() {
self.kind = update.kind;
}
if !update.status.is_undefined() {
self.status = update.status;
}
if !update.content.is_undefined() {
self.content = update.content;
}
if !update.locations.is_undefined() {
self.locations = update.locations;
}
if !update.raw_input.is_undefined() {
self.raw_input = update.raw_input;
}
if !update.raw_output.is_undefined() {
self.raw_output = update.raw_output;
}
}
}
/// A single [`ToolCallContent`] item for the tool call identified by `tool_call_id`.
///
/// [`ToolCallUpdate::content`] carries a whole list; a chunk carries exactly one
/// item.
// NOTE(review): presumably used to stream content incrementally. Confirm the
// append-vs-replace semantics against the protocol spec.
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolCallContentChunk {
    /// Id of the tool call this chunk belongs to.
    pub tool_call_id: ToolCallId,
    /// The single content item carried by this chunk.
    pub content: ToolCallContent,
    /// Extension metadata, serialized as `_meta` and omitted when `None`.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl ToolCallContentChunk {
    /// Builds a chunk carrying the single `content` item for `tool_call_id`, with
    /// no `_meta`.
    #[must_use]
    pub fn new(tool_call_id: impl Into<ToolCallId>, content: impl Into<ToolCallContent>) -> Self {
        let (tool_call_id, content) = (tool_call_id.into(), content.into());
        Self {
            tool_call_id,
            content,
            meta: None,
        }
    }

    /// Sets [`ToolCallContentChunk::meta`] (`_meta` on the wire).
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// Identifier of a tool call, used by [`ToolCallUpdate`] and
/// [`ToolCallContentChunk`].
///
/// Serialized transparently as a plain JSON string. It is backed by `Arc<str>`,
/// so clones share one allocation. Because the struct is `#[non_exhaustive]`,
/// code outside this crate cannot use the `ToolCallId(..)` constructor; use
/// [`ToolCallId::new`] or a `From` conversion instead.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash, Display, From)]
#[serde(transparent)]
#[from(Arc<str>, String)]
#[non_exhaustive]
pub struct ToolCallId(pub Arc<str>);

impl ToolCallId {
    /// Creates an id from anything convertible into `Arc<str>`.
    #[must_use]
    pub fn new(id: impl Into<Arc<str>>) -> Self {
        Self(id.into())
    }
}

/// Converts a string slice of any lifetime into a [`ToolCallId`] by copying it
/// into a fresh `Arc<str>`.
///
/// This replaces a derived `From<&'static str>`. That version rejected
/// non-`'static` slices such as `String::as_str()`, even though
/// `IntoOption<ToolCallId>` already accepts them. `&'static str` callers are
/// unaffected.
impl From<&str> for ToolCallId {
    fn from(id: &str) -> Self {
        Self::new(id)
    }
}
/// Lets a string slice of any lifetime be passed wherever an
/// `impl IntoOption<ToolCallId>` is accepted. The result is always `Some`.
impl IntoOption<ToolCallId> for &str {
    fn into_option(self) -> Option<ToolCallId> {
        Some(self).map(ToolCallId::new)
    }
}
/// Category of a tool call.
///
/// Serialized as a `snake_case` string; for example, `SwitchMode` is
/// `"switch_mode"`. A string that matches none of the named variants
/// deserializes into [`ToolKind::Unknown`] and serializes back verbatim, so kinds
/// this crate does not know about survive a round trip. [`ToolKind::Other`] is
/// the [`Default`].
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ToolKind {
    Read,
    Edit,
    Delete,
    Move,
    Search,
    Execute,
    Think,
    Fetch,
    SwitchMode,
    // The `Default` kind, serialized as "other".
    #[default]
    Other,
    // Any kind string not covered above, kept as-is. serde only tries untagged
    // variants after the named ones, so `Unknown("read".into())` serializes as
    // "read" but deserializes back as `ToolKind::Read`. serde also requires
    // `#[serde(untagged)]` variants to be declared last.
    #[serde(untagged)]
    Unknown(String),
}
/// Status of a tool call.
///
/// Serialized as a `snake_case` string; for example, `InProgress` is
/// `"in_progress"`. A string that matches none of the named variants
/// deserializes into [`ToolCallStatus::Other`] and serializes back verbatim.
/// [`ToolCallStatus::Pending`] is the [`Default`].
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ToolCallStatus {
    #[default]
    Pending,
    InProgress,
    Completed,
    Failed,
    // Any status string not covered above, kept as-is. Like `ToolKind::Unknown`,
    // a string equal to a named status deserializes to that named variant.
    // Must stay the last variant (serde requirement for untagged variants).
    #[serde(untagged)]
    Other(String),
}
/// One item of tool call content, discriminated by its `type` field.
///
/// - `"content"` selects [`ToolCallContent::Content`].
/// - `"diff"` selects [`ToolCallContent::Diff`].
/// - An object with any other string `type` becomes [`ToolCallContent::Other`],
///   with all of its fields kept.
///
/// An object whose `type` is `"content"` or `"diff"` but whose body does not fit
/// that variant is a deserialization error; it does not fall through to `Other`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
#[schemars(extend("discriminator" = {"propertyName": "type"}))]
#[non_exhaustive]
pub enum ToolCallContent {
    Content(Box<Content>),
    Diff(Diff),
    // Fallback for unrecognized `type`s. serde requires untagged variants to be
    // declared after all tagged ones. `OtherToolCallContent`'s deserializer
    // refuses the known `type`s, so malformed known variants are not swallowed
    // here.
    #[serde(untagged)]
    Other(OtherToolCallContent),
}
/// Tool call content whose `type` this crate does not model explicitly.
///
/// Serializes as a single flat JSON object: the `type` key from `type_`, plus
/// the entries of `fields`. Deserialization is hand-written (see the
/// `Deserialize` impl), and `other_tool_call_content_schema` post-processes the
/// JSON schema.
// NOTE(review): `fields` is public, so a caller can insert a "type" entry after
// construction. Because `fields` is `#[serde(flatten)]`, that would emit a
// second `type` key. `OtherToolCallContent::new` strips such an entry.
#[derive(Debug, Clone, PartialEq, Serialize, JsonSchema)]
#[schemars(inline)]
#[schemars(transform = other_tool_call_content_schema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct OtherToolCallContent {
    /// The `type` discriminator, serialized under the key `type`.
    #[serde(rename = "type")]
    pub type_: String,
    /// Every top-level key of the object other than `type`.
    #[serde(flatten)]
    pub fields: BTreeMap<String, serde_json::Value>,
}
impl OtherToolCallContent {
    /// Creates content with discriminator `type_` and the extra `fields`.
    ///
    /// Any `"type"` entry in `fields` is discarded; the discriminator always comes
    /// from `type_`. This keeps the flattened `fields` from colliding with the
    /// serialized `type` key.
    #[must_use]
    pub fn new(type_: impl Into<String>, mut fields: BTreeMap<String, serde_json::Value>) -> Self {
        let type_ = type_.into();
        fields.remove("type");
        Self { type_, fields }
    }
}
impl<'de> Deserialize<'de> for OtherToolCallContent {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut fields = BTreeMap::<String, serde_json::Value>::deserialize(deserializer)?;
let type_ = fields
.remove("type")
.ok_or_else(|| serde::de::Error::missing_field("type"))?;
let serde_json::Value::String(type_) = type_ else {
return Err(serde::de::Error::custom("`type` must be a string"));
};
if is_known_tool_call_content_type(&type_) {
return Err(serde::de::Error::custom(format!(
"known tool call content `{type_}` did not match its schema"
)));
}
Ok(Self { type_, fields })
}
}
/// `type` discriminators of the tool call content variants this crate models
/// explicitly: [`ToolCallContent::Content`] and [`ToolCallContent::Diff`].
///
/// This is the single source of truth for both the deserialization guard and
/// the schema transform below. Update it whenever a tagged variant is added to
/// [`ToolCallContent`].
const KNOWN_TOOL_CALL_CONTENT_TYPES: [&str; 2] = ["content", "diff"];

/// Returns `true` if `type_` is one of [`KNOWN_TOOL_CALL_CONTENT_TYPES`].
fn is_known_tool_call_content_type(type_: &str) -> bool {
    KNOWN_TOOL_CALL_CONTENT_TYPES.contains(&type_)
}

/// `schemars` transform for [`OtherToolCallContent`].
///
/// Passes [`KNOWN_TOOL_CALL_CONTENT_TYPES`] to
/// `schema_util::reject_known_string_discriminators` for the `type` property, so
/// the generated schema excludes the discriminators the deserializer refuses.
/// See that helper for the exact schema shape.
fn other_tool_call_content_schema(schema: &mut Schema) {
    super::schema_util::reject_known_string_discriminators(
        schema,
        "type",
        &KNOWN_TOOL_CALL_CONTENT_TYPES,
    );
}
impl<T> From<T> for ToolCallContent
where
    T: Into<ContentBlock>,
{
    /// Wraps any value convertible into a [`ContentBlock`] as
    /// [`ToolCallContent::Content`], with no `_meta`.
    fn from(content: T) -> Self {
        let item = Content::new(content);
        Self::Content(Box::new(item))
    }
}
impl From<Diff> for ToolCallContent {
fn from(diff: Diff) -> Self {
ToolCallContent::Diff(diff)
}
}
/// Tool call content wrapping a single [`ContentBlock`]. This is the
/// `"content"` variant of [`ToolCallContent`].
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct Content {
    /// The wrapped content block.
    pub content: ContentBlock,
    /// Extension metadata, serialized as `_meta` and omitted when `None`.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl Content {
    /// Wraps `content` (anything convertible into a [`ContentBlock`]), with no
    /// `_meta`.
    #[must_use]
    pub fn new(content: impl Into<ContentBlock>) -> Self {
        let content = content.into();
        Self { content, meta: None }
    }

    /// Sets [`Content::meta`] (`_meta` on the wire).
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// Text change to the file at `path`. This is the `"diff"` variant of
/// [`ToolCallContent`].
// NOTE(review): presumably `old_text == None` means the file did not exist
// before (creation). Confirm against the protocol spec.
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct Diff {
    /// Path of the changed file.
    pub path: PathBuf,
    /// Previous text, if any. Omitted from the JSON when `None`.
    pub old_text: Option<String>,
    /// New text for `path`.
    pub new_text: String,
    /// Extension metadata, serialized as `_meta` and omitted when `None`.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl Diff {
    /// Creates a diff that sets `path` to `new_text`, with no `old_text` and no
    /// `_meta`.
    #[must_use]
    pub fn new(path: impl Into<PathBuf>, new_text: impl Into<String>) -> Self {
        let path = path.into();
        let new_text = new_text.into();
        Self {
            path,
            new_text,
            old_text: None,
            meta: None,
        }
    }

    /// Sets [`Diff::old_text`] from anything implementing `IntoOption<String>`.
    #[must_use]
    pub fn old_text(self, old_text: impl IntoOption<String>) -> Self {
        Self {
            old_text: old_text.into_option(),
            ..self
        }
    }

    /// Sets [`Diff::meta`] (`_meta` on the wire).
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// A file location referenced by a tool call: a `path` plus an optional `line`.
#[skip_serializing_none]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolCallLocation {
    /// Path of the file.
    pub path: PathBuf,
    /// Line within `path`. Defaults to `None` when the key is absent, and is
    /// omitted from the JSON when `None`.
    // NOTE(review): whether lines are 0- or 1-based is not established here.
    #[serde(default)]
    pub line: Option<u32>,
    /// Extension metadata, serialized as `_meta` and omitted when `None`.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl ToolCallLocation {
    /// Creates a location for `path`, with no line and no `_meta`.
    #[must_use]
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path = path.into();
        Self {
            path,
            line: None,
            meta: None,
        }
    }

    /// Sets [`ToolCallLocation::line`] from anything implementing `IntoOption<u32>`.
    #[must_use]
    pub fn line(self, line: impl IntoOption<u32>) -> Self {
        Self {
            line: line.into_option(),
            ..self
        }
    }

    /// Sets [`ToolCallLocation::meta`] (`_meta` on the wire).
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
// Serialization, leniency, unknown-variant and merge-semantics tests for the
// tool call types in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MaybeUndefined;

    #[test]
    fn tool_call_serializes_as_upsert() {
        let tool_call = ToolCallUpdate::new("tc_1")
            .title("Reading configuration")
            .status(ToolCallStatus::InProgress)
            .raw_input(serde_json::json!({"path": "settings.json"}));
        assert_eq!(
            serde_json::to_value(tool_call).unwrap(),
            serde_json::json!({
                "toolCallId": "tc_1",
                "title": "Reading configuration",
                "status": "in_progress",
                "rawInput": {
                    "path": "settings.json"
                }
            })
        );
    }

    #[test]
    fn tool_call_update_distinguishes_omitted_null_and_value() {
        let tool_call = ToolCallUpdate::new("tc_1")
            .status(ToolCallStatus::Completed)
            .content(None::<Vec<ToolCallContent>>);
        assert_eq!(
            serde_json::to_value(tool_call).unwrap(),
            serde_json::json!({
                "toolCallId": "tc_1",
                "status": "completed",
                "content": null
            })
        );
        let deserialized: ToolCallUpdate = serde_json::from_value(serde_json::json!({
            "toolCallId": "tc_1",
            "status": null,
            "locations": []
        }))
        .unwrap();
        assert_eq!(deserialized.title, MaybeUndefined::Undefined);
        assert_eq!(deserialized.status, MaybeUndefined::Null);
        assert_eq!(deserialized.locations, MaybeUndefined::Value(Vec::new()));
    }

    #[test]
    fn tool_call_update_skips_malformed_list_items() {
        let deserialized: ToolCallUpdate = serde_json::from_value(serde_json::json!({
            "toolCallId": "tc_1",
            "content": [
                {
                    "type": "content",
                    "content": {
                        "type": "text",
                        "text": "ok"
                    }
                },
                {
                    "type": "diff",
                    "path": "/bad"
                }
            ],
            "locations": [
                {
                    "path": "/ok",
                    "line": 3
                },
                {
                    "line": 4
                }
            ]
        }))
        .unwrap();
        let MaybeUndefined::Value(content) = deserialized.content else {
            panic!("content should deserialize to a value");
        };
        assert_eq!(content.len(), 1);
        let MaybeUndefined::Value(locations) = deserialized.locations else {
            panic!("locations should deserialize to a value");
        };
        assert_eq!(locations.len(), 1);
    }

    #[test]
    fn tool_call_content_chunk_serializes_single_content_item() {
        let chunk = ToolCallContentChunk::new(
            "tc_1",
            ContentBlock::Text(crate::v2::TextContent::new("partial output")),
        );
        assert_eq!(
            serde_json::to_value(chunk).unwrap(),
            serde_json::json!({
                "toolCallId": "tc_1",
                "content": {
                    "type": "content",
                    "content": {
                        "type": "text",
                        "text": "partial output"
                    }
                }
            })
        );
    }

    #[test]
    fn tool_kind_preserves_unknown_variant() {
        let kind: ToolKind = serde_json::from_str("\"review\"").unwrap();
        assert_eq!(kind, ToolKind::Unknown("review".to_string()));
        assert_eq!(serde_json::to_value(&kind).unwrap(), "review");
    }

    #[test]
    fn tool_call_status_preserves_unknown_variant() {
        let status: ToolCallStatus = serde_json::from_str("\"deferred\"").unwrap();
        assert_eq!(status, ToolCallStatus::Other("deferred".to_string()));
        assert_eq!(serde_json::to_value(&status).unwrap(), "deferred");
    }

    #[test]
    fn tool_call_content_preserves_unknown_variant() {
        let content: ToolCallContent = serde_json::from_value(serde_json::json!({
            "type": "_chart",
            "title": "Tests",
            "data": [1, 2, 3]
        }))
        .unwrap();
        let ToolCallContent::Other(unknown) = content else {
            panic!("expected unknown tool call content");
        };
        assert_eq!(unknown.type_, "_chart");
        assert_eq!(
            unknown.fields.get("title"),
            Some(&serde_json::json!("Tests"))
        );
        assert_eq!(
            serde_json::to_value(ToolCallContent::Other(unknown)).unwrap(),
            serde_json::json!({
                "type": "_chart",
                "title": "Tests",
                "data": [1, 2, 3]
            })
        );
    }

    #[test]
    fn tool_call_content_does_not_hide_malformed_known_variant() {
        assert!(
            serde_json::from_value::<ToolCallContent>(serde_json::json!({
                "type": "diff"
            }))
            .is_err()
        );
    }

    // Pins the merge rules: Undefined keeps, Null clears, Value replaces.
    #[test]
    fn apply_update_overwrites_only_provided_fields() {
        let mut state = ToolCallUpdate::new("tc_1")
            .title("Reading configuration")
            .status(ToolCallStatus::InProgress)
            .raw_input(serde_json::json!({"path": "settings.json"}));
        state.apply_update(
            ToolCallUpdate::new("tc_1")
                .status(ToolCallStatus::Completed)
                .content(None::<Vec<ToolCallContent>>),
        );
        // Omitted from the update: kept as-is.
        assert_eq!(
            state.title,
            MaybeUndefined::Value("Reading configuration".to_string())
        );
        assert_eq!(
            state.raw_input,
            MaybeUndefined::Value(serde_json::json!({"path": "settings.json"}))
        );
        assert_eq!(state.kind, MaybeUndefined::Undefined);
        // Provided with a value: replaced.
        assert_eq!(state.status, MaybeUndefined::Value(ToolCallStatus::Completed));
        // Provided as an explicit null: becomes null.
        assert_eq!(state.content, MaybeUndefined::Null);
    }

    #[test]
    fn other_tool_call_content_new_drops_type_from_fields() {
        let fields = BTreeMap::from([
            ("type".to_string(), serde_json::json!("content")),
            ("title".to_string(), serde_json::json!("Tests")),
        ]);
        let other = OtherToolCallContent::new("_chart", fields);
        assert_eq!(other.type_, "_chart");
        assert!(!other.fields.contains_key("type"));
        assert_eq!(
            serde_json::to_value(ToolCallContent::Other(other)).unwrap(),
            serde_json::json!({
                "type": "_chart",
                "title": "Tests"
            })
        );
    }

    #[test]
    fn tool_call_content_rejects_missing_or_non_string_type() {
        assert!(
            serde_json::from_value::<ToolCallContent>(serde_json::json!({
                "title": "no discriminator"
            }))
            .is_err()
        );
        assert!(
            serde_json::from_value::<ToolCallContent>(serde_json::json!({
                "type": 5
            }))
            .is_err()
        );
    }
}