use serde::{Deserialize, Serialize};
/// One piece of an incrementally streamed response.
///
/// Serialized externally tagged with snake_case variant names, for example
/// `{"text_delta": "hello"}` or `{"tool_call_start": {"id": "id-1", "name": "calc"}}`.
///
/// Tool-call arguments arrive as JSON text fragments: the `args_json_delta`
/// values of the [`Chunk::ToolCallArgsDelta`] chunks sharing an `id` are meant
/// to be concatenated in stream order to rebuild that call's argument JSON.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Chunk {
    /// A fragment of text output.
    TextDelta(String),
    /// Opens the tool call identified by `id`.
    ToolCallStart {
        /// Identifier that later args/end chunks of this call refer to.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Optional signature attached to this tool call.
        ///
        /// Left out of the serialized form when `None`; a missing key
        /// deserializes as `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// A fragment of the JSON-encoded arguments for the call `id`.
    ToolCallArgsDelta {
        /// Identifier of the tool call these arguments belong to.
        id: String,
        /// Raw JSON text fragment; not necessarily valid JSON on its own.
        args_json_delta: String,
    },
    /// Marks the end of the tool call identified by `id`.
    ToolCallEnd {
        /// Identifier of the tool call being closed.
        id: String,
    },
    /// Token usage figures.
    Usage(Usage),
    /// Why the response stopped.
    Stop(StopReason),
}
impl Chunk {
    /// Builds a [`Chunk::TextDelta`] from anything convertible into a `String`.
    #[must_use]
    pub fn text_delta(s: impl Into<String>) -> Self {
        let text: String = s.into();
        Self::TextDelta(text)
    }

    /// Builds a [`Chunk::ToolCallStart`] with no signature.
    ///
    /// Equivalent to [`Chunk::tool_call_start_signed`] with `signature = None`.
    #[must_use]
    pub fn tool_call_start(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self::tool_call_start_signed(id, name, None)
    }

    /// Builds a [`Chunk::ToolCallStart`] carrying the given optional `signature`.
    #[must_use]
    pub fn tool_call_start_signed(
        id: impl Into<String>,
        name: impl Into<String>,
        signature: Option<String>,
    ) -> Self {
        // Convert in the same order as the fields are declared (id, then name).
        let (id, name) = (id.into(), name.into());
        Self::ToolCallStart {
            id,
            name,
            signature,
        }
    }

    /// Builds a [`Chunk::ToolCallArgsDelta`] holding one JSON fragment for call `id`.
    #[must_use]
    pub fn tool_call_args_delta(id: impl Into<String>, args_json_delta: impl Into<String>) -> Self {
        let id: String = id.into();
        let args_json_delta: String = args_json_delta.into();
        Self::ToolCallArgsDelta {
            id,
            args_json_delta,
        }
    }

    /// Builds a [`Chunk::ToolCallEnd`] closing the tool call `id`.
    #[must_use]
    pub fn tool_call_end(id: impl Into<String>) -> Self {
        let id: String = id.into();
        Self::ToolCallEnd { id }
    }
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
pub input_tokens: u64,
pub output_tokens: u64,
}
impl Usage {
    /// Returns `input_tokens + output_tokens`, clamped to `u64::MAX` instead of
    /// overflowing.
    #[must_use]
    pub const fn total_tokens(self) -> u64 {
        // Exhaustive destructure: adding a new counter to `Usage` forces this
        // total to be revisited at compile time.
        let Self {
            input_tokens,
            output_tokens,
        } = self;
        input_tokens.saturating_add(output_tokens)
    }
}
/// Reason a response stopped, carried by [`Chunk::Stop`].
///
/// Serialized as a bare snake_case string (for example `"end_turn"`). Marked
/// `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
/// Derives `Hash` so it can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StopReason {
    /// The turn ended (`"end_turn"`).
    EndTurn,
    /// Output stopped at a token limit (`"max_tokens"`).
    MaxTokens,
    /// Output stopped on a stop sequence (`"stop_sequence"`).
    StopSequence,
    /// Output stopped for a tool call (`"tool_use"`).
    ToolUse,
    /// The response was a refusal (`"refusal"`).
    Refusal,
}
#[cfg(test)]
mod tests {
    // Unit tests for the constructors, the serde wire format, and `Usage` arithmetic.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use serde_json::{Value, json};

    use super::*;

    #[test]
    fn text_delta_constructor() {
        assert_eq!(Chunk::text_delta("hi"), Chunk::TextDelta("hi".to_owned()));
    }

    #[test]
    fn tool_call_start_constructor() {
        match Chunk::tool_call_start("call-1", "search") {
            Chunk::ToolCallStart { id, name, .. } => {
                assert_eq!(id, "call-1");
                assert_eq!(name, "search");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn tool_call_start_leaves_signature_unset() {
        assert_eq!(
            Chunk::tool_call_start("call-1", "search"),
            Chunk::ToolCallStart {
                id: "call-1".to_owned(),
                name: "search".to_owned(),
                signature: None,
            },
        );
    }

    #[test]
    fn tool_call_start_signed_constructor() {
        assert_eq!(
            Chunk::tool_call_start_signed("call-1", "search", Some("sig".to_owned())),
            Chunk::ToolCallStart {
                id: "call-1".to_owned(),
                name: "search".to_owned(),
                signature: Some("sig".to_owned()),
            },
        );
        assert_eq!(
            Chunk::tool_call_start_signed("call-1", "search", None),
            Chunk::tool_call_start("call-1", "search"),
        );
    }

    #[test]
    fn tool_call_args_delta_constructor() {
        match Chunk::tool_call_args_delta("call-1", r#"{"q":"#) {
            Chunk::ToolCallArgsDelta {
                id,
                args_json_delta,
            } => {
                assert_eq!(id, "call-1");
                assert_eq!(args_json_delta, r#"{"q":"#);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn tool_call_end_constructor() {
        assert_eq!(
            Chunk::tool_call_end("call-1"),
            Chunk::ToolCallEnd {
                id: "call-1".to_owned()
            },
        );
    }

    #[test]
    fn text_delta_serializes_as_tagged_object() {
        let v: Value = serde_json::to_value(Chunk::text_delta("hello")).unwrap();
        assert_eq!(v, json!({"text_delta": "hello"}));
    }

    #[test]
    fn tool_call_start_serializes_with_named_fields() {
        let v: Value = serde_json::to_value(Chunk::tool_call_start("id-1", "calc")).unwrap();
        assert_eq!(
            v,
            json!({"tool_call_start": {"id": "id-1", "name": "calc"}})
        );
    }

    #[test]
    fn tool_call_start_serializes_signature_when_present() {
        let v: Value = serde_json::to_value(Chunk::tool_call_start_signed(
            "id-1",
            "calc",
            Some("sig".to_owned()),
        ))
        .unwrap();
        assert_eq!(
            v,
            json!({"tool_call_start": {"id": "id-1", "name": "calc", "signature": "sig"}})
        );
    }

    #[test]
    fn tool_call_start_without_signature_key_deserializes_as_none() {
        let back: Chunk =
            serde_json::from_value(json!({"tool_call_start": {"id": "id-1", "name": "calc"}}))
                .unwrap();
        assert_eq!(back, Chunk::tool_call_start("id-1", "calc"));
    }

    #[test]
    fn tool_call_args_delta_serializes_with_named_fields() {
        let v: Value =
            serde_json::to_value(Chunk::tool_call_args_delta("id-1", r#"{"x":1}"#)).unwrap();
        assert_eq!(
            v,
            json!({"tool_call_args_delta": {"id": "id-1", "args_json_delta": r#"{"x":1}"#}}),
        );
    }

    #[test]
    fn chunk_round_trips_all_variants() {
        for chunk in [
            Chunk::text_delta("partial"),
            Chunk::tool_call_start("c1", "weather"),
            Chunk::tool_call_start_signed("c2", "search", Some("sig".to_owned())),
            Chunk::tool_call_args_delta("c1", r#"{"city":"NYC"}"#),
            Chunk::tool_call_end("c1"),
            Chunk::Usage(Usage {
                input_tokens: 10,
                output_tokens: 20,
            }),
            Chunk::Stop(StopReason::EndTurn),
        ] {
            let json = serde_json::to_string(&chunk).unwrap();
            let back: Chunk = serde_json::from_str(&json).unwrap();
            assert_eq!(back, chunk);
        }
    }

    #[test]
    fn reassemble_tool_call_args_from_deltas_by_id() {
        let stream = [
            Chunk::tool_call_start("a", "weather"),
            Chunk::tool_call_args_delta("a", r#"{"city":"#),
            Chunk::tool_call_args_delta("b", "IGNORED"),
            Chunk::tool_call_args_delta("a", r#""NYC"}"#),
            Chunk::tool_call_end("a"),
        ];
        let mut assembled = String::new();
        for c in &stream {
            if let Chunk::ToolCallArgsDelta {
                id,
                args_json_delta,
            } = c
                && id == "a"
            {
                assembled.push_str(args_json_delta);
            }
        }
        let parsed: Value = serde_json::from_str(&assembled).unwrap();
        assert_eq!(parsed, json!({"city": "NYC"}));
    }

    #[test]
    fn usage_total_sums_input_and_output() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 250,
        };
        assert_eq!(u.total_tokens(), 350);
    }

    #[test]
    fn usage_total_saturates_on_overflow() {
        let u = Usage {
            input_tokens: u64::MAX,
            output_tokens: 1,
        };
        assert_eq!(u.total_tokens(), u64::MAX);
    }

    #[test]
    fn usage_default_is_all_zero() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
        assert_eq!(u.total_tokens(), 0);
    }

    #[test]
    fn stop_reason_serializes_to_snake_case() {
        assert_eq!(
            serde_json::to_string(&StopReason::EndTurn).unwrap(),
            r#""end_turn""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::MaxTokens).unwrap(),
            r#""max_tokens""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::StopSequence).unwrap(),
            r#""stop_sequence""#,
        );
        assert_eq!(
            serde_json::to_string(&StopReason::ToolUse).unwrap(),
            r#""tool_use""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::Refusal).unwrap(),
            r#""refusal""#,
        );
    }

    #[test]
    fn stop_reason_round_trips() {
        for reason in [
            StopReason::EndTurn,
            StopReason::MaxTokens,
            StopReason::StopSequence,
            StopReason::ToolUse,
            StopReason::Refusal,
        ] {
            let json = serde_json::to_string(&reason).unwrap();
            let back: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(back, reason);
        }
    }
}