use super::{types::Message, LLMOptions};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Serializable body of a chat-completion request.
///
/// `model`, `messages` and `stream` are always written. Every `Option` field
/// is left out of the serialized output for as long as it is `None`.
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionRequest {
    /// Model identifier; always serialized.
    pub model: String,
    /// Messages carried by the request; always serialized.
    pub messages: Vec<Message>,

    // Optional settings: each one is omitted from the payload while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repeat_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_overflow_policy: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,

    /// Streaming flag; always serialized, even when `false`.
    pub stream: bool,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<Options>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
}
/// Nested option object carried in `ChatCompletionRequest::options`.
///
/// Serialize-only; the struct does not derive `Deserialize`.
#[derive(Debug, Clone, Serialize)]
pub struct Options {
    /// Response format given as a plain string; omitted while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
}
/// A tool entry for the request's `tools` list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Kind of tool; appears under the key `type` in serialized form
    /// (`type` is a Rust keyword, hence the different field name).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Definition of the function this tool wraps.
    pub function: FunctionDefinition,
}
/// Name, description and parameter schema of a function used by a [`Tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Free-text description of the function.
    pub description: String,
    /// Schema of the arguments the function takes.
    pub parameters: Parameters,
}
/// Argument schema attached to a `FunctionDefinition`.
///
/// On the wire the fields use JSON Schema keyword spellings: `param_type` is
/// written as `type` and `additional_properties` as `additionalProperties`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Parameters {
    /// Schema type; serialized under the key `type`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Per-property schema, kept as free-form JSON.
    pub properties: serde_json::Value,
    /// Names of the properties that must be supplied.
    pub required: Vec<String>,
    /// Whether properties not listed in `properties` are allowed.
    ///
    /// JSON Schema spells this keyword `additionalProperties`; a snake_case
    /// key would be treated as an unknown keyword by schema consumers. The
    /// old `additional_properties` spelling is still accepted when
    /// deserializing so previously stored payloads keep loading.
    #[serde(rename = "additionalProperties", alias = "additional_properties")]
    pub additional_properties: bool,
}
/// Named JSON schema embedded in a [`ResponseFormat`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonSchemaFormat {
    /// Name given to the schema.
    pub name: String,
    /// Strictness flag, serialized as-is.
    pub strict: bool,
    /// The schema document itself, kept as free-form JSON.
    pub schema: Value,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResponseFormat {
pub r#type: String,
pub json_schema: JsonSchemaFormat,
}
impl ChatCompletionRequest {
pub fn new(model: String, messages: Vec<Message>) -> Self {
Self {
model,
messages,
temperature: None,
top_p: None,
max_tokens: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
stream: false,
response_format: None,
options: None,
tools: None,
tool_choice: None,
context_overflow_policy: None,
repeat_penalty: None,
top_k: None,
}
}
pub fn from_llm_options(&mut self, opts: &LLMOptions) -> &mut Self {
self.stream = opts.streaming;
if let Some(t) = opts.temperature {
self.temperature = Some(t);
}
if let Some(p) = opts.top_p {
self.top_p = Some(p);
}
if let Some(max) = opts.max_tokens {
self.max_tokens = Some(max);
}
if let Some(ref stop) = opts.stop {
self.stop = Some(stop.clone());
}
if let Some(p) = opts.presence_penalty {
self.presence_penalty = Some(p);
}
if let Some(p) = opts.frequency_penalty {
self.frequency_penalty = Some(p);
}
self
}
pub fn with_response_format(&mut self, format: ResponseFormat) -> &mut Self {
self.response_format = Some(format);
self
}
pub fn with_temperature(&mut self, temperature: f32) -> &mut Self {
self.temperature = Some(temperature);
self
}
}