use crate::ValidationError;
use crate::macros::impl_display_for_serialize;
use crate::messages::{
ClaudeModel, MaxTokens, Message, Metadata, StopSequence, StreamOption,
SystemPrompt, Temperature, ToolDefinition, TopK, TopP,
};
/// JSON request body for the Messages endpoint.
///
/// `model`, `messages` and `max_tokens` are always serialized. Every
/// `Option` field is left out of the JSON entirely when it is `None`.
/// Fields serialize in declaration order.
///
/// Construct it directly with struct-update syntax, or via
/// `MessagesRequestBuilder`.
#[derive(Debug, Clone, PartialEq, Default)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct MessagesRequestBody {
    /// Target model, serialized as its identifier string.
    pub model: ClaudeModel,
    /// Conversation messages; emitted even when empty (`"messages":[]`).
    pub messages: Vec<Message>,
    /// Optional system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<SystemPrompt>,
    /// Token limit for the request; always emitted.
    pub max_tokens: MaxTokens,
    /// Optional request metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    /// Optional list of stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<StopSequence>>,
    /// Optional streaming option.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<StreamOption>,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<Temperature>,
    /// Optional tool definitions made available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// Optional nucleus-sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<TopP>,
    /// Optional top-k sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<TopK>,
    /// Optional extended-thinking configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<Thinking>,
}
impl_display_for_serialize!(MessagesRequestBody);
/// Extended-thinking configuration, sent as the `thinking` field of a
/// `MessagesRequestBody`.
///
/// Serializes as `{"type": ..., "budget_tokens": ...}`. serde strips the
/// `r#` prefix from the raw identifier, so the JSON key is `type`.
///
/// NOTE(review): `r#type` is a free-form string, and the derived `Default`
/// yields `type: ""` with a zero budget, which is not a meaningful
/// configuration. Prefer [`Thinking::enabled`] to build one.
#[derive(
    Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize,
)]
pub struct Thinking {
    /// Thinking mode, e.g. `"enabled"`.
    pub r#type: String,
    /// Token budget the model may spend on thinking. Per the Anthropic
    /// extended-thinking docs it must be at least 1024 and below the
    /// request's `max_tokens`; this type does not enforce that.
    pub budget_tokens: u64,
}

impl Thinking {
    /// Builds a configuration that turns extended thinking on with the given
    /// token budget. The `type` is set to `"enabled"`.
    #[must_use]
    pub fn enabled(budget_tokens: u64) -> Self {
        Self {
            r#type: "enabled".to_owned(),
            budget_tokens,
        }
    }
}
impl_display_for_serialize!(Thinking);
/// Fluent, consuming builder for `MessagesRequestBody`.
///
/// Every setter takes the builder by value and hands it back. Dropping a
/// returned builder therefore loses the configuration, which is why the type
/// is `#[must_use]`. Call `build()` to obtain the request body.
#[derive(Debug, Clone, PartialEq, Default)]
#[must_use = "a builder does nothing unless `build()` is called on it"]
pub struct MessagesRequestBuilder {
    /// The body being assembled; returned as-is by `build()`.
    request_body: MessagesRequestBody,
}
impl MessagesRequestBuilder {
    /// Starts a request for `model`.
    ///
    /// `max_tokens` comes from `MaxTokens::from_model(model)`, and every
    /// other field is left at its `Default` value.
    pub fn new(model: ClaudeModel) -> Self {
        let max_tokens = MaxTokens::from_model(model);
        let request_body = MessagesRequestBody {
            model,
            max_tokens,
            ..Default::default()
        };
        Self { request_body }
    }

    /// Starts a request for `model` with an explicit `max_tokens` value,
    /// which is validated by `MaxTokens::new(max_tokens, model)`.
    ///
    /// # Errors
    ///
    /// Propagates the `ValidationError` produced when `MaxTokens::new`
    /// rejects `max_tokens` for `model`.
    pub fn new_with_max_tokens(
        model: ClaudeModel,
        max_tokens: u32,
    ) -> Result<Self, ValidationError<u32>> {
        let max_tokens = MaxTokens::new(max_tokens, model)?;
        let request_body = MessagesRequestBody {
            model,
            max_tokens,
            ..Default::default()
        };
        Ok(Self { request_body })
    }

    /// Applies `edit` to the body under construction and returns the builder.
    /// Shared plumbing for all setters below.
    fn edit_body(
        mut self,
        edit: impl FnOnce(&mut MessagesRequestBody),
    ) -> Self {
        edit(&mut self.request_body);
        self
    }

    /// Replaces the conversation messages.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        self.edit_body(|body| body.messages = messages)
    }

    /// Sets the system prompt.
    pub fn system(self, system: SystemPrompt) -> Self {
        self.edit_body(|body| body.system = Some(system))
    }

    /// Overrides the token limit chosen at construction time.
    pub fn max_tokens(self, max_tokens: MaxTokens) -> Self {
        self.edit_body(|body| body.max_tokens = max_tokens)
    }

    /// Sets the request metadata.
    pub fn metadata(self, metadata: Metadata) -> Self {
        self.edit_body(|body| body.metadata = Some(metadata))
    }

    /// Sets the stop sequences.
    pub fn stop_sequences(self, stop_sequences: Vec<StopSequence>) -> Self {
        self.edit_body(|body| body.stop_sequences = Some(stop_sequences))
    }

    /// Sets the streaming option.
    pub fn stream(self, stream: StreamOption) -> Self {
        self.edit_body(|body| body.stream = Some(stream))
    }

    /// Sets the sampling temperature.
    pub fn temperature(self, temperature: Temperature) -> Self {
        self.edit_body(|body| body.temperature = Some(temperature))
    }

    /// Sets the tool definitions.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        self.edit_body(|body| body.tools = Some(tools))
    }

    /// Sets the top-p sampling parameter.
    pub fn top_p(self, top_p: TopP) -> Self {
        self.edit_body(|body| body.top_p = Some(top_p))
    }

    /// Sets the top-k sampling parameter.
    pub fn top_k(self, top_k: TopK) -> Self {
        self.edit_body(|body| body.top_k = Some(top_k))
    }

    /// Sets the extended-thinking configuration.
    pub fn thinking(self, thinking: Thinking) -> Self {
        self.edit_body(|body| body.thinking = Some(thinking))
    }

    /// Consumes the builder and yields the assembled request body.
    pub fn build(self) -> MessagesRequestBody {
        let Self { request_body } = self;
        request_body
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A struct literal with `..Default::default()` leaves every optional
    /// field, including `tools` and `thinking`, unset.
    #[test]
    fn new() {
        let messages_request_body = MessagesRequestBody {
            model: ClaudeModel::Claude3Sonnet20240229,
            messages: vec![],
            max_tokens: MaxTokens::new(16, ClaudeModel::Claude3Sonnet20240229)
                .unwrap(),
            ..Default::default()
        };
        assert_eq!(
            messages_request_body.model,
            ClaudeModel::Claude3Sonnet20240229
        );
        assert_eq!(messages_request_body.messages, vec![]);
        assert_eq!(messages_request_body.system, None);
        assert_eq!(
            messages_request_body.max_tokens,
            MaxTokens::new(16, ClaudeModel::Claude3Sonnet20240229).unwrap()
        );
        assert_eq!(messages_request_body.metadata, None);
        assert_eq!(messages_request_body.stop_sequences, None);
        assert_eq!(messages_request_body.stream, None);
        assert_eq!(messages_request_body.temperature, None);
        assert_eq!(messages_request_body.tools, None);
        assert_eq!(messages_request_body.top_p, None);
        assert_eq!(messages_request_body.top_k, None);
        assert_eq!(messages_request_body.thinking, None);
    }

    #[test]
    fn default() {
        let messages_request_body = MessagesRequestBody::default();
        assert_eq!(messages_request_body.model, ClaudeModel::default());
        assert_eq!(messages_request_body.messages, vec![]);
        assert_eq!(messages_request_body.system, None);
        assert_eq!(messages_request_body.max_tokens, MaxTokens::default());
        assert_eq!(messages_request_body.metadata, None);
        assert_eq!(messages_request_body.stop_sequences, None);
        assert_eq!(messages_request_body.stream, None);
        assert_eq!(messages_request_body.temperature, None);
        assert_eq!(messages_request_body.tools, None);
        assert_eq!(messages_request_body.top_p, None);
        assert_eq!(messages_request_body.top_k, None);
        assert_eq!(messages_request_body.thinking, None);
    }

    #[test]
    fn display() {
        let messages_request_body = MessagesRequestBody::default();
        assert_eq!(
            messages_request_body.to_string(),
            "{\n \"model\": \"claude-3-sonnet-20240229\",\n \"messages\": [],\n \"max_tokens\": 4096\n}"
        );
    }

    #[test]
    fn serialize() {
        let messages_request_body = MessagesRequestBody::default();
        assert_eq!(
            serde_json::to_string(&messages_request_body).unwrap(),
            "{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"max_tokens\":4096}"
        );
        let messages_request_body = MessagesRequestBody {
            model: ClaudeModel::Claude3Sonnet20240229,
            messages: vec![],
            max_tokens: MaxTokens::new(16, ClaudeModel::Claude3Sonnet20240229)
                .unwrap(),
            system: Some(SystemPrompt::new("system-prompt")),
            metadata: Some(Metadata {
                user_id: "metadata".into(),
            }),
            stop_sequences: Some(vec![StopSequence::new("stop-sequence")]),
            stream: Some(StreamOption::ReturnOnce),
            temperature: Some(Temperature::new(0.5).unwrap()),
            tools: None,
            top_p: Some(TopP::new(0.5).unwrap()),
            top_k: Some(TopK::new(50)),
            thinking: None,
        };
        assert_eq!(
            serde_json::to_string(&messages_request_body).unwrap(),
            "{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"system\":\"system-prompt\",\"max_tokens\":16,\"metadata\":{\"user_id\":\"metadata\"},\"stop_sequences\":[\"stop-sequence\"],\"stream\":false,\"temperature\":0.5,\"top_p\":0.5,\"top_k\":50}"
        );
    }

    #[test]
    fn deserialize() {
        let messages_request_body = MessagesRequestBody::default();
        assert_eq!(
            serde_json::from_str::<MessagesRequestBody>("{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"max_tokens\":4096}").unwrap(),
            messages_request_body
        );
        let messages_request_body = MessagesRequestBody {
            model: ClaudeModel::Claude3Sonnet20240229,
            messages: vec![],
            max_tokens: MaxTokens::new(16, ClaudeModel::Claude3Sonnet20240229)
                .unwrap(),
            system: Some(SystemPrompt::new("system-prompt")),
            metadata: Some(Metadata {
                user_id: "metadata".into(),
            }),
            stop_sequences: Some(vec![StopSequence::new("stop-sequence")]),
            stream: Some(StreamOption::ReturnOnce),
            temperature: Some(Temperature::new(0.5).unwrap()),
            tools: None,
            top_p: Some(TopP::new(0.5).unwrap()),
            top_k: Some(TopK::new(50)),
            thinking: None,
        };
        assert_eq!(
            serde_json::from_str::<MessagesRequestBody>("{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"system\":\"system-prompt\",\"max_tokens\":16,\"metadata\":{\"user_id\":\"metadata\"},\"stop_sequences\":[\"stop-sequence\"],\"stream\":false,\"temperature\":0.5,\"top_p\":0.5,\"top_k\":50}").unwrap(),
            messages_request_body
        );
    }

    /// `thinking` serializes with the JSON key `type` (no `r#` prefix), is
    /// emitted last (declaration order), and round-trips through serde.
    #[test]
    fn thinking_round_trip() {
        let thinking = Thinking {
            r#type: "enabled".into(),
            budget_tokens: 1024,
        };
        assert_eq!(
            serde_json::to_string(&thinking).unwrap(),
            "{\"type\":\"enabled\",\"budget_tokens\":1024}"
        );
        let messages_request_body = MessagesRequestBody {
            thinking: Some(thinking),
            ..Default::default()
        };
        let json = "{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"max_tokens\":4096,\"thinking\":{\"type\":\"enabled\",\"budget_tokens\":1024}}";
        assert_eq!(
            serde_json::to_string(&messages_request_body).unwrap(),
            json
        );
        assert_eq!(
            serde_json::from_str::<MessagesRequestBody>(json).unwrap(),
            messages_request_body
        );
    }

    #[test]
    fn builder() {
        let messages_request_body =
            MessagesRequestBuilder::new(ClaudeModel::Claude3Sonnet20240229)
                .messages(vec![])
                .system(SystemPrompt::new("system-prompt"))
                .max_tokens(
                    MaxTokens::new(16, ClaudeModel::Claude3Sonnet20240229)
                        .unwrap(),
                )
                .metadata(Metadata {
                    user_id: "metadata".into(),
                })
                .stop_sequences(vec![StopSequence::new("stop-sequence")])
                .stream(StreamOption::ReturnOnce)
                .temperature(Temperature::new(0.5).unwrap())
                .tools(vec![ToolDefinition {
                    name: "tool".into(),
                    description: Some("tool description".into()),
                    input_schema: serde_json::Value::Null,
                }])
                .top_p(TopP::new(0.5).unwrap())
                .top_k(TopK::new(50))
                .thinking(Thinking {
                    r#type: "enabled".into(),
                    budget_tokens: 1024,
                })
                .build();
        assert_eq!(
            messages_request_body.model,
            ClaudeModel::Claude3Sonnet20240229
        );
        assert_eq!(messages_request_body.messages, vec![]);
        assert_eq!(
            messages_request_body.system,
            Some(SystemPrompt::new("system-prompt"))
        );
        assert_eq!(
            messages_request_body.max_tokens,
            MaxTokens::new(16, ClaudeModel::Claude3Sonnet20240229).unwrap()
        );
        assert_eq!(
            messages_request_body.metadata,
            Some(Metadata {
                user_id: "metadata".into(),
            })
        );
        assert_eq!(
            messages_request_body.stop_sequences,
            Some(vec![StopSequence::new("stop-sequence")])
        );
        assert_eq!(
            messages_request_body.stream,
            Some(StreamOption::ReturnOnce)
        );
        assert_eq!(
            messages_request_body.temperature,
            Some(Temperature::new(0.5).unwrap())
        );
        assert_eq!(
            messages_request_body.tools,
            Some(vec![ToolDefinition {
                name: "tool".into(),
                description: Some("tool description".into()),
                input_schema: serde_json::Value::Null,
            }])
        );
        assert_eq!(
            messages_request_body.top_p,
            Some(TopP::new(0.5).unwrap())
        );
        assert_eq!(messages_request_body.top_k, Some(TopK::new(50)));
        assert_eq!(
            messages_request_body.thinking,
            Some(Thinking {
                r#type: "enabled".into(),
                budget_tokens: 1024,
            })
        );
    }

    #[test]
    fn builder_with_max_tokens() {
        let messages_request_body =
            MessagesRequestBuilder::new_with_max_tokens(
                ClaudeModel::Claude3Sonnet20240229,
                16,
            )
            .unwrap()
            .messages(vec![])
            .system(SystemPrompt::new("system-prompt"))
            .metadata(Metadata {
                user_id: "metadata".into(),
            })
            .stop_sequences(vec![StopSequence::new("stop-sequence")])
            .stream(StreamOption::ReturnOnce)
            .temperature(Temperature::new(0.5).unwrap())
            .tools(vec![ToolDefinition {
                name: "tool".into(),
                description: Some("tool description".into()),
                input_schema: serde_json::Value::Null,
            }])
            .top_p(TopP::new(0.5).unwrap())
            .top_k(TopK::new(50))
            .build();
        assert_eq!(
            messages_request_body.model,
            ClaudeModel::Claude3Sonnet20240229
        );
        assert_eq!(messages_request_body.messages, vec![]);
        assert_eq!(
            messages_request_body.system,
            Some(SystemPrompt::new("system-prompt"))
        );
        assert_eq!(
            messages_request_body.max_tokens,
            MaxTokens::new(16, ClaudeModel::Claude3Sonnet20240229).unwrap()
        );
        assert_eq!(
            messages_request_body.metadata,
            Some(Metadata {
                user_id: "metadata".into(),
            })
        );
        assert_eq!(
            messages_request_body.stop_sequences,
            Some(vec![StopSequence::new("stop-sequence")])
        );
        assert_eq!(
            messages_request_body.stream,
            Some(StreamOption::ReturnOnce)
        );
        assert_eq!(
            messages_request_body.temperature,
            Some(Temperature::new(0.5).unwrap())
        );
        assert_eq!(
            messages_request_body.tools,
            Some(vec![ToolDefinition {
                name: "tool".into(),
                description: Some("tool description".into()),
                input_schema: serde_json::Value::Null,
            }])
        );
        assert_eq!(
            messages_request_body.top_p,
            Some(TopP::new(0.5).unwrap())
        );
        assert_eq!(messages_request_body.top_k, Some(TopK::new(50)));
        // `thinking` was never set, so it must stay unset.
        assert_eq!(messages_request_body.thinking, None);
    }
}