use crate::common::types::{Body, InParam, RetryCount, Timeout};
use http::{
HeaderValue,
header::{IntoHeaderName, USER_AGENT},
};
use serde_json::Value;
use std::{collections::HashMap, time::Duration};
/// Builder for a text-completions request.
///
/// Wraps an [`InParam`] and fills in its JSON body, headers and extensions
/// through chained, consuming setter methods.
pub struct CompletionsParam {
    /// The request being assembled; handed back to the crate by `take`.
    inner: InParam,
}
impl CompletionsParam {
pub fn new(model: &str, prompt: &str) -> Self {
let mut inner = InParam::new();
inner.body = Some(Body::new());
inner
.body
.as_mut()
.unwrap()
.insert("model".to_string(), serde_json::to_value(model).unwrap());
inner
.body
.as_mut()
.unwrap()
.insert("prompt".to_string(), serde_json::to_value(prompt).unwrap());
CompletionsParam { inner }
}
pub fn max_tokens(mut self, max_tokens: i32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"max_tokens".to_string(),
serde_json::to_value(max_tokens).unwrap(),
);
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"temperature".to_string(),
serde_json::to_value(temperature).unwrap(),
);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("top_p".to_string(), serde_json::to_value(top_p).unwrap());
self
}
pub fn n(mut self, n: i32) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("n".to_string(), serde_json::to_value(n).unwrap());
self
}
pub fn logprobs(mut self, logprobs: i32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"logprobs".to_string(),
serde_json::to_value(logprobs).unwrap(),
);
self
}
pub fn echo(mut self, echo: bool) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("echo".to_string(), serde_json::to_value(echo).unwrap());
self
}
pub fn stop(mut self, stop: Vec<String>) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("stop".to_string(), serde_json::to_value(stop).unwrap());
self
}
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"presence_penalty".to_string(),
serde_json::to_value(presence_penalty).unwrap(),
);
self
}
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"frequency_penalty".to_string(),
serde_json::to_value(frequency_penalty).unwrap(),
);
self
}
pub fn best_of(mut self, best_of: i32) -> Self {
self.inner.body.as_mut().unwrap().insert(
"best_of".to_string(),
serde_json::to_value(best_of).unwrap(),
);
self
}
pub fn logit_bias(mut self, bias: HashMap<String, i32>) -> Self {
self.inner.body.as_mut().unwrap().insert(
"logit_bias".to_string(),
serde_json::to_value(bias).unwrap(),
);
self
}
pub fn user(mut self, user: String) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert("user".to_string(), serde_json::to_value(user).unwrap());
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.inner.extensions.insert(Timeout(timeout));
self
}
pub fn user_agent(mut self, user_agent: HeaderValue) -> Self {
self.inner.headers.insert(USER_AGENT, user_agent);
self
}
pub fn header<K: IntoHeaderName>(mut self, key: K, val: HeaderValue) -> Self {
self.inner.headers.insert(key, val);
self
}
pub fn body<K: Into<String>, V: Into<Value>>(mut self, key: K, val: V) -> Self {
self.inner
.body
.as_mut()
.unwrap()
.insert(key.into(), val.into());
self
}
pub fn retry_count(mut self, retry_count: usize) -> Self {
self.inner.extensions.insert(RetryCount(retry_count));
self
}
}
impl CompletionsParam {
    /// Consumes the builder and hands back the assembled [`InParam`].
    pub(crate) fn take(self) -> InParam {
        let Self { inner } = self;
        inner
    }
}