use super::types::{
ChatCompletionMessageParam, ChatCompletionPredictionContentParam, ChatCompletionToolParam,
Modality, ReasoningEffort, ToolChoice,
};
use crate::common::types::{Body, InParam, RetryCount, ServiceTier, Timeout};
use http::{
HeaderValue,
header::{IntoHeaderName, USER_AGENT},
};
use serde_json::Value;
use std::{collections::HashMap, time::Duration};
/// Builder for a chat-completion request.
///
/// Wraps an [`InParam`] whose JSON body, HTTP headers and typed extensions are filled in
/// by the builder methods. Every builder method consumes `self` and returns the updated
/// builder, so the accumulated configuration lives only in the returned value.
#[must_use = "ChatParam builder methods return the updated builder; dropping it discards the configured request"]
pub struct ChatParam {
    // The request parameters being assembled; handed out via `take`.
    inner: InParam,
}
impl ChatParam {
pub fn new(model: &str, messages: &Vec<ChatCompletionMessageParam>) -> Self {
let mut inner = InParam::new();
inner.body = Some(Body::new());
let mut_body = inner.body.as_mut().unwrap();
mut_body.insert("model".to_string(), serde_json::to_value(model).unwrap());
mut_body.insert(
"messages".to_string(),
serde_json::to_value(messages).unwrap(),
);
ChatParam { inner }
}
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"frequency_penalty".to_string(),
serde_json::to_value(frequency_penalty).unwrap(),
);
self
}
pub fn logit_bias(mut self, logit_bias: HashMap<String, i32>) -> Self {
self.inner.body.as_mut().unwrap().insert(
"logit_bias".to_string(),
serde_json::to_value(logit_bias).unwrap(),
);
self
}
pub fn logprobs(mut self, logprobs: bool) -> Self {
self.inner.body.as_mut().unwrap().insert(
"logprobs".to_string(),
serde_json::to_value(logprobs).unwrap(),
);
self
}
pub fn modalities(mut self, modalities: Vec<Modality>) -> Self {
self.inner.body.as_mut().unwrap().insert(
"modalities".to_string(),
serde_json::to_value(modalities).unwrap(),
);
self
}
pub fn max_completion_tokens(mut self, max_completion_tokens: i32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"max_completion_tokens".to_string(),
serde_json::to_value(max_completion_tokens).unwrap(),
);
self
}
pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
self.inner.body.as_mut().unwrap().insert(
"metadata".to_string(),
serde_json::to_value(metadata).unwrap(),
);
self
}
pub fn parallel_tool_calls(mut self, parallel_tool_calls: bool) -> Self {
self.inner.body.as_mut().unwrap().insert(
"parallel_tool_calls".to_string(),
serde_json::to_value(parallel_tool_calls).unwrap(),
);
self
}
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"presence_penalty".to_string(),
serde_json::to_value(presence_penalty).unwrap(),
);
self
}
pub fn n(mut self, n: i32) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("n".to_string(), serde_json::to_value(n).unwrap());
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("top_p".to_string(), serde_json::to_value(top_p).unwrap());
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"temperature".to_string(),
serde_json::to_value(temperature).unwrap(),
);
self
}
pub fn user(mut self, user: String) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("user".to_string(), serde_json::to_value(user).unwrap());
self
}
pub fn top_logprobs(mut self, top_logprobs: i32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"top_logprobs".to_string(),
serde_json::to_value(top_logprobs).unwrap(),
);
self
}
pub fn prediction(mut self, prediction: ChatCompletionPredictionContentParam) -> Self {
self.inner.body.as_mut().unwrap().insert(
"prediction".to_string(),
serde_json::to_value(prediction).unwrap(),
);
self
}
pub fn reasoning_effort(mut self, reasoning_effort: ReasoningEffort) -> Self {
self.inner.body.as_mut().unwrap().insert(
"reasoning_effort".to_string(),
serde_json::to_value(reasoning_effort).unwrap(),
);
self
}
pub fn service_tier(mut self, service_tier: ServiceTier) -> Self {
self.inner.body.as_mut().unwrap().insert(
"service_tier".to_string(),
serde_json::to_value(service_tier).unwrap(),
);
self
}
pub fn tools(mut self, tools: Vec<ChatCompletionToolParam>) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("tools".to_string(), serde_json::to_value(tools).unwrap());
self
}
pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
self.inner.body.as_mut().unwrap().insert(
"tool_choice".to_string(),
serde_json::to_value(tool_choice).unwrap(),
);
self
}
pub fn retry_count(mut self, retry_count: usize) -> Self {
self.inner.extensions.insert(RetryCount(retry_count));
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.inner.extensions.insert(Timeout(timeout));
self
}
pub fn user_agent(mut self, user_agent: HeaderValue) -> Self {
self.inner.headers.insert(USER_AGENT, user_agent);
self
}
pub fn header<K: IntoHeaderName>(mut self, key: K, val: HeaderValue) -> Self {
self.inner.headers.insert(key, val);
self
}
pub fn body<K: Into<String>, V: Into<Value>>(mut self, key: K, val: V) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert(key.into(), val.into());
self
}
}
impl ChatParam {
pub(crate) fn take(self) -> InParam {
self.inner
}
}