#[macro_use]
extern crate derive_builder;
use thiserror::Error;
/// Crate-internal shorthand for a `Result` whose error type is this crate's [`Error`].
type Result<T> = std::result::Result<T, Error>;
/// Request and response types exchanged with the OpenAI HTTP API.
// NOTE(review): presumably allowed because code generated by `derive_builder`
// (e.g. for `#[builder(default)]`) calls `Default::default()`; confirm before removing.
#[allow(clippy::default_trait_access)]
pub mod api {
    use std::{collections::HashMap, convert::TryFrom, fmt::Display};

    use serde::{Deserialize, Serialize};

    /// Envelope for responses whose payload is a top-level `data` array
    /// (e.g. the engine list).
    #[derive(Deserialize, Debug)]
    pub(crate) struct Container<T> {
        pub data: Vec<T>,
    }

    /// Metadata describing a single engine.
    #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
    pub struct EngineInfo {
        pub id: String,
        pub owner: String,
        pub ready: bool,
    }

    /// Arguments for a completion request.
    ///
    /// Build one with [`CompletionArgs::builder`], or convert a prompt `&str` with
    /// `From`, which uses the builder defaults for every other field. Field names
    /// are serialized as-is into the JSON request body (no serde renames).
    #[derive(Serialize, Debug, Builder, Clone)]
    #[builder(pattern = "immutable")]
    pub struct CompletionArgs {
        /// Engine to run the completion on. It is not serialized into the request
        /// body; the client places it in the request URL instead.
        /// Defaults to `"davinci"`.
        #[builder(setter(into), default = "\"davinci\".into()")]
        #[serde(skip_serializing)]
        pub(super) engine: String,
        /// Prompt text. Defaults to `"<|endoftext|>"`.
        #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
        prompt: String,
        /// Defaults to `16`.
        #[builder(default = "16")]
        max_tokens: u64,
        /// Defaults to `1.0`.
        #[builder(default = "1.0")]
        temperature: f64,
        /// Defaults to `1.0`.
        #[builder(default = "1.0")]
        top_p: f64,
        /// Defaults to `1`.
        #[builder(default = "1")]
        n: u64,
        /// Defaults to `None`. Serialized as `null` when unset.
        #[builder(setter(strip_option), default)]
        logprobs: Option<u64>,
        /// Defaults to `false`.
        #[builder(default = "false")]
        echo: bool,
        /// Defaults to `None`. Serialized as `null` when unset.
        #[builder(setter(strip_option), default)]
        stop: Option<Vec<String>>,
        /// Defaults to `0.0`.
        #[builder(default = "0.0")]
        presence_penalty: f64,
        /// Defaults to `0.0`.
        #[builder(default = "0.0")]
        frequency_penalty: f64,
        /// Defaults to an empty map.
        #[builder(default)]
        logit_bias: HashMap<String, f64>,
    }

    impl From<&str> for CompletionArgs {
        /// Uses `prompt_string` as the prompt and the builder defaults for everything else.
        fn from(prompt_string: &str) -> Self {
            Self {
                prompt: prompt_string.into(),
                // Every field has a builder default, so building an untouched builder
                // cannot fail.
                ..CompletionArgsBuilder::default()
                    .build()
                    .expect("default should build")
            }
        }
    }

    impl CompletionArgs {
        /// Returns a builder pre-populated with the documented defaults.
        #[must_use]
        pub fn builder() -> CompletionArgsBuilder {
            CompletionArgsBuilder::default()
        }
    }

    impl TryFrom<CompletionArgsBuilder> for CompletionArgs {
        type Error = CompletionArgsBuilderError;

        /// Finishes `builder`; fails if the builder reports an error.
        fn try_from(builder: CompletionArgsBuilder) -> Result<Self, Self::Error> {
            builder.build()
        }
    }

    /// A completion response.
    #[derive(Deserialize, Debug, Clone)]
    pub struct Completion {
        pub id: String,
        pub created: u64,
        pub model: String,
        pub choices: Vec<Choice>,
    }

    impl std::fmt::Display for Completion {
        /// Writes the text of the first choice. Writes nothing when `choices` is
        /// empty, rather than panicking on an out-of-bounds index.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self.choices.first() {
                Some(choice) => write!(f, "{}", choice),
                None => Ok(()),
            }
        }
    }

    /// One generated alternative within a [`Completion`].
    #[derive(Deserialize, Debug, Clone)]
    pub struct Choice {
        pub text: String,
        pub index: u64,
        pub logprobs: Option<LogProbs>,
        pub finish_reason: String,
    }

    impl std::fmt::Display for Choice {
        /// Writes only the generated `text`.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            self.text.fmt(f)
        }
    }

    /// Per-token log-probability data attached to a [`Choice`].
    #[derive(Deserialize, Debug, Clone)]
    pub struct LogProbs {
        pub tokens: Vec<String>,
        pub token_logprobs: Vec<Option<f64>>,
        pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
        pub text_offset: Vec<u64>,
    }

    /// Error object reported by the API.
    #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
    pub struct ErrorMessage {
        pub message: String,
        /// Deserialized from the JSON key `type`, which is a Rust keyword.
        #[serde(rename = "type")]
        pub error_type: String,
    }

    impl Display for ErrorMessage {
        /// Writes only the human-readable `message`.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            self.message.fmt(f)
        }
    }

    /// `{ "error": { ... } }` envelope around an [`ErrorMessage`].
    #[derive(Deserialize, Debug)]
    pub(crate) struct ErrorWrapper {
        pub error: ErrorMessage,
    }
}
#[derive(Error, Debug)]
pub enum Error {
#[error("API returned an Error: {}", .0.message)]
Api(api::ErrorMessage),
#[error("Bad arguments: {0}")]
BadArguments(String),
#[error("Error at the protocol level: {0}")]
AsyncProtocol(reqwest::Error),
}
/// Wraps an error object reported by the API as [`Error::Api`].
impl From<api::ErrorMessage> for Error {
    fn from(message: api::ErrorMessage) -> Self {
        Self::Api(message)
    }
}
impl From<String> for Error {
fn from(e: String) -> Self {
Error::BadArguments(e)
}
}
impl From<reqwest::Error> for Error {
fn from(e: reqwest::Error) -> Self {
Error::AsyncProtocol(e)
}
}
/// An API token whose `Debug` output reveals only its first 8 bytes
/// (presumably so the full secret never lands in logs).
struct BearerToken {
    token: String,
}

impl std::fmt::Debug for BearerToken {
    /// Prints `Bearer { token: "<first 8 bytes>" }`.
    ///
    /// If the token is shorter than 8 bytes, or byte 8 is not a UTF-8 character
    /// boundary, an empty prefix is printed. The original implementation returned
    /// `fmt::Error` here, and an error from a formatting impl makes `format!("{:?}")` panic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = self.token.get(..8).unwrap_or("");
        write!(f, r#"Bearer {{ token: "{}" }}"#, prefix)
    }
}

impl BearerToken {
    /// Wraps an owned copy of `token`.
    fn new(token: &str) -> Self {
        Self {
            token: token.to_owned(),
        }
    }
}
/// Asynchronous client for the OpenAI HTTP API.
///
/// `Debug` is implemented by hand so that the API token is never printed.
#[derive(Clone)]
pub struct Client {
    client: reqwest::Client,
    base_url: String,
    token: String,
}

impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("client", &self.client)
            .field("base_url", &self.base_url)
            // The token is a secret credential; a derived `Debug` would leak it into logs.
            .field("token", &"<redacted>")
            .finish()
    }
}
impl Client {
    /// Creates a client targeting `https://api.openai.com/v1/` that sends `token`
    /// as a bearer token in the `Authorization` header of every request.
    #[must_use]
    pub fn new(token: &str) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1/".to_string(),
            token: token.to_string(),
        }
    }

    /// Sends an authenticated `GET` to `endpoint` (a full URL) and decodes the
    /// JSON response as `T`.
    ///
    /// # Errors
    /// See [`Client::decode_response`]; also [`Error::AsyncProtocol`] if sending fails.
    async fn get<T>(&self, endpoint: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response = self
            .client
            .get(endpoint)
            .header("Authorization", format!("Bearer {}", self.token))
            .send()
            .await?;
        Self::decode_response(response).await
    }

    /// Fetches `GET {base_url}engines` and returns its `data` array.
    ///
    /// # Errors
    /// [`Error::Api`] if the API reports an error; [`Error::AsyncProtocol`] if the
    /// request fails or the body cannot be decoded.
    pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
        let container: api::Container<api::EngineInfo> =
            self.get(&self.build_url_from_path("engines")).await?;
        Ok(container.data)
    }

    /// Fetches `GET {base_url}engines/{engine}`.
    ///
    /// # Errors
    /// [`Error::Api`] if the API reports an error; [`Error::AsyncProtocol`] if the
    /// request fails or the body cannot be decoded.
    pub async fn engine(&self, engine: &str) -> Result<api::EngineInfo> {
        self.get(&self.build_url_from_path(&format!("engines/{}", engine)))
            .await
    }

    /// Sends an authenticated `POST` with `body` serialized as JSON to `endpoint`
    /// (a full URL) and decodes the JSON response as `R`.
    ///
    /// # Errors
    /// See [`Client::decode_response`]; also [`Error::AsyncProtocol`] if sending fails.
    async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
    where
        B: serde::ser::Serialize,
        R: serde::de::DeserializeOwned,
    {
        let response = self
            .client
            .post(endpoint)
            .header("Authorization", format!("Bearer {}", self.token))
            .json(&body)
            .send()
            .await?;
        Self::decode_response(response).await
    }

    /// Turns an HTTP response into the crate's result type.
    ///
    /// A `200 OK` body is decoded as `T`. Any other status is decoded as the API's
    /// `{ "error": ... }` envelope and returned as [`Error::Api`].
    ///
    /// # Errors
    /// [`Error::Api`] for a non-200 response carrying an error envelope.
    /// [`Error::AsyncProtocol`] if either body cannot be read or decoded; this
    /// covers non-JSON error pages, which previously caused a panic in `post`.
    async fn decode_response<T>(response: reqwest::Response) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        if response.status() == reqwest::StatusCode::OK {
            Ok(response.json::<T>().await?)
        } else {
            let err = response.json::<api::ErrorWrapper>().await?.error;
            Err(Error::Api(err))
        }
    }

    /// Appends `path` to the client's base URL (plain concatenation, no escaping).
    #[must_use]
    pub fn build_url_from_path(&self, path: &str) -> String {
        format!("{}{}", self.base_url, path)
    }

    /// Requests a completion from `POST {base_url}engines/{engine}/completions`.
    ///
    /// `prompt` may be a `&str` (all other arguments take their defaults) or a fully
    /// built [`api::CompletionArgs`]. The engine is taken from the arguments and
    /// placed in the URL; the remaining arguments form the JSON body.
    ///
    /// # Errors
    /// [`Error::Api`] if the API reports an error; [`Error::AsyncProtocol`] if the
    /// request fails or the body cannot be decoded.
    pub async fn complete_prompt(
        &self,
        prompt: impl Into<api::CompletionArgs>,
    ) -> Result<api::Completion> {
        let args = prompt.into();
        let url = self.build_url_from_path(&format!("engines/{}/completions", args.engine));
        self.post(&url, args).await
    }
}