use std::{fmt, fmt::Display};
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// API version string; presumably sent as the `anthropic-version` request header — confirm
/// at the call site.
pub const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Default API host name; presumably used when no endpoint override is supplied.
pub const DEFAULT_ENDPOINT_HOST: &str = "api.anthropic.com";
/// Default model identifier; presumably used when the caller does not pick a model.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-20250514";
/// Request body for a Messages call, serialized with `serde`.
///
/// Every field borrows from the caller, so the body can be serialized without cloning the
/// conversation history or the tool list.
#[derive(Debug, Serialize)]
pub struct MessagesBody<'a> {
    /// Model identifier (for example [`DEFAULT_MODEL`]).
    pub model: &'a str,
    /// Sent as `max_tokens`.
    pub max_tokens: u32,
    /// Optional system prompt; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<&'a str>,
    /// Conversation history, serialized in order.
    pub messages: &'a im::Vector<Message>,
    /// Tool definitions; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<&'a im::Vector<Tool>>,
    /// Streaming flag; only written to the JSON when `true`.
    #[serde(skip_serializing_if = "is_false")]
    pub stream: bool,
}
/// `skip_serializing_if` predicate that reports whether a flag is unset.
///
/// With this predicate, a `false` field is left out of the serialized JSON instead of being
/// written as `false`.
fn is_false(value: &bool) -> bool {
    !*value
}
/// Author of a [`Message`]; (de)serialized as `"user"` or `"assistant"`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    User,
    Assistant,
}
/// One conversation turn: who produced it and its ordered content blocks.
///
/// The blocks can be iterated with `for block in &message`, or consumed with
/// `message.into_iter()`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    pub role: Role,
    pub content: Vec<Content>,
}
impl Message {
pub fn from_text<S: Into<String>>(role: Role, text: S) -> Self {
Self {
role,
content: vec![Content::from_text(text.into())],
}
}
}
impl IntoIterator for Message {
type Item = Content;
type IntoIter = std::vec::IntoIter<Content>;
fn into_iter(self) -> Self::IntoIter {
self.content.into_iter()
}
}
impl<'a> IntoIterator for &'a Message {
type Item = &'a Content;
type IntoIter = std::slice::Iter<'a, Content>;
fn into_iter(self) -> Self::IntoIter {
self.content.iter()
}
}
/// A tool the model may call, as listed in a request's `tools` field.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    /// JSON Schema describing the tool's input; [`Tool::new`] generates it with `schemars`.
    pub input_schema: Value,
}
impl Tool {
pub fn new<T, N, D>(name: N, description: D) -> Self
where
T: JsonSchema,
N: Into<String>,
D: Into<String>,
{
let schema = schema_for!(T);
let input_schema =
serde_json::to_value(schema).expect("Schema serialization should not fail");
Self {
name: name.into(),
description: description.into(),
input_schema,
}
}
}
/// A tool invocation requested by the model (a `"tool_use"` block inside [`Content`]).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolUse {
    /// Call identifier; a [`ToolResult`] presumably refers back to it via `tool_use_id`.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Tool arguments as arbitrary JSON.
    pub input: Value,
}
/// Human-readable one-line summary of a tool call: `name(id) with <input as JSON>`.
impl Display for ToolUse {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Value`'s `Display` prints compact JSON (`{"query":"rust"}`). Its `Debug` prints Rust
        // enum syntax (`Object {"query": String("rust")}`), which is diagnostic noise in
        // user-facing text.
        write!(f, "{}({}) with {}", self.name, self.id, self.input)
    }
}
/// The outcome of running a tool, carried as a `"tool_result"` block inside [`Content`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolResult {
    /// Identifier of the tool call this result answers.
    pub tool_use_id: String,
    /// Result payload: plain text or a list of content blocks.
    pub content: ToolResultContent,
    /// `Some(true)` marks a failed call (see [`ToolResult::error`]); omitted from the JSON
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}
/// Payload of a [`ToolResult`]: either a list of content blocks or a bare string.
///
/// The enum is `untagged`, so serde tries the variants in declaration order: a JSON array
/// becomes `Content` and a JSON string becomes `String`. Keep this order in mind when adding
/// variants.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolResultContent {
    Content(Vec<Content>),
    String(String),
}
/// Renders a string payload verbatim. A block list renders each block with its own `Display`,
/// separated by a single space.
impl Display for ToolResultContent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let blocks = match self {
            ToolResultContent::String(text) => return f.write_str(text),
            ToolResultContent::Content(blocks) => blocks,
        };
        let mut remaining = blocks.iter();
        if let Some(first) = remaining.next() {
            Display::fmt(first, f)?;
            for block in remaining {
                f.write_str(" ")?;
                Display::fmt(block, f)?;
            }
        }
        Ok(())
    }
}
impl From<String> for ToolResultContent {
fn from(s: String) -> Self {
ToolResultContent::String(s)
}
}
impl From<&str> for ToolResultContent {
fn from(s: &str) -> Self {
ToolResultContent::String(s.to_string())
}
}
impl From<Vec<Content>> for ToolResultContent {
fn from(content: Vec<Content>) -> Self {
ToolResultContent::Content(content)
}
}
impl ToolResult {
pub fn success<T: Into<ToolResultContent>>(tool_use_id: String, content: T) -> Self {
Self {
tool_use_id,
content: content.into(),
is_error: None,
}
}
pub fn error<T: Into<ToolResultContent>>(tool_use_id: String, error_content: T) -> Self {
Self {
tool_use_id,
content: error_content.into(),
is_error: Some(true),
}
}
pub fn unknown_tool<S: AsRef<str>>(tool_use_id: String, tool_name: S) -> Self {
Self::error(tool_use_id, format!("Unknown tool: {}", tool_name.as_ref()))
}
}
/// One-line summary naming the call id. The wording switches to "error" when
/// `is_error == Some(true)`.
impl Display for ToolResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (id, payload) = (&self.tool_use_id, &self.content);
        match self.is_error {
            Some(true) => write!(f, "Tool result error for {}: {}", id, payload),
            _ => write!(f, "Tool result for {}: {}", id, payload),
        }
    }
}
/// One hit inside a [`Content::WebSearchToolResult`] block.
///
/// JSON fields not declared here, such as the item's own `"type"`, are ignored when
/// deserializing and are not written back when serializing.
// NOTE(review): if these results are echoed back to the API in later requests, confirm that
// the missing `"type": "web_search_result"` field is accepted.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WebSearchResult {
    pub title: String,
    pub url: String,
    /// Stored and re-serialized as-is.
    pub encrypted_content: String,
    /// Free-form age string (e.g. `"2 days ago"`); `None` when absent or `null`.
    pub page_age: Option<String>,
}
/// A single content block, selected by its JSON `"type"` field (the snake_case variant name).
///
/// Any `"type"` not listed here deserializes to [`Content::Unknown`] (`#[serde(other)]`),
/// so unfamiliar block kinds do not fail parsing. Their fields are dropped, though.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// Image block. It carries no fields, so any image payload is discarded on deserialization
    /// and only `{"type":"image"}` is serialized.
    Image,
    ToolUse(ToolUse),
    ToolResult(ToolResult),
    /// A tool call run on the server side (e.g. `name: "web_search"`).
    ServerToolUse {
        id: String,
        name: String,
        input: Value,
    },
    /// Web search hits for the server tool call identified by `tool_use_id`.
    WebSearchToolResult {
        tool_use_id: String,
        content: Vec<WebSearchResult>,
    },
    /// Catch-all for unrecognised block types.
    #[serde(other)]
    Unknown,
}
impl Display for Content {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Content::Text { text } => f.write_str(text),
Content::Image => f.write_str("<image>"),
Content::ToolUse(tool_use) => tool_use.fmt(f),
Content::ToolResult(tool_result) => tool_result.fmt(f),
Content::ServerToolUse { id, name, .. } => write!(f, "<server_tool_use:{name}({id})>"),
Content::WebSearchToolResult { tool_use_id, .. } => {
write!(f, "<web_search_result:{tool_use_id}>")
}
Content::Unknown => f.write_str("<unknown>"),
}
}
}
impl Content {
pub fn from_text<S: Into<String>>(text: S) -> Self {
Content::Text { text: text.into() }
}
pub fn as_text(&self) -> Option<&str> {
match self {
Content::Text { text } => Some(text.as_str()),
_ => None,
}
}
}
/// Error category reported by the API, selected by the error object's `"type"` field
/// (the snake_case variant name, e.g. `"invalid_request_error"`).
///
/// Variants carry no data, so any other fields of the error object are ignored on
/// deserialization. A `"type"` not listed here is a deserialization error. `Display` prints
/// the `#[error(...)]` text.
#[derive(Clone, Debug, thiserror::Error, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ApiError {
    #[error("Invalid request")]
    InvalidRequestError,
    #[error("Authentication error")]
    AuthenticationError,
    #[error("Permission error")]
    PermissionError,
    #[error("Not found")]
    NotFoundError,
    #[error("Request too large")]
    RequestTooLarge,
    #[error("Rate limit exceeded")]
    RateLimitError,
    #[error("API error")]
    // The variant repeats the enum's name, which clippy flags. The name is kept because
    // `rename_all` turns it into the wire tag `"api_error"`.
    #[allow(clippy::enum_variant_names)]
    ApiError,
    #[error("API overloaded")]
    OverloadedError,
}
/// A decoded response body, dispatched on its `"type"` field.
///
/// `"message"` yields a [`MessagesResponse`] whose fields sit alongside `"type"`. `"error"`
/// yields the [`ApiError`] found under the `"error"` key.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ApiResponse {
    Message(MessagesResponse),
    Error { error: ApiError },
}
impl ApiResponse {
pub fn kind(&self) -> &'static str {
match self {
ApiResponse::Message(_) => "message",
ApiResponse::Error { .. } => "error",
}
}
}
impl TryFrom<ApiResponse> for MessagesResponse {
type Error = ();
fn try_from(helper: ApiResponse) -> Result<Self, Self::Error> {
match helper {
ApiResponse::Message(response) => Ok(response),
ApiResponse::Error { error: _ } => Err(()),
}
}
}
/// Why generation ended, as reported in `stop_reason` (snake_case on the wire, e.g.
/// `"end_turn"`, `"tool_use"`). A value not listed here fails to deserialize.
#[derive(Copy, Clone, Debug, Eq, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    StopSequence,
    ToolUse,
    PauseTurn,
    Refusal,
}
/// Body of a complete `"message"` response.
#[derive(Debug, Deserialize, Serialize)]
pub struct MessagesResponse {
    pub id: String,
    pub model: String,
    /// Required here, unlike in [`StreamingMessage`] where it is optional.
    pub stop_reason: StopReason,
    /// `None` when the JSON value is `null` or absent.
    pub stop_sequence: Option<String>,
    pub usage: Usage,
    /// `flatten`: the response's top-level `role` and `content` keys populate this
    /// [`Message`], rather than a nested object.
    #[serde(flatten)]
    pub message: Message,
}
/// Token counts attached to a [`MessagesResponse`] and to a [`StreamingMessage`] snapshot.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Required; deserialization fails if it is missing.
    pub input_tokens: u32,
    /// Required; deserialization fails if it is missing.
    pub output_tokens: u32,
    /// Defaults to 0 when missing from the JSON.
    #[serde(default)]
    pub cache_creation_input_tokens: u32,
    /// Defaults to 0 when missing from the JSON.
    #[serde(default)]
    pub cache_read_input_tokens: u32,
}
/// Usage counters for server-side tools, nested under `server_tool_use` in [`StreamingUsage`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerToolUsage {
    pub web_search_requests: u32,
}
/// Usage carried by a `message_delta` stream event.
///
/// Every field is optional because a delta may report only some counters (e.g. just
/// `output_tokens`).
#[derive(Debug, Deserialize)]
pub struct StreamingUsage {
    pub input_tokens: Option<u32>,
    pub output_tokens: Option<u32>,
    pub cache_creation_input_tokens: Option<u32>,
    pub cache_read_input_tokens: Option<u32>,
    pub server_tool_use: Option<ServerToolUsage>,
}
/// Message snapshot delivered by the `message_start` stream event.
///
/// `stop_reason` and `stop_sequence` may be `null` at that point. Later `message_delta`
/// events fill them in through [`StreamingMessage::update`].
#[derive(Debug, Deserialize)]
pub struct StreamingMessage {
    pub id: String,
    pub model: String,
    pub stop_reason: Option<StopReason>,
    pub stop_sequence: Option<String>,
    pub usage: Usage,
    pub role: Role,
    pub content: Vec<Content>,
}
impl StreamingMessage {
pub fn update(&mut self, delta: MessageDelta) {
if let Some(stop_reason) = delta.stop_reason {
self.stop_reason = Some(stop_reason);
}
if let Some(stop_sequence) = delta.stop_sequence {
self.stop_sequence = Some(stop_sequence);
}
}
}
/// One event of a streaming response, selected by its `"type"` field (the snake_case
/// variant name).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// `message_start`: the initial message snapshot.
    MessageStart { message: StreamingMessage },
    /// `content_block_start`: the opening state of the content block at `index`.
    ContentBlockStart { index: u32, content_block: Content },
    /// `content_block_delta`: an incremental change to the content block at `index`.
    ContentBlockDelta { index: u32, delta: Delta },
    /// `content_block_stop`: the content block at `index` is finished.
    ContentBlockStop { index: u32 },
    /// `message_delta`: message-level changes (applied with [`StreamingMessage::update`])
    /// plus optional usage counters.
    MessageDelta {
        delta: MessageDelta,
        usage: Option<StreamingUsage>,
    },
    MessageStop,
    Ping,
    /// `error`: an API error reported inside the stream.
    Error { error: ApiError },
    /// Never produced by serde (`#[serde(skip)]`). Presumably built by the crate's
    /// `deserialize_event` for unrecognised event types, keeping the raw type name and the JSON
    /// body — confirm there.
    #[serde(skip)]
    Unknown {
        event_type: Vec<u8>,
        contents: serde_json::Value,
    },
}
/// Payload of a `content_block_delta` event, selected by its own `"type"` field.
/// A delta type not listed here fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Delta {
    /// `text_delta`: a piece of text for the block.
    TextDelta { text: String },
    /// `input_json_delta`: a fragment of JSON, presumably concatenated across events before
    /// parsing — confirm in the consumer.
    InputJsonDelta { partial_json: String },
    /// `thinking_delta`.
    ThinkingDelta { thinking: String },
    /// `signature_delta`.
    SignatureDelta { signature: String },
}
/// Message-level fields carried by a `message_delta` event. When applied through
/// [`StreamingMessage::update`], a `None` field means "unchanged".
#[derive(Debug, Deserialize, Clone)]
pub struct MessageDelta {
    pub stop_reason: Option<StopReason>,
    pub stop_sequence: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::{Content, Delta, StopReason, StreamEvent, Usage};

    /// Parses one stream event directly through serde, bypassing `crate::deserialize_event`.
    fn parse_event(raw: &[u8]) -> StreamEvent {
        serde_json::from_slice(raw).expect("should deserialize")
    }

    #[test]
    fn test_deserialize_content_block_delta_text() {
        let raw = br#"{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}"#;
        let event: StreamEvent = serde_json::from_slice(raw).unwrap();
        let (index, delta) = match event {
            StreamEvent::ContentBlockDelta { index, delta } => (index, delta),
            _ => panic!("Expected ContentBlockDelta"),
        };
        assert_eq!(index, 0);
        let text = match delta {
            Delta::TextDelta { text } => text,
            _ => panic!("Expected TextDelta"),
        };
        assert_eq!(text, "Hello");
    }

    #[test]
    fn test_deserialize_message_start() {
        let raw = br#"{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-opus-4-20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}"#;
        let message = match crate::deserialize_event(raw).unwrap() {
            StreamEvent::MessageStart { message } => message,
            other => panic!("Expected MessageStart event, but got: {:?}", other),
        };
        assert_eq!(message.id, "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY");
        assert_eq!(message.model, "claude-opus-4-20250514");
        assert_eq!(message.stop_reason, None);
        assert_eq!(message.stop_sequence, None);
        assert_eq!(message.usage.input_tokens, 25);
        assert_eq!(message.usage.output_tokens, 1);
        assert!(message.content.is_empty());
    }

    #[test]
    fn test_deserialize_message_delta() {
        let raw = br#"{"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence": null}, "usage": {"output_tokens": 38}}"#;
        let (delta, usage) = match crate::deserialize_event(raw).unwrap() {
            StreamEvent::MessageDelta { delta, usage } => (delta, usage),
            StreamEvent::Unknown {
                event_type,
                contents,
            } => {
                eprintln!(
                    "Got Unknown: {:?}, contents: {:?}",
                    String::from_utf8_lossy(&event_type),
                    contents
                );
                panic!("Expected MessageDelta but got Unknown");
            }
            other => panic!("Expected MessageDelta event, but got: {:?}", other),
        };
        assert_eq!(delta.stop_reason, Some(StopReason::EndTurn));
        assert_eq!(delta.stop_sequence, None);
        assert_eq!(usage.unwrap().output_tokens, Some(38));
    }

    #[test]
    fn test_usage_with_cache_fields() {
        let raw = br#"{"input_tokens": 100, "output_tokens": 50, "cache_creation_input_tokens": 1000, "cache_read_input_tokens": 500}"#;
        let usage: Usage = serde_json::from_slice(raw).expect("should deserialize");
        let counts = (
            usage.input_tokens,
            usage.output_tokens,
            usage.cache_creation_input_tokens,
            usage.cache_read_input_tokens,
        );
        assert_eq!(counts, (100, 50, 1000, 500));
    }

    #[test]
    fn test_usage_without_cache_fields() {
        let raw = br#"{"input_tokens": 100, "output_tokens": 50}"#;
        let usage: Usage = serde_json::from_slice(raw).expect("should deserialize");
        let cache_counts = (usage.cache_creation_input_tokens, usage.cache_read_input_tokens);
        assert_eq!(cache_counts, (0, 0));
    }

    #[test]
    fn test_deserialize_server_tool_use() {
        let raw = br#"{"type": "content_block_start", "index": 0, "content_block": {"type": "server_tool_use", "id": "srvtoolu_xxx", "name": "web_search", "input": {"query": "rust programming"}}}"#;
        let (id, name, input) = match parse_event(raw) {
            StreamEvent::ContentBlockStart {
                content_block: Content::ServerToolUse { id, name, input },
                ..
            } => (id, name, input),
            _ => panic!("expected ContentBlockStart with ServerToolUse"),
        };
        assert_eq!(id, "srvtoolu_xxx");
        assert_eq!(name, "web_search");
        assert_eq!(input["query"], "rust programming");
    }

    #[test]
    fn test_deserialize_web_search_result() {
        let raw = br#"{"type": "content_block_start", "index": 1, "content_block": {"type": "web_search_tool_result", "tool_use_id": "srvtoolu_xxx", "content": [{"type": "web_search_result", "title": "Rust Programming Language", "url": "https://www.rust-lang.org/", "encrypted_content": "...", "page_age": "2 days ago"}]}}"#;
        let (tool_use_id, hits) = match parse_event(raw) {
            StreamEvent::ContentBlockStart {
                content_block:
                    Content::WebSearchToolResult {
                        tool_use_id,
                        content,
                    },
                ..
            } => (tool_use_id, content),
            _ => panic!("expected ContentBlockStart with WebSearchToolResult"),
        };
        assert_eq!(tool_use_id, "srvtoolu_xxx");
        assert_eq!(hits.len(), 1);
        let hit = &hits[0];
        assert_eq!(hit.title, "Rust Programming Language");
        assert_eq!(hit.url, "https://www.rust-lang.org/");
        assert_eq!(hit.page_age, Some("2 days ago".to_string()));
    }

    #[test]
    fn test_deserialize_unknown_content() {
        let raw = br#"{"type": "content_block_start", "index": 0, "content_block": {"type": "future_content_type", "some_field": "value"}}"#;
        let event = parse_event(raw);
        assert!(
            matches!(
                event,
                StreamEvent::ContentBlockStart {
                    content_block: Content::Unknown,
                    ..
                }
            ),
            "expected ContentBlockStart with Unknown"
        );
    }
}