use crate::{anthropic_post, anthropic_request_stream, ApiResponseOrError, Credentials, Usage};
use anyhow::Result;
use derive_builder::Builder;
use futures_util::StreamExt;
use reqwest::Method;
use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::mpsc::{channel, Receiver, Sender};
/// A complete, non-streaming reply from the `messages` endpoint.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct MessagesResponse {
    /// Identifier assigned to this message by the API.
    pub id: String,
    /// Model that produced the reply.
    pub model: String,
    /// Author of the reply.
    pub role: MessageRole,
    /// Ordered content blocks that make up the reply.
    pub content: Vec<ResponseContentBlock>,
    /// Reason generation stopped, when the API reports one.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Object type tag, read from the JSON key `type` (a Rust keyword, hence the rename).
    #[serde(rename = "type")]
    pub typ: String,
    /// Usage information returned with the reply.
    pub usage: Usage,
}
/// One content block of a [`MessagesResponse`], discriminated by its JSON `type` field.
///
/// Variant names map to snake_case tags: `text`, `tool_use`, `thinking`,
/// `redacted_thinking`.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseContentBlock {
    /// Plain text output.
    Text { text: String },
    /// A tool invocation requested by the model, with its JSON input.
    ToolUse { id: String, name: String, input: Value },
    /// Thinking text together with its signature.
    Thinking { signature: String, thinking: String },
    /// Redacted thinking payload.
    RedactedThinking { data: String },
}
/// One server-sent event from a streaming `messages` request, discriminated by its JSON
/// `type` field (variant names map to snake_case tags, e.g. `content_block_delta`).
///
/// `error` events are modelled explicitly. Previously they failed to deserialize, which
/// aborted the forwarding task, and the consumer only saw the channel close.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// `message_start`: opens the message and carries its initial metadata.
    MessageStart { message: MessageStart },
    /// `content_block_start`: a new content block begins at `index`.
    ContentBlockStart {
        index: u32,
        content_block: ContentBlockStart,
    },
    /// `content_block_delta`: incremental content for the block at `index`.
    ContentBlockDelta {
        index: u32,
        delta: ContentBlockDelta,
    },
    /// `content_block_stop`: the block at `index` is finished.
    ContentBlockStop { index: u32 },
    /// `message_delta`: message-level changes (stop reason) plus usage.
    MessageDelta { delta: MessageDelta, usage: Usage },
    /// `message_stop`: the message is complete.
    MessageStop,
    /// `ping`: carries no payload.
    Ping,
    /// `error`: the API reported an error in the middle of the stream.
    Error { error: StreamErrorDetails },
}

/// Payload of an `error` stream event, e.g. `{"type": "overloaded_error", "message": "Overloaded"}`.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct StreamErrorDetails {
    /// Error category reported by the API (JSON key `type`).
    #[serde(rename = "type")]
    pub error_type: String,
    /// Human-readable description of the error.
    pub message: String,
}
/// Initial message metadata carried by a `message_start` stream event.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct MessageStart {
    /// Identifier of the message being streamed.
    pub id: String,
    /// Model generating the message.
    pub model: String,
    /// Author of the message.
    pub role: MessageRole,
    /// Content blocks included in the start payload.
    pub content: Vec<ContentBlockStart>,
}
/// Initial state of a content block, announced by a `content_block_start` event.
///
/// Matched structurally (`untagged`). Serde tries the variants in order and keeps the
/// first whose required fields are present, so each variant keys on a distinct field set.
/// `Thinking` and `RedactedThinking` let streams with extended thinking enabled
/// deserialize. Without them, the first thinking block aborts the stream.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
#[serde(untagged)]
pub enum ContentBlockStart {
    /// A text block (`"type": "text"`).
    Text { text: String },
    /// A tool invocation block (`"type": "tool_use"`).
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
    /// A thinking block (`"type": "thinking"`); its text arrives via thinking deltas.
    Thinking { thinking: String },
    /// A redacted thinking block (`"type": "redacted_thinking"`).
    RedactedThinking { data: String },
}
/// Incremental content carried by a `content_block_delta` event.
///
/// Matched structurally (`untagged`). Serde keeps the first variant whose required
/// field is present, so each variant keys on a distinct field. The thinking and
/// signature variants are needed when the request enables extended thinking. Without
/// them, the first such delta fails to parse and aborts the stream.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
#[serde(untagged)]
pub enum ContentBlockDelta {
    /// `text_delta`: more text for a text block.
    Text { text: String },
    /// `input_json_delta`: a fragment of a tool call's JSON input.
    InputJsonDelta { partial_json: String },
    /// `thinking_delta`: more thinking text for a thinking block.
    ThinkingDelta { thinking: String },
    /// `signature_delta`: the signature for a thinking block.
    SignatureDelta { signature: String },
}
/// Message-level changes reported by a `message_delta` stream event.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct MessageDelta {
    /// Reason generation stopped, once reported.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
}
/// Request body for the `messages` endpoint.
///
/// Usually assembled with [`MessagesBuilder`]. `model`, `messages` and `max_tokens` are
/// required. Every other field defaults to `None` in the builder and is omitted from the
/// serialized JSON while unset.
#[derive(Serialize, Builder, Debug, Clone)]
#[builder(
    name = "MessagesBuilder",
    pattern = "owned",
    derive(Clone, Debug, PartialEq),
    setter(strip_option, into)
)]
pub struct MessagesRequest {
    /// Model identifier to run.
    pub model: String,
    /// Conversation turns sent to the model.
    pub messages: Vec<Message>,
    /// Upper bound on tokens to generate.
    pub max_tokens: u64,
    /// Optional request metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub metadata: Option<Metadata>,
    /// Custom sequences that stop generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub stop_sequences: Option<Vec<String>>,
    /// Requests a server-sent-event reply when `Some(true)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub stream: Option<bool>,
    /// System prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub system: Option<String>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub temperature: Option<f64>,
    /// Extended-thinking configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub thinking: Option<Thinking>,
    /// How the model may use `tools`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub tool_choice: Option<ToolChoice>,
    /// Tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub tools: Option<Vec<Tool>>,
    /// Top-k sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub top_k: Option<u32>,
    /// Nucleus (top-p) sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub top_p: Option<f64>,
    /// Credentials for this call. Never serialized into the body; the request helpers
    /// pass them separately to the HTTP layer.
    #[serde(skip_serializing)]
    #[builder(default)]
    pub credentials: Option<Credentials>,
}
/// One conversation turn in a [`MessagesRequest`].
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct Message {
    /// Author of the turn.
    pub role: MessageRole,
    /// Content of the turn: a plain string or a list of content blocks.
    pub content: MessageContent,
}
/// Author of a message; (de)serialized as the lowercase strings `"user"` / `"assistant"`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// A turn written by the caller.
    User,
    /// A turn produced by the model.
    Assistant,
}
/// Content of a request message.
///
/// Serialized untagged: `Text` becomes a bare JSON string and `ContentBlocks` becomes a
/// JSON array of block objects.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain-text content, sent as a JSON string.
    Text(String),
    /// Explicit list of content blocks.
    ContentBlocks(Vec<RequestContentBlock>),
}
/// A content block that can be sent in a request, tagged by its JSON `type` field.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum RequestContentBlock {
    /// Serialized as `{"type": "text", "text": ...}`.
    Text { text: String },
    /// Serialized as `{"type": "image", "source": {...}}`.
    Image { source: ImageSource },
}
/// Inline image payload used by [`RequestContentBlock::Image`].
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ImageSource {
    /// Encoding of `data`, written to the JSON key `type` (presumably `"base64"`; confirm
    /// against the API docs).
    #[serde(rename = "type")]
    pub source_type: String,
    /// Media type of the image, presumably a MIME type such as `"image/png"`.
    pub media_type: String,
    /// The encoded image data.
    pub data: String,
}
/// On/off switch for extended thinking, serialized as `"enabled"` / `"disabled"`.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingType {
    Enabled,
    Disabled,
}
/// Extended-thinking configuration, sent as the request's `thinking` field.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct Thinking {
    /// Whether thinking is enabled; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub thinking_type: ThinkingType,
    /// Token budget for thinking. Always serialized, even when `thinking_type` is
    /// `Disabled`.
    pub budget_tokens: u64,
}
/// Definition of a tool the model may call.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct Tool {
    /// Name the model uses to refer to the tool.
    pub name: String,
    /// Description of what the tool does.
    pub description: String,
    /// Schema for the tool's input, sent verbatim (presumably a JSON Schema object).
    pub input_schema: Value,
}
/// How the model may use the supplied tools, tagged by the JSON `type` field.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    /// Serialized as `{"type": "auto"}`.
    Auto,
    /// Serialized as `{"type": "any"}`.
    Any,
    /// Serialized as `{"type": "tool", "name": ...}`, naming one specific tool.
    Tool { name: String },
    /// Serialized as `{"type": "none"}`.
    None,
}
/// Optional metadata attached to a request.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct Metadata {
    /// Caller-chosen user identifier; serialized as JSON `null` when `None`.
    pub user_id: Option<String>,
}
impl MessagesResponse {
    /// Sends `request` to the `messages` endpoint and returns the single parsed reply.
    ///
    /// `stream` is cleared before sending. This method returns one `MessagesResponse`,
    /// and a server-sent-event reply (requested by `stream: Some(true)`) cannot be
    /// parsed as one. Use [`StreamEvent::create_stream`] for streaming.
    pub async fn create(mut request: MessagesRequest) -> ApiResponseOrError<Self> {
        request.stream = None;
        // Credentials are excluded from the body; hand them to the HTTP helper separately.
        let credentials = request.credentials.clone();
        anthropic_post("messages", &request, credentials).await
    }
}
impl StreamEvent {
    /// Starts a streaming `messages` request and returns a channel of parsed events.
    ///
    /// `stream` is forced to `Some(true)`, so the endpoint answers with server-sent events
    /// even when the caller built `request` without it. Without this, direct callers got a
    /// plain JSON reply and an empty channel. Events are parsed on a spawned task and
    /// delivered through a bounded channel of capacity 32. The channel closes when that
    /// task finishes, whether the stream completed or failed.
    ///
    /// # Errors
    /// Propagates [`CannotCloneRequestError`] from `anthropic_request_stream`.
    pub async fn create_stream(
        mut request: MessagesRequest,
    ) -> Result<Receiver<Self>, CannotCloneRequestError> {
        request.stream = Some(true);
        let credentials = request.credentials.clone();
        let events = anthropic_request_stream(
            Method::POST,
            "messages",
            |r| r.json(&request),
            credentials,
        )
        .await?;
        let (tx, rx) = channel::<Self>(32);
        tokio::spawn(forward_deserialized_anthropic_stream(events, tx));
        Ok(rx)
    }
}
/// Parses each SSE message from `stream` as a [`StreamEvent`] and forwards it on `tx`.
///
/// `ping` events are dropped. Forwarding stops right after `message_stop`, the final
/// event of a complete response, and the event source is closed at that point.
/// Previously the loop only ended when the server hung up. `reqwest_eventsource` reports
/// that as an error (`StreamEnded`), so every successful stream returned `Err`.
///
/// # Errors
/// Returns an error if the event source fails, an event payload cannot be parsed, or the
/// receiver of `tx` has been dropped.
async fn forward_deserialized_anthropic_stream(
    mut stream: EventSource,
    tx: Sender<StreamEvent>,
) -> anyhow::Result<()> {
    while let Some(event) = stream.next().await {
        let message = match event? {
            Event::Message(message) => message,
            // The connection-open notification carries no payload.
            Event::Open => continue,
        };
        let stream_event: StreamEvent = serde_json::from_str(&message.data)?;
        match stream_event {
            StreamEvent::Ping => {}
            StreamEvent::MessageStop => {
                tx.send(stream_event).await?;
                stream.close();
                break;
            }
            other => tx.send(other).await?,
        }
    }
    Ok(())
}
impl MessagesBuilder {
    /// Starts a builder with the three required request fields already set.
    pub fn builder(model: &str, messages: impl Into<Vec<Message>>, max_tokens: u64) -> Self {
        Self::create_empty()
            .model(model)
            .messages(messages)
            .max_tokens(max_tokens)
    }

    /// Builds the request and sends it as a single, non-streaming call.
    ///
    /// `stream` is cleared first because a streamed reply cannot be returned as a
    /// [`MessagesResponse`]; use [`MessagesBuilder::create_stream`] for streaming.
    ///
    /// # Panics
    /// Panics if `model`, `messages` or `max_tokens` was never set, i.e. the builder was
    /// not started via [`MessagesBuilder::builder`].
    pub async fn create(self) -> ApiResponseOrError<MessagesResponse> {
        let mut request = self.build().expect("Failed to build MessagesRequest");
        request.stream = None;
        MessagesResponse::create(request).await
    }

    /// Builds the request with `stream` enabled and starts a streaming call.
    ///
    /// # Panics
    /// Panics under the same condition as [`MessagesBuilder::create`].
    pub async fn create_stream(self) -> Result<Receiver<StreamEvent>, CannotCloneRequestError> {
        let mut request = self.build().expect("Failed to build MessagesRequest");
        request.stream = Some(true);
        StreamEvent::create_stream(request).await
    }
}
impl MessagesResponse {
    /// Convenience entry point for building a `messages` request.
    ///
    /// Equivalent to [`MessagesBuilder::builder`]: the model, the messages and the
    /// token limit are preset, and everything else is left for the caller to set.
    pub fn builder(
        model: &str,
        messages: impl Into<Vec<Message>>,
        max_tokens: u64,
    ) -> MessagesBuilder {
        MessagesBuilder::builder(model, messages, max_tokens)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single user turn shared by the live API tests below.
    fn hello() -> Vec<Message> {
        vec![Message {
            role: MessageRole::User,
            content: MessageContent::Text("Hello!".to_string()),
        }]
    }

    #[tokio::test]
    async fn test_simple_message() {
        let credentials = Credentials::from_env();
        let response = MessagesResponse::builder("claude-3-7-sonnet-20250219", hello(), 100)
            .credentials(credentials)
            .create()
            .await
            .unwrap();
        assert!(!response.content.is_empty());
    }

    #[tokio::test]
    async fn test_streaming_message() {
        let credentials = Credentials::from_env();
        let mut stream = MessagesResponse::builder("claude-3-7-sonnet-20250219", hello(), 100)
            .credentials(credentials)
            .create_stream()
            .await
            .unwrap();
        let mut text = String::new();
        let mut saw_message_stop = false;
        while let Some(event) = stream.recv().await {
            match event {
                StreamEvent::ContentBlockDelta {
                    delta: ContentBlockDelta::Text { text: chunk },
                    ..
                } => {
                    print!("{}", chunk);
                    text.push_str(&chunk);
                }
                StreamEvent::MessageStop => saw_message_stop = true,
                _ => {}
            }
        }
        // The spawned forwarding task's errors are discarded and only close the channel,
        // so an empty or truncated stream would otherwise pass silently.
        assert!(saw_message_stop, "stream closed before message_stop");
        assert!(!text.is_empty(), "stream produced no text");
    }
}