use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::{content::ContentBlock, core::Role};
use super::tools::Tool;
/// Which server context a sampling request asks to have included.
///
/// On the wire each variant is a camelCase string: `"none"`, `"thisServer"`
/// or `"allServers"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum IncludeContext {
    /// Wire value `"none"`.
    #[serde(rename = "none")]
    None,
    /// Wire value `"thisServer"`.
    #[serde(rename = "thisServer")]
    ThisServer,
    /// Wire value `"allServers"`.
    #[serde(rename = "allServers")]
    AllServers,
}
/// A single message in the conversation handed to the sampler.
///
/// Field order is significant only for the order of keys in the emitted JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingMessage {
    /// Who authored this message.
    pub role: Role,
    /// The message payload.
    pub content: ContentBlock,
    /// Optional arbitrary key/value metadata; the key is left out of the JSON
    /// entirely when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Parameters of a sampling ("create message") request.
///
/// Every field that has a multi-word name carries an explicit camelCase
/// `rename`, so the wire names are spelled out one by one. There is no
/// container-level `rename_all`, which also keeps `_meta` verbatim. Optional
/// fields are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateMessageRequest {
    /// Conversation to sample a continuation for.
    pub messages: Vec<SamplingMessage>,
    /// Model-selection hints and priorities.
    ///
    /// Serialized as `modelPreferences` to match the other camelCase keys.
    /// Earlier versions of this type emitted `model_preferences`, so that
    /// spelling is still accepted when deserializing.
    #[serde(
        rename = "modelPreferences",
        alias = "model_preferences",
        skip_serializing_if = "Option::is_none"
    )]
    pub model_preferences: Option<ModelPreferences>,
    /// Optional system prompt (`systemPrompt`).
    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Which server context to include (`includeContext`).
    #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
    pub include_context: Option<IncludeContext>,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Token limit for the reply (`maxTokens`); always serialized.
    #[serde(rename = "maxTokens")]
    pub max_tokens: u32,
    /// Optional stop sequences (`stopSequences`).
    #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// Tools offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// How the model may use `tools` (`toolChoice`).
    #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Optional task metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task: Option<crate::types::tasks::TaskMetadata>,
    /// Protocol-level `_meta` object, passed through untouched.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub _meta: Option<serde_json::Value>,
}
/// One model-name hint, used inside [`ModelPreferences`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelHint {
    /// Hinted model name; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

impl ModelHint {
    /// Creates a hint whose `name` is set to `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        Self { name: Some(owned) }
    }
}
/// Hints and relative priorities that guide model selection for sampling.
///
/// Every field is optional and omitted from the JSON when `None`. The
/// priorities are presumably normalized weights in `0.0..=1.0`. That range
/// is NOT validated here; confirm against the protocol spec before relying
/// on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPreferences {
    /// Model-name hints.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hints: Option<Vec<ModelHint>>,
    /// Weight given to cost (`costPriority`).
    #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
    pub cost_priority: Option<f64>,
    /// Weight given to speed (`speedPriority`).
    #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
    pub speed_priority: Option<f64>,
    /// Weight given to capability (`intelligencePriority`).
    #[serde(
        rename = "intelligencePriority",
        skip_serializing_if = "Option::is_none"
    )]
    pub intelligence_priority: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateMessageResult {
pub role: super::core::Role,
pub content: ContentBlock,
pub model: String,
#[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
#[serde(skip_serializing_if = "Option::is_none")]
pub _meta: Option<serde_json::Value>,
}
/// Reason a sampling run ended.
///
/// Each variant maps to a camelCase JSON string.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum StopReason {
    /// Wire value `"endTurn"`.
    #[serde(rename = "endTurn")]
    EndTurn,
    /// Wire value `"maxTokens"`.
    #[serde(rename = "maxTokens")]
    MaxTokens,
    /// Wire value `"stopSequence"`.
    #[serde(rename = "stopSequence")]
    StopSequence,
    /// Wire value `"contentFilter"`.
    #[serde(rename = "contentFilter")]
    ContentFilter,
    /// Wire value `"toolUse"`.
    #[serde(rename = "toolUse")]
    ToolUse,
}
/// Token counters for a sampling run.
///
/// Each counter is optional and left out of the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
    /// Serialized as `inputTokens`.
    #[serde(rename = "inputTokens", skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    /// Serialized as `outputTokens`.
    #[serde(rename = "outputTokens", skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    /// Serialized as `totalTokens`.
    #[serde(rename = "totalTokens", skip_serializing_if = "Option::is_none")]
    pub total_tokens: Option<u32>,
}
/// How a model may use the tools offered to it.
///
/// Serialized as a lowercase string. `Auto` is the `Default`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum ToolChoiceMode {
    /// Wire value `"auto"`; the default.
    #[default]
    #[serde(rename = "auto")]
    Auto,
    /// Wire value `"required"`.
    #[serde(rename = "required")]
    Required,
    /// Wire value `"none"`.
    #[serde(rename = "none")]
    None,
}
/// Tool-usage directive attached to a sampling request.
///
/// `mode` is optional: when `None` it is omitted from the JSON, which then
/// serializes as an empty object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolChoice {
    /// Selected mode, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<ToolChoiceMode>,
}
impl ToolChoice {
    /// A choice with `mode` set to [`ToolChoiceMode::Auto`].
    pub fn auto() -> Self {
        // `T -> Option<T>` via `From`, i.e. `Some(ToolChoiceMode::Auto)`.
        Self {
            mode: ToolChoiceMode::Auto.into(),
        }
    }

    /// A choice with `mode` set to [`ToolChoiceMode::Required`].
    pub fn required() -> Self {
        Self {
            mode: ToolChoiceMode::Required.into(),
        }
    }

    /// A choice with `mode` set to [`ToolChoiceMode::None`].
    ///
    /// Note that this sets an explicit mode; it is not the same as leaving
    /// `mode` unset (`None`).
    pub fn none() -> Self {
        Self {
            mode: ToolChoiceMode::None.into(),
        }
    }
}
impl Default for ToolChoice {
fn default() -> Self {
Self::auto()
}
}
#[cfg(test)]
mod tests {
    //! Wire-format checks for the sampling types. Each test pins the exact
    //! JSON spelling, because peers match on these strings.
    use super::*;

    #[test]
    fn test_tool_choice_mode_serialization() {
        assert_eq!(
            serde_json::to_string(&ToolChoiceMode::Auto).unwrap(),
            "\"auto\""
        );
        assert_eq!(
            serde_json::to_string(&ToolChoiceMode::Required).unwrap(),
            "\"required\""
        );
        assert_eq!(
            serde_json::to_string(&ToolChoiceMode::None).unwrap(),
            "\"none\""
        );
    }

    #[test]
    fn test_tool_choice_mode_round_trip() {
        for mode in [
            ToolChoiceMode::Auto,
            ToolChoiceMode::Required,
            ToolChoiceMode::None,
        ] {
            let json = serde_json::to_string(&mode).unwrap();
            let back: ToolChoiceMode = serde_json::from_str(&json).unwrap();
            assert_eq!(back, mode);
        }
    }

    #[test]
    fn test_tool_choice_constructors() {
        let auto = ToolChoice::auto();
        assert_eq!(auto.mode, Some(ToolChoiceMode::Auto));
        let required = ToolChoice::required();
        assert_eq!(required.mode, Some(ToolChoiceMode::Required));
        let none = ToolChoice::none();
        assert_eq!(none.mode, Some(ToolChoiceMode::None));
    }

    #[test]
    fn test_tool_choice_default() {
        let default = ToolChoice::default();
        assert_eq!(default.mode, Some(ToolChoiceMode::Auto));
    }

    #[test]
    fn test_tool_choice_serialization() {
        let choice = ToolChoice::required();
        let json = serde_json::to_string(&choice).unwrap();
        assert!(json.contains("\"mode\":\"required\""));
    }

    #[test]
    fn test_tool_choice_unset_mode_is_omitted_and_accepted() {
        let empty = ToolChoice { mode: None };
        assert_eq!(serde_json::to_string(&empty).unwrap(), "{}");
        let parsed: ToolChoice = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed, empty);
    }

    #[test]
    fn test_include_context_serialization() {
        let cases = [
            (IncludeContext::None, "\"none\""),
            (IncludeContext::ThisServer, "\"thisServer\""),
            (IncludeContext::AllServers, "\"allServers\""),
        ];
        for (ctx, expected) in cases {
            assert_eq!(serde_json::to_string(&ctx).unwrap(), expected);
            assert_eq!(
                serde_json::from_str::<IncludeContext>(expected).unwrap(),
                ctx
            );
        }
    }

    #[test]
    fn test_stop_reason_serialization() {
        let cases = [
            (StopReason::EndTurn, "\"endTurn\""),
            (StopReason::MaxTokens, "\"maxTokens\""),
            (StopReason::StopSequence, "\"stopSequence\""),
            (StopReason::ContentFilter, "\"contentFilter\""),
            (StopReason::ToolUse, "\"toolUse\""),
        ];
        for (reason, expected) in cases {
            assert_eq!(serde_json::to_string(&reason).unwrap(), expected);
            assert_eq!(
                serde_json::from_str::<StopReason>(expected).unwrap(),
                reason
            );
        }
    }

    #[test]
    fn test_model_hint_serialization() {
        let hint = ModelHint::new("claude");
        assert_eq!(
            serde_json::to_string(&hint).unwrap(),
            "{\"name\":\"claude\"}"
        );
        let unnamed = ModelHint { name: None };
        assert_eq!(serde_json::to_string(&unnamed).unwrap(), "{}");
    }

    #[test]
    fn test_model_preferences_uses_camel_case_keys() {
        let prefs = ModelPreferences {
            hints: None,
            cost_priority: Some(0.5),
            speed_priority: None,
            intelligence_priority: Some(1.0),
        };
        let value = serde_json::to_value(&prefs).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "costPriority": 0.5, "intelligencePriority": 1.0 })
        );
    }
}