use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use super::common::{
AssistantMessage, ChatCompletionTool, Message, ResponseFormat, StopSequence, ToolChoice, ToolType, Usage,
};
use crate::cost;
/// Finish reason carried by `Choice::finish_reason` and
/// `StreamChoice::finish_reason`.
///
/// Variants are (de)serialized using snake_case names: `stop`, `length`,
/// `tool_calls`, `content_filter`, `function_call` and `other`. The precise
/// meaning of each value is defined by the upstream API, not by this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
Stop,
Length,
ToolCalls,
ContentFilter,
// The explicit rename yields the same string that `rename_all = "snake_case"`
// would already produce for this variant.
#[serde(rename = "function_call")]
FunctionCall,
/// Catch-all: any string that matches no other variant deserializes to this
/// variant instead of producing an error.
///
/// The original string is not retained; serializing this variant writes
/// `"other"`, so a round trip of an unknown value is lossy.
#[serde(other)]
Other,
}
/// Value type of `ChatCompletionRequest::reasoning_effort`.
///
/// Variants are (de)serialized as the lowercase strings `low`, `medium` and
/// `high`. There is no catch-all variant, so any other string fails to
/// deserialize.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
Low,
Medium,
High,
}
/// Request body for a chat completion.
///
/// Serde behavior visible in this definition:
/// - `deny_unknown_fields`: deserialization fails if the input contains a key
///   that is not one of the fields below.
/// - `model` and `messages` carry no `serde(default)`, so they must be present
///   when deserializing.
/// - Every other field is an `Option` that defaults to `None` when absent and
///   is omitted from the serialized output when `None`.
///
/// The derived `Default` produces an empty `model`, an empty `messages` list
/// and `None` for every optional field.
// NOTE(review): the accepted ranges and exact semantics of the sampling
// parameters are defined by the upstream API, not by this file.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ChatCompletionRequest {
/// Model identifier, sent as a plain string.
pub model: String,
/// Messages sent with the request; `Message` is defined in `super::common`.
pub messages: Vec<Message>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// Crate-private, unlike the other fields: code outside this crate cannot
/// read or set it directly, although it is still (de)serialized.
// NOTE(review): presumably set by the crate's streaming entry point — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub(crate) stream: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop: Option<StopSequence>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
/// String-keyed map of bias values. Being a `BTreeMap`, its entries are
/// serialized in ascending key order.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<BTreeMap<String, f64>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ChatCompletionTool>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<ReasoningEffort>,
/// Arbitrary JSON value. As declared here it is not flattened: it is
/// serialized as a nested value under the `extra_body` key.
// NOTE(review): whether a later layer merges this into the top-level body
// is not visible in this file — confirm against the request sender.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_body: Option<serde_json::Value>,
}
/// Value type of `ChatCompletionRequest::stream_options`.
///
/// `deny_unknown_fields` makes deserialization fail on any key other than
/// `include_usage`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct StreamOptions {
/// Optional flag; defaults to `None` when absent and is omitted from the
/// serialized output when `None`.
// NOTE(review): presumably asks the server to report usage on streamed
// chunks (see `ChatCompletionChunk::usage`) — confirm against the API docs.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub include_usage: Option<bool>,
}
/// Complete (non-streaming) chat completion response.
///
/// `id`, `object`, `created`, `model` and `choices` are required when
/// deserializing. The remaining fields are optional: they default to `None`
/// when absent and are omitted from serialized output when `None`.
///
/// This type does not use `deny_unknown_fields`, so unrecognized keys in the
/// input are ignored rather than rejected.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
// NOTE(review): presumably a Unix timestamp in seconds — confirm.
pub created: u64,
/// Model name; also passed to `cost::completion_cost` by `estimated_cost`.
pub model: String,
pub choices: Vec<Choice>,
/// Token usage, if reported. `estimated_cost` returns `None` when this is
/// absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub service_tier: Option<String>,
}
impl ChatCompletionResponse {
    /// Estimates the cost of this response from its reported token usage.
    ///
    /// Returns `None` when the response carries no `usage`. Otherwise the
    /// model name together with the prompt and completion token counts is
    /// handed to `cost::completion_cost`, and whatever that function yields
    /// is returned unchanged.
    #[must_use]
    pub fn estimated_cost(&self) -> Option<f64> {
        self.usage.as_ref().and_then(|reported| {
            cost::completion_cost(
                &self.model,
                reported.prompt_tokens,
                reported.completion_tokens,
            )
        })
    }
}
/// One entry of `ChatCompletionResponse::choices`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Choice {
/// Index reported for this choice.
pub index: u32,
/// The assistant message for this choice; `AssistantMessage` is defined in
/// `super::common`.
pub message: AssistantMessage,
/// Optional finish reason. Unlike most optional fields in this module it has
/// no `skip_serializing_if`, so `None` is serialized explicitly (as `null`
/// in JSON) rather than omitted.
pub finish_reason: Option<FinishReason>,
}
/// One chunk of a streamed chat completion.
///
/// Has the same fields as `ChatCompletionResponse`, except that `choices`
/// holds `StreamChoice` values (incremental deltas) instead of `Choice`.
/// `usage`, `system_fingerprint` and `service_tier` default to `None` when
/// absent and are omitted from serialized output when `None`. Unknown keys in
/// the input are ignored, since `deny_unknown_fields` is not set.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
// NOTE(review): presumably a Unix timestamp in seconds — confirm.
pub created: u64,
pub model: String,
pub choices: Vec<StreamChoice>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub service_tier: Option<String>,
}
/// One entry of `ChatCompletionChunk::choices`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamChoice {
/// Index reported for this choice.
pub index: u32,
/// Incremental payload carried by this chunk for the choice.
pub delta: StreamDelta,
/// Optional finish reason. There is no `skip_serializing_if` here, so `None`
/// is serialized explicitly (as `null` in JSON) rather than omitted.
pub finish_reason: Option<FinishReason>,
}
/// Incremental payload of a `StreamChoice`.
///
/// Every field is optional: each defaults to `None` when absent from the input
/// and is omitted from serialized output when `None`, so an empty JSON object
/// deserializes successfully.
// NOTE(review): presumably callers accumulate `content` / `tool_calls`
// fragments across chunks — the accumulation logic is not in this file.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamDelta {
/// Role as a free-form string (not an enum), so any value is accepted.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<StreamToolCall>>,
/// Uses the same payload type as `StreamToolCall::function`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub function_call: Option<StreamFunctionCall>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub refusal: Option<String>,
}
/// One entry of `StreamDelta::tool_calls`.
///
/// Only `index` is required; the other fields default to `None` when absent
/// and are omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamToolCall {
// NOTE(review): presumably identifies which tool call a fragment belongs
// to when several are streamed in parallel — confirm against the consumer.
pub index: u32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// (De)serialized under the key `type`; the Rust field is named `call_type`
/// because `type` is a reserved keyword.
#[serde(default, skip_serializing_if = "Option::is_none", rename = "type")]
pub call_type: Option<ToolType>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub function: Option<StreamFunctionCall>,
}
/// Function-call payload used by `StreamToolCall::function` and
/// `StreamDelta::function_call`.
///
/// Both fields are optional: each defaults to `None` when absent and is
/// omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamFunctionCall {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Kept as a raw string; this type does not parse or validate it.
// NOTE(review): presumably a partial JSON fragment that must be concatenated
// across chunks before parsing — confirm against the stream consumer.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}