use serde::{Deserialize, Serialize};
use crate::types::shared::{RequestId, Usage};
use crate::files::{File, FileError};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
pub id: String,
#[serde(rename = "type")]
pub type_: String,
pub role: Role,
#[serde(default)]
pub content: Vec<ContentBlock>,
pub model: String,
pub stop_reason: Option<StopReason>,
pub stop_sequence: Option<String>,
pub usage: Usage,
#[serde(skip)]
pub request_id: Option<RequestId>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "thinking")]
Thinking {
thinking: String,
signature: String,
},
#[serde(rename = "image")]
Image { source: ImageSource },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: Option<String>,
is_error: Option<bool>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum ImageSource {
#[serde(rename = "base64")]
Base64 {
media_type: String,
data: String,
},
#[serde(rename = "url")]
Url {
url: String,
},
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
EndTurn,
MaxTokens,
StopSequence,
ToolUse,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageCreateParams {
pub model: String,
pub max_tokens: u32,
pub messages: Vec<MessageParam>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<crate::types::Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<crate::types::ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<std::collections::HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageParam {
pub role: Role,
pub content: MessageContent,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
Text(String),
Blocks(Vec<ContentBlockParam>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlockParam {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "thinking")]
Thinking {
thinking: String,
signature: String,
},
#[serde(rename = "image")]
Image { source: ImageSource },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: Option<String>,
is_error: Option<bool>,
},
}
#[derive(Debug, Clone)]
pub struct MessageCreateBuilder {
params: MessageCreateParams,
}
impl MessageCreateBuilder {
pub fn new(model: impl Into<String>, max_tokens: u32) -> Self {
Self {
params: MessageCreateParams {
model: model.into(),
max_tokens,
messages: Vec::new(),
system: None,
temperature: None,
top_p: None,
top_k: None,
stop_sequences: None,
stream: None,
tools: None,
tool_choice: None,
metadata: None,
},
}
}
pub fn message(mut self, role: Role, content: impl Into<MessageContent>) -> Self {
self.params.messages.push(MessageParam {
role,
content: content.into(),
});
self
}
pub fn user(self, content: impl Into<MessageContent>) -> Self {
self.message(Role::User, content)
}
pub fn assistant(self, content: impl Into<MessageContent>) -> Self {
self.message(Role::Assistant, content)
}
pub fn system(mut self, system: impl Into<String>) -> Self {
self.params.system = Some(system.into());
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.params.temperature = Some(temperature);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.params.top_p = Some(top_p);
self
}
pub fn top_k(mut self, top_k: u32) -> Self {
self.params.top_k = Some(top_k);
self
}
pub fn stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
self.params.stop_sequences = Some(stop_sequences);
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.params.stream = Some(stream);
self
}
pub fn tools(mut self, tools: Vec<crate::types::Tool>) -> Self {
self.params.tools = Some(tools);
self
}
pub fn tool_choice(mut self, tool_choice: crate::types::ToolChoice) -> Self {
self.params.tool_choice = Some(tool_choice);
self
}
pub fn metadata(mut self, metadata: std::collections::HashMap<String, String>) -> Self {
self.params.metadata = Some(metadata);
self
}
pub fn build(self) -> MessageCreateParams {
self.params
}
}
impl From<String> for MessageContent {
fn from(text: String) -> Self {
Self::Text(text)
}
}
impl From<&str> for MessageContent {
fn from(text: &str) -> Self {
Self::Text(text.to_string())
}
}
impl From<Vec<ContentBlockParam>> for MessageContent {
fn from(blocks: Vec<ContentBlockParam>) -> Self {
Self::Blocks(blocks)
}
}
impl ContentBlockParam {
pub fn text(text: impl Into<String>) -> Self {
Self::Text { text: text.into() }
}
pub fn image_base64(media_type: impl Into<String>, data: impl Into<String>) -> Self {
Self::Image {
source: ImageSource::Base64 {
media_type: media_type.into(),
data: data.into(),
},
}
}
pub fn image_url(url: impl Into<String>) -> Self {
Self::Image {
source: ImageSource::Url {
url: url.into(),
},
}
}
pub async fn image_file(file: File) -> Result<Self, FileError> {
if !file.is_image() {
return Err(FileError::InvalidMimeType {
mime_type: file.mime_type.to_string(),
allowed: vec!["image/*".to_string()],
});
}
let base64_data = file.to_base64().await?;
Ok(Self::Image {
source: ImageSource::Base64 {
media_type: file.mime_type.to_string(),
data: base64_data,
},
})
}
pub async fn from_file(file: File) -> Result<Self, FileError> {
Self::image_file(file).await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_builder() {
let params = MessageCreateBuilder::new("claude-3-5-sonnet-latest", 1024)
.user("Hello, Claude!")
.system("You are a helpful assistant.")
.temperature(0.7)
.build();
assert_eq!(params.model, "claude-3-5-sonnet-latest");
assert_eq!(params.max_tokens, 1024);
assert_eq!(params.messages.len(), 1);
assert_eq!(params.messages[0].role, Role::User);
assert_eq!(params.system, Some("You are a helpful assistant.".to_string()));
assert_eq!(params.temperature, Some(0.7));
}
#[test]
fn test_content_block_creation() {
let text_block = ContentBlockParam::text("Hello world");
match text_block {
ContentBlockParam::Text { text } => assert_eq!(text, "Hello world"),
_ => panic!("Expected text block"),
}
let image_block = ContentBlockParam::image_base64("image/jpeg", "base64data");
match image_block {
ContentBlockParam::Image { source } => match source {
ImageSource::Base64 { media_type, data } => {
assert_eq!(media_type, "image/jpeg");
assert_eq!(data, "base64data");
},
_ => panic!("Expected base64 image source"),
},
_ => panic!("Expected image block"),
}
}
#[test]
fn test_message_content_from_string() {
let content: MessageContent = "Hello".into();
match content {
MessageContent::Text(text) => assert_eq!(text, "Hello"),
_ => panic!("Expected text content"),
}
}
}