use crate::backend::TokenEventV2;
use inferd_proto::v2::{
ContentBlock, MessageV2, ResolvedV2, RoleV2, StopReasonV2, ToolCallId, UsageV2,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum BodyError {
#[error("bedrock-invoke does not support {0} attachments in v0.2.0")]
AttachmentUnsupported(&'static str),
#[error("bedrock-invoke received an unknown content-block type")]
UnknownContentBlock,
#[error("bedrock-invoke tool_result content must be text only")]
NonTextToolResult,
}
/// JSON body sent to Bedrock for an Anthropic Messages-style invocation.
///
/// Fields serialise in declaration order, so `anthropic_version` is the first
/// key of the emitted object (asserted by the
/// `body_serialises_anthropic_version_first` test). Reordering the fields
/// changes the wire output.
#[derive(Debug, Clone, Serialize)]
pub(super) struct AnthropicRequest {
    /// API version marker; `request_body` sets `"bedrock-2023-05-31"`.
    pub anthropic_version: &'static str,
    /// Conversation turns. `request_body` only adds user/assistant messages
    /// here; system text goes into `system` instead.
    pub messages: Vec<AnthropicMessage>,
    /// Flattened system prompt; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Generation cap; always serialised.
    pub max_tokens: u32,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Nucleus-sampling cutoff; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Tool declarations; the key is omitted entirely when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<AnthropicToolDecl>,
}
/// One conversation turn inside an [`AnthropicRequest`].
#[derive(Debug, Clone, Serialize)]
pub(super) struct AnthropicMessage {
    /// Wire role name; `request_body` fills it from `role_to_str` for
    /// user/assistant messages only.
    pub role: &'static str,
    /// Ordered content blocks for this turn.
    pub content: Vec<AnthropicBlock>,
}
/// Outgoing content block, serialised as an internally tagged object whose
/// `"type"` key is the snake_case variant name (`"text"`, `"tool_use"`,
/// `"tool_result"`).
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(super) enum AnthropicBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation previously made by the assistant.
    ToolUse {
        /// Tool-call id, copied from the v2 `ToolCallId`.
        id: String,
        name: String,
        /// Tool arguments as arbitrary JSON.
        input: Value,
    },
    /// The result returned for an earlier `ToolUse`.
    ToolResult {
        /// Id of the `ToolUse` block this result answers.
        tool_use_id: String,
        /// Result text, flattened from the v2 text blocks by
        /// `tool_result_to_string`.
        content: String,
    },
}
/// Tool declaration in the request's `tools` array; `request_body` copies
/// each field verbatim from the resolved v2 tool.
#[derive(Debug, Clone, Serialize)]
pub(super) struct AnthropicToolDecl {
    pub name: String,
    pub description: String,
    /// JSON Schema describing the tool's `input`.
    pub input_schema: Value,
}
/// `max_tokens` value sent when the resolved request does not specify one.
const DEFAULT_MAX_TOKENS: u32 = 1_024;
pub(super) fn request_body(resolved: &ResolvedV2) -> Result<AnthropicRequest, BodyError> {
if !resolved.attachments.is_empty() {
return Err(BodyError::AttachmentUnsupported("multimodal"));
}
let mut system: Option<String> = None;
let mut messages: Vec<AnthropicMessage> = Vec::with_capacity(resolved.messages.len());
for msg in &resolved.messages {
match msg.role {
RoleV2::System => {
let mut buf = String::new();
for block in &msg.content {
match block {
ContentBlock::Text { text } => {
if !buf.is_empty() {
buf.push('\n');
}
buf.push_str(text);
}
ContentBlock::Unknown => return Err(BodyError::UnknownContentBlock),
_ => {
return Err(BodyError::AttachmentUnsupported("system-non-text"));
}
}
}
system = Some(match system {
Some(prev) => format!("{prev}\n{buf}"),
None => buf,
});
}
RoleV2::User | RoleV2::Assistant => {
let role = role_to_str(msg.role);
let blocks = blocks_for(msg)?;
if !blocks.is_empty() {
messages.push(AnthropicMessage {
role,
content: blocks,
});
}
}
}
}
let tools = resolved
.tools
.iter()
.map(|t| AnthropicToolDecl {
name: t.name.clone(),
description: t.description.clone(),
input_schema: t.input_schema.clone(),
})
.collect();
Ok(AnthropicRequest {
anthropic_version: "bedrock-2023-05-31",
messages,
system,
max_tokens: resolved.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
temperature: resolved.temperature,
top_p: resolved.top_p,
top_k: resolved.top_k,
tools,
})
}
/// Converts the content of a user/assistant message into Anthropic blocks,
/// preserving block order.
///
/// # Errors
///
/// Stops at the first block that cannot be represented:
/// * image/audio/video blocks map to [`BodyError::AttachmentUnsupported`];
/// * unrecognised blocks map to [`BodyError::UnknownContentBlock`];
/// * tool results with non-text content map to [`BodyError::NonTextToolResult`].
fn blocks_for(msg: &MessageV2) -> Result<Vec<AnthropicBlock>, BodyError> {
    msg.content
        .iter()
        .map(|block| {
            let translated = match block {
                ContentBlock::Text { text } => AnthropicBlock::Text { text: text.clone() },
                ContentBlock::ToolUse {
                    tool_call_id,
                    name,
                    input,
                } => AnthropicBlock::ToolUse {
                    id: tool_call_id.as_str().to_string(),
                    name: name.clone(),
                    input: input.clone(),
                },
                ContentBlock::ToolResult {
                    tool_call_id,
                    content,
                } => {
                    let flattened = tool_result_to_string(content)?;
                    AnthropicBlock::ToolResult {
                        tool_use_id: tool_call_id.as_str().to_string(),
                        content: flattened,
                    }
                }
                ContentBlock::Image { .. } => {
                    return Err(BodyError::AttachmentUnsupported("image"));
                }
                ContentBlock::Audio { .. } => {
                    return Err(BodyError::AttachmentUnsupported("audio"));
                }
                ContentBlock::Video { .. } => {
                    return Err(BodyError::AttachmentUnsupported("video"));
                }
                ContentBlock::Unknown => return Err(BodyError::UnknownContentBlock),
            };
            Ok(translated)
        })
        .collect()
}
/// Flattens the content of a v2 tool result into a single string.
///
/// Text blocks are joined with `'\n'`, the same rule used when flattening
/// system prompts in `request_body`. Previously blocks were concatenated
/// with no separator, which fused adjacent words. A separator is only
/// inserted once some text has been collected.
///
/// # Errors
///
/// [`BodyError::NonTextToolResult`] if any block is not text.
fn tool_result_to_string(content: &[ContentBlock]) -> Result<String, BodyError> {
    let mut out = String::new();
    for block in content {
        match block {
            ContentBlock::Text { text } => {
                if !out.is_empty() {
                    out.push('\n');
                }
                out.push_str(text);
            }
            _ => return Err(BodyError::NonTextToolResult),
        }
    }
    Ok(out)
}
/// Maps a v2 role to its lowercase Anthropic wire name.
fn role_to_str(role: RoleV2) -> &'static str {
    match role {
        RoleV2::User => "user",
        RoleV2::Assistant => "assistant",
        RoleV2::System => "system",
    }
}
/// One decoded streaming event, discriminated by its `"type"` field
/// (the snake_case variant name, e.g. `"content_block_delta"`).
///
/// Event types this enum does not recognise deserialise to
/// [`AnthropicEvent::Unknown`] instead of failing; `StreamAccumulator`
/// ignores them.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(super) enum AnthropicEvent {
    /// Opens the message. A missing `message` object falls back to the
    /// default payload (no usage).
    MessageStart {
        #[serde(default)]
        message: MessageStartPayload,
    },
    /// Opens content block `index`.
    ContentBlockStart {
        index: usize,
        content_block: ContentBlockStart,
    },
    /// Incremental content for block `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentBlockDelta,
    },
    /// Closes content block `index`.
    ContentBlockStop {
        index: usize,
    },
    /// Message-level update carrying the stop reason and/or token usage.
    MessageDelta {
        #[serde(default)]
        delta: MessageDeltaPayload,
        #[serde(default)]
        usage: Option<UsagePayload>,
    },
    /// End of message; carries no data used here.
    MessageStop {},
    /// Keep-alive; carries no data used here.
    Ping {},
    /// Stream-level error reported by the upstream service.
    Error {
        #[serde(default)]
        error: ErrorPayload,
    },
    /// Any event type not listed above.
    #[serde(other)]
    Unknown,
}
/// The part of a `message_start` event's `message` object that this backend
/// reads; other keys are ignored.
#[derive(Debug, Clone, Default, Deserialize)]
pub(super) struct MessageStartPayload {
    /// Initial token counts, if present.
    #[serde(default)]
    pub usage: Option<UsagePayload>,
}
/// Initial payload of a content block, tagged by its `"type"` field.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(super) enum ContentBlockStart {
    /// Text block; any initial text (empty if absent) is forwarded immediately.
    Text {
        #[serde(default)]
        text: String,
    },
    /// Tool-use block. `input` deserialises to `Value::Null` when absent; the
    /// arguments may instead (or additionally) arrive as `input_json_delta`
    /// fragments.
    ToolUse {
        id: String,
        name: String,
        #[serde(default)]
        input: Value,
    },
    /// Reasoning block; recognised so it is not treated as unknown, but
    /// nothing from it is forwarded.
    Thinking {},
    /// Any block type not listed above.
    #[serde(other)]
    Unknown,
}
/// Incremental update to a content block, tagged by its `"type"` field.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(super) enum ContentBlockDelta {
    /// Additional text for a text block.
    TextDelta {
        text: String,
    },
    /// A fragment of a tool-use block's JSON `input`; fragments are
    /// concatenated and only parsed once the block stops.
    InputJsonDelta {
        partial_json: String,
    },
    /// Reasoning text; parsed so the variant is recognised, never forwarded.
    ThinkingDelta {
        #[serde(default)]
        // The value is matched with `{ .. }` and never read.
        #[allow(dead_code)]
        thinking: String,
    },
    /// Signature for a reasoning block; parsed but never forwarded.
    SignatureDelta {
        #[serde(default)]
        // The value is matched with `{ .. }` and never read.
        #[allow(dead_code)]
        signature: String,
    },
    /// Any delta type not listed above.
    #[serde(other)]
    Unknown,
}
/// The `delta` object of a `message_delta` event.
#[derive(Debug, Clone, Default, Deserialize)]
pub(super) struct MessageDeltaPayload {
    /// Raw upstream stop reason; `StreamAccumulator::finalize` maps it to a
    /// `StopReasonV2`.
    #[serde(default)]
    pub stop_reason: Option<String>,
}
/// Token counts reported by the stream.
///
/// A missing field deserialises as `0`. For `message_delta` usage,
/// `StreamAccumulator` treats `0` as "no update" and keeps the previous value.
#[derive(Debug, Clone, Default, Deserialize)]
pub(super) struct UsagePayload {
    #[serde(default)]
    pub input_tokens: u32,
    #[serde(default)]
    pub output_tokens: u32,
}
/// Body of an `error` event; both fields default to empty strings when absent.
#[derive(Debug, Clone, Default, Deserialize)]
pub(super) struct ErrorPayload {
    /// Error category, read from the JSON `type` key.
    #[serde(default, rename = "type")]
    pub kind: String,
    /// Human-readable description.
    #[serde(default)]
    pub message: String,
}
/// In-progress `tool_use` content block, held by [`StreamAccumulator`] under
/// its stream index until the matching `content_block_stop` arrives.
#[derive(Debug)]
struct ToolBlockBuffer {
    id: String,
    name: String,
    /// `input` from `content_block_start`, used only when no `input_json_delta`
    /// fragments arrive for this block.
    seed: Value,
    /// Concatenated `partial_json` fragments from `input_json_delta` events.
    partial: String,
}

/// Folds a stream of [`AnthropicEvent`]s into v2 token events.
///
/// * Text is forwarded as soon as it arrives.
/// * Tool calls are buffered per content-block index and emitted whole on
///   `content_block_stop`.
/// * The stop reason, token usage and any stream error are remembered for
///   [`StreamAccumulator::finalize`].
#[derive(Debug, Default)]
pub(super) struct StreamAccumulator {
    tool_blocks: std::collections::HashMap<usize, ToolBlockBuffer>,
    stop_reason: Option<String>,
    input_tokens: u32,
    output_tokens: u32,
    error: Option<String>,
}

impl StreamAccumulator {
    /// Creates an accumulator with no buffered state.
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Consumes one decoded stream event and returns the token events it
    /// produces immediately (possibly none).
    pub(super) fn ingest(&mut self, event: AnthropicEvent) -> Vec<TokenEventV2> {
        let mut out = Vec::new();
        match event {
            AnthropicEvent::MessageStart { message } => {
                if let Some(u) = message.usage {
                    self.input_tokens = u.input_tokens;
                    self.output_tokens = u.output_tokens;
                }
            }
            AnthropicEvent::ContentBlockStart {
                index,
                content_block,
            } => match content_block {
                ContentBlockStart::Text { text } => {
                    if !text.is_empty() {
                        out.push(TokenEventV2::Text(text));
                    }
                }
                ContentBlockStart::ToolUse { id, name, input } => {
                    // Keep the start `input` aside instead of pre-filling
                    // `partial` with its serialisation. Otherwise a `{}`
                    // placeholder followed by `input_json_delta` fragments
                    // concatenates into invalid JSON such as `{}{"q":"xy"}`.
                    self.tool_blocks.insert(
                        index,
                        ToolBlockBuffer {
                            id,
                            name,
                            seed: input,
                            partial: String::new(),
                        },
                    );
                }
                ContentBlockStart::Thinking {} | ContentBlockStart::Unknown => {}
            },
            AnthropicEvent::ContentBlockDelta { index, delta } => match delta {
                ContentBlockDelta::TextDelta { text } => {
                    if !text.is_empty() {
                        out.push(TokenEventV2::Text(text));
                    }
                }
                ContentBlockDelta::InputJsonDelta { partial_json } => {
                    // Fragments for an index with no open tool block are dropped.
                    if let Some(buf) = self.tool_blocks.get_mut(&index) {
                        buf.partial.push_str(&partial_json);
                    }
                }
                ContentBlockDelta::ThinkingDelta { .. }
                | ContentBlockDelta::SignatureDelta { .. }
                | ContentBlockDelta::Unknown => {}
            },
            AnthropicEvent::ContentBlockStop { index } => {
                if let Some(buf) = self.tool_blocks.remove(&index) {
                    out.push(TokenEventV2::ToolUse {
                        tool_call_id: ToolCallId(buf.id),
                        name: buf.name,
                        input: resolve_tool_input(buf.seed, &buf.partial),
                    });
                }
            }
            AnthropicEvent::MessageDelta { delta, usage } => {
                if let Some(reason) = delta.stop_reason {
                    self.stop_reason = Some(reason);
                }
                if let Some(u) = usage {
                    // Zero counts are treated as "not reported" so they do not
                    // overwrite values seen earlier in the stream.
                    if u.input_tokens > 0 {
                        self.input_tokens = u.input_tokens;
                    }
                    if u.output_tokens > 0 {
                        self.output_tokens = u.output_tokens;
                    }
                }
            }
            AnthropicEvent::MessageStop {} | AnthropicEvent::Ping {} => {}
            AnthropicEvent::Error { error } => {
                self.error = Some(if error.message.is_empty() {
                    error.kind
                } else {
                    format!("{}: {}", error.kind, error.message)
                });
            }
            AnthropicEvent::Unknown => {}
        }
        out
    }

    /// Ends the stream and returns the terminal [`TokenEventV2::Done`] event.
    ///
    /// The stop reason is `Error` if an `error` event was seen, or if no
    /// `message_delta` ever supplied a stop reason (as when the stream ends
    /// early). Unrecognised upstream stop reasons are reported as `EndTurn`.
    /// Tool-use blocks that never received `content_block_stop` are discarded.
    pub(super) fn finalize(self) -> Vec<TokenEventV2> {
        let stop_reason = if self.error.is_some() {
            StopReasonV2::Error
        } else {
            match self.stop_reason.as_deref() {
                Some("end_turn") | Some("stop_sequence") => StopReasonV2::EndTurn,
                Some("max_tokens") => StopReasonV2::MaxTokens,
                Some("tool_use") => StopReasonV2::ToolUse,
                None => StopReasonV2::Error,
                Some(_) => StopReasonV2::EndTurn,
            }
        };
        vec![TokenEventV2::Done {
            stop_reason,
            usage: UsageV2 {
                input_tokens: self.input_tokens,
                output_tokens: self.output_tokens,
            },
        }]
    }
}

/// Produces the final `input` for a completed tool-use block.
///
/// * Streamed `input_json_delta` fragments take precedence over the start
///   event's `input`.
/// * With no fragments, the start `input` is used; a missing/`null` one
///   becomes an empty object.
/// * Fragments that do not form valid JSON yield `Value::Null`. This is best
///   effort, and the parse error is not surfaced.
fn resolve_tool_input(seed: Value, partial: &str) -> Value {
    if !partial.is_empty() {
        serde_json::from_str(partial).unwrap_or(Value::Null)
    } else if seed.is_null() {
        Value::Object(serde_json::Map::new())
    } else {
        seed
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use inferd_proto::v2::{ContentBlock, MessageV2, RequestV2, RoleV2, Tool};
    use serde_json::json;

    fn resolved_with_messages(messages: Vec<MessageV2>) -> ResolvedV2 {
        RequestV2 {
            id: "req-1".into(),
            messages,
            ..Default::default()
        }
        .resolve()
        .unwrap()
    }

    #[test]
    fn text_only_request_round_trips() {
        let r = resolved_with_messages(vec![
            MessageV2 {
                role: RoleV2::System,
                content: vec![ContentBlock::Text {
                    text: "be terse".into(),
                }],
            },
            MessageV2 {
                role: RoleV2::User,
                content: vec![ContentBlock::Text {
                    text: "hello".into(),
                }],
            },
        ]);
        let body = request_body(&r).unwrap();
        assert_eq!(body.anthropic_version, "bedrock-2023-05-31");
        assert_eq!(body.system.as_deref(), Some("be terse"));
        assert_eq!(body.max_tokens, DEFAULT_MAX_TOKENS);
        assert_eq!(body.messages.len(), 1);
        assert_eq!(body.messages[0].role, "user");
        assert!(
            matches!(body.messages[0].content[0], AnthropicBlock::Text { ref text } if text == "hello")
        );
    }

    #[test]
    fn multiple_system_messages_concatenate() {
        let r = resolved_with_messages(vec![
            MessageV2 {
                role: RoleV2::System,
                content: vec![ContentBlock::Text { text: "one".into() }],
            },
            MessageV2 {
                role: RoleV2::System,
                content: vec![ContentBlock::Text { text: "two".into() }],
            },
            MessageV2 {
                role: RoleV2::User,
                content: vec![ContentBlock::Text { text: "go".into() }],
            },
        ]);
        let body = request_body(&r).unwrap();
        assert_eq!(body.system.as_deref(), Some("one\ntwo"));
    }

    #[test]
    fn tools_translate_to_anthropic_tools() {
        let mut r = resolved_with_messages(vec![MessageV2 {
            role: RoleV2::User,
            content: vec![ContentBlock::Text { text: "go".into() }],
        }]);
        r.tools = vec![Tool {
            name: "lookup".into(),
            description: "look something up".into(),
            input_schema: json!({"type": "object"}),
        }];
        let body = request_body(&r).unwrap();
        assert_eq!(body.tools.len(), 1);
        assert_eq!(body.tools[0].name, "lookup");
        assert_eq!(body.tools[0].description, "look something up");
    }

    #[test]
    fn assistant_tool_use_round_trips_inline() {
        let r = resolved_with_messages(vec![
            MessageV2 {
                role: RoleV2::User,
                content: vec![ContentBlock::Text { text: "go".into() }],
            },
            MessageV2 {
                role: RoleV2::Assistant,
                content: vec![ContentBlock::ToolUse {
                    tool_call_id: ToolCallId("call_1".into()),
                    name: "lookup".into(),
                    input: json!({"q": "x"}),
                }],
            },
            MessageV2 {
                role: RoleV2::User,
                content: vec![ContentBlock::ToolResult {
                    tool_call_id: ToolCallId("call_1".into()),
                    content: vec![ContentBlock::Text {
                        text: "answer".into(),
                    }],
                }],
            },
        ]);
        let body = request_body(&r).unwrap();
        assert_eq!(body.messages.len(), 3);
        assert!(matches!(
            body.messages[1].content[0],
            AnthropicBlock::ToolUse { ref id, ref name, .. }
            if id == "call_1" && name == "lookup"
        ));
        assert!(matches!(
            body.messages[2].content[0],
            AnthropicBlock::ToolResult { ref tool_use_id, ref content }
            if tool_use_id == "call_1" && content == "answer"
        ));
    }

    #[test]
    fn image_attachment_block_is_rejected() {
        let r = ResolvedV2 {
            id: "x".into(),
            messages: vec![MessageV2 {
                role: RoleV2::User,
                content: vec![ContentBlock::Image {
                    attachment_id: "img-1".into(),
                }],
            }],
            attachments: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            top_p: None,
            top_k: None,
            max_tokens: None,
            stream: None,
        };
        let err = request_body(&r).unwrap_err();
        assert_eq!(err, BodyError::AttachmentUnsupported("image"));
    }

    #[test]
    fn body_serialises_anthropic_version_first() {
        let r = resolved_with_messages(vec![MessageV2 {
            role: RoleV2::User,
            content: vec![ContentBlock::Text { text: "hi".into() }],
        }]);
        let body = request_body(&r).unwrap();
        let json = serde_json::to_string(&body).unwrap();
        assert!(
            json.starts_with(r#"{"anthropic_version":"bedrock-2023-05-31""#),
            "body: {json}"
        );
    }

    fn parse_event(s: &str) -> AnthropicEvent {
        serde_json::from_str(s).unwrap()
    }

    #[test]
    fn accumulator_passes_text_deltas_through() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(
            r#"{"type":"message_start","message":{"usage":{"input_tokens":5,"output_tokens":0}}}"#,
        ));
        acc.ingest(parse_event(
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}"#,
        ));
        let out = acc.ingest(parse_event(
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hello"}}"#,
        ));
        assert_eq!(out.len(), 1);
        assert!(matches!(&out[0], TokenEventV2::Text(t) if t == "hello"));
        let out = acc.ingest(parse_event(
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}}"#,
        ));
        assert!(matches!(&out[0], TokenEventV2::Text(t) if t == " world"));
        acc.ingest(parse_event(r#"{"type":"content_block_stop","index":0}"#));
        acc.ingest(parse_event(
            r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":2}}"#,
        ));
        acc.ingest(parse_event(r#"{"type":"message_stop"}"#));
        let final_evs = acc.finalize();
        assert_eq!(final_evs.len(), 1);
        match &final_evs[0] {
            TokenEventV2::Done { stop_reason, usage } => {
                assert_eq!(*stop_reason, StopReasonV2::EndTurn);
                assert_eq!(usage.input_tokens, 5);
                assert_eq!(usage.output_tokens, 2);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    #[test]
    fn accumulator_assembles_tool_use_across_input_json_deltas() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(r#"{"type":"message_start","message":{}}"#));
        acc.ingest(parse_event(
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"call_42","name":"lookup","input":{}}}"#,
        ));
        acc.ingest(parse_event(
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"q\":\"x"}}"#,
        ));
        let out = acc.ingest(parse_event(
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"y\"}"}}"#,
        ));
        assert!(out.is_empty(), "tool-use events fire on block_stop");
        let out = acc.ingest(parse_event(r#"{"type":"content_block_stop","index":0}"#));
        assert_eq!(out.len(), 1);
        match &out[0] {
            TokenEventV2::ToolUse {
                tool_call_id,
                name,
                input,
            } => {
                assert_eq!(tool_call_id.as_str(), "call_42");
                assert_eq!(name, "lookup");
                // The `{}` placeholder from content_block_start must not be
                // glued onto the streamed fragments.
                assert_eq!(input, &json!({"q": "xy"}));
            }
            other => panic!("expected ToolUse, got {other:?}"),
        }
        acc.ingest(parse_event(
            r#"{"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":7}}"#,
        ));
        let final_evs = acc.finalize();
        assert!(matches!(
            &final_evs[0],
            TokenEventV2::Done {
                stop_reason: StopReasonV2::ToolUse,
                ..
            }
        ));
    }

    #[test]
    fn accumulator_tool_use_with_only_partial_json_parses() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(r#"{"type":"message_start","message":{}}"#));
        acc.ingest(parse_event(
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"call_1","name":"f","input":null}}"#,
        ));
        acc.ingest(parse_event(
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"k\":\"v\"}"}}"#,
        ));
        let out = acc.ingest(parse_event(r#"{"type":"content_block_stop","index":0}"#));
        match &out[0] {
            TokenEventV2::ToolUse { input, .. } => assert_eq!(input, &json!({"k": "v"})),
            other => panic!("expected ToolUse, got {other:?}"),
        }
    }

    #[test]
    fn accumulator_tool_use_uses_start_input_when_no_deltas() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(r#"{"type":"message_start","message":{}}"#));
        acc.ingest(parse_event(
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"call_7","name":"f","input":{"k":1}}}"#,
        ));
        let out = acc.ingest(parse_event(r#"{"type":"content_block_stop","index":0}"#));
        match &out[0] {
            TokenEventV2::ToolUse { input, .. } => assert_eq!(input, &json!({"k": 1})),
            other => panic!("expected ToolUse, got {other:?}"),
        }
    }

    #[test]
    fn accumulator_tool_use_with_placeholder_and_no_deltas_is_empty_object() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(r#"{"type":"message_start","message":{}}"#));
        acc.ingest(parse_event(
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"call_8","name":"f","input":{}}}"#,
        ));
        let out = acc.ingest(parse_event(r#"{"type":"content_block_stop","index":0}"#));
        match &out[0] {
            TokenEventV2::ToolUse { input, .. } => assert_eq!(input, &json!({})),
            other => panic!("expected ToolUse, got {other:?}"),
        }
    }

    #[test]
    fn accumulator_missing_message_delta_is_error() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(r#"{"type":"message_start","message":{}}"#));
        acc.ingest(parse_event(
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}"#,
        ));
        acc.ingest(parse_event(
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}"#,
        ));
        let final_evs = acc.finalize();
        assert!(matches!(
            &final_evs[0],
            TokenEventV2::Done {
                stop_reason: StopReasonV2::Error,
                ..
            }
        ));
    }

    #[test]
    fn accumulator_max_tokens_stop_reason() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(r#"{"type":"message_start","message":{}}"#));
        acc.ingest(parse_event(
            r#"{"type":"message_delta","delta":{"stop_reason":"max_tokens"}}"#,
        ));
        let final_evs = acc.finalize();
        assert!(matches!(
            &final_evs[0],
            TokenEventV2::Done {
                stop_reason: StopReasonV2::MaxTokens,
                ..
            }
        ));
    }

    #[test]
    fn accumulator_explicit_error_event_surfaces() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(
            r#"{"type":"error","error":{"type":"overloaded_error","message":"upstream busy"}}"#,
        ));
        let final_evs = acc.finalize();
        assert!(matches!(
            &final_evs[0],
            TokenEventV2::Done {
                stop_reason: StopReasonV2::Error,
                ..
            }
        ));
    }

    #[test]
    fn accumulator_skips_unknown_event_types() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(
            r#"{"type":"future_event_type","payload":{"x":1}}"#,
        ));
        acc.ingest(parse_event(
            r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"}}"#,
        ));
        let final_evs = acc.finalize();
        assert!(matches!(
            &final_evs[0],
            TokenEventV2::Done {
                stop_reason: StopReasonV2::EndTurn,
                ..
            }
        ));
    }

    #[test]
    fn accumulator_skips_thinking_deltas() {
        let mut acc = StreamAccumulator::new();
        acc.ingest(parse_event(r#"{"type":"message_start","message":{}}"#));
        acc.ingest(parse_event(
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"thinking"}}"#,
        ));
        let out = acc.ingest(parse_event(
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"reasoning..."}}"#,
        ));
        assert!(out.is_empty(), "thinking delta should not surface");
        acc.ingest(parse_event(r#"{"type":"content_block_stop","index":0}"#));
        acc.ingest(parse_event(
            r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"}}"#,
        ));
        let final_evs = acc.finalize();
        assert!(matches!(
            &final_evs[0],
            TokenEventV2::Done {
                stop_reason: StopReasonV2::EndTurn,
                ..
            }
        ));
    }
}