use defect_core::llm::{
CompletionRequest, ImageData, Message, MessageContent, ProviderChunk, ProviderErrorKind,
ReasoningEffort, Role, SamplingParams, StopReason, ThinkingConfig, ToolChoice, ToolResultBody,
ToolResultContent,
};
use defect_core::tool::ToolSchema;
use futures::StreamExt;
use serde_json::json;
use sse_stream::Sse;
use tokio_util::sync::CancellationToken;
use super::*;
use crate::wire::anthropic::components as wire;
/// Error type for the synthetic SSE streams fed to `decode_stream_generic`.
///
/// The test streams are built entirely from `Ok` items (see
/// `run_decode_stream_generic`), so this error is never actually produced; it
/// only exists to give the stream's `Result` item a concrete error type.
#[derive(Debug, thiserror::Error)]
#[error("test sse never errors")]
struct NeverError;
/// Builds one `Sse` frame per `(event_name, data)` pair, preserving input order.
///
/// `id` and `retry` are always left unset; none of the tests here need them.
fn make_sse_events(events: &[(&str, &str)]) -> Vec<Sse> {
    let mut frames = Vec::with_capacity(events.len());
    for &(event_name, payload) in events {
        frames.push(Sse {
            event: Some(event_name.to_owned()),
            data: Some(payload.to_owned()),
            id: None,
            retry: None,
        });
    }
    frames
}
/// Drives `process_sse` synchronously over `events`, one SSE frame at a time.
///
/// Returns the final decoder state together with every chunk or error emitted.
/// Processing stops after the first frame that leaves the state marked `fatal`.
fn run_state_machine(
    events: &[(&str, &str)],
) -> (DecoderState, Vec<Result<ProviderChunk, ProviderError>>) {
    let mut state = DecoderState::default();
    let mut emitted = Vec::new();
    for frame in make_sse_events(events) {
        let mut pending = Vec::new();
        process_sse(&mut state, frame, &mut pending);
        // NOTE(review): each per-frame batch is appended in reverse, which
        // suggests `process_sse` pushes items for a `pop()`-based consumer —
        // confirm against `decode_stream_generic`.
        emitted.extend(pending.into_iter().rev());
        if state.fatal {
            break;
        }
    }
    (state, emitted)
}
/// Runs the full async decoder over an in-memory stream built from `events` and
/// collects everything it yields.
///
/// Every SSE frame is wrapped in `Ok`, so the stream itself never errors.
async fn run_decode_stream_generic(
    events: &[(&str, &str)],
    cancel: CancellationToken,
) -> Vec<Result<ProviderChunk, ProviderError>> {
    let source = futures::stream::iter(
        make_sse_events(events)
            .into_iter()
            .map(Ok::<Sse, NeverError>),
    );
    let decoded = decode_stream_generic(source, cancel);
    decoded.collect::<Vec<_>>().await
}
/// Unwraps every decoder result in order, panicking ("err chunk") on the first `Err`.
fn ok_chunks(results: Vec<Result<ProviderChunk, ProviderError>>) -> Vec<ProviderChunk> {
    let mut chunks = Vec::with_capacity(results.len());
    for result in results {
        chunks.push(result.expect("err chunk"));
    }
    chunks
}
/// A system prompt plus one user text message encodes with default `max_tokens`,
/// streaming on, cache breakpoints on the system block and the user text, no
/// tools, `auto` tool choice, and no thinking config.
#[test]
fn encode_minimal_request() {
    let user_turn = Message {
        role: Role::User,
        content: vec![MessageContent::Text { text: "hi".into() }].into(),
    };
    let req = CompletionRequest {
        model: "claude-opus-4-7".into(),
        system: Some("you are helpful".into()),
        messages: vec![user_turn],
        tools: vec![],
        tool_choice: ToolChoice::Auto,
        sampling: SamplingParams::default(),
        hosted_capabilities: ::defect_core::llm::HostedCapabilities::default(),
    };
    let encoded = encode_request(&req, ThinkingWireFormat::Adaptive);

    // Request-level defaults.
    assert_eq!(encoded.max_tokens, i64::from(DEFAULT_MAX_TOKENS));
    assert_eq!(encoded.stream, Some(true));

    // System prompt is sent as a block list whose single block is cache-marked.
    let system_ok = matches!(
        &encoded.system,
        Some(wire::SystemPrompt::SystemPromptVariant1(blocks))
            if matches!(
                blocks.as_slice(),
                [wire::TextBlockParam {
                    text,
                    cache_control: Some(_),
                    ..
                }] if text == "you are helpful"
            )
    );
    assert!(system_ok);

    // Exactly one user message, sent as a cache-marked text block list.
    assert_eq!(encoded.messages.len(), 1);
    let first = &encoded.messages[0];
    assert!(matches!(first.role, wire::MessageParamRole::User));
    let wire::MessageParamContent::MessageParamContentVariant1(content) = &first.content else {
        panic!("expected list content");
    };
    assert!(matches!(
        content.as_slice(),
        [wire::ContentBlockParam::TextBlockParam(wire::TextBlockParam {
            text,
            cache_control: Some(_),
            ..
        })] if text == "hi"
    ));

    // No tools, automatic tool choice, no thinking.
    assert!(encoded.tools.is_none());
    assert!(matches!(
        encoded.tool_choice,
        Some(wire::ToolChoice::ToolChoiceAuto(_))
    ));
    assert!(encoded.thinking.is_none());
}
/// Explicit sampling params reach the legacy wire request unchanged: `max_tokens`,
/// `temperature`, `top_p`, `top_k` and stop sequences are copied, `Required` tool
/// choice is encoded as `any`, and an enabled thinking config keeps its budget.
#[test]
fn encode_request_carries_sampling() {
    let req = CompletionRequest {
        model: "claude-opus-4-7".into(),
        system: None,
        messages: vec![Message {
            role: Role::User,
            content: vec![MessageContent::Text { text: "x".into() }].into(),
        }],
        tools: vec![],
        tool_choice: ToolChoice::Required,
        sampling: SamplingParams {
            max_tokens: Some(8000),
            temperature: Some(0.5),
            top_p: Some(0.9),
            top_k: Some(40),
            stop_sequences: vec!["END".into()],
            thinking: ThinkingConfig::Enabled {
                budget_tokens: Some(2000),
            },
            reasoning_effort: None,
        },
        hosted_capabilities: ::defect_core::llm::HostedCapabilities::default(),
    };
    let w = encode_request(&req, ThinkingWireFormat::Legacy);
    assert_eq!(w.max_tokens, 8000);
    assert_eq!(w.temperature, Some(0.5));
    assert_eq!(w.top_p, Some(0.9));
    assert_eq!(w.top_k, Some(40));
    assert_eq!(w.stop_sequences.as_deref(), Some(&["END".to_string()][..]));
    assert!(matches!(
        w.tool_choice,
        Some(wire::ToolChoice::ToolChoiceAny(_))
    ));
    // Read the budget through the shared helper, like the other legacy-budget tests.
    assert_eq!(thinking_budget(&w), Some(2000));
}
/// Returns the `budget_tokens` of an enabled (budgeted) thinking config, or `None`
/// when thinking is absent or uses any other variant.
fn thinking_budget(w: &wire::CreateMessageParams) -> Option<i64> {
    if let Some(wire::ThinkingConfigParam::ThinkingConfigEnabled(enabled)) = w.thinking.as_ref() {
        Some(enabled.budget_tokens)
    } else {
        None
    }
}
/// Builds a single-user-message request where only the reasoning effort, the
/// thinking config and `max_tokens` vary; all other sampling fields are unset.
fn req_with(
    effort: Option<ReasoningEffort>,
    thinking: ThinkingConfig,
    max_tokens: u32,
) -> CompletionRequest {
    let sampling = SamplingParams {
        max_tokens: Some(max_tokens),
        temperature: None,
        top_p: None,
        top_k: None,
        stop_sequences: Vec::new(),
        thinking,
        reasoning_effort: effort,
    };
    let prompt = Message {
        role: Role::User,
        content: vec![MessageContent::Text { text: "x".into() }].into(),
    };
    CompletionRequest {
        model: "claude-opus-4-7".into(),
        system: None,
        messages: vec![prompt],
        tools: Vec::new(),
        tool_choice: ToolChoice::Auto,
        sampling,
        hosted_capabilities: ::defect_core::llm::HostedCapabilities::default(),
    }
}
/// In the legacy format, `High` effort becomes a 16_384-token thinking budget and
/// no `output_config` is sent.
#[test]
fn legacy_effort_maps_to_thinking_budget() {
    let req = req_with(Some(ReasoningEffort::High), ThinkingConfig::Disabled, 64_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Legacy);
    assert_eq!(thinking_budget(&encoded), Some(16_384));
    assert!(encoded.output_config.is_none());
}
/// `ReasoningEffort::None` turns thinking off in the legacy format, even when the
/// thinking config is enabled with a budget.
#[test]
fn legacy_effort_none_disables_thinking() {
    let thinking = ThinkingConfig::Enabled {
        budget_tokens: Some(8_000),
    };
    let req = req_with(Some(ReasoningEffort::None), thinking, 64_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Legacy);
    assert!(encoded.thinking.is_none());
}
/// An explicit effort overrides the thinking config's budget in the legacy format:
/// `Low` yields 4_096 tokens instead of the configured 30_000.
#[test]
fn legacy_effort_takes_precedence_over_thinking_config() {
    let thinking = ThinkingConfig::Enabled {
        budget_tokens: Some(30_000),
    };
    let req = req_with(Some(ReasoningEffort::Low), thinking, 64_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Legacy);
    assert_eq!(thinking_budget(&encoded), Some(4_096));
}
/// An effort-derived budget is clamped to stay below `max_tokens` (5_000 → 4_999).
#[test]
fn legacy_effort_budget_clamped_below_max_tokens() {
    let req = req_with(Some(ReasoningEffort::Xhigh), ThinkingConfig::Disabled, 5_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Legacy);
    assert_eq!(thinking_budget(&encoded), Some(4_999));
}
/// When `max_tokens` (here 501) leaves no room for the minimum thinking budget,
/// thinking is omitted entirely.
#[test]
fn legacy_thinking_dropped_when_max_tokens_too_small_for_minimum_budget() {
    let req = req_with(Some(ReasoningEffort::High), ThinkingConfig::Disabled, 501);
    let encoded = encode_request(&req, ThinkingWireFormat::Legacy);
    assert!(encoded.thinking.is_none());
}
/// Without an explicit effort, the legacy format uses the thinking config's own budget.
#[test]
fn legacy_thinking_config_used_when_no_effort() {
    let thinking = ThinkingConfig::Enabled {
        budget_tokens: Some(2_000),
    };
    let req = req_with(None, thinking, 64_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Legacy);
    assert_eq!(thinking_budget(&encoded), Some(2_000));
}
/// Returns `output_config.effort`, or `None` when no output config is present.
fn output_effort(w: &wire::CreateMessageParams) -> Option<wire::OutputConfigEffort> {
    match &w.output_config {
        Some(config) => config.effort,
        None => None,
    }
}
/// Reports whether the request carries the adaptive thinking variant.
fn is_adaptive(w: &wire::CreateMessageParams) -> bool {
    w.thinking.as_ref().is_some_and(|config| {
        matches!(config, wire::ThinkingConfigParam::ThinkingConfigAdaptive(_))
    })
}
/// In the adaptive format, `High` effort becomes adaptive thinking plus
/// `output_config.effort = High`, with no legacy budget.
#[test]
fn adaptive_effort_maps_to_output_config_effort() {
    let req = req_with(Some(ReasoningEffort::High), ThinkingConfig::Disabled, 64_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Adaptive);
    assert!(is_adaptive(&encoded));
    assert_eq!(
        output_effort(&encoded),
        Some(wire::OutputConfigEffort::High)
    );
    assert_eq!(thinking_budget(&encoded), None);
}
/// `Xhigh` maps straight to the `Xhigh` effort in adaptive mode, even with a tiny
/// `max_tokens` (500) that would force clamping in the legacy format.
#[test]
fn adaptive_xhigh_maps_to_xhigh_effort_without_budget_clamp() {
    let req = req_with(Some(ReasoningEffort::Xhigh), ThinkingConfig::Disabled, 500);
    let encoded = encode_request(&req, ThinkingWireFormat::Adaptive);
    assert!(is_adaptive(&encoded));
    assert_eq!(
        output_effort(&encoded),
        Some(wire::OutputConfigEffort::Xhigh)
    );
}
/// `ReasoningEffort::None` suppresses both thinking and `output_config` in adaptive
/// mode, overriding an enabled thinking config.
#[test]
fn adaptive_effort_none_disables_thinking() {
    let thinking = ThinkingConfig::Enabled {
        budget_tokens: Some(8_000),
    };
    let req = req_with(Some(ReasoningEffort::None), thinking, 64_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Adaptive);
    assert!(encoded.thinking.is_none());
    assert!(encoded.output_config.is_none());
}
/// An enabled thinking config with no effort produces adaptive thinking but no
/// `output_config`.
#[test]
fn adaptive_enabled_without_effort_omits_output_config() {
    let thinking = ThinkingConfig::Enabled {
        budget_tokens: Some(2_000),
    };
    let req = req_with(None, thinking, 64_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Adaptive);
    assert!(is_adaptive(&encoded));
    assert!(encoded.output_config.is_none());
}
/// With thinking disabled and no effort, adaptive mode sends neither `thinking`
/// nor `output_config`.
#[test]
fn adaptive_disabled_thinking_sends_nothing() {
    let req = req_with(None, ThinkingConfig::Disabled, 64_000);
    let encoded = encode_request(&req, ThinkingWireFormat::Adaptive);
    assert!(encoded.thinking.is_none());
    assert!(encoded.output_config.is_none());
}
/// A tool round-trip encodes correctly: a `Named` tool choice, the tool schema
/// (with its `required` list and a cache breakpoint), the assistant `tool_use`
/// turn, and the matching user `tool_result` turn.
#[test]
fn encode_request_tool_uses_and_results() {
    let req = CompletionRequest {
        model: "claude-opus-4-7".into(),
        system: None,
        messages: vec![
            Message {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse {
                    id: "toolu_1".into(),
                    name: "fs_read".into(),
                    args: json!({"path": "/tmp/a"}),
                }]
                .into(),
            },
            Message {
                role: Role::User,
                content: vec![MessageContent::ToolResult {
                    tool_use_id: "toolu_1".into(),
                    output: ToolResultBody::Text {
                        text: "hello".into(),
                    },
                    is_error: false,
                }]
                .into(),
            },
        ],
        tools: vec![ToolSchema {
            name: "fs_read".into(),
            description: "Read a file".into(),
            input_schema: json!({
                "type": "object",
                "properties": {"path": {"type": "string"}},
                "required": ["path"]
            }),
        }],
        tool_choice: ToolChoice::Named {
            name: "fs_read".into(),
        },
        sampling: SamplingParams::default(),
        hosted_capabilities: ::defect_core::llm::HostedCapabilities::default(),
    };
    let w = encode_request(&req, ThinkingWireFormat::Adaptive);

    // `Named` tool choice targets the specific tool.
    assert!(matches!(
        w.tool_choice,
        Some(wire::ToolChoice::ToolChoiceTool(ref t)) if t.name == "fs_read"
    ));

    // Tool schema: destructure once and check every property on that binding,
    // including the cache breakpoint.
    let tools = w.tools.as_ref().expect("tools");
    assert_eq!(tools.len(), 1);
    let wire::ToolUnion::Tool(tool) = &tools[0] else {
        panic!("expected Tool");
    };
    assert_eq!(tool.name, "fs_read");
    assert_eq!(tool.description.as_deref(), Some("Read a file"));
    assert_eq!(
        tool.input_schema.required.as_deref(),
        Some(&["path".to_string()][..])
    );
    assert!(tool.cache_control.is_some());

    // Assistant turn carries the tool call.
    let wire::MessageParamContent::MessageParamContentVariant1(assistant) = &w.messages[0].content
    else {
        panic!("expected list content");
    };
    let wire::ContentBlockParam::ToolUseBlockParam(tu) = &assistant[0] else {
        panic!("expected tool_use_block_param");
    };
    assert_eq!(tu.id, "toolu_1");
    assert_eq!(tu.name, "fs_read");
    assert_eq!(tu.input.get("path"), Some(&json!("/tmp/a")));
    assert!(tu.cache_control.is_some());

    // User turn carries the matching tool result.
    let wire::MessageParamContent::MessageParamContentVariant1(user) = &w.messages[1].content
    else {
        panic!("expected list content");
    };
    let wire::ContentBlockParam::ToolResultBlockParam(tr) = &user[0] else {
        panic!("expected tool_result_block_param");
    };
    assert_eq!(tr.tool_use_id, "toolu_1");
    assert_eq!(tr.is_error, Some(false));
}
/// A tool result with mixed content encodes as a list-form `tool_result` whose
/// blocks keep their order: the text block first, then the image block.
#[test]
fn encode_multimodal_tool_result_emits_text_and_image_blocks() {
    let result_blocks = vec![
        ToolResultContent::Text {
            text: "here is the screenshot".into(),
        },
        ToolResultContent::Image {
            mime: "image/png".into(),
            data: ImageData::Base64 {
                encoded: "AAAA".into(),
            },
        },
    ];
    let tool_result = MessageContent::ToolResult {
        tool_use_id: "toolu_img".into(),
        output: ToolResultBody::Content {
            blocks: result_blocks,
        },
        is_error: false,
    };
    let req = CompletionRequest {
        model: "claude-opus-4-7".into(),
        system: None,
        messages: vec![Message {
            role: Role::User,
            content: vec![tool_result].into(),
        }],
        tools: vec![],
        tool_choice: ToolChoice::Auto,
        sampling: SamplingParams::default(),
        hosted_capabilities: ::defect_core::llm::HostedCapabilities::default(),
    };
    let encoded = encode_request(&req, ThinkingWireFormat::Adaptive);

    let wire::MessageParamContent::MessageParamContentVariant1(user) =
        &encoded.messages[0].content
    else {
        panic!("expected list content");
    };
    let wire::ContentBlockParam::ToolResultBlockParam(tr) = &user[0] else {
        panic!("expected tool_result_block_param");
    };
    let Some(wire::ToolResultBlockParamContent102::ToolResultBlockParamContent102Variant1(blocks)) =
        &tr.content
    else {
        panic!("expected list tool_result content");
    };

    assert_eq!(blocks.len(), 2);
    let (text_block, image_block) = (&blocks[0], &blocks[1]);
    assert!(matches!(
        text_block,
        wire::ToolResultBlockParamContent::TextBlockParam(t) if t.text == "here is the screenshot"
    ));
    assert!(matches!(
        image_block,
        wire::ToolResultBlockParamContent::ImageBlockParam(_)
    ));
}
/// Encodes a single assistant turn of `[Thinking { text, signature }, Text("answer")]`
/// and returns the resulting wire content blocks (panics if content is not list-form).
fn encode_with_thinking(text: &str, signature: Option<&str>) -> Vec<wire::ContentBlockParam> {
    let thinking = MessageContent::Thinking {
        text: text.to_owned(),
        signature: signature.map(str::to_owned),
    };
    let answer = MessageContent::Text {
        text: "answer".into(),
    };
    let req = CompletionRequest {
        model: "claude-opus-4-7".into(),
        system: None,
        messages: vec![Message {
            role: Role::Assistant,
            content: vec![thinking, answer].into(),
        }],
        tools: vec![],
        tool_choice: ToolChoice::Auto,
        sampling: SamplingParams::default(),
        hosted_capabilities: ::defect_core::llm::HostedCapabilities::default(),
    };
    let encoded = encode_request(&req, ThinkingWireFormat::Adaptive);
    match encoded.messages[0].content.clone() {
        wire::MessageParamContent::MessageParamContentVariant1(blocks) => blocks,
        _ => panic!("expected list content"),
    }
}
/// A signed thinking block is replayed as the first of two blocks, a `thinking`
/// block param that keeps both the text and the signature.
#[test]
fn encode_thinking_with_signature_emits_thinking_block_param() {
    let blocks = encode_with_thinking("step 1", Some("sig-abc"));
    assert_eq!(blocks.len(), 2);
    match &blocks[0] {
        wire::ContentBlockParam::ThinkingBlockParam(thinking) => {
            assert_eq!(thinking.thinking, "step 1");
            assert_eq!(thinking.signature, "sig-abc");
        }
        other => panic!("expected thinking block first, got {:?}", other),
    }
}
/// An unsigned thinking block is dropped from the encoded turn, leaving only the
/// text answer.
#[test]
fn encode_thinking_without_signature_skips_thinking_block_param() {
    let blocks = encode_with_thinking("step 1", None);
    assert_eq!(blocks.len(), 1);
    let only_block = &blocks[0];
    assert!(matches!(
        only_block,
        wire::ContentBlockParam::TextBlockParam(t) if t.text == "answer"
    ));
}
// Canned streaming payloads shared by the decoder tests below. Each string is
// the `data:` body of one SSE frame; tests pair it with the event name.

/// `message_start` for message `msg_1` reporting 42 input tokens and 1 output token.
const MODEL_START: &str = r#"{"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-opus-4-7","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":42,"output_tokens":1}}}"#;
/// Opens text content block 0 with empty text and no citations.
const TEXT_START_0: &str = r#"{"type":"content_block_start","index":0,"content_block":{"type":"text","text":"","citations":[]}}"#;
/// First text delta on block 0: `"hello "`.
const TEXT_DELTA_0: &str =
r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hello "}}"#;
/// Second text delta on block 0: `"world"`.
const TEXT_DELTA_1: &str =
r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"world"}}"#;
/// Closes content block 0.
const TEXT_STOP_0: &str = r#"{"type":"content_block_stop","index":0}"#;
/// Opens tool_use block 1: id `toolu_a`, tool `calc`, empty input.
const TOOL_START_1: &str = r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_a","name":"calc","input":{}}}"#;
/// First partial-JSON fragment for block 1: `{"x":1` (deliberately unterminated).
const TOOL_DELTA_1A: &str = r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"x\":1"}}"#;
/// Second partial-JSON fragment for block 1: the closing `}`.
const TOOL_DELTA_1B: &str = r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"}"}}"#;
/// Closes content block 1.
const TOOL_STOP_1: &str = r#"{"type":"content_block_stop","index":1}"#;
/// `message_delta` with stop reason `end_turn` and 17 output tokens.
const MSG_DELTA_END: &str =
r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":17}}"#;
/// `message_delta` with stop reason `tool_use` and 3 output tokens.
const MSG_DELTA_TOOL: &str =
r#"{"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":3}}"#;
/// Terminal `message_stop` event.
const MSG_STOP: &str = r#"{"type":"message_stop"}"#;
/// `ping` event with no payload beyond its type.
const PING: &str = r#"{"type":"ping"}"#;
/// A text block followed by a tool_use block decodes, in order, to: message start,
/// input usage, two text deltas, tool start, two argument deltas, tool end, the
/// `tool_use` stop, and the final output usage.
#[test]
fn decode_text_then_tool_use() {
    let events = [
        ("message_start", MODEL_START),
        ("content_block_start", TEXT_START_0),
        ("content_block_delta", TEXT_DELTA_0),
        ("content_block_delta", TEXT_DELTA_1),
        ("content_block_stop", TEXT_STOP_0),
        ("content_block_start", TOOL_START_1),
        ("content_block_delta", TOOL_DELTA_1A),
        ("content_block_delta", TOOL_DELTA_1B),
        ("content_block_stop", TOOL_STOP_1),
        ("message_delta", MSG_DELTA_TOOL),
        ("message_stop", MSG_STOP),
    ];
    let (state, results) = run_state_machine(&events);
    assert!(state.stopped);

    let mut remaining = ok_chunks(results).into_iter();
    let mut next_chunk = || remaining.next().unwrap();

    assert!(matches!(
        next_chunk(),
        ProviderChunk::MessageStart { id, .. } if id == "msg_1"
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::Usage(u) if u.input_tokens == Some(42)
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::TextDelta { text } if text == "hello "
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::TextDelta { text } if text == "world"
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::ToolUseStart { id, name } if id == "toolu_a" && name == "calc"
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::ToolUseArgsDelta { id, fragment }
            if id == "toolu_a" && fragment.starts_with("{\"x\"")
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::ToolUseArgsDelta { id, .. } if id == "toolu_a"
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::ToolUseEnd { id } if id == "toolu_a"
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::Stop {
            reason: StopReason::ToolUse
        }
    ));
    assert!(matches!(
        next_chunk(),
        ProviderChunk::Usage(u) if u.output_tokens == Some(3)
    ));
}
/// Two tool_use blocks whose start/delta/stop events interleave are tracked
/// independently by block index: each argument fragment is attributed to the
/// right tool id, and starts/ends come out in event order.
#[test]
fn decode_two_concurrent_tool_uses_interleaved() {
    let tool_start_b = r#"{"type":"content_block_start","index":2,"content_block":{"type":"tool_use","id":"toolu_b","name":"echo","input":{}}}"#;
    let tool_delta_a = r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"x\":1}"}}"#;
    let tool_delta_b = r#"{"type":"content_block_delta","index":2,"delta":{"type":"input_json_delta","partial_json":"{\"y\":2}"}}"#;
    let tool_stop_b = r#"{"type":"content_block_stop","index":2}"#;
    let events = [
        ("message_start", MODEL_START),
        ("content_block_start", TOOL_START_1),
        ("content_block_start", tool_start_b),
        ("content_block_delta", tool_delta_a),
        ("content_block_delta", tool_delta_b),
        ("content_block_stop", TOOL_STOP_1),
        ("content_block_stop", tool_stop_b),
        ("message_delta", MSG_DELTA_TOOL),
        ("message_stop", MSG_STOP),
    ];
    let (state, results) = run_state_machine(&events);
    assert!(state.stopped);
    let chunks = ok_chunks(results);

    // Bucket the tool-related chunks in one pass, preserving their order.
    let mut starts = Vec::new();
    let mut args_pairs = Vec::new();
    let mut ends = Vec::new();
    for chunk in &chunks {
        match chunk {
            ProviderChunk::ToolUseStart { id, .. } => starts.push(id.clone()),
            ProviderChunk::ToolUseArgsDelta { id, fragment } => {
                args_pairs.push((id.clone(), fragment.clone()));
            }
            ProviderChunk::ToolUseEnd { id } => ends.push(id.clone()),
            _ => {}
        }
    }

    assert_eq!(starts, vec!["toolu_a", "toolu_b"]);
    assert_eq!(
        args_pairs,
        vec![
            ("toolu_a".into(), "{\"x\":1}".into()),
            ("toolu_b".into(), "{\"y\":2}".into()),
        ]
    );
    assert_eq!(ends, vec!["toolu_a", "toolu_b"]);
}
/// A thinking block streams its text as `ThinkingDelta` and its signature as
/// `ThinkingSignature`, and the message still runs to a clean stop.
#[test]
fn decode_thinking_with_signature() {
    let think_start = r#"{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":"","signature":""}}"#;
    let think_delta = r#"{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"step 1"}}"#;
    let sig_delta = r#"{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":"abc"}}"#;
    let events = [
        ("message_start", MODEL_START),
        ("content_block_start", think_start),
        ("content_block_delta", think_delta),
        ("content_block_delta", sig_delta),
        ("content_block_stop", TEXT_STOP_0),
        ("message_delta", MSG_DELTA_END),
        ("message_stop", MSG_STOP),
    ];
    let (state, results) = run_state_machine(&events);
    // The other complete-stream tests check this too: make sure the thinking path
    // does not leave the decoder short of the final `message_stop`.
    assert!(state.stopped);

    let mut saw_think = false;
    let mut saw_sig = false;
    for c in ok_chunks(results) {
        match c {
            ProviderChunk::ThinkingDelta { text } if text == "step 1" => saw_think = true,
            ProviderChunk::ThinkingSignature { signature } if signature == "abc" => saw_sig = true,
            _ => {}
        }
    }
    assert!(saw_think, "expected ThinkingDelta");
    assert!(saw_sig, "expected ThinkingSignature");
}
/// `ping` events produce no content chunks, and the stream still yields exactly
/// one `Stop`.
#[test]
fn decode_ping_is_swallowed() {
    let events = [
        ("message_start", MODEL_START),
        ("ping", PING),
        ("ping", PING),
        ("message_delta", MSG_DELTA_END),
        ("message_stop", MSG_STOP),
    ];
    let (_state, results) = run_state_machine(&events);
    let chunks = ok_chunks(results);

    let mut text_deltas = 0usize;
    let mut stops = 0usize;
    for chunk in &chunks {
        match chunk {
            ProviderChunk::TextDelta { .. } => text_deltas += 1,
            ProviderChunk::Stop { .. } => stops += 1,
            _ => {}
        }
    }
    assert_eq!(text_deltas, 0);
    assert_eq!(stops, 1);
}
/// An SSE `error` event (`overloaded_error`) marks the decoder fatal and is
/// surfaced as a final `ServerError`.
#[test]
fn decode_error_event_terminates() {
    let err = r#"{"type":"error","error":{"type":"overloaded_error","message":"too busy"}}"#;
    let events = [("message_start", MODEL_START), ("error", err)];
    let (state, results) = run_state_machine(&events);
    assert!(state.fatal);

    let last = results.last().expect("at least one chunk");
    assert!(last.is_err(), "last must be Err");
    if let Err(provider_err) = last {
        assert!(matches!(
            provider_err.kind,
            ProviderErrorKind::ServerError { .. }
        ));
    }
}
/// A frame with unparseable JSON yields a `Malformed` error but does not stop
/// decoding: later text deltas still arrive and the stream finishes.
#[test]
fn decode_malformed_json_continues() {
    let bad = r#"{not json}"#;
    let events = [
        ("message_start", MODEL_START),
        ("content_block_start", TEXT_START_0),
        ("content_block_delta", bad),
        ("content_block_delta", TEXT_DELTA_0),
        ("content_block_stop", TEXT_STOP_0),
        ("message_delta", MSG_DELTA_END),
        ("message_stop", MSG_STOP),
    ];
    let (state, results) = run_state_machine(&events);
    assert!(state.stopped);

    let (mut saw_malformed, mut saw_text) = (false, false);
    for result in results {
        match result {
            Err(err) => saw_malformed |= matches!(err.kind, ProviderErrorKind::Malformed(_)),
            Ok(ProviderChunk::TextDelta { text }) => saw_text |= text == "hello ",
            Ok(_) => {}
        }
    }
    assert!(saw_malformed);
    assert!(saw_text);
}
/// The async decoder handles a complete text-only message without errors, and
/// the last item it yields is the final `Usage` chunk.
#[tokio::test]
async fn decode_stream_end_to_end_text_path() {
    let events = [
        ("message_start", MODEL_START),
        ("content_block_start", TEXT_START_0),
        ("content_block_delta", TEXT_DELTA_0),
        ("content_block_stop", TEXT_STOP_0),
        ("message_delta", MSG_DELTA_END),
        ("message_stop", MSG_STOP),
    ];
    let chunks = run_decode_stream_generic(&events, CancellationToken::new()).await;

    let all_ok = chunks.iter().all(Result::is_ok);
    assert!(all_ok, "got error chunks: {:?}", chunks);

    let last = chunks.last().and_then(|r| r.as_ref().ok()).unwrap();
    assert!(matches!(last, ProviderChunk::Usage(_)));
}
/// If the upstream stream ends before `message_stop`, the decoder finishes with a
/// `ProtocolViolation` error.
#[tokio::test]
async fn decode_stream_protocol_violation_when_no_stop() {
    let events = [
        ("message_start", MODEL_START),
        ("content_block_start", TEXT_START_0),
        ("content_block_delta", TEXT_DELTA_0),
        ("content_block_stop", TEXT_STOP_0),
    ];
    let chunks = run_decode_stream_generic(&events, CancellationToken::new()).await;

    let last = chunks.last().expect("chunks");
    assert!(last.is_err());
    if let Err(provider_err) = last {
        assert!(matches!(
            provider_err.kind,
            ProviderErrorKind::ProtocolViolation { .. }
        ));
    }
}
/// Reports whether any text / tool_use / tool_result / image block of message
/// `idx` carries a cache breakpoint. Non-list content and other block kinds count
/// as "no breakpoint".
fn message_has_breakpoint(w: &wire::CreateMessageParams, idx: usize) -> bool {
    match &w.messages[idx].content {
        wire::MessageParamContent::MessageParamContentVariant1(blocks) => {
            blocks.iter().any(|block| match block {
                wire::ContentBlockParam::TextBlockParam(p) => p.cache_control.is_some(),
                wire::ContentBlockParam::ToolUseBlockParam(p) => p.cache_control.is_some(),
                wire::ContentBlockParam::ToolResultBlockParam(p) => p.cache_control.is_some(),
                wire::ContentBlockParam::ImageBlockParam(p) => p.cache_control.is_some(),
                _ => false,
            })
        }
        _ => false,
    }
}
/// Reports whether the system prompt is in block-list form with at least one
/// cache-marked block.
fn system_has_breakpoint(w: &wire::CreateMessageParams) -> bool {
    let Some(wire::SystemPrompt::SystemPromptVariant1(blocks)) = &w.system else {
        return false;
    };
    blocks.iter().any(|block| block.cache_control.is_some())
}
/// Builds a message from `role` holding a single text block.
fn text_msg(role: Role, text: &str) -> Message {
    let body = vec![MessageContent::Text {
        text: text.into(),
    }];
    Message {
        role,
        content: body.into(),
    }
}
/// With a system prompt and six alternating messages, the system prompt plus the
/// three most recent messages (indices 3–5) carry cache breakpoints; older
/// messages do not.
#[test]
fn cache_breakpoints_are_end_biased() {
    let mut messages = Vec::with_capacity(6);
    for i in 0..6 {
        let role = match i % 2 {
            0 => Role::User,
            _ => Role::Assistant,
        };
        messages.push(text_msg(role, &format!("m{i}")));
    }
    let req = CompletionRequest {
        model: "claude-opus-4-7".into(),
        system: Some("sys".into()),
        messages,
        tools: vec![],
        tool_choice: ToolChoice::Auto,
        sampling: SamplingParams::default(),
        hosted_capabilities: ::defect_core::llm::HostedCapabilities::default(),
    };
    let w = encode_request(&req, ThinkingWireFormat::Adaptive);

    assert!(
        system_has_breakpoint(&w),
        "system must carry the static breakpoint"
    );
    // Newest-first, matching the order the checks were originally written in.
    for idx in (3..6).rev() {
        assert!(message_has_breakpoint(&w, idx));
    }
    for idx in (0..3).rev() {
        assert!(!message_has_breakpoint(&w, idx));
    }
}
/// Without a system prompt, the static breakpoint lands on the last tool
/// definition instead (tool `b`), not on earlier ones.
#[test]
fn cache_breakpoint_falls_back_to_last_tool_without_system() {
    let empty_schema = || json!({"type": "object", "properties": {}});
    let req = CompletionRequest {
        model: "claude-opus-4-7".into(),
        system: None,
        messages: vec![text_msg(Role::User, "hi")],
        tools: vec![
            ToolSchema {
                name: "a".into(),
                description: "first".into(),
                input_schema: empty_schema(),
            },
            ToolSchema {
                name: "b".into(),
                description: "second".into(),
                input_schema: empty_schema(),
            },
        ],
        tool_choice: ToolChoice::Auto,
        sampling: SamplingParams::default(),
        hosted_capabilities: ::defect_core::llm::HostedCapabilities::default(),
    };
    let w = encode_request(&req, ThinkingWireFormat::Adaptive);
    assert!(!system_has_breakpoint(&w));

    let tools = w.tools.as_ref().expect("tools");
    let flags: Vec<bool> = tools
        .iter()
        .take(2)
        .map(|entry| {
            let wire::ToolUnion::Tool(t) = entry else {
                panic!("expected Tool");
            };
            t.cache_control.is_some()
        })
        .collect();
    assert_eq!(flags, vec![false, true]);
}
/// Cancelling the token before the stream is driven must end decoding quietly:
/// no `Err` item may be yielded. In particular, there must be no protocol-violation
/// error even though these events never reach `message_stop`.
#[tokio::test]
async fn decode_stream_cancel_terminates_silently() {
    let events = [
        ("message_start", MODEL_START),
        ("content_block_start", TEXT_START_0),
        ("content_block_delta", TEXT_DELTA_0),
    ];
    let cancel = CancellationToken::new();
    // Cancel up front, before the decoder polls anything.
    cancel.cancel();
    let chunks = run_decode_stream_generic(&events, cancel).await;
    assert!(chunks.iter().all(|r| r.is_ok()), "expected no Err chunks");
}