use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::content::Role;
/// Model-selection preferences that can be attached to a sampling request.
///
/// Every field is optional and is left out of the serialized form when unset. Keys are
/// emitted in camelCase (`hints`, `costPriority`, `speedPriority`, `intelligencePriority`).
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ModelPreferences {
    /// Model hints, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hints: Option<Vec<ModelHint>>,
    /// Weight given to cost, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_priority: Option<f64>,
    /// Weight given to speed, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed_priority: Option<f64>,
    /// Weight given to intelligence, if specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub intelligence_priority: Option<f64>,
}

/// Consuming builder methods for [`ModelPreferences`].
///
/// NOTE(review): priority values are stored exactly as passed — no range validation
/// happens here. Confirm whether callers are expected to keep them within `0.0..=1.0`.
impl ModelPreferences {
    /// Returns preferences with every field unset (identical to [`Default::default`]).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `self` with the cost priority set to `priority`.
    pub fn with_cost_priority(self, priority: f64) -> Self {
        Self {
            cost_priority: Some(priority),
            ..self
        }
    }

    /// Returns `self` with the speed priority set to `priority`.
    pub fn with_speed_priority(self, priority: f64) -> Self {
        Self {
            speed_priority: Some(priority),
            ..self
        }
    }

    /// Returns `self` with the intelligence priority set to `priority`.
    pub fn with_intelligence_priority(self, priority: f64) -> Self {
        Self {
            intelligence_priority: Some(priority),
            ..self
        }
    }

    /// Returns `self` with the hint list replaced by `hints`.
    pub fn with_hints(self, hints: Vec<ModelHint>) -> Self {
        Self {
            hints: Some(hints),
            ..self
        }
    }
}
/// A single model hint, carried in the `hints` list of [`ModelPreferences`].
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ModelHint {
    /// Hinted model name; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

impl ModelHint {
    /// Creates a hint whose `name` is set to the given value.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self { name: Some(name) }
    }
}
/// Tool-selection directive for a sampling request.
///
/// The [`Default`] value leaves `mode` unset. In that case the `mode` key is omitted
/// from the serialized form entirely.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolChoice {
    /// Requested tool-use mode, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<ToolChoiceMode>,
}

impl ToolChoice {
    /// Shared constructor: a choice whose `mode` is set to `mode`.
    fn from_mode(mode: ToolChoiceMode) -> Self {
        Self { mode: Some(mode) }
    }

    /// A choice with mode [`ToolChoiceMode::Auto`].
    pub fn auto() -> Self {
        Self::from_mode(ToolChoiceMode::Auto)
    }

    /// A choice with mode [`ToolChoiceMode::Required`].
    pub fn required() -> Self {
        Self::from_mode(ToolChoiceMode::Required)
    }

    /// A choice with mode [`ToolChoiceMode::None`] (an explicit mode, not an unset one).
    pub fn none() -> Self {
        Self::from_mode(ToolChoiceMode::None)
    }
}
/// Tool-use mode carried by [`ToolChoice`], serialized as a lowercase string.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolChoiceMode {
    /// Wire value `"auto"`.
    #[serde(rename = "auto")]
    Auto,
    /// Wire value `"required"`.
    #[serde(rename = "required")]
    Required,
    /// Wire value `"none"`.
    #[serde(rename = "none")]
    None,
}
/// One content item of a sampling message, internally tagged by a `"type"` key.
///
/// The tag values are `text`, `image`, `audio`, `tool_use` and `tool_result`. Each
/// variant's fields are renamed to camelCase on the wire. Every variant's optional `meta`
/// field travels under the `_meta` key and is omitted when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum SamplingMessageContent {
    /// Plain text.
    #[serde(rename = "text")]
    #[serde(rename_all = "camelCase")]
    Text {
        /// The text itself.
        text: String,
        /// Optional `_meta` object.
        #[serde(rename = "_meta")]
        #[serde(skip_serializing_if = "Option::is_none")]
        meta: Option<serde_json::Map<String, Value>>,
    },
    /// Image payload with its MIME type (serialized as `mimeType`).
    #[serde(rename = "image")]
    #[serde(rename_all = "camelCase")]
    Image {
        /// Encoded image data (presumably base64 — confirm with producers).
        data: String,
        /// MIME type of `data`.
        mime_type: String,
        /// Optional `_meta` object.
        #[serde(rename = "_meta")]
        #[serde(skip_serializing_if = "Option::is_none")]
        meta: Option<serde_json::Map<String, Value>>,
    },
    /// Audio payload with its MIME type (serialized as `mimeType`).
    #[serde(rename = "audio")]
    #[serde(rename_all = "camelCase")]
    Audio {
        /// Encoded audio data (presumably base64 — confirm with producers).
        data: String,
        /// MIME type of `data`.
        mime_type: String,
        /// Optional `_meta` object.
        #[serde(rename = "_meta")]
        #[serde(skip_serializing_if = "Option::is_none")]
        meta: Option<serde_json::Map<String, Value>>,
    },
    /// A tool invocation: tool `name`, invocation `id`, and raw JSON `input`.
    #[serde(rename = "tool_use")]
    #[serde(rename_all = "camelCase")]
    ToolUse {
        /// Name of the tool being invoked.
        name: String,
        /// Identifier of this invocation.
        id: String,
        /// Tool arguments as arbitrary JSON.
        input: Value,
        /// Optional `_meta` object.
        #[serde(rename = "_meta")]
        #[serde(skip_serializing_if = "Option::is_none")]
        meta: Option<serde_json::Map<String, Value>>,
    },
    /// Result of a tool invocation (keys `toolUseId`, `content`, `structuredContent`,
    /// `isError`, `_meta`).
    #[serde(rename = "tool_result")]
    #[serde(rename_all = "camelCase")]
    ToolResult {
        /// Identifier of the invocation this result answers.
        tool_use_id: String,
        /// Result content items.
        content: Vec<super::content::Content>,
        /// Optional structured result as raw JSON.
        #[serde(skip_serializing_if = "Option::is_none")]
        structured_content: Option<Value>,
        /// Optional error flag.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        /// Optional `_meta` object.
        #[serde(rename = "_meta")]
        #[serde(skip_serializing_if = "Option::is_none")]
        meta: Option<serde_json::Map<String, Value>>,
    },
}
/// Parameters for a sampling create-message request.
///
/// Build with [`CreateMessageParams::new`] plus the `with_*` methods. Keys are serialized
/// in camelCase, and every `Option` field is omitted from the serialized form when `None`.
/// `include_context` is always serialized and falls back to [`IncludeContext::None`] when
/// the key is absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "camelCase")]
pub struct CreateMessageParams {
    /// Messages to sample from.
    pub messages: Vec<SamplingMessage>,
    /// Optional model-selection preferences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_preferences: Option<ModelPreferences>,
    /// Optional system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Context-inclusion request; defaults to [`IncludeContext::None`] when absent on input.
    #[serde(default)]
    pub include_context: IncludeContext,
    /// Optional sampling temperature (not range-checked here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Optional token limit.
    // NOTE(review): optional in this type; confirm whether the targeted protocol version
    // requires `maxTokens` on every request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// Optional free-form metadata, carried as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Value>,
    /// Optional tool definitions offered with the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<super::tools::ToolInfo>>,
    /// Optional tool-selection directive.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

/// Constructor and consuming builder methods for [`CreateMessageParams`].
///
/// Every settable field has a `with_*` method, so callers never need to mutate fields
/// after construction.
impl CreateMessageParams {
    /// Creates params for `messages`.
    ///
    /// Every optional field is `None` and `include_context` is its default.
    pub fn new(messages: Vec<SamplingMessage>) -> Self {
        Self {
            messages,
            model_preferences: None,
            system_prompt: None,
            include_context: IncludeContext::default(),
            temperature: None,
            max_tokens: None,
            stop_sequences: None,
            metadata: None,
            tools: None,
            tool_choice: None,
        }
    }

    /// Sets the model-selection preferences.
    pub fn with_model_preferences(mut self, prefs: ModelPreferences) -> Self {
        self.model_preferences = Some(prefs);
        self
    }

    /// Sets the system prompt.
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets which context should be included (see [`IncludeContext`]).
    pub fn with_include_context(mut self, include_context: IncludeContext) -> Self {
        self.include_context = include_context;
        self
    }

    /// Sets the sampling temperature (stored as given, not validated).
    pub fn with_temperature(mut self, temp: f64) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Sets the token limit.
    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

    /// Sets the stop sequences, replacing any previously set.
    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
        self.stop_sequences = Some(sequences);
        self
    }

    /// Attaches free-form JSON metadata, replacing any previously set.
    pub fn with_metadata(mut self, metadata: Value) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Sets the tool definitions offered with the request.
    pub fn with_tools(mut self, tools: Vec<super::tools::ToolInfo>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets the tool-selection directive.
    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
        self.tool_choice = Some(choice);
        self
    }
}
/// Result of a sampling request whose reply is a single content item.
///
/// Keys are serialized in camelCase; `usage` and `stopReason` are omitted when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct CreateMessageResult {
    /// The returned content item.
    pub content: super::content::Content,
    /// The `model` string reported with the result.
    pub model: String,
    /// Optional token accounting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
    /// Optional stop reason, as a free-form string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
}

impl CreateMessageResult {
    /// Creates a result from `content` and `model`, with no usage and no stop reason.
    pub fn new(content: super::content::Content, model: impl Into<String>) -> Self {
        let model: String = model.into();
        Self {
            content,
            model,
            usage: None,
            stop_reason: None,
        }
    }

    /// Returns `self` with `usage` attached.
    pub fn with_usage(self, usage: TokenUsage) -> Self {
        Self {
            usage: Some(usage),
            ..self
        }
    }

    /// Returns `self` with the stop reason set.
    pub fn with_stop_reason(self, reason: impl Into<String>) -> Self {
        Self {
            stop_reason: Some(reason.into()),
            ..self
        }
    }
}
/// Result of a sampling request whose reply is a role plus a list of content items.
///
/// The items may include `tool_use` / `tool_result` entries. Keys are serialized in
/// camelCase; `stopReason` and `_meta` are omitted when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct CreateMessageResultWithTools {
    /// The `model` string reported with the result.
    pub model: String,
    /// Optional stop reason, as a free-form string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// Role of the returned message.
    pub role: Role,
    /// Returned content items, in order.
    pub content: Vec<SamplingMessageContent>,
    /// Optional `_meta` object.
    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<serde_json::Map<String, Value>>,
}

impl CreateMessageResultWithTools {
    /// Creates a result with the given model, role and content.
    ///
    /// Stop reason and `_meta` start unset.
    pub fn new(model: impl Into<String>, role: Role, content: Vec<SamplingMessageContent>) -> Self {
        Self {
            content,
            role,
            meta: None,
            stop_reason: None,
            model: model.into(),
        }
    }

    /// Returns `self` with the stop reason set.
    pub fn with_stop_reason(self, reason: impl Into<String>) -> Self {
        Self {
            stop_reason: Some(reason.into()),
            ..self
        }
    }

    /// Returns `self` with the `_meta` object replaced by `meta`.
    pub fn with_meta(self, meta: serde_json::Map<String, Value>) -> Self {
        Self {
            meta: Some(meta),
            ..self
        }
    }
}
/// Token counts reported with a [`CreateMessageResult`].
///
/// Serialized as `inputTokens`, `outputTokens` and `totalTokens`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct TokenUsage {
    /// Input-side token count.
    pub input_tokens: u32,
    /// Output-side token count.
    pub output_tokens: u32,
    /// Total token count, as supplied by the caller.
    pub total_tokens: u32,
}

impl TokenUsage {
    /// Builds a usage record from three explicit counts.
    ///
    /// `total_tokens` is stored exactly as passed; it is neither computed from nor
    /// checked against `input_tokens` and `output_tokens`.
    pub fn new(input_tokens: u32, output_tokens: u32, total_tokens: u32) -> Self {
        Self {
            total_tokens,
            output_tokens,
            input_tokens,
        }
    }
}
/// One message in the `messages` list of [`CreateMessageParams`].
///
/// Each message is a role plus a single content item.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct SamplingMessage {
    /// Role attached to this message.
    pub role: Role,
    /// The message's single content item.
    pub content: SamplingMessageContent,
}

impl SamplingMessage {
    /// Pairs `role` with `content`.
    pub fn new(role: Role, content: SamplingMessageContent) -> Self {
        Self { content, role }
    }
}
/// Which context a sampling request asks to have included.
///
/// Serialized in camelCase as `"allServers"`, `"none"` or `"thisServer"`. The default
/// is [`IncludeContext::None`].
///
/// Derives `PartialEq`/`Eq` (matching the sibling `ToolChoiceMode`), so values can be
/// compared with `==` instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub enum IncludeContext {
    /// Wire value `"allServers"`.
    AllServers,
    /// Wire value `"none"`; the default.
    #[default]
    None,
    /// Wire value `"thisServer"`.
    ThisServer,
}
#[cfg(test)]
mod tests {
    // Wire-format tests for the sampling types. They check JSON shape (type tags,
    // camelCase keys, omission of unset options) and serialize/deserialize round trips.
    use super::*;
    use serde_json::json;

    #[test]
    fn include_context_serializes_correctly() {
        assert_eq!(
            serde_json::to_value(IncludeContext::AllServers).unwrap(),
            "allServers"
        );
        assert_eq!(serde_json::to_value(IncludeContext::None).unwrap(), "none");
        assert_eq!(
            serde_json::to_value(IncludeContext::ThisServer).unwrap(),
            "thisServer"
        );
    }

    #[test]
    fn include_context_deserializes_correctly() {
        let all: IncludeContext = serde_json::from_value(json!("allServers")).unwrap();
        assert!(matches!(all, IncludeContext::AllServers));
        let none: IncludeContext = serde_json::from_value(json!("none")).unwrap();
        assert!(matches!(none, IncludeContext::None));
        let this: IncludeContext = serde_json::from_value(json!("thisServer")).unwrap();
        assert!(matches!(this, IncludeContext::ThisServer));
    }

    #[test]
    fn tool_choice_serialization() {
        let choice = ToolChoice::auto();
        let json = serde_json::to_value(&choice).unwrap();
        assert_eq!(json["mode"], "auto");
        let choice2 = ToolChoice::required();
        let json2 = serde_json::to_value(&choice2).unwrap();
        assert_eq!(json2["mode"], "required");
        let choice3 = ToolChoice::none();
        let json3 = serde_json::to_value(&choice3).unwrap();
        assert_eq!(json3["mode"], "none");
    }

    #[test]
    fn create_message_result_with_tools_roundtrip() {
        let result = CreateMessageResultWithTools::new(
            "claude-3",
            Role::Assistant,
            vec![
                SamplingMessageContent::Text {
                    text: "I'll call the tool.".to_string(),
                    meta: None,
                },
                SamplingMessageContent::ToolUse {
                    name: "search".to_string(),
                    id: "tu-1".to_string(),
                    input: json!({"query": "rust"}),
                    meta: None,
                },
            ],
        )
        .with_stop_reason("end_turn");
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["model"], "claude-3");
        assert_eq!(json["role"], "assistant");
        assert_eq!(json["content"].as_array().unwrap().len(), 2);
        assert_eq!(json["content"][0]["type"], "text");
        assert_eq!(json["content"][1]["type"], "tool_use");
        assert_eq!(json["content"][1]["name"], "search");
        let roundtrip: CreateMessageResultWithTools = serde_json::from_value(json).unwrap();
        assert_eq!(roundtrip.model, "claude-3");
        assert_eq!(roundtrip.content.len(), 2);
    }

    #[test]
    fn sampling_message_with_tool_use_content() {
        let msg = SamplingMessage::new(
            Role::Assistant,
            SamplingMessageContent::ToolUse {
                name: "calculate".to_string(),
                id: "tu-2".to_string(),
                input: json!({"expression": "2+2"}),
                meta: None,
            },
        );
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "assistant");
        assert_eq!(json["content"]["type"], "tool_use");
        assert_eq!(json["content"]["name"], "calculate");
        let roundtrip: SamplingMessage = serde_json::from_value(json).unwrap();
        match roundtrip.content {
            SamplingMessageContent::ToolUse { name, id, .. } => {
                assert_eq!(name, "calculate");
                assert_eq!(id, "tu-2");
            },
            _ => panic!("Expected ToolUse content"),
        }
    }

    #[test]
    fn sampling_message_content_text_roundtrip() {
        let content = SamplingMessageContent::Text {
            text: "hello".to_string(),
            meta: None,
        };
        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "text");
        assert_eq!(json["text"], "hello");
        let roundtrip: SamplingMessageContent = serde_json::from_value(json).unwrap();
        match roundtrip {
            SamplingMessageContent::Text { text, .. } => assert_eq!(text, "hello"),
            _ => panic!("Expected Text"),
        }
    }

    #[test]
    fn tool_result_uses_camel_case_fields_and_roundtrips() {
        let content = SamplingMessageContent::ToolResult {
            tool_use_id: "tu-1".to_string(),
            content: Vec::new(),
            structured_content: None,
            is_error: Some(true),
            meta: None,
        };
        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "tool_result");
        // Variant-level rename_all = "camelCase" must rename the snake_case fields.
        assert_eq!(json["toolUseId"], "tu-1");
        assert_eq!(json["isError"], true);
        assert!(json.get("tool_use_id").is_none());
        // Unset optionals are skipped entirely.
        assert!(json.get("structuredContent").is_none());
        assert!(json.get("_meta").is_none());
        let roundtrip: SamplingMessageContent = serde_json::from_value(json).unwrap();
        match roundtrip {
            SamplingMessageContent::ToolResult {
                tool_use_id,
                content,
                structured_content,
                is_error,
                ..
            } => {
                assert_eq!(tool_use_id, "tu-1");
                assert!(content.is_empty());
                assert!(structured_content.is_none());
                assert_eq!(is_error, Some(true));
            },
            _ => panic!("Expected ToolResult"),
        }
    }

    #[test]
    fn image_content_uses_mime_type_key() {
        let content = SamplingMessageContent::Image {
            data: "aGVsbG8=".to_string(),
            mime_type: "image/png".to_string(),
            meta: None,
        };
        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "image");
        assert_eq!(json["mimeType"], "image/png");
        assert!(json.get("mime_type").is_none());
    }

    #[test]
    fn model_preferences_use_camel_case_and_skip_unset() {
        let prefs = ModelPreferences::new()
            .with_cost_priority(0.5)
            .with_hints(vec![ModelHint::new("claude")]);
        let json = serde_json::to_value(&prefs).unwrap();
        assert_eq!(json["costPriority"], 0.5_f64);
        assert_eq!(json["hints"][0]["name"], "claude");
        assert!(json.get("speedPriority").is_none());
        assert!(json.get("intelligencePriority").is_none());
    }

    #[test]
    fn create_message_params_serializes_camel_case_and_skips_unset() {
        let params = CreateMessageParams::new(Vec::new())
            .with_system_prompt("be brief")
            .with_max_tokens(100);
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["systemPrompt"], "be brief");
        assert_eq!(json["maxTokens"], 100_u32);
        // include_context has no skip_serializing_if, so it is always present.
        assert_eq!(json["includeContext"], "none");
        assert!(json.get("temperature").is_none());
        assert!(json.get("toolChoice").is_none());
    }

    #[test]
    fn create_message_params_include_context_defaults_when_absent() {
        let params: CreateMessageParams =
            serde_json::from_value(json!({ "messages": [] })).unwrap();
        assert!(params.messages.is_empty());
        assert!(matches!(params.include_context, IncludeContext::None));
        assert!(params.max_tokens.is_none());
    }
}