use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::message::{Message, Role};
use super::tool::{Tool, ToolCall, ToolChoice};
use super::usage::CompletionUsage;
/// Reasoning-effort level that can be requested for a completion.
///
/// Because of `rename_all = "lowercase"`, each variant (de)serializes as its
/// lowercase name: `"none"`, `"minimal"`, `"low"`, `"medium"`, `"high"`, `"auto"`.
/// [`ReasoningEffort::Auto`] is the `Default` variant.
///
/// NOTE(review): which of these strings a given provider accepts is not visible
/// here (e.g. whether `"auto"` or `"none"` is understood upstream) — confirm
/// against the target API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    /// Serialized as `"none"`. This is distinct from `Option::None`: a field holding
    /// `Some(ReasoningEffort::None)` is still emitted, not skipped.
    None,
    /// Serialized as `"minimal"`.
    Minimal,
    /// Serialized as `"low"`.
    Low,
    /// Serialized as `"medium"`.
    Medium,
    /// Serialized as `"high"`.
    High,
    /// Serialized as `"auto"`; the variant returned by `Default::default()`.
    #[default]
    Auto,
}
/// One or more stop sequences for generation.
///
/// The enum is `#[serde(untagged)]`, so no variant tag appears on the wire. A JSON
/// string maps to [`StopSequence::Single`], and a JSON array of strings maps to
/// [`StopSequence::Multiple`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    /// A single stop string, serialized as a bare JSON string.
    Single(String),
    /// A list of stop strings, serialized as a JSON array.
    Multiple(Vec<String>),
}
impl StopSequence {
pub fn single(s: impl Into<String>) -> Self {
StopSequence::Single(s.into())
}
pub fn multiple(sequences: Vec<String>) -> Self {
StopSequence::Multiple(sequences)
}
pub fn to_vec(&self) -> Vec<String> {
match self {
StopSequence::Single(s) => vec![s.clone()],
StopSequence::Multiple(v) => v.clone(),
}
}
}
/// Parameters for a chat-completion request.
///
/// Only `model_id` and `messages` are required. Every other field is an `Option`
/// and is omitted entirely from serialized output when `None`
/// (`skip_serializing_if = "Option::is_none"`).
///
/// `Default` is derived: it yields an empty `model_id`, no messages and every
/// optional field set to `None`. Prefer [`CompletionParams::new`] plus the
/// `with_*` builders when constructing a request.
///
/// NOTE(review): the model field serializes under the key `model_id`, not
/// `model`. Confirm this matches whatever consumes the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompletionParams {
    /// Identifier of the model to query.
    pub model_id: String,
    /// Conversation history sent with the request.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Stop sequence(s); serialized as a string or an array (see [`StopSequence`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,
    /// String-keyed bias map. Presumably the keys are token ids rendered as
    /// strings (confirm against the provider). Being a `HashMap`, its key order
    /// in serialized output is unspecified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<std::collections::HashMap<String, f32>>,
    /// Free-form JSON passed through as-is; no structure is enforced here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
impl CompletionParams {
    /// Creates parameters for `model_id` with the given conversation `messages`.
    ///
    /// All remaining fields come from `Default`, i.e. every optional setting
    /// starts as `None`.
    #[must_use]
    pub fn new(model_id: impl Into<String>, messages: Vec<Message>) -> Self {
        Self {
            model_id: model_id.into(),
            messages,
            ..Default::default()
        }
    }

    /// Sets `temperature`, replacing any previously set value.
    ///
    /// Like every builder here, this consumes `self` and returns the updated
    /// value. `#[must_use]` makes discarding that return value a compiler warning,
    /// because otherwise the setting would be silently lost.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `max_tokens`, replacing any previously set value.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets `stream`, replacing any previously set value.
    #[must_use]
    pub fn with_stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }

    /// Sets the tool list, replacing (not appending to) any previous list.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets `tool_choice`, replacing any previously set value.
    #[must_use]
    pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Sets `reasoning_effort`, replacing any previously set value.
    #[must_use]
    pub fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
        self.reasoning_effort = Some(effort);
        self
    }
}
/// Reasoning text carried on a [`ChatCompletionMessage`].
///
/// Serialized as an object with a single `content` string field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Reasoning {
    /// The reasoning text itself.
    pub content: String,
}
impl Reasoning {
pub fn new(content: impl Into<String>) -> Self {
Self {
content: content.into(),
}
}
}
/// A message as it appears inside a [`Choice`] of a [`ChatCompletion`] response.
///
/// Only `role` is always present. Every other field is an `Option` and is
/// omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatCompletionMessage {
    /// Author role of the message.
    pub role: Role,
    /// Text content, if the message has any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls attached to the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Reasoning text, if the provider returned any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// NOTE(review): presumably a provider-supplied refusal explanation — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<String>,
}
impl Default for ChatCompletionMessage {
fn default() -> Self {
Self {
role: Role::Assistant,
content: None,
tool_calls: None,
reasoning: None,
refusal: None,
}
}
}
/// One candidate completion within a [`ChatCompletion`] response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Index of this choice as reported in the response (presumably its position
    /// in `choices` — confirm against the provider).
    pub index: u32,
    /// The generated message for this choice.
    pub message: ChatCompletionMessage,
    /// Why generation stopped, kept as a free-form string; the set of possible
    /// values is provider-defined and not enforced here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Raw log-probability payload, kept as untyped JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<Value>,
}
/// A complete (non-streamed) chat-completion response.
///
/// The inherent accessor methods on this type (`content`, `tool_calls`,
/// `reasoning`, `finish_reason`) read only from the first entry in `choices`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatCompletion {
    /// Response identifier as returned by the provider.
    pub id: String,
    /// Object-type tag as returned by the provider (value not validated here).
    pub object: String,
    /// Creation time; presumably Unix seconds — TODO confirm.
    pub created: i64,
    /// Model that produced the response.
    pub model: String,
    /// Candidate completions; may be empty.
    pub choices: Vec<Choice>,
    /// Token usage, when the provider reports it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<CompletionUsage>,
    /// Provider-supplied fingerprint string, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
impl ChatCompletion {
    /// Text content of the first choice's message.
    ///
    /// Returns `None` when there are no choices or the message has no content.
    pub fn content(&self) -> Option<&str> {
        let first = self.choices.first()?;
        first.message.content.as_deref()
    }

    /// Tool calls on the first choice's message, borrowed as a slice.
    ///
    /// Returns `None` when there are no choices or the message has no tool calls.
    pub fn tool_calls(&self) -> Option<&[ToolCall]> {
        let first = self.choices.first()?;
        first.message.tool_calls.as_deref()
    }

    /// Reasoning text on the first choice's message.
    ///
    /// Returns `None` when there are no choices or the message has no reasoning.
    pub fn reasoning(&self) -> Option<&str> {
        let first = self.choices.first()?;
        let attached = first.message.reasoning.as_ref()?;
        Some(attached.content.as_str())
    }

    /// Finish reason reported for the first choice.
    ///
    /// Returns `None` when there are no choices or no finish reason was reported.
    pub fn finish_reason(&self) -> Option<&str> {
        let first = self.choices.first()?;
        first.finish_reason.as_deref()
    }
}