use serde::{Deserialize, Serialize};
/// Author of a [`Message`] in a conversation.
///
/// Serialized as a lowercase string because of `rename_all = "lowercase"`:
/// `User` <-> `"user"`, `Assistant` <-> `"assistant"`. Any other string fails
/// to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Serialized as `"user"`.
User,
/// Serialized as `"assistant"`.
Assistant,
}
/// One typed piece of message content.
///
/// Internally tagged: the variant is carried in a `"type"` field whose value is
/// the snake_case variant name (`"text"`, `"tool_use"` or `"tool_result"`), and
/// the variant's fields sit alongside it in the same JSON object.
///
/// NOTE(review): a block whose `"type"` is anything else fails to deserialize,
/// which also makes any containing response (e.g. `CreateMessageResponse::content`)
/// fail as a whole. Confirm the server never sends other block types to this code.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
/// Plain text; serialized as `{"type": "text", "text": ...}`.
Text { text: String },
/// A request to invoke a tool; serialized with `"type": "tool_use"`.
ToolUse {
/// Identifier of this tool call (presumably what a later `ToolResult`
/// quotes as `tool_use_id`; confirm against the API docs).
id: String,
/// Name of the tool to invoke (presumably one of the offered `Tool::name`s).
name: String,
/// Tool arguments as an arbitrary JSON value.
input: serde_json::Value,
},
/// The outcome of a tool call; serialized with `"type": "tool_result"`.
ToolResult {
/// Identifier of the `ToolUse` block this result answers (presumably its `id`).
tool_use_id: String,
/// Result payload, modeled here only as a plain string.
/// NOTE(review): if the API also allows structured (array) content here,
/// such payloads will not deserialize into this field. Confirm.
content: String,
/// Optional error flag; the key is omitted entirely from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
is_error: Option<bool>,
},
}
/// Body of a [`Message`]: either a bare string or a list of [`ContentBlock`]s.
///
/// `#[serde(untagged)]` writes no wrapper or tag. `Text` serializes as a JSON
/// string and `Blocks` as a JSON array. When deserializing, serde tries the
/// variants in declaration order (`Text` first) and keeps the first that fits.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// Shorthand form: the whole content is one string.
Text(String),
/// Structured form: an ordered list of typed blocks.
Blocks(Vec<ContentBlock>),
}
/// A single conversation turn: who is speaking and what they said.
///
/// Serializes as `{"role": ..., "content": ...}`, where `content` is either a
/// JSON string or an array of content blocks (see [`MessageContent`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Author of this turn.
pub role: Role,
/// Body of this turn.
pub content: MessageContent,
}
impl Message {
    /// Builds a [`Role::User`] turn whose content is a single plain string.
    pub fn user(text: impl Into<String>) -> Self {
        let content = MessageContent::Text(text.into());
        Self { role: Role::User, content }
    }

    /// Builds a [`Role::Assistant`] turn whose content is a single plain string.
    pub fn assistant(text: impl Into<String>) -> Self {
        let content = MessageContent::Text(text.into());
        Self { role: Role::Assistant, content }
    }

    /// Builds a [`Role::User`] turn from an ordered list of content blocks,
    /// e.g. tool results being returned to the model.
    pub fn user_blocks(blocks: Vec<ContentBlock>) -> Self {
        Self { content: MessageContent::Blocks(blocks), role: Role::User }
    }

    /// Builds a [`Role::Assistant`] turn from an ordered list of content blocks,
    /// e.g. replaying an earlier response that contained tool calls.
    pub fn assistant_blocks(blocks: Vec<ContentBlock>) -> Self {
        Self { content: MessageContent::Blocks(blocks), role: Role::Assistant }
    }
}
/// Definition of a tool that can be offered via `CreateMessageRequest::tools`.
///
/// Serializes as `{"name": ..., "description": ..., "input_schema": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// Tool name (presumably what a `ContentBlock::ToolUse::name` refers back to).
pub name: String,
/// Free-text description of the tool.
pub description: String,
/// Schema describing the tool's input, as an arbitrary JSON value
/// (presumably a JSON Schema object; confirm against the API docs).
pub input_schema: serde_json::Value,
}
impl Tool {
    /// Creates a tool definition.
    ///
    /// `name` and `description` accept anything convertible into an owned
    /// `String`; `input_schema` is stored exactly as given.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        input_schema: serde_json::Value,
    ) -> Self {
        let (name, description) = (name.into(), description.into());
        Self {
            name,
            description,
            input_schema,
        }
    }
}
/// Request body for creating a message.
///
/// Only `Serialize` is derived, so this type can be written to JSON but not
/// parsed from it. Every `Option` field is omitted from the JSON entirely when
/// `None` (rather than being sent as `null`).
#[derive(Debug, Clone, Serialize)]
pub struct CreateMessageRequest {
/// Model identifier.
pub model: String,
/// Conversation turns, serialized in order.
pub messages: Vec<Message>,
/// Token limit for the request (presumably the cap on generated tokens; confirm).
pub max_tokens: usize,
/// System prompt text.
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
/// Sampling temperature; no range check is performed by this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// `top_p` sampling parameter; no range check is performed by this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Strings at which generation should stop.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
/// Tools offered with this request.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Streaming flag. NOTE(review): presumably `Some(true)` requires a
/// streaming-aware client; `CreateMessageResponse` is a single JSON body.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
}
impl CreateMessageRequest {
    /// Creates a request with the three required fields set.
    ///
    /// Every optional field starts as `None`, so it is omitted from the
    /// serialized JSON until one of the `with_*` builders sets it.
    #[must_use]
    pub fn new(model: impl Into<String>, messages: Vec<Message>, max_tokens: usize) -> Self {
        Self {
            model: model.into(),
            messages,
            max_tokens,
            system: None,
            temperature: None,
            top_p: None,
            stop_sequences: None,
            tools: None,
            stream: None,
        }
    }

    /// Sets the system prompt, replacing any previous value.
    #[must_use]
    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Sets the sampling temperature, replacing any previous value.
    ///
    /// The value is forwarded as-is; no range validation happens here.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `top_p`, replacing any previous value.
    ///
    /// The value is forwarded as-is; no range validation happens here.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets the stop sequences, replacing any previously set list.
    #[must_use]
    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
        self.stop_sequences = Some(sequences);
        self
    }

    /// Sets the tools offered with this request, replacing any previously set list.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets the `stream` flag, replacing any previous value.
    ///
    /// Gives `stream` the same builder-style setter every other optional
    /// field has, so callers need not mutate the public field directly.
    /// NOTE(review): `true` presumably requires a streaming-aware transport;
    /// `CreateMessageResponse` models a single JSON body.
    #[must_use]
    pub fn with_stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }
}
/// Token counts attached to a [`CreateMessageResponse`].
///
/// Only the two fields below are read. Any other keys in the JSON object are
/// ignored, because `deny_unknown_fields` is not set.
#[derive(Debug, Clone, Deserialize)]
pub struct Usage {
/// Input-side token count reported by the server.
pub input_tokens: usize,
/// Output-side token count reported by the server.
pub output_tokens: usize,
}
/// Parsed body of a successful create-message call.
///
/// Only `Deserialize` is derived. Unknown top-level keys are ignored, but every
/// element of `content` must match a known [`ContentBlock`] variant or the whole
/// response fails to deserialize.
#[derive(Debug, Clone, Deserialize)]
pub struct CreateMessageResponse {
/// Identifier of this response.
pub id: String,
/// Value of the JSON `"type"` key; renamed because `type` is a Rust keyword.
#[serde(rename = "type")]
pub response_type: String,
/// Author of the returned turn.
pub role: Role,
/// Returned content blocks, in the order the server sent them.
pub content: Vec<ContentBlock>,
/// Model identifier echoed by the server.
pub model: String,
/// Why generation stopped, as a raw string; `None` if the key is absent or null.
pub stop_reason: Option<String>,
/// The stop sequence that was hit, if any (presumably only set when
/// `stop_reason` indicates a stop sequence; confirm).
pub stop_sequence: Option<String>,
/// Token counts for this call.
pub usage: Usage,
}
impl CreateMessageResponse {
    /// Returns all text in the response: every [`ContentBlock::Text`] block in
    /// `content`, concatenated in order with no separator.
    ///
    /// Returns `None` only when `content` holds no `Text` block at all. A
    /// response whose text blocks are all empty yields `Some(String::new())`.
    ///
    /// `content` may hold several `Text` blocks interleaved with other block
    /// kinds. Stopping at the first one would silently drop the rest of the
    /// output.
    pub fn text(&self) -> Option<String> {
        let mut texts = self.content.iter().filter_map(|block| match block {
            ContentBlock::Text { text } => Some(text.as_str()),
            _ => None,
        });
        // Take the first piece separately so "no text blocks" still maps to `None`.
        let mut joined = texts.next()?.to_owned();
        joined.extend(texts);
        Some(joined)
    }

    /// Returns `(id, name, input)` for every [`ContentBlock::ToolUse`] block, in
    /// the order they appear in `content`.
    ///
    /// All references borrow from `self`. The vector is empty when there are
    /// no tool calls.
    pub fn tool_uses(&self) -> Vec<(&str, &str, &serde_json::Value)> {
        self.content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::ToolUse { id, name, input } => {
                    Some((id.as_str(), name.as_str(), input))
                }
                _ => None,
            })
            .collect()
    }

    /// Returns `true` if `content` contains at least one [`ContentBlock::ToolUse`] block.
    pub fn has_tool_use(&self) -> bool {
        self.content
            .iter()
            .any(|block| matches!(block, ContentBlock::ToolUse { .. }))
    }
}
/// Metadata describing one model, as returned by the server.
///
/// Only `Deserialize` is derived; keys other than the three below are ignored.
#[derive(Debug, Clone, Deserialize)]
pub struct Model {
/// Model identifier (presumably what `CreateMessageRequest::model` accepts).
pub id: String,
/// Human-readable model name.
pub display_name: String,
/// Creation time kept as a raw string.
/// NOTE(review): presumably a timestamp; confirm its format before parsing.
pub created_at: String,
}
/// Top-level error body: a `"type"` discriminator plus a nested [`ErrorDetail`]
/// under the `"error"` key.
#[derive(Debug, Clone, Deserialize)]
pub struct ErrorResponse {
/// Value of the top-level JSON `"type"` key (renamed because `type` is a keyword).
#[serde(rename = "type")]
pub error_type: String,
/// Nested error description.
pub error: ErrorDetail,
}
/// The nested `"error"` object of an [`ErrorResponse`].
#[derive(Debug, Clone, Deserialize)]
pub struct ErrorDetail {
/// Value of the nested JSON `"type"` key: the specific error kind as a raw string.
#[serde(rename = "type")]
pub error_type: String,
/// Error message text.
pub message: String,
}