use serde::{Deserialize, Serialize};
use super::{Role, Annotations};
use std::collections::HashMap;
/// Result of a `sampling/createMessage` request: one content block, its role, and the
/// name of the model that produced it.
///
/// Field names are serialized in camelCase; `meta` is written as `_meta`. Optional fields
/// are omitted from the output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateMessageResult {
    /// Optional metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<super::messages::ResultMeta>,
    /// The generated content block.
    pub content: ContentType,
    /// Role associated with the generated message.
    pub role: Role,
    /// Name of the model that produced the result.
    pub model: String,
    /// Reason generation stopped, if reported (serialized as `stopReason`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
}
impl CreateMessageResult {
    /// Builds a result with no `_meta` and no stop reason.
    ///
    /// `model` names the model that produced `content`.
    pub fn new(role: Role, content: ContentType, model: impl Into<String>) -> Self {
        let model = model.into();
        Self {
            role,
            content,
            model,
            meta: None,
            stop_reason: None,
        }
    }

    /// Builds a result like [`CreateMessageResult::new`], additionally recording a stop reason.
    pub fn with_stop_reason(
        role: Role,
        content: ContentType,
        model: impl Into<String>,
        stop_reason: impl Into<String>,
    ) -> Self {
        Self {
            stop_reason: Some(stop_reason.into()),
            ..Self::new(role, content, model)
        }
    }
}
/// A `sampling/createMessage` request: the method name plus its parameters.
///
/// `CreateMessageRequest::new` sets `method` to `"sampling/createMessage"`; the field is a
/// plain `String`, so other values are possible when the struct is built directly.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CreateMessageRequest {
/// Method name; `new` sets it to `"sampling/createMessage"`.
pub method: String,
/// Request parameters.
pub params: CreateMessageParams,
}
impl CreateMessageRequest {
    /// Builds a `sampling/createMessage` request carrying only the required parameters.
    ///
    /// Every optional parameter (model preferences, system prompt, context inclusion,
    /// temperature, stop sequences, metadata) starts out as `None`.
    pub fn new(messages: Vec<SamplingMessage>, max_tokens: i32) -> Self {
        let params = CreateMessageParams {
            messages,
            max_tokens,
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            temperature: None,
            stop_sequences: None,
            metadata: None,
        };
        Self {
            method: String::from("sampling/createMessage"),
            params,
        }
    }
}
/// Parameters of a `sampling/createMessage` request.
///
/// Field names are serialized in camelCase (`modelPreferences`, `systemPrompt`,
/// `includeContext`, `maxTokens`, `stopSequences`). Every optional field is omitted from
/// the output when it is `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateMessageParams {
    /// Messages included in the request, in order.
    pub messages: Vec<SamplingMessage>,
    /// Optional model-selection preferences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_preferences: Option<ModelPreferences>,
    /// Optional system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Optional context-inclusion setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_context: Option<IncludeContext>,
    /// Optional temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Requested maximum token count; always serialized (as `maxTokens`).
    pub max_tokens: i32,
    /// Optional stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// Optional free-form metadata object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Value of `CreateMessageParams::include_context`.
///
/// Variants are serialized in camelCase: `"none"`, `"thisServer"`, `"allServers"`.
/// Note that `Some(IncludeContext::None)` is written as `"none"`, whereas an outer
/// `Option::None` omits the `includeContext` field entirely.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "camelCase")]
pub enum IncludeContext {
/// Serialized as `"none"`.
None,
/// Serialized as `"thisServer"`.
ThisServer,
/// Serialized as `"allServers"`.
AllServers,
}
/// One entry of `CreateMessageParams::messages`: a role paired with a single content block.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SamplingMessage {
/// Role associated with this message.
pub role: Role,
/// The message's content block (text, image, or audio).
pub content: ContentType,
}
impl SamplingMessage {
    /// Creates a message for `role` whose content is a text block holding `text`.
    pub fn text(role: Role, text: impl Into<String>) -> Self {
        let content = ContentType::Text(TextContent::new(text));
        Self { role, content }
    }

    /// Creates a message for `role` whose content is an image block.
    ///
    /// `data` and `mime_type` are passed through to `ImageContent::new` unchanged.
    pub fn image(role: Role, data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        let content = ContentType::Image(ImageContent::new(data, mime_type));
        Self { role, content }
    }

    /// Creates a message for `role` whose content is an audio block.
    ///
    /// `data` and `mime_type` are passed through to `AudioContent::new` unchanged.
    pub fn audio(role: Role, data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        let content = ContentType::Audio(AudioContent::new(data, mime_type));
        Self { role, content }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ContentType {
Text(TextContent),
Image(ImageContent),
Audio(AudioContent),
}
/// A text content block.
///
/// `r#type` is serialized under the JSON key `"type"`; the constructors set it to
/// `"text"`. The derived `Deserialize` for this struct accepts any string there.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TextContent {
/// Content-type discriminator (`"text"` when built via the constructors).
pub r#type: String,
/// The text payload.
pub text: String,
/// Optional annotations, omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub annotations: Option<Annotations>,
}
impl TextContent {
    /// Creates a `"text"` content block without annotations.
    pub fn new(text: impl Into<String>) -> Self {
        let text = text.into();
        Self {
            r#type: String::from("text"),
            text,
            annotations: None,
        }
    }

    /// Creates a `"text"` content block carrying `annotations`.
    pub fn with_annotations(text: impl Into<String>, annotations: Annotations) -> Self {
        Self {
            annotations: Some(annotations),
            ..Self::new(text)
        }
    }
}
/// An image content block.
///
/// Fields are serialized in camelCase: `r#type` as `"type"` and `mime_type` as
/// `"mimeType"`. The constructors set `r#type` to `"image"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageContent {
    /// Content-type discriminator (`"image"` when built via the constructors).
    pub r#type: String,
    /// Image payload as a string; its encoding is whatever the caller supplies.
    pub data: String,
    /// MIME type of the payload.
    pub mime_type: String,
    /// Optional annotations, omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotations: Option<Annotations>,
}
impl ImageContent {
    /// Creates an `"image"` content block without annotations.
    pub fn new(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        let (data, mime_type) = (data.into(), mime_type.into());
        Self {
            r#type: String::from("image"),
            data,
            mime_type,
            annotations: None,
        }
    }

    /// Creates an `"image"` content block carrying `annotations`.
    pub fn with_annotations(
        data: impl Into<String>,
        mime_type: impl Into<String>,
        annotations: Annotations,
    ) -> Self {
        Self {
            annotations: Some(annotations),
            ..Self::new(data, mime_type)
        }
    }
}
/// An audio content block.
///
/// Fields are serialized in camelCase: `r#type` as `"type"` and `mime_type` as
/// `"mimeType"`. The constructors set `r#type` to `"audio"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AudioContent {
    /// Content-type discriminator (`"audio"` when built via the constructors).
    pub r#type: String,
    /// Audio payload as a string; its encoding is whatever the caller supplies.
    pub data: String,
    /// MIME type of the payload.
    pub mime_type: String,
    /// Optional annotations, omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotations: Option<Annotations>,
}
impl AudioContent {
    /// Creates an `"audio"` content block without annotations.
    pub fn new(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        let (data, mime_type) = (data.into(), mime_type.into());
        Self {
            r#type: String::from("audio"),
            data,
            mime_type,
            annotations: None,
        }
    }

    /// Creates an `"audio"` content block carrying `annotations`.
    pub fn with_annotations(
        data: impl Into<String>,
        mime_type: impl Into<String>,
        annotations: Annotations,
    ) -> Self {
        Self {
            annotations: Some(annotations),
            ..Self::new(data, mime_type)
        }
    }
}
/// Model-selection preferences attached to `CreateMessageParams::model_preferences`.
///
/// Priority fields are serialized in camelCase (`costPriority`, `speedPriority`,
/// `intelligencePriority`); every field is omitted from the output when `None`.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelPreferences {
    /// Optional list of model hints.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hints: Option<Vec<ModelHint>>,
    /// Optional cost priority; the constructors/setters on this type clamp it to `[0.0, 1.0]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_priority: Option<f64>,
    /// Optional speed priority; the constructors/setters on this type clamp it to `[0.0, 1.0]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed_priority: Option<f64>,
    /// Optional intelligence priority; the constructors/setters on this type clamp it to
    /// `[0.0, 1.0]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub intelligence_priority: Option<f64>,
}
impl ModelPreferences {
    /// Creates preferences with no hints and no priorities set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates preferences with all three priorities set and no hints.
    ///
    /// Each priority is clamped to `[0.0, 1.0]`; a NaN input is stored as `0.0`.
    pub fn with_priorities(cost: f64, speed: f64, intelligence: f64) -> Self {
        Self {
            hints: None,
            cost_priority: Some(Self::clamp_priority(cost)),
            speed_priority: Some(Self::clamp_priority(speed)),
            intelligence_priority: Some(Self::clamp_priority(intelligence)),
        }
    }

    /// Appends `hint`, creating the hint list on first use.
    pub fn add_hint(&mut self, hint: ModelHint) {
        self.hints.get_or_insert_with(Vec::new).push(hint);
    }

    /// Sets the cost priority, clamped to `[0.0, 1.0]` (NaN becomes `0.0`).
    pub fn set_cost_priority(&mut self, priority: f64) {
        self.cost_priority = Some(Self::clamp_priority(priority));
    }

    /// Sets the speed priority, clamped to `[0.0, 1.0]` (NaN becomes `0.0`).
    pub fn set_speed_priority(&mut self, priority: f64) {
        self.speed_priority = Some(Self::clamp_priority(priority));
    }

    /// Sets the intelligence priority, clamped to `[0.0, 1.0]` (NaN becomes `0.0`).
    pub fn set_intelligence_priority(&mut self, priority: f64) {
        self.intelligence_priority = Some(Self::clamp_priority(priority));
    }

    /// Clamps a priority into `[0.0, 1.0]`.
    ///
    /// Deliberately uses `max`/`min` instead of `f64::clamp`: `f64::max` returns the
    /// non-NaN operand, so NaN maps to `0.0` rather than propagating as NaN, which is
    /// what `clamp` would do.
    fn clamp_priority(value: f64) -> f64 {
        value.max(0.0).min(1.0)
    }
}
/// An entry of `ModelPreferences::hints`: an optional model name plus arbitrary extra
/// properties.
///
/// `additional` is `#[serde(flatten)]`: on deserialization it collects every key other than
/// `name`, and on serialization its entries are written into the same JSON object as `name`.
///
/// NOTE(review): `additional` is a public map, so an entry keyed `"name"` inserted directly
/// would be serialized next to the `name` field as a duplicate JSON key; keep that key out
/// of `additional`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ModelHint {
/// Optional model name, omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Extra properties, flattened into the hint's JSON object.
#[serde(flatten)]
pub additional: HashMap<String, serde_json::Value>,
}
impl ModelHint {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: Some(name.into()),
additional: HashMap::new(),
}
}
pub fn add_property<T: Serialize>(&mut self, name: impl Into<String>, value: &T) -> Result<(), serde_json::Error> {
self.additional.insert(name.into(), serde_json::to_value(value)?);
Ok(())
}
}