use serde::{Deserialize, Serialize};
use crate::message::{ContentBlock, Message, ToolDefinition};
/// Extended-thinking settings attached to a [`CompletionRequest`].
///
/// Serializes as `{"enabled": <bool>}`, plus `"budget_tokens"` only when a budget
/// is set; a `None` budget is omitted from the JSON entirely.
///
/// NOTE(review): this shape is embedded verbatim in the serialized
/// `CompletionRequest::thinking`. If the request is posted straight to a provider
/// API, confirm that the provider accepts an `enabled` boolean here.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ThinkingConfig {
    /// Whether extended thinking is requested.
    pub enabled: bool,
    /// Optional cap on thinking tokens; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub budget_tokens: Option<u32>,
}
/// Serializable body of a completion call.
///
/// Unset optional settings are dropped from the JSON rather than sent as
/// `null`/empty: `None` options, an empty `tools` or `stop_sequences` list, and
/// `stream == false` are all skipped. Construct with [`CompletionRequest::new`]
/// and refine with the builder methods.
#[derive(Clone, Debug, Serialize)]
pub struct CompletionRequest {
    /// Model identifier sent verbatim.
    pub model: String,
    /// Conversation history, oldest first.
    pub messages: Vec<Message>,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// System prompt; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Tool definitions offered to the model; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Streaming flag; only written to the JSON when `true`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Custom stop sequences; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// Extended-thinking settings; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<ThinkingConfig>,
}
impl CompletionRequest {
    /// Creates a request for `model` with the given conversation and output-token cap.
    ///
    /// Every optional setting starts unset: no system prompt, no tools, streaming
    /// off, no temperature, no stop sequences and no thinking configuration.
    pub fn new(model: impl Into<String>, messages: Vec<Message>, max_tokens: u32) -> Self {
        Self {
            model: model.into(),
            messages,
            max_tokens,
            system: None,
            tools: Vec::new(),
            stream: false,
            temperature: None,
            stop_sequences: Vec::new(),
            thinking: None,
        }
    }

    /// Sets the system prompt.
    #[must_use]
    pub fn system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Turns streaming on or off.
    #[must_use]
    pub fn stream(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }

    /// Sets the sampling temperature.
    #[must_use]
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Replaces the tool definitions offered to the model.
    #[must_use]
    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools = tools;
        self
    }

    /// Replaces the stop sequences.
    ///
    /// Accepts any iterable of string-like values, e.g. `["END"]`, a `Vec<String>`
    /// or a `Vec<&str>`. Passing an empty iterable clears the list, which is then
    /// omitted from the serialized request.
    #[must_use]
    pub fn stop_sequences<I, S>(mut self, sequences: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.stop_sequences = sequences.into_iter().map(Into::into).collect();
        self
    }

    /// Sets the extended-thinking configuration.
    #[must_use]
    pub fn thinking(mut self, config: ThinkingConfig) -> Self {
        self.thinking = Some(config);
        self
    }
}
/// A parsed completion response.
#[derive(Clone, Debug, Deserialize)]
pub struct CompletionResponse {
    /// Identifier of this response.
    pub id: String,
    /// Content blocks in the order they appear in the JSON.
    pub content: Vec<ContentBlock>,
    /// Model that produced the response.
    pub model: String,
    /// Why generation stopped; `None` when the field is absent or `null`.
    pub stop_reason: Option<StopReason>,
    /// Token accounting for this response.
    pub usage: Usage,
}
/// Reason generation stopped.
///
/// Each variant maps to the explicit wire string given in its `rename`. Any other
/// string fails to deserialize, because no catch-all variant is declared.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
    /// Wire value `"end_turn"`.
    #[serde(rename = "end_turn")]
    EndTurn,
    /// Wire value `"max_tokens"`.
    #[serde(rename = "max_tokens")]
    MaxTokens,
    /// Wire value `"stop_sequence"`.
    #[serde(rename = "stop_sequence")]
    StopSequence,
    /// Wire value `"tool_use"`.
    #[serde(rename = "tool_use")]
    ToolUse,
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Usage {
#[serde(default)]
pub input_tokens: u32,
#[serde(default)]
pub output_tokens: u32,
#[serde(default)]
pub cache_read_tokens: u32,
#[serde(default)]
pub cache_write_tokens: u32,
}
impl Usage {
    /// Returns the sum of the input, output, cache-read and cache-write counters.
    ///
    /// The sum saturates at `u32::MAX` rather than overflowing. Plain `+` would
    /// panic in debug builds and silently wrap to a small number in release builds.
    pub fn total(&self) -> u32 {
        self.input_tokens
            .saturating_add(self.output_tokens)
            .saturating_add(self.cache_read_tokens)
            .saturating_add(self.cache_write_tokens)
    }
}
/// One event of a streamed completion.
#[derive(Clone, Debug)]
pub enum StreamEvent {
    /// Start of a message: its identifier, model and initial usage.
    MessageStart {
        id: String,
        model: String,
        usage: Usage,
    },
    /// A content block begins at position `index`.
    ContentBlockStart {
        index: u32,
        content_block: ContentBlock,
    },
    /// Incremental update to the content block at `index`.
    ContentBlockDelta {
        index: u32,
        delta: ContentDelta,
    },
    /// The content block at `index` has finished.
    ContentBlockStop {
        index: u32,
    },
    /// Message-level update carrying an optional stop reason and usage.
    MessageDelta {
        stop_reason: Option<StopReason>,
        usage: Usage,
    },
    /// End of the message.
    MessageStop,
    /// Ping event; carries no payload.
    Ping,
    /// Error event carrying a message string.
    Error {
        message: String,
    },
}
/// Incremental content payload of a [`StreamEvent::ContentBlockDelta`].
///
/// Internally tagged: the JSON carries a `"type"` field whose value is the
/// snake_case variant name (`"text_delta"`, `"input_json_delta"`,
/// `"thinking_delta"`).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of generated text.
    TextDelta { text: String },
    /// A partial JSON string fragment.
    InputJsonDelta { partial_json: String },
    /// A fragment of extended-thinking output.
    ///
    /// Serializes its payload under `"text"`. On input it also accepts
    /// `"thinking"`, the key the Anthropic streaming API uses for `thinking_delta`.
    ThinkingDelta {
        #[serde(alias = "thinking")]
        text: String,
    },
}
#[cfg(test)]
mod tests {
    // Unit tests for the request/response/stream types in this module.
    // Results are checked via `is_ok()` + `if let Ok(..)` instead of `unwrap()`,
    // matching the existing style of this file.
    use super::*;
    use crate::message::Message;

    #[test]
    fn request_builder() {
        let req = CompletionRequest::new(
            "claude-sonnet-4-5-20250929",
            vec![Message::user("hi")],
            1024,
        )
        .system("You are helpful")
        .temperature(0.7)
        .stream(true);
        assert_eq!(req.model, "claude-sonnet-4-5-20250929");
        assert_eq!(req.max_tokens, 1024);
        assert!(req.stream);
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.system, Some("You are helpful".into()));
    }

    #[test]
    fn request_serialization() {
        let req = CompletionRequest::new(
            "claude-sonnet-4-5-20250929",
            vec![Message::user("hi")],
            1024,
        );
        let json = serde_json::to_string(&req);
        assert!(json.is_ok());
        let json_str = json.as_deref().unwrap_or("");
        assert!(json_str.contains("claude-sonnet-4-5-20250929"));
        assert!(json_str.contains("1024"));
        assert!(!json_str.contains("stream"));
    }

    #[test]
    fn request_serialization_omits_unset_optionals() {
        let req = CompletionRequest::new(
            "claude-sonnet-4-5-20250929",
            vec![Message::user("hi")],
            1024,
        );
        let value = serde_json::to_value(&req);
        assert!(value.is_ok());
        if let Ok(v) = value {
            // Inspect top-level keys so message contents cannot cause false matches.
            for key in ["system", "tools", "stream", "temperature", "stop_sequences", "thinking"] {
                assert!(v.get(key).is_none(), "unexpected key {key}");
            }
            assert_eq!(v.get("max_tokens"), Some(&serde_json::json!(1024)));
        }
    }

    #[test]
    fn request_serialization_includes_set_optionals() {
        let req = CompletionRequest::new(
            "claude-sonnet-4-5-20250929",
            vec![Message::user("hi")],
            1024,
        )
        .system("be brief")
        .stream(true)
        .thinking(ThinkingConfig {
            enabled: true,
            budget_tokens: None,
        });
        let value = serde_json::to_value(&req);
        assert!(value.is_ok());
        if let Ok(v) = value {
            assert_eq!(v.get("system"), Some(&serde_json::json!("be brief")));
            assert_eq!(v.get("stream"), Some(&serde_json::json!(true)));
            assert_eq!(v.get("thinking"), Some(&serde_json::json!({"enabled": true})));
        }
    }

    #[test]
    fn response_parsing() {
        let json = r#"{
            "id": "msg_123",
            "content": [{"type": "text", "text": "Hello!"}],
            "model": "claude-sonnet-4-5-20250929",
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        }"#;
        let resp: std::result::Result<CompletionResponse, _> = serde_json::from_str(json);
        assert!(resp.is_ok());
        if let Ok(resp) = resp {
            assert_eq!(resp.id, "msg_123");
            assert_eq!(resp.usage.total(), 15);
            assert_eq!(resp.usage.cache_read_tokens, 0);
            assert_eq!(resp.usage.cache_write_tokens, 0);
        }
    }

    #[test]
    fn response_parsing_without_stop_reason() {
        let json = r#"{"id": "msg_1", "content": [], "model": "m", "usage": {}}"#;
        let resp: std::result::Result<CompletionResponse, _> = serde_json::from_str(json);
        assert!(resp.is_ok());
        if let Ok(resp) = resp {
            assert!(resp.stop_reason.is_none());
            assert!(resp.content.is_empty());
            assert_eq!(resp.usage.total(), 0);
        }
    }

    #[test]
    fn stop_reason_parsing() {
        let json = r#""end_turn""#;
        let reason: Result<StopReason, _> = serde_json::from_str(json);
        assert_eq!(reason.ok(), Some(StopReason::EndTurn));
        let json = r#""tool_use""#;
        let reason: Result<StopReason, _> = serde_json::from_str(json);
        assert_eq!(reason.ok(), Some(StopReason::ToolUse));
    }

    #[test]
    fn stop_reason_remaining_variants_and_unknown() {
        let reason: Result<StopReason, _> = serde_json::from_str(r#""max_tokens""#);
        assert_eq!(reason.ok(), Some(StopReason::MaxTokens));
        let reason: Result<StopReason, _> = serde_json::from_str(r#""stop_sequence""#);
        assert_eq!(reason.ok(), Some(StopReason::StopSequence));
        let reason: Result<StopReason, _> = serde_json::from_str(r#""not_a_reason""#);
        assert!(reason.is_err());
    }

    #[test]
    fn stop_reason_serializes_snake_case() {
        let json = serde_json::to_string(&StopReason::StopSequence);
        assert_eq!(json.ok().as_deref(), Some(r#""stop_sequence""#));
    }

    #[test]
    fn usage_total() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        };
        assert_eq!(u.total(), 150);
    }

    #[test]
    fn usage_total_with_cache_tokens() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 20,
            cache_write_tokens: 10,
        };
        assert_eq!(u.total(), 180);
    }

    #[test]
    fn content_delta_serialization() {
        let delta = ContentDelta::TextDelta {
            text: "hello".into(),
        };
        let json = serde_json::to_string(&delta);
        assert!(json.is_ok());
        assert!(json.as_deref().unwrap_or("").contains("text_delta"));
    }

    #[test]
    fn content_delta_input_json_deserialization() {
        let json = r#"{"type": "input_json_delta", "partial_json": "{\"city\":"}"#;
        let delta: std::result::Result<ContentDelta, _> = serde_json::from_str(json);
        assert!(delta.is_ok());
        if let Ok(d) = delta {
            assert!(matches!(
                d,
                ContentDelta::InputJsonDelta { ref partial_json } if partial_json == "{\"city\":"
            ));
        }
    }

    #[test]
    fn thinking_config_serialization() {
        let config = ThinkingConfig {
            enabled: true,
            budget_tokens: Some(10_000),
        };
        let json = serde_json::to_string(&config);
        assert!(json.is_ok());
        let json_str = json.as_deref().unwrap_or("");
        assert!(json_str.contains("true"));
        assert!(json_str.contains("10000"));
    }

    #[test]
    fn thinking_config_without_budget() {
        let config = ThinkingConfig {
            enabled: true,
            budget_tokens: None,
        };
        let json = serde_json::to_string(&config);
        assert!(json.is_ok());
        let json_str = json.as_deref().unwrap_or("");
        assert!(json_str.contains("true"));
        assert!(!json_str.contains("budget_tokens"));
    }

    #[test]
    fn thinking_config_roundtrip() {
        let config = ThinkingConfig {
            enabled: true,
            budget_tokens: Some(5000),
        };
        let json = serde_json::to_string(&config).unwrap_or_default();
        let parsed: std::result::Result<ThinkingConfig, _> = serde_json::from_str(&json);
        assert!(parsed.is_ok());
        if let Ok(c) = parsed {
            assert!(c.enabled);
            assert_eq!(c.budget_tokens, Some(5000));
        }
    }

    #[test]
    fn thinking_delta_variant() {
        let delta = ContentDelta::ThinkingDelta {
            text: "Let me think...".into(),
        };
        let json = serde_json::to_string(&delta);
        assert!(json.is_ok());
        assert!(json.as_deref().unwrap_or("").contains("thinking_delta"));
    }

    #[test]
    fn usage_with_cache_tokens_deserialization() {
        let json = r#"{"input_tokens": 100, "output_tokens": 50, "cache_read_tokens": 20, "cache_write_tokens": 10}"#;
        let usage: std::result::Result<Usage, _> = serde_json::from_str(json);
        assert!(usage.is_ok());
        if let Ok(u) = usage {
            assert_eq!(u.input_tokens, 100);
            assert_eq!(u.output_tokens, 50);
            assert_eq!(u.cache_read_tokens, 20);
            assert_eq!(u.cache_write_tokens, 10);
            assert_eq!(u.total(), 180);
        }
    }

    #[test]
    fn usage_without_cache_tokens_deserialization() {
        let json = r#"{"input_tokens": 100, "output_tokens": 50}"#;
        let usage: std::result::Result<Usage, _> = serde_json::from_str(json);
        assert!(usage.is_ok());
        if let Ok(u) = usage {
            assert_eq!(u.cache_read_tokens, 0);
            assert_eq!(u.cache_write_tokens, 0);
            assert_eq!(u.total(), 150);
        }
    }

    #[test]
    fn usage_empty_object_defaults_to_zero() {
        let usage: std::result::Result<Usage, _> = serde_json::from_str("{}");
        assert!(usage.is_ok());
        if let Ok(u) = usage {
            assert_eq!(u.input_tokens, 0);
            assert_eq!(u.output_tokens, 0);
            assert_eq!(u.total(), 0);
        }
    }

    #[test]
    fn request_with_thinking() {
        let req = CompletionRequest::new("claude-opus-4", vec![Message::user("hi")], 16384)
            .thinking(ThinkingConfig {
                enabled: true,
                budget_tokens: Some(10_000),
            });
        assert!(req.thinking.is_some());
        if let Some(tc) = &req.thinking {
            assert!(tc.enabled);
            assert_eq!(tc.budget_tokens, Some(10_000));
        }
    }
}