use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Author of a conversation [`Message`].
///
/// Serialized as a lowercase string: `"user"`, `"assistant"`, or `"system"`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Wire name `"user"`.
    User,
    /// Wire name `"assistant"`.
    Assistant,
    /// Wire name `"system"`.
    System,
}
/// A single conversation turn: who produced it and the content blocks it carries.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author of this turn.
    pub role: Role,
    /// Content blocks of this turn, kept in the order they were added.
    pub content: Vec<ContentBlock>,
}
impl Message {
pub fn user(text: impl Into<String>) -> Self {
Self {
role: Role::User,
content: vec![ContentBlock::Text { text: text.into() }],
}
}
pub fn assistant(text: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: vec![ContentBlock::Text { text: text.into() }],
}
}
pub fn system(text: impl Into<String>) -> Self {
Self {
role: Role::System,
content: vec![ContentBlock::Text { text: text.into() }],
}
}
}
/// One piece of message content.
///
/// Serialized as an object whose `"type"` tag is the snake_case variant name:
/// `"text"`, `"image"`, `"tool_use"`, `"tool_result"`, or `"thinking"`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text { text: String },
    /// An image, either inline data or a URL (see [`ImageSource`]).
    Image { source: ImageSource },
    /// A tool call: `id` identifies the call, `name` the tool, `input` its JSON arguments.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Result text for the tool call identified by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        content: String,
        /// Omitted from the output when `None`; a missing key deserializes as `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// Thinking text, optionally accompanied by a signature string.
    Thinking {
        thinking: String,
        /// Omitted from the output when `None`; a missing key deserializes as `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
}
/// Where an image's bytes come from, tagged by `"type"` (`"base64"` or `"url"`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline payload: `data` is the encoded image, `media_type` its MIME type
    /// (presumably e.g. `"image/png"`; not validated here).
    Base64 { media_type: String, data: String },
    /// Image fetched from `url`.
    Url { url: String },
}
/// Request payload for a single LLM call.
///
/// `None` optionals are left out of the serialized form. `stream`,
/// `working_directory`, and `session_id` are `#[serde(skip)]`: they are never
/// serialized and take their defaults (`false` / `None`) when deserializing.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMRequest {
    /// Model identifier.
    pub model: String,
    /// Conversation messages sent with the request.
    pub messages: Vec<Message>,
    /// Optional system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Sampling temperature (not range-checked here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Output token limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Local flag only. NOTE(review): not serialized, so presumably the transport
    /// layer decides whether to stream based on it.
    #[serde(skip)]
    pub stream: bool,
    /// Free-form string metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
    /// Local context, never serialized.
    #[serde(skip)]
    pub working_directory: Option<String>,
    /// Local session identifier, never serialized.
    #[serde(skip)]
    pub session_id: Option<Uuid>,
}
impl LLMRequest {
    /// Creates a request for `model` over `messages`.
    ///
    /// Every optional setting starts unset and streaming starts disabled.
    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
        let model = model.into();
        Self {
            model,
            messages,
            system: None,
            tools: None,
            temperature: None,
            max_tokens: None,
            stream: false,
            metadata: None,
            working_directory: None,
            session_id: None,
        }
    }

    /// Returns the request with its system prompt set to `system`.
    pub fn with_system(self, system: impl Into<String>) -> Self {
        Self {
            system: Some(system.into()),
            ..self
        }
    }

    /// Returns the request with `tools` attached, replacing any previous list.
    pub fn with_tools(self, tools: Vec<Tool>) -> Self {
        Self {
            tools: Some(tools),
            ..self
        }
    }

    /// Returns the request with the sampling temperature set.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Returns the request with the output token limit set.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Returns the request with the local `stream` flag turned on.
    pub fn with_streaming(self) -> Self {
        Self {
            stream: true,
            ..self
        }
    }
}
/// A tool definition offered to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Tool {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Schema for the tool's input. NOTE(review): presumably JSON Schema; not validated here.
    pub input_schema: serde_json::Value,
}
/// A complete (non-streamed) model response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMResponse {
    /// Response identifier.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Content blocks of the response.
    pub content: Vec<ContentBlock>,
    /// Why generation stopped, if reported.
    pub stop_reason: Option<StopReason>,
    /// Token accounting for the call.
    pub usage: TokenUsage,
}
/// Reason a generation ended.
///
/// Serialized in snake_case: `"end_turn"`, `"max_tokens"`, `"stop_sequence"`,
/// `"tool_use"`. Any other string fails to deserialize.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Wire name `"end_turn"`.
    EndTurn,
    /// Wire name `"max_tokens"`.
    MaxTokens,
    /// Wire name `"stop_sequence"`.
    StopSequence,
    /// Wire name `"tool_use"`.
    ToolUse,
}
/// Token counters reported for a request or a stream event.
///
/// Every counter falls back to `0` when absent from the payload
/// (container-level `#[serde(default)]`, backed by the derived `Default`).
/// Partial usage objects therefore still deserialize — for example a streaming
/// `MessageDelta` usage that may report only `output_tokens`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct TokenUsage {
    /// Input tokens. The cache counters below are tracked separately and added
    /// on top by the accessor methods.
    pub input_tokens: u32,
    /// Output tokens.
    pub output_tokens: u32,
    /// Cache-creation tokens.
    pub cache_creation_tokens: u32,
    /// Cache-read tokens.
    pub cache_read_tokens: u32,
    /// Billing override for `cache_creation_tokens`; `0` means "no override".
    pub billing_cache_creation: u32,
    /// Billing override for `cache_read_tokens`; `0` means "no override".
    pub billing_cache_read: u32,
}
impl TokenUsage {
    /// Input plus output tokens, excluding both cache counters.
    ///
    /// Saturates at `u32::MAX` rather than overflowing (a debug-build panic or a
    /// silent release-build wrap with plain `+`).
    pub fn total(&self) -> u32 {
        self.input_tokens.saturating_add(self.output_tokens)
    }

    /// Input-side tokens for billing.
    ///
    /// This is `input_tokens` plus the cache-creation and cache-read counts. For
    /// each cache count, the `billing_*` override is used when it is non-zero;
    /// otherwise the raw counter is used. Saturates at `u32::MAX`.
    pub fn billable_input(&self) -> u32 {
        // A zero override means the provider did not supply a billing figure.
        fn prefer_override(billing: u32, raw: u32) -> u32 {
            if billing > 0 {
                billing
            } else {
                raw
            }
        }

        let creation = prefer_override(self.billing_cache_creation, self.cache_creation_tokens);
        let read = prefer_override(self.billing_cache_read, self.cache_read_tokens);
        self.input_tokens
            .saturating_add(creation)
            .saturating_add(read)
    }

    /// Raw input-side tokens: `input_tokens` plus both cache counters.
    ///
    /// Billing overrides are ignored. Saturates at `u32::MAX`.
    pub fn context_input(&self) -> u32 {
        self.input_tokens
            .saturating_add(self.cache_creation_tokens)
            .saturating_add(self.cache_read_tokens)
    }

    /// [`billable_input`](Self::billable_input) plus output tokens, saturating at `u32::MAX`.
    pub fn billable_total(&self) -> u32 {
        self.billable_input().saturating_add(self.output_tokens)
    }
}
/// One event of a streamed response.
///
/// Tagged by `"type"` with the snake_case variant name, e.g. `"message_start"`,
/// `"content_block_delta"`, or `"ping"`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// Message metadata, including initial usage.
    MessageStart { message: StreamMessage },
    /// Initial contents of the content block at `index`.
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    /// Incremental update to the content block at `index`.
    ContentBlockDelta { index: usize, delta: ContentDelta },
    /// The content block at `index` is finished.
    ContentBlockStop { index: usize },
    /// Message-level update (stop reason / stop sequence) plus usage.
    MessageDelta {
        delta: MessageDelta,
        usage: TokenUsage,
    },
    /// End of the message.
    MessageStop,
    /// Keep-alive event with no payload.
    Ping,
    /// Error reported in-stream, as a plain string.
    Error { error: String },
}
/// Message metadata carried by [`StreamEvent::MessageStart`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamMessage {
    /// Response identifier.
    pub id: String,
    /// Model producing the stream.
    pub model: String,
    /// Author of the streamed message.
    pub role: Role,
    /// Usage reported at stream start.
    pub usage: TokenUsage,
}
/// Incremental content for a streamed block, tagged by `"type"`:
/// `"text_delta"`, `"input_json_delta"`, `"reasoning_delta"`, or `"thinking_delta"`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// Appended text.
    TextDelta { text: String },
    /// A fragment of a tool call's JSON input. It is not valid JSON on its own.
    InputJsonDelta { partial_json: String },
    /// Reasoning text carried in a `text` field.
    /// NOTE(review): presumably a different provider's spelling of `ThinkingDelta`.
    ReasoningDelta { text: String },
    /// Thinking text carried in a `thinking` field.
    ThinkingDelta { thinking: String },
}
/// Message-level fields updated by [`StreamEvent::MessageDelta`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    /// Why generation stopped, if known yet.
    pub stop_reason: Option<StopReason>,
    /// The stop sequence that matched, if any.
    pub stop_sequence: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Returns the text of a `Text` block; fails the test for any other variant.
    fn text_of(block: &ContentBlock) -> &str {
        match block {
            ContentBlock::Text { text } => text,
            other => panic!("expected a text block, got {:?}", other),
        }
    }

    #[test]
    fn test_message_creation() {
        let user_msg = Message::user("Hello");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content.len(), 1);
        assert_eq!(text_of(&user_msg.content[0]), "Hello");

        let assistant_msg = Message::assistant("Hi there");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(text_of(&assistant_msg.content[0]), "Hi there");

        let system_msg = Message::system(String::from("Be brief"));
        assert_eq!(system_msg.role, Role::System);
        assert_eq!(text_of(&system_msg.content[0]), "Be brief");
    }

    #[test]
    fn test_llm_request_builder() {
        let request = LLMRequest::new("claude-3-sonnet-20240229", vec![Message::user("Test")])
            .with_system("You are helpful")
            .with_temperature(0.7)
            .with_max_tokens(1000)
            .with_streaming();
        assert_eq!(request.model, "claude-3-sonnet-20240229");
        assert!(request.system.is_some());
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(1000));
        assert!(request.stream);
        // Builder methods that were not called leave their fields unset.
        assert!(request.tools.is_none());
        assert!(request.metadata.is_none());
    }

    #[test]
    fn test_llm_request_serialization_skips_local_fields() {
        let request = LLMRequest::new("model-x", vec![Message::user("Test")])
            .with_max_tokens(16)
            .with_streaming();
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["model"], json!("model-x"));
        assert_eq!(value["max_tokens"], json!(16));
        assert_eq!(
            value["messages"][0],
            json!({"role": "user", "content": [{"type": "text", "text": "Test"}]})
        );
        // `#[serde(skip)]` fields and unset optionals never reach the wire.
        for key in [
            "stream",
            "working_directory",
            "session_id",
            "system",
            "tools",
            "temperature",
            "metadata",
        ] {
            assert!(value.get(key).is_none(), "unexpected key {}", key);
        }
    }

    #[test]
    fn test_token_usage() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 200,
            ..Default::default()
        };
        assert_eq!(usage.total(), 300);
    }

    #[test]
    fn test_token_usage_billing_overrides() {
        let raw = TokenUsage {
            input_tokens: 10,
            output_tokens: 5,
            cache_creation_tokens: 100,
            cache_read_tokens: 200,
            ..Default::default()
        };
        // With no overrides, billing falls back to the raw cache counters.
        assert_eq!(raw.billable_input(), 310);
        assert_eq!(raw.context_input(), 310);
        assert_eq!(raw.billable_total(), 315);

        let billed = TokenUsage {
            billing_cache_creation: 7,
            billing_cache_read: 3,
            ..raw
        };
        assert_eq!(billed.billable_input(), 20);
        assert_eq!(billed.context_input(), 310);
        assert_eq!(billed.billable_total(), 25);
        assert_eq!(billed.total(), 15);
    }

    #[test]
    fn test_token_usage_cache_fields_default_to_zero() {
        let usage: TokenUsage =
            serde_json::from_str(r#"{"input_tokens": 1, "output_tokens": 2}"#).unwrap();
        assert_eq!(usage.cache_creation_tokens, 0);
        assert_eq!(usage.cache_read_tokens, 0);
        assert_eq!(usage.billing_cache_creation, 0);
        assert_eq!(usage.billing_cache_read, 0);
    }

    #[test]
    fn test_enum_wire_names() {
        assert_eq!(serde_json::to_value(Role::System).unwrap(), json!("system"));
        let role: Role = serde_json::from_str("\"assistant\"").unwrap();
        assert_eq!(role, Role::Assistant);

        let stop: StopReason = serde_json::from_str("\"end_turn\"").unwrap();
        assert_eq!(stop, StopReason::EndTurn);
        assert_eq!(
            serde_json::to_value(StopReason::ToolUse).unwrap(),
            json!("tool_use")
        );
    }

    #[test]
    fn test_content_block_wire_format() {
        let result = ContentBlock::ToolResult {
            tool_use_id: "t1".into(),
            content: "ok".into(),
            is_error: None,
        };
        assert_eq!(
            serde_json::to_value(&result).unwrap(),
            json!({"type": "tool_result", "tool_use_id": "t1", "content": "ok"})
        );

        let thinking: ContentBlock =
            serde_json::from_value(json!({"type": "thinking", "thinking": "hmm"})).unwrap();
        match thinking {
            ContentBlock::Thinking {
                thinking,
                signature,
            } => {
                assert_eq!(thinking, "hmm");
                assert!(signature.is_none());
            }
            other => panic!("expected a thinking block, got {:?}", other),
        }
    }

    #[test]
    fn test_stream_event_tagging() {
        let delta: StreamEvent = serde_json::from_value(json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": {"type": "text_delta", "text": "hi"}
        }))
        .unwrap();
        match delta {
            StreamEvent::ContentBlockDelta {
                index,
                delta: ContentDelta::TextDelta { text },
            } => {
                assert_eq!(index, 0);
                assert_eq!(text, "hi");
            }
            other => panic!("unexpected event {:?}", other),
        }

        let stop: StreamEvent = serde_json::from_value(json!({"type": "message_stop"})).unwrap();
        assert!(matches!(stop, StreamEvent::MessageStop));
    }
}