use crate::shared::{IntoArgs, OneOrMany};
use crate::types::{
AudioContent, Content, EmbeddedResource, ImageContent, IntoResponse, PromptMessage, RequestId,
ResourceLink, Response, Role, TextContent, Tool, ToolResult, ToolUse,
};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "tasks")]
use crate::types::TaskMetadata;
#[cfg(feature = "client")]
use std::{future::Future, pin::Pin, sync::Arc};
/// Fallback `max_tokens` used by `CreateMessageRequestParams::default()`.
const DEFAULT_MESSAGE_MAX_TOKENS: i32 = 512;
/// JSON-RPC method names for the sampling capability.
pub mod commands {
    /// Method name for requesting LLM sampling from the client.
    pub const CREATE: &str = "sampling/createMessage";
}
/// A single message in a sampling conversation: a role plus its content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingMessage {
    /// Author of the message (user or assistant).
    pub role: Role,
    /// One or more content items carried by the message.
    pub content: OneOrMany<Content>,
}
/// Parameters for a `sampling/createMessage` request.
///
/// Field names are renamed individually (rather than via a container-level
/// `rename_all`) because `meta` maps to the wire name `metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateMessageRequestParams {
    /// Conversation history to sample from.
    pub messages: Vec<SamplingMessage>,
    /// Maximum number of tokens to sample (wire name `maxTokens`).
    #[serde(rename = "maxTokens")]
    pub max_tokens: i32,
    /// How much server context to include (wire name `includeContext`).
    #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
    pub include_context: Option<ContextInclusion>,
    /// Provider-specific metadata; NOTE the wire name is `metadata`.
    #[serde(rename = "metadata", skip_serializing_if = "Option::is_none")]
    pub meta: Option<serde_json::Value>,
    /// Advisory model-selection preferences (wire name `modelPreferences`).
    #[serde(rename = "modelPreferences", skip_serializing_if = "Option::is_none")]
    pub model_pref: Option<ModelPreferences>,
    /// Optional system prompt (wire name `systemPrompt`).
    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
    pub sys_prompt: Option<String>,
    /// Sampling temperature (wire name `temperature`).
    #[serde(rename = "temperature", skip_serializing_if = "Option::is_none")]
    pub temp: Option<f32>,
    /// Sequences that stop generation (wire name `stopSequences`).
    #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// Tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Constraint on tool usage (wire name `toolChoice`).
    #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Task metadata; only compiled in with the `tasks` feature.
    #[cfg(feature = "tasks")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task: Option<TaskMetadata>,
}
/// Constraint on how the model may use the supplied tools.
///
/// Derives `PartialEq`/`Eq` (the inner [`ToolChoiceMode`] already does) so the
/// wrapper itself can be compared directly instead of reaching into `.mode`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolChoice {
    /// Selection strategy; defaults to [`ToolChoiceMode::Auto`].
    pub mode: ToolChoiceMode,
}
/// Strategy controlling tool usage; serialized lowercase
/// ("auto" / "required" / "none").
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceMode {
    /// The model decides whether to call tools (the default).
    #[default]
    Auto,
    /// Tool use is required.
    Required,
    /// Tool use is disabled.
    None,
}
/// How much MCP server context to include when sampling.
///
/// The container-level `rename_all = "camelCase"` yields exactly the same wire
/// strings ("none" / "thisServer" / "allServers") as the previous per-variant
/// renames. `PartialEq`/`Eq` are derived so callers can compare values
/// directly instead of using `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ContextInclusion {
    /// Include no server context.
    None,
    /// Include context from the requesting server only.
    ThisServer,
    /// Include context from all connected servers.
    AllServers,
}
/// Advisory preferences for the client's model selection; every field is
/// optional.
///
/// A container-level `rename_all = "camelCase"` replaces the per-field
/// renames and produces the same wire names (`costPriority`, `hints`,
/// `speedPriority`, `intelligencePriority`). Field declaration order is kept
/// so serialized key order is unchanged.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelPreferences {
    /// Relative weight placed on cost.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_priority: Option<f32>,
    /// Model-name hints, in preference order.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hints: Option<Vec<ModelHint>>,
    /// Relative weight placed on speed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed_priority: Option<f32>,
    /// Relative weight placed on capability.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub intelligence_priority: Option<f32>,
}
/// A hint suggesting a model by name.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ModelHint {
    /// Suggested model name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
/// Result payload of a `sampling/createMessage` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateMessageResult {
    /// Role attributed to the generated message.
    pub role: Role,
    /// Generated content items.
    pub content: OneOrMany<Content>,
    /// Name of the model that produced the result.
    pub model: String,
    /// Why generation stopped (wire name `stopReason`).
    #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<StopReason>,
}
/// Reason the model stopped generating.
///
/// Hand-written `Serialize`/`Deserialize` impls below map the known variants
/// to camelCase wire strings and round-trip unknown strings through `Other`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopReason {
    /// Natural end of turn ("endTurn").
    EndTurn,
    /// Token limit reached ("maxTokens").
    MaxTokens,
    /// A stop sequence matched ("stopSequence").
    StopSequence,
    /// The model requested a tool call ("toolUse").
    ToolUse,
    /// Any other reason, preserved verbatim.
    Other(String),
}
impl Serialize for StopReason {
    /// Serializes the reason as its camelCase wire string; `Other` passes its
    /// payload through unchanged.
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let wire = match self {
            Self::EndTurn => "endTurn",
            Self::MaxTokens => "maxTokens",
            Self::StopSequence => "stopSequence",
            Self::ToolUse => "toolUse",
            Self::Other(reason) => reason.as_str(),
        };
        serializer.serialize_str(wire)
    }
}
impl<'de> Deserialize<'de> for StopReason {
    /// Deserializes from any string; unknown values become [`StopReason::Other`].
    #[inline]
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        String::deserialize(deserializer).map(Self::from)
    }
}
impl From<String> for StopReason {
#[inline]
fn from(s: String) -> Self {
match s.as_str() {
"endTurn" => StopReason::EndTurn,
"maxTokens" => StopReason::MaxTokens,
"stopSequence" => StopReason::StopSequence,
"toolUse" => StopReason::ToolUse,
_ => StopReason::Other(s),
}
}
}
impl From<&str> for StopReason {
#[inline]
fn from(s: &str) -> Self {
match s {
"endTurn" => StopReason::EndTurn,
"maxTokens" => StopReason::MaxTokens,
"stopSequence" => StopReason::StopSequence,
"toolUse" => StopReason::ToolUse,
_ => StopReason::Other(s.to_string()),
}
}
}
impl Default for CreateMessageRequestParams {
    /// Empty request with `max_tokens` preset to
    /// [`DEFAULT_MESSAGE_MAX_TOKENS`] and room reserved for a few messages.
    #[inline]
    fn default() -> Self {
        Self {
            messages: Vec::with_capacity(8),
            max_tokens: DEFAULT_MESSAGE_MAX_TOKENS,
            include_context: None,
            meta: None,
            model_pref: None,
            sys_prompt: None,
            temp: None,
            stop_sequences: None,
            tools: None,
            tool_choice: None,
            #[cfg(feature = "tasks")]
            task: None,
        }
    }
}
impl IntoResponse for CreateMessageResult {
    /// Converts the result into a JSON-RPC response: a success response
    /// carrying the serialized result, or an error response if serialization
    /// fails.
    #[inline]
    fn into_response(self, req_id: RequestId) -> Response {
        match serde_json::to_value(self) {
            Ok(v) => Response::success(req_id, v),
            Err(err) => Response::error(req_id, err.into()),
        }
    }
}
impl From<&str> for SamplingMessage {
#[inline]
fn from(s: &str) -> Self {
Self::user().with(s)
}
}
impl From<String> for SamplingMessage {
#[inline]
fn from(s: String) -> Self {
Self::user().with(s)
}
}
impl From<PromptMessage> for SamplingMessage {
    /// Converts a prompt message, keeping its role and carrying its content over.
    #[inline]
    fn from(msg: PromptMessage) -> Self {
        let role = msg.role;
        Self::new(role).with(msg.content)
    }
}
impl From<&str> for ModelHint {
    /// Builds a hint from a borrowed model name.
    #[inline]
    fn from(name: &str) -> Self {
        Self {
            name: Some(name.to_string()),
        }
    }
}
impl From<String> for ModelHint {
    /// Builds a hint from an owned model name.
    #[inline]
    fn from(name: String) -> Self {
        Self { name: Some(name) }
    }
}
impl SamplingMessage {
    /// Creates an empty message attributed to the given role.
    #[inline]
    pub fn new(role: Role) -> Self {
        Self {
            role,
            content: OneOrMany::new(),
        }
    }

    /// Empty message attributed to [`Role::User`].
    pub fn user() -> Self {
        Self {
            role: Role::User,
            content: OneOrMany::new(),
        }
    }

    /// Empty message attributed to [`Role::Assistant`].
    pub fn assistant() -> Self {
        Self {
            role: Role::Assistant,
            content: OneOrMany::new(),
        }
    }

    /// Appends one content item and returns the message (builder style).
    pub fn with<T: Into<Content>>(mut self, content: T) -> Self {
        let item = content.into();
        self.content.push(item);
        self
    }
}
impl ModelPreferences {
    /// Empty preferences: every priority and the hint list unset.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the relative cost priority (builder style).
    pub fn with_cost_priority(self, priority: f32) -> Self {
        Self {
            cost_priority: Some(priority),
            ..self
        }
    }

    /// Sets the relative speed priority (builder style).
    pub fn with_speed_priority(self, priority: f32) -> Self {
        Self {
            speed_priority: Some(priority),
            ..self
        }
    }

    /// Sets the relative intelligence priority (builder style).
    pub fn with_intel_priority(self, priority: f32) -> Self {
        Self {
            intelligence_priority: Some(priority),
            ..self
        }
    }

    /// Appends one model hint, creating the list on first use.
    pub fn with_hint(mut self, hint: impl Into<ModelHint>) -> Self {
        match self.hints.as_mut() {
            Some(list) => list.push(hint.into()),
            None => self.hints = Some(vec![hint.into()]),
        }
        self
    }

    /// Appends every hint from the iterator, creating the list on first use.
    pub fn with_hints<T, I>(mut self, hints: T) -> Self
    where
        T: IntoIterator<Item = I>,
        I: Into<ModelHint>,
    {
        let converted = hints.into_iter().map(Into::into);
        self.hints.get_or_insert_with(Vec::new).extend(converted);
        self
    }
}
impl ModelHint {
    /// Creates a hint carrying the given model name.
    #[inline]
    pub fn new(name: impl Into<String>) -> Self {
        let name = Some(name.into());
        Self { name }
    }
}
impl ToolChoice {
#[inline]
pub fn auto() -> Self {
Self {
mode: ToolChoiceMode::Auto,
}
}
#[inline]
pub fn none() -> Self {
Self {
mode: ToolChoiceMode::None,
}
}
#[inline]
pub fn required() -> Self {
Self {
mode: ToolChoiceMode::Required,
}
}
#[inline]
pub fn is_auto(&self) -> bool {
self.mode == ToolChoiceMode::Auto
}
#[inline]
pub fn is_none(&self) -> bool {
self.mode == ToolChoiceMode::None
}
#[inline]
pub fn is_required(&self) -> bool {
self.mode == ToolChoiceMode::Required
}
}
impl CreateMessageRequestParams {
    /// Creates an empty request preset with [`DEFAULT_MESSAGE_MAX_TOKENS`].
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one conversation message (builder style).
    pub fn with_message(mut self, message: impl Into<SamplingMessage>) -> Self {
        self.messages.push(message.into());
        self
    }

    /// Appends every message yielded by the iterator.
    pub fn with_messages<T, I>(mut self, messages: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<SamplingMessage>,
    {
        self.messages.extend(messages.into_iter().map(Into::into));
        self
    }

    /// Sets the system prompt.
    pub fn with_sys_prompt(mut self, sys_prompt: impl Into<String>) -> Self {
        self.sys_prompt = Some(sys_prompt.into());
        self
    }

    /// Sets the maximum number of tokens to sample.
    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the context-inclusion level.
    pub fn with_include_ctx(mut self, inc: ContextInclusion) -> Self {
        self.include_context = Some(inc);
        self
    }

    /// Requests no server context ([`ContextInclusion::None`]).
    pub fn with_no_ctx(mut self) -> Self {
        self.include_context = Some(ContextInclusion::None);
        self
    }

    /// Requests context from the requesting server only.
    pub fn with_this_server(mut self) -> Self {
        self.include_context = Some(ContextInclusion::ThisServer);
        self
    }

    /// Requests context from all connected servers.
    pub fn with_all_servers(mut self) -> Self {
        self.include_context = Some(ContextInclusion::AllServers);
        self
    }

    /// Sets the model-selection preferences.
    pub fn with_pref(mut self, pref: ModelPreferences) -> Self {
        self.model_pref = Some(pref);
        self
    }

    /// Sets the sampling temperature.
    pub fn with_temp(mut self, temp: f32) -> Self {
        self.temp = Some(temp);
        self
    }

    /// Sets the stop sequences.
    pub fn with_stop_seq(mut self, stop_sequences: Vec<String>) -> Self {
        self.stop_sequences = Some(stop_sequences);
        self
    }

    /// Supplies the callable tools and resets `tool_choice` to `Auto`.
    ///
    /// NOTE: this deliberately overwrites any previously set tool choice
    /// (covered by a unit test); call [`Self::with_tool_choice`] afterwards
    /// to pick a different mode.
    pub fn with_tools<T: IntoIterator<Item = Tool>>(mut self, tools: T) -> Self {
        self.tools = Some(tools.into_iter().collect());
        self.with_tool_choice(ToolChoiceMode::Auto)
    }

    /// Sets the tool-choice mode.
    pub fn with_tool_choice(mut self, mode: ToolChoiceMode) -> Self {
        self.tool_choice = Some(ToolChoice { mode });
        self
    }

    /// Attaches task metadata with the given TTL (`tasks` feature only).
    #[cfg(feature = "tasks")]
    pub fn with_ttl(mut self, ttl: Option<usize>) -> Self {
        self.task = Some(TaskMetadata { ttl });
        self
    }

    /// Iterates over all text content across every message.
    pub fn text(&self) -> impl Iterator<Item = &TextContent> {
        self.contents().filter_map(|c| c.as_text())
    }

    /// Iterates over all audio content across every message.
    pub fn audio(&self) -> impl Iterator<Item = &AudioContent> {
        self.contents().filter_map(|c| c.as_audio())
    }

    /// Iterates over all image content across every message.
    pub fn images(&self) -> impl Iterator<Item = &ImageContent> {
        self.contents().filter_map(|c| c.as_image())
    }

    /// Iterates over all resource links across every message.
    pub fn links(&self) -> impl Iterator<Item = &ResourceLink> {
        self.contents().filter_map(|c| c.as_link())
    }

    /// Iterates over all embedded resources across every message.
    pub fn resources(&self) -> impl Iterator<Item = &EmbeddedResource> {
        self.contents().filter_map(|c| c.as_resource())
    }

    /// Iterates over all tool-use items across every message.
    pub fn tools(&self) -> impl Iterator<Item = &ToolUse> {
        self.contents().filter_map(|c| c.as_tool())
    }

    /// Iterates over all tool results across every message.
    pub fn results(&self) -> impl Iterator<Item = &ToolResult> {
        self.contents().filter_map(|c| c.as_result())
    }

    /// Flattens the content of every message into one iterator.
    ///
    /// The typed accessors (`as_text`, `as_audio`, ...) already discriminate
    /// by variant, so the former string-tag pre-filter (`get_type() == "text"`
    /// etc.) was redundant and has been removed.
    #[inline]
    fn contents(&self) -> impl Iterator<Item = &Content> {
        self.messages.iter().flat_map(|m| m.content.as_slice())
    }
}
impl CreateMessageResult {
    /// Creates an empty result for the given role.
    #[inline]
    pub fn new(role: Role) -> Self {
        Self {
            role,
            content: OneOrMany::new(),
            model: String::new(),
            stop_reason: None,
        }
    }

    /// Empty result attributed to the user.
    pub fn user() -> Self {
        Self::new(Role::User)
    }

    /// Empty result attributed to the assistant.
    pub fn assistant() -> Self {
        Self::new(Role::Assistant)
    }

    /// Sets the stop reason (builder style).
    pub fn with_stop_reason(mut self, reason: impl Into<StopReason>) -> Self {
        self.stop_reason = Some(reason.into());
        self
    }

    /// Sets the name of the model that produced the result.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Appends one content item.
    pub fn with_content<T: Into<Content>>(mut self, content: T) -> Self {
        self.content.push(content.into());
        self
    }

    /// Marks the result as a natural end of turn.
    #[inline]
    pub fn end_turn(self) -> Self {
        self.with_stop_reason(StopReason::EndTurn)
    }

    /// Adds a single tool call and marks the result as stopped for tool use.
    pub fn use_tool<N, Args>(self, name: N, args: Args) -> Self
    where
        N: Into<String>,
        Args: IntoArgs,
    {
        self.with_content(ToolUse::new(name, args))
            .with_stop_reason(StopReason::ToolUse)
    }

    /// Adds several tool calls and marks the result as stopped for tool use.
    ///
    /// `stop_reason` is now set exactly once at the end (the previous code
    /// redundantly re-set it per tool via `use_tool` and then again after the
    /// fold). An empty iterator still yields `ToolUse`, preserving the
    /// original behavior.
    pub fn use_tools<N, Args>(self, tools: impl IntoIterator<Item = (N, Args)>) -> Self
    where
        N: Into<String>,
        Args: IntoArgs,
    {
        tools
            .into_iter()
            .fold(self, |acc, (name, args)| {
                acc.with_content(ToolUse::new(name, args))
            })
            .with_stop_reason(StopReason::ToolUse)
    }

    /// Iterates over the result's text content.
    pub fn text(&self) -> impl Iterator<Item = &TextContent> {
        self.content.iter().filter_map(|c| c.as_text())
    }

    /// Iterates over the result's audio content.
    pub fn audio(&self) -> impl Iterator<Item = &AudioContent> {
        self.content.iter().filter_map(|c| c.as_audio())
    }

    /// Iterates over the result's image content.
    pub fn images(&self) -> impl Iterator<Item = &ImageContent> {
        self.content.iter().filter_map(|c| c.as_image())
    }

    /// Iterates over the result's resource links.
    pub fn links(&self) -> impl Iterator<Item = &ResourceLink> {
        self.content.iter().filter_map(|c| c.as_link())
    }

    /// Iterates over the result's embedded resources.
    pub fn resources(&self) -> impl Iterator<Item = &EmbeddedResource> {
        self.content.iter().filter_map(|c| c.as_resource())
    }

    /// Iterates over the result's tool-use items.
    ///
    /// The typed accessors already discriminate by variant, so the former
    /// string-tag pre-filter (`get_type() == "tool_use"` etc.) was redundant
    /// and has been removed along with the `msg_iter` helper.
    pub fn tools(&self) -> impl Iterator<Item = &ToolUse> {
        self.content.iter().filter_map(|c| c.as_tool())
    }

    /// Iterates over the result's tool results.
    pub fn results(&self) -> impl Iterator<Item = &ToolResult> {
        self.content.iter().filter_map(|c| c.as_result())
    }
}
/// Client-side callback used to satisfy a `sampling/createMessage` request:
/// maps the request params to a boxed future resolving to the result.
/// Shared via `Arc` so it can be cloned across tasks.
#[cfg(feature = "client")]
pub(crate) type SamplingHandler = Arc<
    dyn Fn(
            CreateMessageRequestParams,
        ) -> Pin<Box<dyn Future<Output = CreateMessageResult> + Send + 'static>>
        + Send
        + Sync,
>;
#[cfg(test)]
mod tests {
    use super::*;

    // --- Defaults ---------------------------------------------------------

    #[test]
    fn it_sets_auto_tool_choice_mode_by_default() {
        let mode = ToolChoiceMode::default();
        assert_eq!(mode, ToolChoiceMode::Auto);
    }

    #[test]
    fn it_sets_auto_tool_choice_by_default() {
        let tool_choice = ToolChoice::default();
        assert_eq!(tool_choice.mode, ToolChoiceMode::Auto);
    }

    // --- Tool choice behavior ---------------------------------------------

    // with_tools must implicitly select the Auto tool-choice mode.
    #[test]
    #[cfg(feature = "server")]
    fn it_sets_auto_tool_choice_when_tools_specified() {
        let params = CreateMessageRequestParams::new().with_tools([
            Tool::new("test 1", async || "test 1"),
            Tool::new("test 2", async || "test 2"),
        ]);
        assert_eq!(params.tool_choice.unwrap().mode, ToolChoiceMode::Auto);
    }

    #[test]
    fn it_sets_tool_choice() {
        let params = CreateMessageRequestParams::new().with_tool_choice(ToolChoiceMode::Required);
        assert_eq!(params.tool_choice.unwrap().mode, ToolChoiceMode::Required);
    }

    // --- Builders ----------------------------------------------------------

    #[test]
    fn it_builds_sampling_message() {
        let msg = SamplingMessage::user().with("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content.len(), 1);
    }

    #[test]
    fn it_builds_create_message_request_params() {
        let params = CreateMessageRequestParams::new()
            .with_message("Hello")
            .with_sys_prompt("System prompt")
            .with_max_tokens(100)
            .with_temp(0.7);
        assert_eq!(params.messages.len(), 1);
        assert_eq!(params.sys_prompt.as_deref(), Some("System prompt"));
        assert_eq!(params.max_tokens, 100);
        assert_eq!(params.temp, Some(0.7));
    }

    #[test]
    fn it_sets_context_inclusion() {
        let params = CreateMessageRequestParams::new().with_no_ctx();
        assert!(matches!(
            params.include_context,
            Some(ContextInclusion::None)
        ));
        let params = CreateMessageRequestParams::new().with_this_server();
        assert!(matches!(
            params.include_context,
            Some(ContextInclusion::ThisServer)
        ));
        let params = CreateMessageRequestParams::new().with_all_servers();
        assert!(matches!(
            params.include_context,
            Some(ContextInclusion::AllServers)
        ));
    }

    #[test]
    fn it_builds_create_message_result() {
        let result = CreateMessageResult::assistant()
            .with_model("gpt-4")
            .with_content("Hello world")
            .end_turn();
        assert_eq!(result.role, Role::Assistant);
        assert_eq!(result.model, "gpt-4");
        assert_eq!(result.content.len(), 1);
        assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
    }

    // use_tool must add the content item AND set the ToolUse stop reason.
    #[test]
    fn it_handles_tool_use_in_result() {
        let result = CreateMessageResult::assistant().use_tool("calculator", ());
        assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
        assert_eq!(result.content.len(), 1);
        let tool_use = result.tools().next().unwrap();
        assert_eq!(tool_use.name, "calculator");
    }

    #[test]
    fn it_adds_model_hints() {
        let pref = ModelPreferences::new()
            .with_hint("claude")
            .with_hints(["gpt-4", "llama"]);
        assert_eq!(pref.hints.as_ref().unwrap().len(), 3);
        assert_eq!(
            pref.hints.as_ref().unwrap()[0].name.as_deref(),
            Some("claude")
        );
    }

    // --- StopReason conversions and (de)serialization -----------------------

    #[test]
    fn it_converts_stop_reason_from_str() {
        let reasons = [
            (StopReason::ToolUse, "toolUse"),
            (StopReason::MaxTokens, "maxTokens"),
            (StopReason::EndTurn, "endTurn"),
            (StopReason::StopSequence, "stopSequence"),
            (StopReason::Other("test".to_string()), "test"),
        ];
        for (expected, reason_str) in reasons {
            let reason = StopReason::from(reason_str);
            assert_eq!(reason, expected);
        }
    }

    #[test]
    fn it_converts_stop_reason_from_string() {
        let reasons = [
            (StopReason::ToolUse, "toolUse"),
            (StopReason::MaxTokens, "maxTokens"),
            (StopReason::EndTurn, "endTurn"),
            (StopReason::StopSequence, "stopSequence"),
            (StopReason::Other("test".to_string()), "test"),
        ];
        for (expected, reason_str) in reasons {
            let reason = StopReason::from(reason_str.to_string());
            assert_eq!(reason, expected);
        }
    }

    #[test]
    fn it_serializes_stop_reason() {
        let reasons = [
            (StopReason::ToolUse, "\"toolUse\""),
            (StopReason::MaxTokens, "\"maxTokens\""),
            (StopReason::EndTurn, "\"endTurn\""),
            (StopReason::StopSequence, "\"stopSequence\""),
            (StopReason::Other("test".to_string()), "\"test\""),
        ];
        for (reason, expected) in reasons {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, expected);
        }
    }

    #[test]
    fn it_deserializes_stop_reason() {
        let reasons = [
            (StopReason::ToolUse, "\"toolUse\""),
            (StopReason::MaxTokens, "\"maxTokens\""),
            (StopReason::EndTurn, "\"endTurn\""),
            (StopReason::StopSequence, "\"stopSequence\""),
            (StopReason::Other("test".to_string()), "\"test\""),
        ];
        for (expected, reason_str) in reasons {
            let reason: StopReason = serde_json::from_str(reason_str).unwrap();
            assert_eq!(reason, expected);
        }
    }

    // --- ModelPreferences wire format (camelCase keys, key order) -----------

    #[test]
    fn it_serializes_model_preferences() {
        let pref = ModelPreferences::new()
            .with_cost_priority(0.5)
            .with_speed_priority(0.75)
            .with_intel_priority(0.25);
        let json = serde_json::to_string(&pref).unwrap();
        let expected = r#"{"costPriority":0.5,"speedPriority":0.75,"intelligencePriority":0.25}"#;
        assert_eq!(json, expected);
    }

    #[test]
    fn it_deserializes_model_preferences() {
        let json = r#"{"costPriority":0.5,"speedPriority":0.75,"intelligencePriority":0.25}"#;
        let pref: ModelPreferences = serde_json::from_str(json).unwrap();
        assert_eq!(pref.cost_priority, Some(0.5));
        assert_eq!(pref.speed_priority, Some(0.75));
        assert_eq!(pref.intelligence_priority, Some(0.25));
    }
}