use serde::Deserialize;
use crate::api::llm::LlmRequest;
use crate::error::{FlowError, Result};
use crate::json::Json;
use super::request::{
AnnotatedLlmRequest, FunctionDefinition, GenerationParams, Message, MessageContent, ToolChoice,
ToolChoiceFunction, ToolChoiceFunctionName, ToolDefinition,
};
use super::response::{
AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, ResponseToolCall, Usage,
};
use super::traits::{LlmCodec, LlmResponseCodec};
pub struct AnthropicMessagesCodec;
#[derive(Deserialize)]
struct RawAnthropicResponse {
id: Option<String>,
model: Option<String>,
content: Option<Vec<Json>>,
stop_reason: Option<String>,
stop_sequence: Option<String>,
usage: Option<RawAnthropicUsage>,
#[serde(flatten)]
extra: serde_json::Map<String, Json>,
}
#[derive(Deserialize)]
struct RawAnthropicUsage {
input_tokens: Option<u64>,
output_tokens: Option<u64>,
cache_read_input_tokens: Option<u64>,
cache_creation_input_tokens: Option<u64>,
}
fn map_anthropic_stop_reason(reason: &str) -> FinishReason {
match reason {
"end_turn" => FinishReason::Complete,
"max_tokens" => FinishReason::Length,
"tool_use" => FinishReason::ToolUse,
other => FinishReason::Unknown(other.to_string()),
}
}
fn json_f64(v: f64) -> Json {
serde_json::Number::from_f64(v)
.map(Json::Number)
.unwrap_or(Json::Null)
}
const MODELED_REQUEST_KEYS: &[&str] = &[
"system",
"messages",
"model",
"max_tokens",
"temperature",
"top_p",
"stop_sequences",
"tools",
"tool_choice",
];
fn decode_anthropic_tool_choice(val: &Json) -> Option<ToolChoice> {
let obj = val.as_object()?;
let tc_type = obj.get("type")?.as_str()?;
match tc_type {
"auto" => Some(ToolChoice::Auto),
"any" => Some(ToolChoice::Required),
"tool" => {
let name = obj.get("name")?.as_str()?.to_string();
Some(ToolChoice::Specific(ToolChoiceFunction {
choice_type: "function".into(),
function: ToolChoiceFunctionName { name },
}))
}
_ => None,
}
}
fn encode_anthropic_tool_choice(tc: &ToolChoice) -> Json {
match tc {
ToolChoice::Auto => serde_json::json!({"type": "auto"}),
ToolChoice::Required => serde_json::json!({"type": "any"}),
ToolChoice::None => serde_json::json!({"type": "auto"}), ToolChoice::Specific(func) => {
serde_json::json!({"type": "tool", "name": func.function.name})
}
}
}
fn extract_system_message(system_val: &Json) -> Option<Message> {
if let Some(s) = system_val.as_str() {
Some(Message::System {
content: MessageContent::Text(s.to_string()),
name: None,
})
} else if let Some(arr) = system_val.as_array() {
let texts: Vec<&str> = arr
.iter()
.filter_map(|block| {
let block_type = block.get("type")?.as_str()?;
if block_type == "text" {
block.get("text")?.as_str()
} else {
None
}
})
.collect();
if texts.is_empty() {
None
} else {
Some(Message::System {
content: MessageContent::Text(texts.join("\n")),
name: None,
})
}
} else {
None
}
}
fn extract_system_text(msg: &Message) -> Option<String> {
match msg {
Message::System {
content: MessageContent::Text(s),
..
} => Some(s.clone()),
Message::System {
content: MessageContent::Parts(parts),
..
} => {
let texts: Vec<&str> = parts
.iter()
.map(|p| {
let super::request::ContentPart::Text { text } = p;
text.as_str()
})
.collect();
if texts.is_empty() {
None
} else {
Some(texts.join("\n"))
}
}
_ => None,
}
}
fn split_system_and_messages(messages: &[Message]) -> (Option<String>, Vec<&Message>) {
let mut system_text = None;
let mut non_system_messages = Vec::new();
for msg in messages {
if let Some(text) = extract_system_text(msg) {
system_text = Some(text);
} else {
non_system_messages.push(msg);
}
}
(system_text, non_system_messages)
}
fn insert_serialized<T: serde::Serialize>(
obj: &mut serde_json::Map<String, Json>,
key: &str,
value: &T,
context: &str,
) -> Result<()> {
let json = serde_json::to_value(value)
.map_err(|e| FlowError::Internal(format!("Anthropic Messages {context} encode: {e}")))?;
obj.insert(key.into(), json);
Ok(())
}
fn overlay_generation_params(obj: &mut serde_json::Map<String, Json>, params: &GenerationParams) {
if let Some(temp) = params.temperature {
obj.insert("temperature".into(), json_f64(temp));
}
if let Some(top_p) = params.top_p {
obj.insert("top_p".into(), json_f64(top_p));
}
if let Some(max_tokens) = params.max_tokens {
obj.insert("max_tokens".into(), Json::from(max_tokens));
}
}
fn encode_anthropic_tools(tools: &[ToolDefinition]) -> Vec<Json> {
tools
.iter()
.map(|td| {
let mut tool = serde_json::Map::new();
tool.insert("name".into(), Json::String(td.function.name.clone()));
if let Some(ref desc) = td.function.description {
tool.insert("description".into(), Json::String(desc.clone()));
}
if let Some(ref params) = td.function.parameters {
tool.insert("input_schema".into(), params.clone());
}
Json::Object(tool)
})
.collect()
}
impl LlmResponseCodec for AnthropicMessagesCodec {
fn decode_response(&self, response: &Json) -> Result<AnnotatedLlmResponse> {
let raw: RawAnthropicResponse = serde_json::from_value(response.clone())
.map_err(|e| FlowError::Internal(format!("Anthropic Messages response decode: {e}")))?;
let content_blocks = raw.content.as_ref();
let text_parts: Vec<&str> = content_blocks
.map(|blocks| {
blocks
.iter()
.filter_map(|block| {
let block_type = block.get("type")?.as_str()?;
if block_type == "text" {
block.get("text")?.as_str()
} else {
None
}
})
.collect()
})
.unwrap_or_default();
let message = if text_parts.is_empty() {
None
} else {
Some(MessageContent::Text(text_parts.join("\n")))
};
let tool_calls: Vec<ResponseToolCall> = content_blocks
.map(|blocks| {
blocks
.iter()
.filter_map(|block| {
let block_type = block.get("type")?.as_str()?;
if block_type == "tool_use" {
let id = block.get("id")?.as_str()?.to_string();
let name = block.get("name")?.as_str()?.to_string();
let arguments = block.get("input")?.clone();
Some(ResponseToolCall {
id,
name,
arguments,
})
} else {
None
}
})
.collect()
})
.unwrap_or_default();
let tool_calls = if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
};
let finish_reason = raw.stop_reason.as_deref().map(map_anthropic_stop_reason);
let usage = raw.usage.map(|u| {
let prompt = u.input_tokens;
let completion = u.output_tokens;
Usage {
prompt_tokens: prompt,
completion_tokens: completion,
total_tokens: match (prompt, completion) {
(Some(p), Some(c)) => Some(p + c),
_ => None,
},
cache_read_tokens: u.cache_read_input_tokens,
cache_write_tokens: u.cache_creation_input_tokens,
}
});
let api_specific_content_blocks = raw.content.clone();
let api_specific = Some(ApiSpecificResponse::AnthropicMessages {
stop_sequence: raw.stop_sequence,
content_blocks: api_specific_content_blocks,
});
Ok(AnnotatedLlmResponse {
id: raw.id,
model: raw.model,
message,
tool_calls,
finish_reason,
usage,
api_specific,
extra: raw.extra,
})
}
}
impl LlmCodec for AnthropicMessagesCodec {
fn decode(&self, request: &LlmRequest) -> Result<AnnotatedLlmRequest> {
let obj = request
.content
.as_object()
.ok_or_else(|| FlowError::Internal("request content is not an object".into()))?;
let system_msg = obj.get("system").and_then(extract_system_message);
let mut messages: Vec<Message> = obj
.get("messages")
.map(|v| serde_json::from_value(v.clone()).unwrap_or_default())
.unwrap_or_default();
if let Some(sys) = system_msg {
messages.insert(0, sys);
}
let model = obj.get("model").and_then(|v| v.as_str()).map(String::from);
let temperature = obj.get("temperature").and_then(|v| v.as_f64());
let top_p = obj.get("top_p").and_then(|v| v.as_f64());
let max_tokens = obj.get("max_tokens").and_then(|v| v.as_u64());
let stop = obj
.get("stop_sequences")
.and_then(|v| serde_json::from_value::<Vec<String>>(v.clone()).ok());
let params =
if temperature.is_some() || max_tokens.is_some() || top_p.is_some() || stop.is_some() {
Some(GenerationParams {
temperature,
max_tokens,
top_p,
stop,
})
} else {
None
};
let tools: Option<Vec<ToolDefinition>> = obj.get("tools").and_then(|v| {
let arr = v.as_array()?;
let defs: Vec<ToolDefinition> = arr
.iter()
.filter_map(|tool| {
let name = tool.get("name")?.as_str()?.to_string();
let description = tool
.get("description")
.and_then(|d| d.as_str())
.map(String::from);
let parameters = tool.get("input_schema").cloned();
Some(ToolDefinition {
tool_type: "function".into(),
function: FunctionDefinition {
name,
description,
parameters,
},
})
})
.collect();
if defs.is_empty() { None } else { Some(defs) }
});
let tool_choice = obj
.get("tool_choice")
.and_then(decode_anthropic_tool_choice);
let extra: serde_json::Map<String, Json> = obj
.iter()
.filter(|(k, _)| !MODELED_REQUEST_KEYS.contains(&k.as_str()))
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
Ok(AnnotatedLlmRequest {
messages,
model,
params,
tools,
tool_choice,
extra,
})
}
fn encode(&self, annotated: &AnnotatedLlmRequest, original: &LlmRequest) -> Result<LlmRequest> {
let mut content = original.content.clone();
let obj = content
.as_object_mut()
.ok_or_else(|| FlowError::Internal("original content is not an object".into()))?;
let (system_text, non_system_messages) = split_system_and_messages(&annotated.messages);
if let Some(text) = system_text {
obj.insert("system".into(), Json::String(text));
}
insert_serialized(obj, "messages", &non_system_messages, "messages")?;
if let Some(ref model) = annotated.model {
obj.insert("model".into(), Json::String(model.clone()));
}
if let Some(ref params) = annotated.params {
overlay_generation_params(obj, params);
if let Some(ref stop) = params.stop {
insert_serialized(obj, "stop_sequences", stop, "stop_sequences")?;
}
}
if let Some(ref tools) = annotated.tools {
let anthropic_tools = encode_anthropic_tools(tools);
insert_serialized(obj, "tools", &anthropic_tools, "tools")?;
}
if let Some(ref tool_choice) = annotated.tool_choice {
obj.insert(
"tool_choice".into(),
encode_anthropic_tool_choice(tool_choice),
);
}
for (k, v) in &annotated.extra {
obj.insert(k.clone(), v.clone());
}
Ok(LlmRequest {
headers: original.headers.clone(),
content,
})
}
}
#[cfg(test)]
#[path = "../../tests/unit/codec/anthropic_tests.rs"]
mod tests;