use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::borrow::Cow;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
pub use crate::stream_error_kind::StreamErrorKind;
use crate::types::{
AgentContext, AssistantMessage, ContentBlock, Cost, ModelSpec, StopReason, Usage,
};
/// Transport used to receive a model's streamed response.
///
/// Serialized in `snake_case` (the only variant, `Sse`, serializes as `"sse"`)
/// and `Sse` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StreamTransport {
/// Server-Sent Events (NOTE(review): meaning inferred from the name; confirm in the provider adapters).
#[default]
Sse,
}
/// Prompt-caching policy requested for a stream; `None` is the `Default`.
///
/// NOTE(review): the variants are interpreted by provider adapters that are not
/// visible here, so the per-variant notes below are inferred from names only.
#[derive(Debug, Clone, Default)]
pub enum CacheStrategy {
/// No caching requested (the default).
#[default]
None,
/// Presumably lets the adapter choose a caching scheme itself; confirm.
Auto,
/// Presumably Anthropic-style prompt caching; confirm.
Anthropic,
/// Presumably Google context caching; confirm.
Google {
/// Presumably how long the cached context stays valid; confirm with the Google adapter.
ttl: Duration,
},
}
/// Shared, thread-safe callback that receives a `&str`.
///
/// NOTE(review): presumably invoked with the raw provider payload text for
/// logging/inspection; confirm in the stream implementations.
pub type OnRawPayload = Arc<dyn Fn(&str) + Send + Sync>;
/// Per-request options handed to [`StreamFn::stream`].
///
/// Every field is optional or has a default (`Default` is derived). `Debug` is
/// implemented by hand so that `api_key` and the callback are not printed.
#[derive(Clone, Default)]
pub struct StreamOptions {
/// Optional sampling temperature.
pub temperature: Option<f64>,
/// Optional cap on generated tokens.
pub max_tokens: Option<u64>,
/// Optional session identifier (NOTE(review): how providers use it is not visible here).
pub session_id: Option<String>,
/// Optional API key; never printed by the `Debug` impl.
pub api_key: Option<String>,
/// Transport to use; defaults to `StreamTransport::Sse`.
pub transport: StreamTransport,
/// Prompt-caching policy; defaults to `CacheStrategy::None`.
pub cache_strategy: CacheStrategy,
/// Optional raw-payload observer (see [`OnRawPayload`]).
pub on_raw_payload: Option<OnRawPayload>,
}
impl std::fmt::Debug for StreamOptions {
    /// Diagnostic formatting that never prints secrets or callbacks.
    ///
    /// `api_key` is shown as `Some("[REDACTED]")` when set, and `on_raw_payload`
    /// as `Some("<callback>")` because closures have no `Debug` impl. Every
    /// other field is printed as-is, in declaration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `StreamOptions` without
        // deciding how to print it becomes a compile error here.
        let Self {
            temperature,
            max_tokens,
            session_id,
            api_key,
            transport,
            cache_strategy,
            on_raw_payload,
        } = self;
        let masked_key = api_key.as_ref().map(|_| "[REDACTED]");
        let callback_marker = on_raw_payload.as_ref().map(|_| "<callback>");

        let mut out = f.debug_struct("StreamOptions");
        out.field("temperature", temperature);
        out.field("max_tokens", max_tokens);
        out.field("session_id", session_id);
        out.field("api_key", &masked_key);
        out.field("transport", transport);
        out.field("cache_strategy", cache_strategy);
        out.field("on_raw_payload", &callback_marker);
        out.finish()
    }
}
/// One event in a streamed assistant response.
///
/// `accumulate_message` accepts a stream shaped as follows:
/// - `Start` first.
/// - Then content-block events.
/// - Then exactly one terminal event (`Done` or `Error`).
///
/// `content_index` identifies the content block an event belongs to. Each
/// `*Start` event opens the block at the next free index.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum AssistantMessageEvent {
/// Beginning of the message; expected exactly once, before any other event.
Start,
/// Opens a text block.
TextStart { content_index: usize },
/// Appends `delta` to an open text block.
TextDelta { content_index: usize, delta: String },
/// Closes a text block.
TextEnd { content_index: usize },
/// Opens a thinking block.
ThinkingStart { content_index: usize },
/// Appends `delta` to an open thinking block.
ThinkingDelta { content_index: usize, delta: String },
/// Closes a thinking block and stores the optional `signature` on it.
ThinkingEnd {
content_index: usize,
signature: Option<String>,
},
/// Opens a tool-call block for tool `name` with call id `id`.
ToolCallStart {
content_index: usize,
id: String,
name: String,
},
/// Appends a fragment of the tool call's JSON arguments text.
ToolCallDelta { content_index: usize, delta: String },
/// Closes a tool-call block; the accumulated arguments JSON is parsed here.
ToolCallEnd { content_index: usize },
/// Successful terminal event carrying the final usage and cost.
Done {
stop_reason: StopReason,
usage: Usage,
cost: Cost,
},
/// Failed terminal event. `usage` is optional.
/// `error_kind` optionally classifies the failure (see the `error_*` constructors).
Error {
stop_reason: StopReason,
error_message: String,
usage: Option<Usage>,
error_kind: Option<StreamErrorKind>,
},
}
impl AssistantMessageEvent {
pub fn error(message: impl Into<String>) -> Self {
Self::Error {
stop_reason: StopReason::Error,
error_message: message.into(),
usage: None,
error_kind: None,
}
}
pub fn error_throttled(message: impl Into<String>) -> Self {
Self::Error {
stop_reason: StopReason::Error,
error_message: message.into(),
usage: None,
error_kind: Some(StreamErrorKind::Throttled),
}
}
pub fn error_context_overflow(message: impl Into<String>) -> Self {
Self::Error {
stop_reason: StopReason::Error,
error_message: message.into(),
usage: None,
error_kind: Some(StreamErrorKind::ContextWindowExceeded),
}
}
pub fn error_auth(message: impl Into<String>) -> Self {
Self::Error {
stop_reason: StopReason::Error,
error_message: message.into(),
usage: None,
error_kind: Some(StreamErrorKind::Auth),
}
}
pub fn error_network(message: impl Into<String>) -> Self {
Self::Error {
stop_reason: StopReason::Error,
error_message: message.into(),
usage: None,
error_kind: Some(StreamErrorKind::Network),
}
}
pub fn error_content_filtered(message: impl Into<String>) -> Self {
Self::Error {
stop_reason: StopReason::Error,
error_message: message.into(),
usage: None,
error_kind: Some(StreamErrorKind::ContentFiltered),
}
}
pub fn text_response(text: &str) -> Vec<Self> {
vec![
Self::Start,
Self::TextStart { content_index: 0 },
Self::TextDelta {
content_index: 0,
delta: text.to_string(),
},
Self::TextEnd { content_index: 0 },
Self::Done {
stop_reason: StopReason::Stop,
usage: Usage::default(),
cost: Cost::default(),
},
]
}
}
/// Serializable content fragment for the block at `content_index`.
///
/// It mirrors the fields of the `TextDelta` / `ThinkingDelta` / `ToolCallDelta`
/// events. Internally tagged by serde, so it serializes as
/// `{"type": "text" | "thinking" | "tool_call", "content_index": …, "delta": …}`.
/// `Cow<'static, str>` lets a producer pass either a `'static` string or an
/// owned one.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AssistantMessageDelta {
/// Fragment of a text block.
Text {
content_index: usize,
delta: Cow<'static, str>,
},
/// Fragment of a thinking block.
Thinking {
content_index: usize,
delta: Cow<'static, str>,
},
/// Fragment of a tool call's JSON arguments text.
ToolCall {
content_index: usize,
delta: Cow<'static, str>,
},
}
/// A source of streamed assistant responses; must be `Send + Sync`.
///
/// NOTE(review): presumably one implementation per model provider; confirm.
pub trait StreamFn: Send + Sync {
/// Starts streaming a response for `model` given `context` and `options`.
///
/// The returned boxed stream is `Send` and may borrow all three inputs for
/// `'a`. `cancellation_token` is taken by value.
///
/// NOTE(review): presumably implementations:
/// - stop and emit a terminal event once the token is cancelled;
/// - emit events in the order `accumulate_message` validates.
/// Confirm both against the implementations.
fn stream<'a>(
&'a self,
model: &'a ModelSpec,
context: &'a AgentContext,
options: &'a StreamOptions,
cancellation_token: CancellationToken,
) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>;
}
/// Repairs tool-call blocks whose arguments never became a complete JSON object.
///
/// A `ContentBlock::ToolCall` is repaired in either of two cases:
/// - it still holds unparsed `partial_json` text;
/// - its `arguments` are not a JSON object.
///
/// A repaired block gets `arguments` set to `{}` and `partial_json` set to
/// `None`. Other blocks, and tool calls that are already well-formed, are left
/// untouched.
///
/// Returns how many blocks were rewritten. A second call on the same message
/// therefore returns 0.
pub fn sanitize_incomplete_tool_calls(message: &mut AssistantMessage) -> usize {
    message
        .content
        .iter_mut()
        .fold(0, |repaired, block| match block {
            ContentBlock::ToolCall {
                arguments,
                partial_json,
                ..
            } if partial_json.is_some() || !arguments.is_object() => {
                *arguments = Value::Object(serde_json::Map::new());
                *partial_json = None;
                repaired + 1
            }
            _ => repaired,
        })
}
/// Folds a complete stream of [`AssistantMessageEvent`]s into one [`AssistantMessage`].
///
/// The event order is validated while folding:
///
/// - Exactly one `Start`, before any content event. A `Start` that arrives after
///   the terminal event is rejected.
/// - `TextStart` / `ThinkingStart` / `ToolCallStart` must use the next free
///   `content_index` (the current number of content blocks).
/// - Deltas and `*End` events must target an existing block of the matching
///   kind that has not been closed yet.
/// - Exactly one terminal event (`Done` or `Error`), and nothing may follow it.
/// - `Done` requires every content block to be closed. The exception is a
///   [`StopReason::Length`] stream whose still-open blocks are all tool calls
///   (the model was cut off mid-arguments).
/// - `Error` accepts open blocks.
///
/// Tool-call arguments are collected as raw JSON text and parsed on
/// `ToolCallEnd`; an empty argument string becomes `{}`. Under a `Length` stop
/// reason a parse failure is tolerated:
/// - `arguments` is left as `Value::Null`;
/// - the raw text stays in `partial_json`, so
///   [`sanitize_incomplete_tool_calls`] can repair the block afterwards.
///
/// `provider` and `model_id` are copied into the result. `usage` and `cost`
/// fall back to their defaults when the terminal event did not provide them.
///
/// # Errors
///
/// Returns a human-readable description of the first ordering or content
/// violation, or of a missing `Start` / terminal event.
// The body is one state machine over the event variants whose arms all share
// the validation state declared below, so it is deliberately kept together.
#[allow(clippy::too_many_lines)]
pub fn accumulate_message(
    events: Vec<AssistantMessageEvent>,
    provider: &str,
    model_id: &str,
) -> Result<AssistantMessage, String> {
    /// Rejects an event aimed at a block that has already been closed.
    /// An out-of-range index passes here; callers report it with their own
    /// "invalid content_index" message.
    fn ensure_block_open(
        open_blocks: &[bool],
        content_index: usize,
        event_name: &str,
    ) -> Result<(), String> {
        match open_blocks.get(content_index) {
            Some(false) => Err(format!(
                "{event_name}: block at index {content_index} is already closed"
            )),
            Some(true) | None => Ok(()),
        }
    }

    /// Marks the block at `content_index` as closed (no-op when out of range).
    fn close_block(open_blocks: &mut [bool], content_index: usize) {
        if let Some(open) = open_blocks.get_mut(content_index) {
            *open = false;
        }
    }

    /// True when every block that is still open is a tool call.
    fn all_open_blocks_are_tool_calls(content: &[ContentBlock], open_blocks: &[bool]) -> bool {
        open_blocks
            .iter()
            .enumerate()
            .filter(|(_, open)| **open)
            .all(|(content_index, _)| {
                matches!(
                    content.get(content_index),
                    Some(ContentBlock::ToolCall { .. })
                )
            })
    }

    // `None` until `Start` is seen; afterwards, the blocks built so far.
    let mut content: Option<Vec<ContentBlock>> = None;
    // Parallel to `content`: `open_blocks[i]` stays true until block `i` is ended.
    let mut open_blocks: Vec<bool> = Vec::new();
    let mut stop_reason: Option<StopReason> = None;
    let mut usage: Option<Usage> = None;
    let mut cost: Option<Cost> = None;
    let mut error_message: Option<String> = None;
    let mut error_kind: Option<StreamErrorKind> = None;
    let mut saw_start = false;
    let mut saw_terminal = false;

    // A `Length` stop reason means generation was cut off (max-tokens), so the
    // last tool call's JSON may be incomplete. This must be known before the
    // loop because the terminal event arrives after the `ToolCallEnd` it affects.
    let tolerate_truncated_tool_args = events.iter().any(|e| {
        matches!(
            e,
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Length,
                ..
            } | AssistantMessageEvent::Error {
                stop_reason: StopReason::Length,
                ..
            }
        )
    });

    for event in events {
        // Ordering checks shared by whole classes of events.
        match &event {
            AssistantMessageEvent::TextStart { .. }
            | AssistantMessageEvent::TextDelta { .. }
            | AssistantMessageEvent::TextEnd { .. }
            | AssistantMessageEvent::ThinkingStart { .. }
            | AssistantMessageEvent::ThinkingDelta { .. }
            | AssistantMessageEvent::ThinkingEnd { .. }
            | AssistantMessageEvent::ToolCallStart { .. }
            | AssistantMessageEvent::ToolCallDelta { .. }
            | AssistantMessageEvent::ToolCallEnd { .. } => {
                if saw_terminal {
                    return Err("content event after terminal event".into());
                }
            }
            AssistantMessageEvent::Done { .. } | AssistantMessageEvent::Error { .. } => {
                if saw_terminal {
                    return Err("duplicate terminal event".into());
                }
            }
            // Checked in the `Start` arm below, after the duplicate check, so
            // the duplicate-Start error keeps precedence.
            AssistantMessageEvent::Start => {}
        }

        match event {
            AssistantMessageEvent::Start => {
                if saw_start {
                    return Err("duplicate Start event".into());
                }
                // Without this check, a `[terminal, Start]` sequence was
                // accepted as a complete, empty message.
                if saw_terminal {
                    return Err("Start event after terminal event".into());
                }
                saw_start = true;
                content = Some(Vec::new());
            }
            AssistantMessageEvent::TextStart { content_index } => {
                let blocks = content.as_mut().ok_or("TextStart before Start")?;
                if content_index != blocks.len() {
                    return Err(format!(
                        "TextStart content_index {content_index} != content length {}",
                        blocks.len()
                    ));
                }
                blocks.push(ContentBlock::Text {
                    text: String::new(),
                });
                open_blocks.push(true);
            }
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
            } => {
                let blocks = content.as_mut().ok_or("TextDelta before Start")?;
                ensure_block_open(&open_blocks, content_index, "TextDelta")?;
                let block = blocks
                    .get_mut(content_index)
                    .ok_or_else(|| format!("TextDelta: invalid content_index {content_index}"))?;
                match block {
                    ContentBlock::Text { text } => text.push_str(&delta),
                    _ => {
                        return Err(format!(
                            "TextDelta: block at index {content_index} is not Text"
                        ));
                    }
                }
            }
            AssistantMessageEvent::TextEnd { content_index } => {
                let blocks = content.as_ref().ok_or("TextEnd before Start")?;
                let block = blocks
                    .get(content_index)
                    .ok_or_else(|| format!("TextEnd: invalid content_index {content_index}"))?;
                if !matches!(block, ContentBlock::Text { .. }) {
                    return Err(format!(
                        "TextEnd: block at index {content_index} is not Text"
                    ));
                }
                ensure_block_open(&open_blocks, content_index, "TextEnd")?;
                close_block(&mut open_blocks, content_index);
            }
            AssistantMessageEvent::ThinkingStart { content_index } => {
                let blocks = content.as_mut().ok_or("ThinkingStart before Start")?;
                if content_index != blocks.len() {
                    return Err(format!(
                        "ThinkingStart content_index {content_index} != content length {}",
                        blocks.len()
                    ));
                }
                blocks.push(ContentBlock::Thinking {
                    thinking: String::new(),
                    signature: None,
                });
                open_blocks.push(true);
            }
            AssistantMessageEvent::ThinkingDelta {
                content_index,
                delta,
            } => {
                let blocks = content.as_mut().ok_or("ThinkingDelta before Start")?;
                ensure_block_open(&open_blocks, content_index, "ThinkingDelta")?;
                let block = blocks.get_mut(content_index).ok_or_else(|| {
                    format!("ThinkingDelta: invalid content_index {content_index}")
                })?;
                match block {
                    ContentBlock::Thinking { thinking, .. } => thinking.push_str(&delta),
                    _ => {
                        return Err(format!(
                            "ThinkingDelta: block at index {content_index} is not Thinking"
                        ));
                    }
                }
            }
            AssistantMessageEvent::ThinkingEnd {
                content_index,
                signature,
            } => {
                let blocks = content.as_mut().ok_or("ThinkingEnd before Start")?;
                ensure_block_open(&open_blocks, content_index, "ThinkingEnd")?;
                let block = blocks
                    .get_mut(content_index)
                    .ok_or_else(|| format!("ThinkingEnd: invalid content_index {content_index}"))?;
                match block {
                    ContentBlock::Thinking { signature: sig, .. } => *sig = signature,
                    _ => {
                        return Err(format!(
                            "ThinkingEnd: block at index {content_index} is not Thinking"
                        ));
                    }
                }
                close_block(&mut open_blocks, content_index);
            }
            AssistantMessageEvent::ToolCallStart {
                content_index,
                id,
                name,
            } => {
                let blocks = content.as_mut().ok_or("ToolCallStart before Start")?;
                if content_index != blocks.len() {
                    return Err(format!(
                        "ToolCallStart content_index {content_index} != content length {}",
                        blocks.len()
                    ));
                }
                // Arguments stay `Null` until `ToolCallEnd` parses the
                // accumulated `partial_json`.
                blocks.push(ContentBlock::ToolCall {
                    id,
                    name,
                    arguments: Value::Null,
                    partial_json: Some(String::new()),
                });
                open_blocks.push(true);
            }
            AssistantMessageEvent::ToolCallDelta {
                content_index,
                delta,
            } => {
                let blocks = content.as_mut().ok_or("ToolCallDelta before Start")?;
                ensure_block_open(&open_blocks, content_index, "ToolCallDelta")?;
                let block = blocks.get_mut(content_index).ok_or_else(|| {
                    format!("ToolCallDelta: invalid content_index {content_index}")
                })?;
                match block {
                    ContentBlock::ToolCall { partial_json, .. } => {
                        let pj = partial_json
                            .as_mut()
                            .ok_or("ToolCallDelta: partial_json already consumed")?;
                        pj.push_str(&delta);
                    }
                    _ => {
                        return Err(format!(
                            "ToolCallDelta: block at index {content_index} is not ToolCall"
                        ));
                    }
                }
            }
            AssistantMessageEvent::ToolCallEnd { content_index } => {
                let blocks = content.as_mut().ok_or("ToolCallEnd before Start")?;
                let block = blocks
                    .get_mut(content_index)
                    .ok_or_else(|| format!("ToolCallEnd: invalid content_index {content_index}"))?;
                ensure_block_open(&open_blocks, content_index, "ToolCallEnd")?;
                match block {
                    ContentBlock::ToolCall {
                        arguments,
                        partial_json,
                        ..
                    } => {
                        // Parse straight from a borrow of the accumulated text:
                        // the parsed `Value` owns its data, so no copy is needed.
                        let parsed = match partial_json.as_deref() {
                            None => {
                                return Err("ToolCallEnd: partial_json already consumed".into());
                            }
                            // No argument deltas at all: an empty argument object.
                            Some("") => Ok(Value::Object(serde_json::Map::new())),
                            Some(json_str) => serde_json::from_str::<Value>(json_str),
                        };
                        match parsed {
                            Ok(value) => {
                                *arguments = value;
                                *partial_json = None;
                            }
                            // Truncated by max-tokens: keep the raw text in
                            // `partial_json` for `sanitize_incomplete_tool_calls`.
                            Err(_) if tolerate_truncated_tool_args => {}
                            Err(e) => {
                                return Err(format!(
                                    "ToolCallEnd: failed to parse arguments JSON: {e}"
                                ));
                            }
                        }
                    }
                    _ => {
                        return Err(format!(
                            "ToolCallEnd: block at index {content_index} is not ToolCall"
                        ));
                    }
                }
                close_block(&mut open_blocks, content_index);
            }
            AssistantMessageEvent::Done {
                stop_reason: sr,
                usage: u,
                cost: c,
            } => {
                if let Some(idx) = open_blocks.iter().position(|open| *open) {
                    let blocks = content.as_ref().ok_or("Done before Start")?;
                    if tolerate_truncated_tool_args
                        && all_open_blocks_are_tool_calls(blocks, &open_blocks)
                    {
                        tracing::debug!(
                            "Done(Length) with unterminated content block at index {idx} — tolerating for max-tokens recovery"
                        );
                    } else {
                        return Err(format!(
                            "Done received with unterminated content block at index {idx}"
                        ));
                    }
                }
                stop_reason = Some(sr);
                usage = Some(u);
                cost = Some(c);
                saw_terminal = true;
            }
            AssistantMessageEvent::Error {
                stop_reason: sr,
                error_message: em,
                usage: u,
                error_kind: ek,
            } => {
                // Open blocks are acceptable here: the stream failed mid-way.
                stop_reason = Some(sr);
                error_message = Some(em);
                error_kind = ek;
                if let Some(u) = u {
                    usage = Some(u);
                }
                saw_terminal = true;
            }
        }
    }

    let content = content.ok_or("no Start event found")?;
    let stop_reason = stop_reason.ok_or("no terminal event (Done or Error) found")?;
    let timestamp = crate::util::now_timestamp();
    Ok(AssistantMessage {
        content,
        provider: provider.to_owned(),
        model_id: model_id.to_owned(),
        usage: usage.unwrap_or_default(),
        cost: cost.unwrap_or_default(),
        stop_reason,
        error_message,
        error_kind,
        timestamp,
        cache_hint: None,
    })
}
// Compile-time guarantee that the public streaming types can be shared and
// moved across threads: the build fails right here if any of them stops being
// `Send + Sync`. Nothing in this block runs at runtime.
const _: () = {
    const fn require_send_sync<T: Send + Sync>() {}

    require_send_sync::<StreamErrorKind>();
    require_send_sync::<StreamTransport>();
    require_send_sync::<StreamOptions>();
    require_send_sync::<AssistantMessageEvent>();
    require_send_sync::<AssistantMessageDelta>();
};
#[cfg(test)]
mod tests {
// Unit tests for:
// - the error-event constructors and `text_response`;
// - the stream validation in `accumulate_message`;
// - `sanitize_incomplete_tool_calls`.
use super::*;
// --- accumulate_message: `Done` requires every content block to be closed ---
#[test]
fn done_with_unterminated_text_block_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::TextStart { content_index: 0 },
AssistantMessageEvent::TextDelta {
content_index: 0,
delta: "hi".into(),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::Stop,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert!(err.contains("unterminated content block"), "got: {err}");
}
#[test]
fn done_with_unterminated_tool_call_block_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::ToolCallStart {
content_index: 0,
id: "t1".into(),
name: "foo".into(),
},
AssistantMessageEvent::ToolCallDelta {
content_index: 0,
delta: "{}".into(),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::ToolUse,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert!(err.contains("unterminated content block"), "got: {err}");
}
#[test]
fn done_with_all_blocks_terminated_succeeds() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::TextStart { content_index: 0 },
AssistantMessageEvent::TextDelta {
content_index: 0,
delta: "ok".into(),
},
AssistantMessageEvent::TextEnd { content_index: 0 },
AssistantMessageEvent::Done {
stop_reason: StopReason::Stop,
usage: Usage::default(),
cost: Cost::default(),
},
];
let msg = accumulate_message(events, "test", "test").expect("should succeed");
assert_eq!(msg.content.len(), 1);
}
// An `Error` terminal may end the stream while blocks are still open.
#[test]
fn error_with_unterminated_block_is_allowed() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::TextStart { content_index: 0 },
AssistantMessageEvent::Error {
stop_reason: StopReason::Error,
error_message: "boom".into(),
usage: None,
error_kind: None,
},
];
let msg = accumulate_message(events, "test", "test").expect("error terminal ok");
assert_eq!(msg.error_message.as_deref(), Some("boom"));
}
// --- Error-event constructors set the expected `error_kind` ---
#[test]
fn error_constructor_sets_kind_none() {
let event = AssistantMessageEvent::error("boom");
match event {
AssistantMessageEvent::Error { error_kind, .. } => {
assert_eq!(error_kind, None);
}
other => panic!("expected Error, got {other:?}"),
}
}
#[test]
fn error_throttled_constructor_sets_kind() {
let event = AssistantMessageEvent::error_throttled("rate limited");
match event {
AssistantMessageEvent::Error {
error_kind,
error_message,
..
} => {
assert_eq!(error_kind, Some(StreamErrorKind::Throttled));
assert_eq!(error_message, "rate limited");
}
other => panic!("expected Error, got {other:?}"),
}
}
#[test]
fn error_context_overflow_constructor_sets_kind() {
let event = AssistantMessageEvent::error_context_overflow("too long");
match event {
AssistantMessageEvent::Error { error_kind, .. } => {
assert_eq!(error_kind, Some(StreamErrorKind::ContextWindowExceeded));
}
other => panic!("expected Error, got {other:?}"),
}
}
#[test]
fn error_auth_constructor_sets_kind() {
let event = AssistantMessageEvent::error_auth("bad key");
match event {
AssistantMessageEvent::Error { error_kind, .. } => {
assert_eq!(error_kind, Some(StreamErrorKind::Auth));
}
other => panic!("expected Error, got {other:?}"),
}
}
#[test]
fn error_network_constructor_sets_kind() {
let event = AssistantMessageEvent::error_network("timeout");
match event {
AssistantMessageEvent::Error { error_kind, .. } => {
assert_eq!(error_kind, Some(StreamErrorKind::Network));
}
other => panic!("expected Error, got {other:?}"),
}
}
#[test]
fn error_content_filtered_constructor_sets_kind() {
let event = AssistantMessageEvent::error_content_filtered("blocked by safety filter");
match event {
AssistantMessageEvent::Error {
error_kind,
error_message,
..
} => {
assert_eq!(error_kind, Some(StreamErrorKind::ContentFiltered));
assert_eq!(error_message, "blocked by safety filter");
}
other => panic!("expected Error, got {other:?}"),
}
}
// --- text_response produces the canned five-event sequence ---
#[test]
fn text_response_produces_valid_event_sequence() {
let events = AssistantMessageEvent::text_response("hello world");
assert_eq!(events.len(), 5);
assert!(matches!(events[0], AssistantMessageEvent::Start));
assert!(matches!(
events[1],
AssistantMessageEvent::TextStart { content_index: 0 }
));
match &events[2] {
AssistantMessageEvent::TextDelta {
content_index,
delta,
} => {
assert_eq!(*content_index, 0);
assert_eq!(delta, "hello world");
}
other => panic!("expected TextDelta, got {other:?}"),
}
assert!(matches!(
events[3],
AssistantMessageEvent::TextEnd { content_index: 0 }
));
assert!(matches!(
events[4],
AssistantMessageEvent::Done {
stop_reason: StopReason::Stop,
..
}
));
}
// --- Done(Length) tolerates open blocks only when they are all tool calls ---
#[test]
fn done_length_with_unterminated_tool_call_is_tolerated() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::ToolCallStart {
content_index: 0,
id: "tc_1".into(),
name: "read_file".into(),
},
AssistantMessageEvent::ToolCallDelta {
content_index: 0,
delta: r#"{"path": "/tmp"#.into(),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::Length,
usage: Usage::default(),
cost: Cost::default(),
},
];
let msg = accumulate_message(events, "test", "test")
.expect("Done(Length) with open tool-call block should succeed");
assert_eq!(msg.stop_reason, StopReason::Length);
match &msg.content[0] {
ContentBlock::ToolCall { partial_json, .. } => {
assert!(
partial_json.is_some(),
"partial_json should be Some for incomplete tool call"
);
}
other => panic!("expected ToolCall, got {other:?}"),
}
}
#[test]
fn done_length_with_unterminated_text_block_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::TextStart { content_index: 0 },
AssistantMessageEvent::TextDelta {
content_index: 0,
delta: "partial".into(),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::Length,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert!(err.contains("unterminated content block"), "got: {err}");
}
#[test]
fn done_length_with_unterminated_thinking_block_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::ThinkingStart { content_index: 0 },
AssistantMessageEvent::ThinkingDelta {
content_index: 0,
delta: "partial".into(),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::Length,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert!(err.contains("unterminated content block"), "got: {err}");
}
// The text_response sequence round-trips through accumulate_message.
#[test]
fn text_response_accumulates_correctly() {
let events = AssistantMessageEvent::text_response("accumulated text");
let msg = accumulate_message(events, "test", "test-model").expect("accumulation failed");
assert_eq!(msg.content.len(), 1);
assert_eq!(ContentBlock::extract_text(&msg.content), "accumulated text");
assert_eq!(msg.stop_reason, StopReason::Stop);
}
// --- Deltas and End events aimed at an already-closed block are rejected ---
#[test]
fn text_delta_after_text_end_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::TextStart { content_index: 0 },
AssistantMessageEvent::TextDelta {
content_index: 0,
delta: "hello".into(),
},
AssistantMessageEvent::TextEnd { content_index: 0 },
AssistantMessageEvent::TextDelta {
content_index: 0,
delta: " again".into(),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::Stop,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert_eq!(err, "TextDelta: block at index 0 is already closed");
}
#[test]
fn duplicate_text_end_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::TextStart { content_index: 0 },
AssistantMessageEvent::TextDelta {
content_index: 0,
delta: "hello".into(),
},
AssistantMessageEvent::TextEnd { content_index: 0 },
AssistantMessageEvent::TextEnd { content_index: 0 },
AssistantMessageEvent::Done {
stop_reason: StopReason::Stop,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert_eq!(err, "TextEnd: block at index 0 is already closed");
}
#[test]
fn duplicate_thinking_end_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::ThinkingStart { content_index: 0 },
AssistantMessageEvent::ThinkingDelta {
content_index: 0,
delta: "step 1".into(),
},
AssistantMessageEvent::ThinkingEnd {
content_index: 0,
signature: Some("sig-1".into()),
},
AssistantMessageEvent::ThinkingEnd {
content_index: 0,
signature: Some("sig-2".into()),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::Stop,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert_eq!(err, "ThinkingEnd: block at index 0 is already closed");
}
#[test]
fn tool_call_delta_after_end_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::ToolCallStart {
content_index: 0,
id: "tool-1".into(),
name: "read_file".into(),
},
AssistantMessageEvent::ToolCallDelta {
content_index: 0,
delta: "{\"path\":\"/tmp/a\"}".into(),
},
AssistantMessageEvent::ToolCallEnd { content_index: 0 },
AssistantMessageEvent::ToolCallDelta {
content_index: 0,
delta: ",\"extra\":true}".into(),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::ToolUse,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert_eq!(err, "ToolCallDelta: block at index 0 is already closed");
}
#[test]
fn duplicate_tool_call_end_is_rejected() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::ToolCallStart {
content_index: 0,
id: "tool-1".into(),
name: "read_file".into(),
},
AssistantMessageEvent::ToolCallDelta {
content_index: 0,
delta: "{\"path\":\"/tmp/a\"}".into(),
},
AssistantMessageEvent::ToolCallEnd { content_index: 0 },
AssistantMessageEvent::ToolCallEnd { content_index: 0 },
AssistantMessageEvent::Done {
stop_reason: StopReason::ToolUse,
usage: Usage::default(),
cost: Cost::default(),
},
];
let err = accumulate_message(events, "test", "test").unwrap_err();
assert_eq!(err, "ToolCallEnd: block at index 0 is already closed");
}
/// Builds a message with a single ToolCall block for the sanitize tests.
/// The block has id "tc_1" and name "read_file"; the message uses
/// `StopReason::Length` and timestamp 0.
fn build_assistant_with_tool_call(
arguments: Value,
partial_json: Option<String>,
) -> AssistantMessage {
AssistantMessage {
content: vec![ContentBlock::ToolCall {
id: "tc_1".into(),
name: "read_file".into(),
arguments,
partial_json,
}],
provider: "test".into(),
model_id: "test".into(),
usage: Usage::default(),
cost: Cost::default(),
stop_reason: StopReason::Length,
error_message: None,
error_kind: None,
timestamp: 0,
cache_hint: None,
}
}
// --- sanitize_incomplete_tool_calls ---
#[test]
fn sanitize_null_arguments_with_partial_json_returns_empty_object() {
let mut msg = build_assistant_with_tool_call(Value::Null, Some("{\"path\": \"/tm".into()));
let fixed = sanitize_incomplete_tool_calls(&mut msg);
assert_eq!(fixed, 1);
match &msg.content[0] {
ContentBlock::ToolCall {
arguments,
partial_json,
..
} => {
assert_eq!(*arguments, Value::Object(serde_json::Map::new()));
assert!(
partial_json.is_none(),
"partial_json must be cleared after scrub"
);
}
other => panic!("expected ToolCall, got {other:?}"),
}
}
#[test]
fn sanitize_leaves_valid_object_arguments_untouched() {
let args = serde_json::json!({ "path": "/tmp/a" });
let mut msg = build_assistant_with_tool_call(args.clone(), None);
let fixed = sanitize_incomplete_tool_calls(&mut msg);
assert_eq!(fixed, 0);
match &msg.content[0] {
ContentBlock::ToolCall {
arguments,
partial_json,
..
} => {
assert_eq!(*arguments, args);
assert!(partial_json.is_none());
}
other => panic!("expected ToolCall, got {other:?}"),
}
}
#[test]
fn sanitize_coerces_non_object_arguments() {
let mut msg = build_assistant_with_tool_call(Value::String("truncated".into()), None);
let fixed = sanitize_incomplete_tool_calls(&mut msg);
assert_eq!(fixed, 1);
match &msg.content[0] {
ContentBlock::ToolCall { arguments, .. } => {
assert_eq!(*arguments, Value::Object(serde_json::Map::new()));
}
other => panic!("expected ToolCall, got {other:?}"),
}
}
#[test]
fn sanitize_is_idempotent() {
let mut msg = build_assistant_with_tool_call(Value::Null, Some("{\"path\":".into()));
assert_eq!(sanitize_incomplete_tool_calls(&mut msg), 1);
assert_eq!(sanitize_incomplete_tool_calls(&mut msg), 0);
}
#[test]
fn sanitize_preserves_non_tool_blocks() {
let mut msg = AssistantMessage {
content: vec![
ContentBlock::Text {
text: "hello".into(),
},
ContentBlock::ToolCall {
id: "tc_1".into(),
name: "foo".into(),
arguments: Value::Null,
partial_json: Some("{".into()),
},
ContentBlock::Text {
text: "world".into(),
},
],
provider: "test".into(),
model_id: "test".into(),
usage: Usage::default(),
cost: Cost::default(),
stop_reason: StopReason::Length,
error_message: None,
error_kind: None,
timestamp: 0,
cache_hint: None,
};
let fixed = sanitize_incomplete_tool_calls(&mut msg);
assert_eq!(fixed, 1);
match &msg.content[0] {
ContentBlock::Text { text } => assert_eq!(text, "hello"),
other => panic!("expected Text, got {other:?}"),
}
match &msg.content[2] {
ContentBlock::Text { text } => assert_eq!(text, "world"),
other => panic!("expected Text, got {other:?}"),
}
}
// End-to-end: a truncated tool call survives accumulation with `partial_json`
// set and `arguments` Null; sanitize then turns it into an empty object.
#[test]
fn accumulate_plus_sanitize_yields_adapter_safe_tool_call() {
let events = vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::ToolCallStart {
content_index: 0,
id: "tc_1".into(),
name: "read_file".into(),
},
AssistantMessageEvent::ToolCallDelta {
content_index: 0,
delta: r#"{"path": "/tm"#.into(),
},
AssistantMessageEvent::Done {
stop_reason: StopReason::Length,
usage: Usage::default(),
cost: Cost::default(),
},
];
let mut msg = accumulate_message(events, "test", "test")
.expect("Done(Length) with unterminated tool-call should accumulate");
match &msg.content[0] {
ContentBlock::ToolCall {
arguments,
partial_json,
..
} => {
assert!(partial_json.is_some());
assert!(arguments.is_null());
}
other => panic!("expected ToolCall, got {other:?}"),
}
sanitize_incomplete_tool_calls(&mut msg);
match &msg.content[0] {
ContentBlock::ToolCall {
arguments,
partial_json,
..
} => {
assert!(arguments.is_object());
assert_eq!(arguments.as_object().unwrap().len(), 0);
assert!(partial_json.is_none());
}
other => panic!("expected ToolCall, got {other:?}"),
}
}
}