#![allow(dead_code)]
use std::{error::Error, fmt::Display};
use async_trait::async_trait;
use once_cell::sync::OnceCell;
use serde::Deserialize;
use serde_json::json;
pub mod chat;
pub mod completion;
/// Lazily-initialized `reqwest::Client` shared by every `send` call, so a single
/// client is constructed for the whole program instead of one per request.
static RQCLIENT: OnceCell<reqwest::Client> = OnceCell::new();
/// Endpoint that `Request<CompletionState>` is posted to.
const COMPLETION_URL: &str = "https://api.openai.com/v1/completions";
/// Endpoint that `Request<ChatState>` is posted to.
const CHAT_URL: &str = "https://api.openai.com/v1/chat/completions";
/// Payload of [`SendRequestError::JsonError`]: a response that could not be
/// decoded into the expected response type.
#[derive(Debug, Clone)]
pub struct JsonParseError {
    /// The response payload that failed to decode.
    json_string: String,
}
impl JsonParseError {
    /// Returns the payload that failed to decode, without the
    /// `"Could not parse json: "` prefix that `Display` adds.
    pub fn json_string(&self) -> &str {
        &self.json_string
    }
}
/// Error returned by the [`SendRequest`] implementations in this module.
#[derive(Debug)]
pub enum SendRequestError {
    /// Error reported by `reqwest` while sending the request; produced through
    /// the `From<reqwest::Error>` impl so `?` can propagate it.
    ReqwestError(reqwest::Error),
    /// Request rejected locally (e.g. a chat request without messages) or an
    /// `error` payload returned by the API; holds the message or the
    /// pretty-printed JSON.
    OpenAiError(String),
    /// The response could not be deserialized into the expected response type.
    JsonError(JsonParseError),
}
impl Display for SendRequestError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SendRequestError::ReqwestError(e) => write!(f, "Reqwest error: {}", e),
SendRequestError::OpenAiError(e) => write!(f, "OpenAI error: {}", e),
SendRequestError::JsonError(e) => write!(f, "Json error: {}", e),
}
}
}
impl Display for JsonParseError {
    /// Writes a fixed prefix followed by the payload that failed to decode.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Could not parse json: ")?;
        f.write_str(&self.json_string)
    }
}
impl Error for SendRequestError {}
impl From<reqwest::Error> for SendRequestError {
fn from(e: reqwest::Error) -> Self {
SendRequestError::ReqwestError(e)
}
}
/// A request that can be sent asynchronously.
///
/// Implemented in this module for `Request<CompletionState>` and
/// `Request<ChatState>`. `send` takes `self` by value, so a request is
/// consumed by sending it.
#[async_trait]
pub trait SendRequest {
    /// Decoded response produced on success.
    type Response;
    /// Error produced on failure.
    type Error;
    /// Sends the request and awaits the decoded response.
    async fn send(self) -> Result<Self::Response, Self::Error>;
}
/// Marker trait implemented by the builder state types ([`CompletionState`]
/// and [`ChatState`]); it bounds the setters shared by both request kinds.
#[doc(hidden)]
pub trait CompletionLike {}
/// Builder/request state for completion requests.
#[doc(hidden)]
pub struct CompletionState;
/// Builder/request state for chat requests.
#[doc(hidden)]
pub struct ChatState;
/// Completion model identifiers; `Display`/`to_string()` yields the wire name.
#[derive(Debug, Clone)]
pub enum CompletionModel {
    /// `text-davinci-003`
    TextDavinci003,
    /// `text-davinci-002`
    TextDavinci002,
    /// `code-davinci-002`
    CodeDavinci002,
}
/// Chat model identifiers; `Display`/`to_string()` yields the wire name.
#[derive(Debug, Clone)]
pub enum ChatModel {
    /// `gpt-3.5-turbo`
    Gpt35Turbo,
    /// `gpt-3.5-turbo-0301`
    GPT35Turbo0301,
}
impl CompletionLike for CompletionState {}
impl CompletionLike for ChatState {}
impl CompletionModel {
    /// The model name as placed in the request's `model` field.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompletionModel::TextDavinci003 => "text-davinci-003",
            CompletionModel::TextDavinci002 => "text-davinci-002",
            CompletionModel::CodeDavinci002 => "code-davinci-002",
        }
    }
}
impl ChatModel {
    /// The model name as placed in the request's `model` field.
    pub fn as_str(&self) -> &'static str {
        match self {
            ChatModel::Gpt35Turbo => "gpt-3.5-turbo",
            ChatModel::GPT35Turbo0301 => "gpt-3.5-turbo-0301",
        }
    }
}
// `Display` (rather than a hand-written `ToString`) keeps `to_string()` working
// via the blanket impl and also enables `{}` formatting; `pad` honours
// width/alignment flags.
impl Display for CompletionModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
impl Display for ChatModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// A built request with a serialized JSON body, ready for [`SendRequest::send`].
///
/// `T` is a unit state marker (`CompletionState` or `ChatState`) that selects
/// which `SendRequest` impl, and therefore which endpoint, is used.
#[must_use = "a `Request` does nothing until it is sent with `.send().await`"]
pub struct Request<T> {
    /// Serialized JSON request body.
    to_send: String,
    /// `Authorization` header value, already formatted as `"Bearer <key>"`.
    api_key: String,
    state: std::marker::PhantomData<T>,
}
/// Hand-written instead of derived: `#[derive(Debug)]` would require
/// `T: Debug`, which the state markers don't implement, and would print the
/// API key verbatim.
impl<T> std::fmt::Debug for Request<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Request")
            .field("to_send", &self.to_send)
            .field("api_key", &"<redacted>")
            .field("state", &self.state)
            .finish()
    }
}
#[async_trait]
impl SendRequest for Request<CompletionState> {
type Response = completion::CompletionResponse;
type Error = SendRequestError;
async fn send(self) -> Result<Self::Response, Self::Error> {
use SendRequestError::*;
let client = RQCLIENT.get_or_init(reqwest::Client::new);
let resp = client
.post(COMPLETION_URL)
.header("Content-Type", "application/json")
.header("Authorization", self.api_key)
.body(self.to_send)
.send()
.await?;
let body = resp.text().await.unwrap();
let json: serde_json::Value = serde_json::from_str(&body).unwrap();
let response = match completion::CompletionResponse::deserialize(json.clone()) {
Ok(r) => r,
Err(_) => return Err(JsonError(JsonParseError { json_string: serde_json::to_string_pretty(&json).unwrap() })),
};
Ok(response)
}
}
#[async_trait]
impl SendRequest for Request<ChatState> {
type Response = chat::ChatResponse;
type Error = SendRequestError;
async fn send(self) -> Result<Self::Response, SendRequestError> {
use SendRequestError::*;
if !self.to_send.contains("messages") {
return Err(OpenAiError("No messages in request.".into()));
}
let client = RQCLIENT.get_or_init(reqwest::Client::new);
let resp = client
.post(CHAT_URL)
.header("Content-Type", "application/json")
.header("Authorization", self.api_key)
.body(self.to_send)
.send()
.await?;
let body = resp.text().await.unwrap();
let json: serde_json::Value = serde_json::from_str(&body).unwrap();
if !json["error"].is_null() {
return Err(OpenAiError(serde_json::to_string_pretty(&json).unwrap().into()));
}
let response = match chat::ChatResponse::deserialize(json.clone()) {
Ok(r) => r,
Err(_) => return Err(JsonError(JsonParseError { json_string: serde_json::to_string_pretty(&json).unwrap() })),
};
Ok(response)
}
}
/// Type-state builder for a [`Request`]; `T` (`CompletionState` or `ChatState`)
/// decides which setters and which `build_*` method are available.
#[must_use = "builder methods return the updated builder; dropping it discards the request"]
pub struct RequestBuilder<T> {
    /// JSON request body accumulated so far; `new` starts it as an object
    /// containing `model`.
    req: serde_json::Value,
    /// `Authorization` header value, already formatted as `"Bearer <key>"`.
    api_key: String,
    state: std::marker::PhantomData<T>,
}
/// Hand-written instead of derived: `#[derive(Debug)]` would require
/// `T: Debug`, which the state markers don't implement, and would print the
/// API key verbatim.
impl<T> std::fmt::Debug for RequestBuilder<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RequestBuilder")
            .field("req", &self.req)
            .field("api_key", &"<redacted>")
            .field("state", &self.state)
            .finish()
    }
}
impl<C: CompletionLike> RequestBuilder<C> {
pub fn new<T: ToString, S: Display>(model: T, api_key: S) -> Self {
let api_key = format!("Bearer {api_key}");
let req = json!({
"model": model.to_string(),
});
Self {
req,
api_key,
state: std::marker::PhantomData,
}
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.req["max_tokens"] = json!(max_tokens);
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.req["temperature"] = json!(temperature);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.req["top_p"] = json!(top_p);
self
}
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.req["frequency_penalty"] = json!(frequency_penalty);
self
}
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.req["presence_penalty"] = json!(presence_penalty);
self
}
pub fn stop<T: ToString>(mut self, stop: T) -> Self {
self.req["stop"] = json!(stop.to_string());
self
}
pub fn n(mut self, n: u32) -> Self {
self.req["n"] = json!(n);
self
}
pub fn user(mut self, user: String) -> Self {
self.req["user"] = json!(user);
self
}
}
impl RequestBuilder<CompletionState> {
pub fn prompt<T: ToString>(mut self, prompt: T) -> Self {
self.req["prompt"] = json!(prompt.to_string());
self
}
pub fn build_completion(self) -> Request<CompletionState> {
Request {
api_key: self.api_key,
to_send: self.req.to_string(),
state: std::marker::PhantomData,
}
}
}
impl RequestBuilder<ChatState> {
    /// Sets the `messages` field of the chat request, replacing any earlier value.
    pub fn messages(mut self, messages: Vec<chat::ChatMessage>) -> Self {
        self.req["messages"] = json!(messages);
        self
    }
    /// Replaces the request body with the serialized `chat_parameters`,
    /// carrying over the `messages` and `model` fields already set on this
    /// builder. Any other fields set earlier are dropped.
    fn chat_parameters(mut self, chat_parameters: chat::ChatParameters) -> Self {
        let mut params = json!(chat_parameters);
        // Copy only keys that are present, so calling this before `messages()`
        // doesn't panic; `send` checks for `messages` before posting.
        for key in ["messages", "model"] {
            if let Some(value) = self.req.get_mut(key) {
                // `self.req` is overwritten below, so move the value out rather
                // than cloning it.
                params[key] = std::mem::take(value);
            }
        }
        self.req = params;
        self
    }
    /// Serializes the body built so far into a [`Request`] for the chat endpoint.
    pub fn build_chat(self) -> Request<ChatState> {
        Request {
            api_key: self.api_key,
            to_send: self.req.to_string(),
            state: std::marker::PhantomData,
        }
    }
}