use serde::{Deserialize, Serialize};
/// Who a [`Message`] is attributed to.
///
/// Because of `rename_all = "lowercase"`, the variants travel over serde as
/// the strings `"user"` and `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
/// One element of a [`Message`]'s content list.
///
/// The serde representation is internally tagged: each part is an object whose
/// `type` field holds the snake_case variant name (`text`, `image`,
/// `tool_use`, `tool_result`), with the variant's own fields next to it,
/// e.g. `{"type": "text", "text": "hi"}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// A run of plain text.
    Text {
        /// The text itself.
        text: String,
    },
    /// An image carried inline as strings.
    Image {
        /// Media type of the image, e.g. `image/png`.
        media_type: String,
        /// The image payload.
        // NOTE(review): presumably base64-encoded bytes — the encoding is not
        // established anywhere in this file; confirm with the producers.
        data: String,
    },
    /// A tool invocation.
    ToolUse {
        /// Identifier of this invocation.
        // NOTE(review): presumably what `ToolResult::tool_use_id` refers back to — confirm.
        id: String,
        /// Name of the tool being invoked.
        name: String,
        /// Arguments for the tool as arbitrary JSON.
        input: serde_json::Value,
    },
    /// The outcome of a tool invocation.
    ToolResult {
        /// Identifier of the invocation this result belongs to.
        tool_use_id: String,
        /// The result, as a string.
        content: String,
    },
}
/// A single conversation entry: a [`Role`] plus an ordered list of
/// [`ContentPart`]s.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    /// Who the message is attributed to.
    pub role: Role,
    /// The parts making up the message, in order.
    pub content: Vec<ContentPart>,
}
impl Message {
    /// Builds a message whose content is exactly one [`ContentPart::Text`] part.
    ///
    /// `text` may be anything convertible into a `String` (`&str`, `String`, ...).
    // NOTE(review): the reason for silencing `dead_code` is not recorded here;
    // presumably this constructor is only exercised by tests in some build
    // configurations — confirm whether the allow is still needed.
    #[allow(dead_code)]
    pub fn text(role: Role, text: impl Into<String>) -> Self {
        Self {
            role,
            content: vec![ContentPart::Text { text: text.into() }],
        }
    }

    /// Returns the text of every [`ContentPart::Text`] part, concatenated in
    /// order with no separator.
    ///
    /// Non-text parts are skipped. The result is the empty string when the
    /// message contains no text parts.
    pub fn text_content(&self) -> String {
        self.content
            .iter()
            .filter_map(|part| match part {
                ContentPart::Text { text } => Some(text.as_str()),
                // Listed explicitly (no `_` arm) so that adding a new variant
                // forces a decision about whether it contributes text.
                ContentPart::Image { .. }
                | ContentPart::ToolUse { .. }
                | ContentPart::ToolResult { .. } => None,
            })
            // `String: FromIterator<&str>` appends each slice directly, so no
            // intermediate `Vec<&str>` is needed.
            .collect()
    }
}
/// Parameters of a single chat request.
///
/// On deserialization only `model` and `messages` are mandatory: every
/// `Option` field becomes `None` when absent, and the remaining fields carry
/// `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Name of the model the request targets.
    pub model: String,
    /// Optional system text.
    pub system: Option<String>,
    /// The conversation, in order.
    pub messages: Vec<Message>,
    /// Optional token limit.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional free-form JSON; its schema is not defined in this file.
    pub thinking: Option<serde_json::Value>,
    /// Optional effort setting; the accepted values are not defined in this file.
    pub effort: Option<String>,
    /// Optional free-form JSON; its schema is not defined in this file.
    pub task_budget: Option<serde_json::Value>,
    /// Optional free-form JSON describing the desired output.
    // NOTE(review): presumably a JSON Schema — confirm with the consumers.
    pub output_schema: Option<serde_json::Value>,
    /// Tools offered with the request. Empty when absent from the input, and
    /// left out of the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Streaming flag. Defaults to `false` when absent from the input, so a
    /// payload that omits it is not rejected with a "missing field" error
    /// (matching how every other non-mandatory field behaves).
    #[serde(default)]
    pub stream: bool,
    /// Plugin entries attached to the request; empty when absent from the input.
    #[serde(default)]
    pub plugins: Vec<PluginRequest>,
    /// Provider override. Read from the input when present (`None`
    /// otherwise) but never written when serializing.
    #[serde(default, skip_serializing)]
    pub forced_provider: Option<String>,
    /// Tags attached to the request. Read from the input when present (empty
    /// otherwise) but never written when serializing.
    #[serde(default, skip_serializing)]
    pub tags: Vec<String>,
}
impl ChatRequest {
    /// Reports which capabilities this request depends on.
    ///
    /// * `"vision"` — some message contains a [`ContentPart::Image`].
    /// * `"tools"` — the `tools` list is non-empty.
    ///
    /// Names appear in that order; the list is empty when neither applies.
    pub fn needed_capabilities(&self) -> Vec<String> {
        let has_image = self
            .messages
            .iter()
            .flat_map(|message| message.content.iter())
            .any(|part| matches!(part, ContentPart::Image { .. }));
        let has_tools = !self.tools.is_empty();

        // Table of (is it required?, capability name), kept in output order.
        [(has_image, "vision"), (has_tools, "tools")]
            .iter()
            .filter(|&&(required, _)| required)
            .map(|&(_, capability)| capability.to_string())
            .collect()
    }
}
/// Declaration of a tool offered with a [`ChatRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Name of the tool.
    pub name: String,
    /// Optional description of the tool.
    pub description: Option<String>,
    /// Arbitrary JSON describing the tool's input, e.g. `{"type": "object"}`.
    pub input_schema: serde_json::Value,
}
/// A tool invocation reported in a [`ChatResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of the invocation.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Arguments for the tool as arbitrary JSON.
    pub input: serde_json::Value,
}
/// One plugin entry of a [`ChatRequest`].
///
/// In serialized form this is a single flat object: the `id` key plus
/// whatever other keys are present, which serde collects into `settings`
/// through `#[serde(flatten)]`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PluginRequest {
    /// Identifier of the plugin.
    pub id: String,
    /// Every key of the plugin object other than `id`, as arbitrary JSON.
    #[serde(flatten)]
    pub settings: serde_json::Map<String, serde_json::Value>,
}
/// The stop reason carried by a [`ChatResponse`] or a [`StreamEvent::Done`].
///
/// Serialized as snake_case strings: `"end_turn"`, `"max_tokens"`,
/// `"tool_use"` and `"other"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Serialized as `"end_turn"`.
    EndTurn,
    /// Serialized as `"max_tokens"`.
    MaxTokens,
    /// Serialized as `"tool_use"`.
    ToolUse,
    /// Serialized as `"other"`.
    Other,
}
/// Token counts; `Usage::default()` has both counters at zero.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u32,
    /// Number of output tokens.
    pub output_tokens: u32,
}
/// A complete (non-streamed) chat response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Identifier of the response.
    pub id: String,
    /// Name of the model reported for the response.
    pub model: String,
    /// Text content of the response.
    pub content: String,
    /// The stop reason reported with the response.
    pub stop_reason: StopReason,
    /// Tool invocations carried by the response. Empty when absent from the
    /// input, and left out of the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Token counts for the exchange.
    pub usage: Usage,
    /// Tags attached to the response; empty when absent from the input.
    #[serde(default)]
    pub tags: Vec<String>,
}
/// One event of a streamed response.
///
/// This type derives only `Debug`, `Clone` and `PartialEq`; it has no serde
/// representation of its own.
#[derive(Debug, Clone, PartialEq)]
pub enum StreamEvent {
    /// A piece of text.
    TextDelta {
        /// The text carried by this event.
        text: String,
    },
    /// Announces a tool call.
    ToolCallStart {
        /// Identifier of the tool call.
        id: String,
        /// Name of the tool being called.
        name: String,
    },
    /// A piece of a tool call's input.
    ToolCallDelta {
        /// Identifier of the tool call this piece belongs to.
        id: String,
        /// The input fragment.
        // NOTE(review): presumably a chunk of JSON text meant to be
        // concatenated across events — confirm with the producers.
        partial_input: String,
    },
    /// Final event, carrying the stop reason and token counts.
    Done {
        /// The stop reason.
        stop_reason: StopReason,
        /// Token counts.
        usage: Usage,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- fixtures -------------------------------------------------------

    /// A `ContentPart::Text` holding `text`.
    fn text_part(text: &str) -> ContentPart {
        ContentPart::Text { text: text.to_string() }
    }

    /// A `ContentPart::Image` of media type `image/png` holding `data`.
    fn png_part(data: &str) -> ContentPart {
        ContentPart::Image { media_type: "image/png".to_string(), data: data.to_string() }
    }

    /// A `ContentPart::ToolUse` with the given id, tool name and input.
    fn tool_use_part(id: &str, name: &str, input: serde_json::Value) -> ContentPart {
        ContentPart::ToolUse { id: id.to_string(), name: name.to_string(), input }
    }

    /// A `ContentPart::ToolResult` answering `tool_use_id` with `content`.
    fn tool_result_part(tool_use_id: &str, content: &str) -> ContentPart {
        ContentPart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
        }
    }

    /// A user message made of the given parts.
    fn user_message(content: Vec<ContentPart>) -> Message {
        Message { role: Role::User, content }
    }

    /// The tool declaration shared by the capability tests.
    fn weather_tool() -> Tool {
        Tool {
            name: "get_weather".to_string(),
            description: None,
            input_schema: json!({"type": "object"}),
        }
    }

    /// Serializes a part to a JSON value, panicking on failure.
    fn as_json(part: &ContentPart) -> serde_json::Value {
        serde_json::to_value(part).unwrap()
    }

    /// A request with every optional setting unset and the given messages.
    fn base_request(messages: Vec<Message>) -> ChatRequest {
        ChatRequest {
            model: "test-model".to_string(),
            system: None,
            messages,
            max_tokens: None,
            temperature: None,
            thinking: None,
            effort: None,
            task_budget: None,
            output_schema: None,
            tools: Vec::new(),
            stream: false,
            plugins: Vec::new(),
            forced_provider: None,
            tags: Vec::new(),
        }
    }

    // ---- ContentPart serde shape ----------------------------------------

    #[test]
    fn text_part_serializes_as_tagged_shape() {
        let expected = json!({"type": "text", "text": "hi"});
        assert_eq!(as_json(&text_part("hi")), expected);
    }

    #[test]
    fn image_part_serializes_as_tagged_shape() {
        let expected = json!({"type": "image", "media_type": "image/png", "data": "abc123"});
        assert_eq!(as_json(&png_part("abc123")), expected);
    }

    #[test]
    fn tool_use_part_serializes_as_tagged_shape() {
        let part = tool_use_part("call_1", "get_weather", json!({"city": "nyc"}));
        let expected = json!({
            "type": "tool_use",
            "id": "call_1",
            "name": "get_weather",
            "input": {"city": "nyc"}
        });
        assert_eq!(as_json(&part), expected);
    }

    #[test]
    fn tool_result_part_serializes_as_tagged_shape() {
        let expected = json!({"type": "tool_result", "tool_use_id": "call_1", "content": "sunny"});
        assert_eq!(as_json(&tool_result_part("call_1", "sunny")), expected);
    }

    #[test]
    fn content_part_round_trips_through_serde() {
        let samples = vec![
            text_part("hi"),
            png_part("abc"),
            tool_use_part("1", "f", json!({})),
            tool_result_part("1", "r"),
        ];
        for original in samples {
            let encoded = as_json(&original);
            let decoded: ContentPart = serde_json::from_value(encoded).unwrap();
            assert_eq!(original, decoded);
        }
    }

    // ---- Message helpers ------------------------------------------------

    #[test]
    fn text_constructor_produces_single_text_part() {
        let Message { role, content } = Message::text(Role::User, "hi");
        assert_eq!(role, Role::User);
        assert_eq!(content, vec![text_part("hi")]);
    }

    #[test]
    fn text_content_joins_multiple_text_parts() {
        let msg = user_message(vec![text_part("hello "), text_part("world")]);
        assert_eq!(msg.text_content(), "hello world");
    }

    #[test]
    fn text_content_skips_non_text_parts() {
        let msg = user_message(vec![
            text_part("describe this"),
            png_part("abc"),
            tool_use_part("1", "f", json!({})),
            tool_result_part("1", "r"),
        ]);
        assert_eq!(msg.text_content(), "describe this");
    }

    // ---- ChatRequest::needed_capabilities -------------------------------

    #[test]
    fn needed_capabilities_empty_for_plain_text_request() {
        let req = base_request(vec![Message::text(Role::User, "hi")]);
        assert!(req.needed_capabilities().is_empty());
    }

    #[test]
    fn needed_capabilities_includes_vision_for_image_content() {
        let req = base_request(vec![user_message(vec![png_part("abc")])]);
        assert_eq!(req.needed_capabilities(), ["vision"]);
    }

    #[test]
    fn needed_capabilities_includes_tools_for_non_empty_tools() {
        let req = ChatRequest {
            tools: vec![weather_tool()],
            ..base_request(vec![Message::text(Role::User, "hi")])
        };
        assert_eq!(req.needed_capabilities(), ["tools"]);
    }

    #[test]
    fn needed_capabilities_includes_both_when_applicable() {
        let req = ChatRequest {
            tools: vec![weather_tool()],
            ..base_request(vec![user_message(vec![png_part("abc")])])
        };
        assert_eq!(req.needed_capabilities(), ["vision", "tools"]);
    }
}