use serde::{Deserialize, Serialize};
use crate::effort::Effort;
use crate::error::{Error, Result};
use crate::message::{ContentBlock, Message, Role};
use crate::thinking::ThinkingSetting;
use crate::tool::{Tool, ToolChoice};
/// A completion request: target model, conversation, tool configuration,
/// sampling overrides, and metadata.
///
/// Can be assembled with [`CompletionRequest::builder`]; the builder's `build()`
/// runs [`CompletionRequest::validate`] before returning the request.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct CompletionRequest {
    /// Model identifier. `validate` rejects an empty string.
    pub model: String,
    /// Conversation messages. `validate` requires every `Role::System` message to
    /// come first (text-only) and at least one User/Assistant message to exist.
    pub messages: Vec<Message>,
    /// Tools offered with the request; left out of the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Tool-selection setting; `ToolChoice::default()` when absent on deserialization.
    #[serde(default)]
    pub tool_choice: ToolChoice,
    /// Maximum token count; must be non-zero (checked by `validate`).
    pub max_tokens: u32,
    /// Optional sampling override; left out of the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional sampling override; left out of the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Optional sampling override; left out of the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Stop strings; left out of the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// Thinking configuration; `ThinkingSetting::default()` when absent on
    /// deserialization (note: the builder starts from `ThinkingSetting::Auto`).
    #[serde(default)]
    pub thinking: ThinkingSetting,
    /// Optional effort level; left out of the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub effort: Option<Effort>,
    /// Request metadata; empty (`RequestMetadata::default()`) when absent on
    /// deserialization.
    #[serde(default)]
    pub metadata: RequestMetadata,
}
/// Caller-declared purpose of a request, stored in [`RequestMetadata::purpose`].
///
/// Serialized in `snake_case`, e.g. `RequestPurpose::FastClassifier` <-> `"fast_classifier"`.
///
/// `Hash` is derived alongside `Eq` so the tag can key a `HashMap`/`HashSet`
/// (for example, per-purpose counters or routing tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RequestPurpose {
    MainLoop,
    Summarization,
    FastClassifier,
    SubAgent,
    Embedding,
    Other,
}
/// Auxiliary information attached to a [`CompletionRequest`].
///
/// Every field is optional: each defaults to `None` when missing on
/// deserialization and is left out of the serialized form when `None`.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct RequestMetadata {
    /// Optional user identifier (set by `CompletionRequestBuilder::user_id`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
    /// Optional purpose tag for the request.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub purpose: Option<RequestPurpose>,
}
impl CompletionRequest {
pub fn builder(model: impl Into<String>) -> CompletionRequestBuilder {
CompletionRequestBuilder {
model: model.into(),
messages: Vec::new(),
tools: Vec::new(),
tool_choice: ToolChoice::default(),
max_tokens: 1024,
temperature: None,
top_p: None,
top_k: None,
stop_sequences: Vec::new(),
thinking: ThinkingSetting::Auto,
effort: None,
metadata: RequestMetadata::default(),
}
}
pub fn validate(&self) -> Result<()> {
if self.model.is_empty() {
return Err(Error::InvalidRequest("model is empty".into()));
}
if self.max_tokens == 0 {
return Err(Error::InvalidRequest("max_tokens must be > 0".into()));
}
validate_messages(&self.messages)
}
}
fn validate_messages(messages: &[Message]) -> Result<()> {
let mut seen_non_system = false;
let mut has_user_or_assistant = false;
for (i, msg) in messages.iter().enumerate() {
match msg.role {
Role::System => {
if seen_non_system {
return Err(Error::InvalidRequest(format!(
"Role::System message at index {i} appears after a User/Assistant \
message; System must lead"
)));
}
for block in &msg.content {
if !matches!(block, ContentBlock::Text(_)) {
return Err(Error::InvalidRequest(format!(
"Role::System message at index {i} contains a non-text block"
)));
}
}
}
Role::User | Role::Assistant => {
seen_non_system = true;
has_user_or_assistant = true;
}
}
}
if !has_user_or_assistant {
return Err(Error::InvalidRequest(
"request has no User or Assistant messages".into(),
));
}
Ok(())
}
/// Fluent builder for [`CompletionRequest`], created by [`CompletionRequest::builder`].
///
/// Nothing is validated until `build()` is called. `Debug` and `Clone` are derived
/// so a partially-configured builder can be logged or reused as a template; every
/// field type already implements both (as `CompletionRequest` derives them too).
#[derive(Debug, Clone)]
#[must_use = "builder has no effect until .build() is called"]
pub struct CompletionRequestBuilder {
    model: String,
    messages: Vec<Message>,
    tools: Vec<Tool>,
    tool_choice: ToolChoice,
    max_tokens: u32,
    temperature: Option<f32>,
    top_p: Option<f32>,
    top_k: Option<u32>,
    stop_sequences: Vec<String>,
    thinking: ThinkingSetting,
    effort: Option<Effort>,
    metadata: RequestMetadata,
}
impl CompletionRequestBuilder {
pub fn system(mut self, text: impl Into<String>) -> Self {
let insertion_index = self
.messages
.iter()
.position(|m| m.role != Role::System)
.unwrap_or(self.messages.len());
self.messages
.insert(insertion_index, Message::system_text(text));
self
}
pub fn user_text(mut self, text: impl Into<String>) -> Self {
self.messages.push(Message::user_text(text));
self
}
pub fn assistant_text(mut self, text: impl Into<String>) -> Self {
self.messages.push(Message::assistant_text(text));
self
}
pub fn message(mut self, m: Message) -> Self {
self.messages.push(m);
self
}
pub fn tool(mut self, t: Tool) -> Self {
self.tools.push(t);
self
}
pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
self.tool_choice = choice;
self
}
pub fn max_tokens(mut self, n: u32) -> Self {
self.max_tokens = n;
self
}
pub fn temperature(mut self, t: f32) -> Self {
self.temperature = Some(t);
self
}
pub fn top_p(mut self, p: f32) -> Self {
self.top_p = Some(p);
self
}
pub fn top_k(mut self, k: u32) -> Self {
self.top_k = Some(k);
self
}
pub fn stop_sequence(mut self, s: impl Into<String>) -> Self {
self.stop_sequences.push(s.into());
self
}
pub fn thinking(mut self, setting: ThinkingSetting) -> Self {
self.thinking = setting;
self
}
pub fn effort(mut self, e: Effort) -> Self {
self.effort = Some(e);
self
}
pub fn user_id(mut self, id: impl Into<String>) -> Self {
self.metadata.user_id = Some(id.into());
self
}
#[must_use = "discarding the Result silently ignores validation errors"]
pub fn build(self) -> Result<CompletionRequest> {
let req = CompletionRequest {
model: self.model,
messages: self.messages,
tools: self.tools,
tool_choice: self.tool_choice,
max_tokens: self.max_tokens,
temperature: self.temperature,
top_p: self.top_p,
top_k: self.top_k,
stop_sequences: self.stop_sequences,
thinking: self.thinking,
effort: self.effort,
metadata: self.metadata,
};
req.validate()?;
Ok(req)
}
}