crate::ix!();
/// A single request entry of a GPT batch job.
///
/// Serialized with serde; the field declaration order below is also the
/// order in which keys are emitted, so it must not be rearranged.
/// NOTE(review): presumably each value becomes one JSONL line of a batch
/// input file (see the `Display` impl, which renders compact JSON) — confirm.
#[derive(Serialize, Deserialize, Debug)]
pub struct GptBatchAPIRequest {
    /// Caller-chosen identifier; the constructors in this file derive it from
    /// the request index via `custom_id_for_idx` (`"request-{idx}"`).
    custom_id: String,

    /// HTTP verb, stored on the wire as a plain string by the `http_method` adapter.
    #[serde(with = "http_method")]
    method: HttpMethod,

    /// Target endpoint, stored on the wire as a path string by the `api_url` adapter.
    #[serde(with = "api_url")]
    url: GptApiUrl,

    /// The request payload.
    body: GptRequestBody,
}
impl GptBatchAPIRequest {
pub fn new_basic(idx: usize, system_message: &str, user_message: &str) -> Self {
Self {
custom_id: Self::custom_id_for_idx(idx),
method: HttpMethod::Post,
url: GptApiUrl::ChatCompletions,
body: GptRequestBody::new_basic(system_message,user_message),
}
}
pub fn new_with_image(idx: usize, system_message: &str, user_message: &str, image_b64: &str) -> Self {
Self {
custom_id: Self::custom_id_for_idx(idx),
method: HttpMethod::Post,
url: GptApiUrl::ChatCompletions,
body: GptRequestBody::new_with_image(system_message,user_message,image_b64),
}
}
}
impl Display for GptBatchAPIRequest {
    /// Renders the request as compact JSON via `serde_json::to_string`.
    ///
    /// A serialization failure is written into the output as text rather than
    /// returned as `fmt::Error`, so `to_string()` on this type never panics.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let rendered = match serde_json::to_string(self) {
            Ok(text) => text,
            Err(err) => return write!(f, "Error serializing to JSON: {}", err),
        };
        f.write_str(&rendered)
    }
}
impl GptBatchAPIRequest {
    /// Returns the `custom_id` for the request at position `idx`:
    /// `"request-0"`, `"request-1"`, and so on.
    pub(crate) fn custom_id_for_idx(idx: usize) -> String {
        let mut id = String::from("request-");
        id.push_str(&idx.to_string());
        id
    }
}
mod http_method {
use super::*;
pub fn serialize<S>(value: &HttpMethod, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&value.to_string())
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<HttpMethod, D::Error>
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
match s.as_ref() {
"POST" => Ok(HttpMethod::Post),
_ => Err(serde::de::Error::custom("unknown method")),
}
}
}
mod api_url {
use super::*;
pub fn serialize<S>(value: &GptApiUrl, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&value.to_string())
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<GptApiUrl, D::Error>
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
match s.as_ref() {
"/v1/chat/completions" => Ok(GptApiUrl::ChatCompletions),
_ => Err(serde::de::Error::custom("unknown URL")),
}
}
}