#[macro_use]
extern crate derive_builder;
use thiserror::Error;
type Result<T> = std::result::Result<T, OpenAIError>;
#[allow(clippy::default_trait_access)]
pub mod api {
use std::collections::HashMap;
use super::OpenAIClient;
use serde::{Deserialize, Serialize};
/// Envelope for list endpoints: the API wraps collections in `{ "data": [...] }`.
#[derive(Deserialize, Debug)]
pub(super) struct Container<T> {
pub data: Vec<T>,
}
/// Metadata for a single engine as returned by the `engines` endpoints.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct EngineInfo {
// Unknown ids deserialize to `Engine::Other` (via `#[serde(other)]`).
pub id: Engine,
// Owning organization, e.g. "openai".
pub owner: String,
pub ready: bool,
}
/// Engine ids recognized by this crate.
///
/// Serialized in kebab-case to match the API's wire format; `ContentFilter`
/// carries an explicit rename because its id doesn't follow that scheme.
#[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum Engine {
Ada,
Babbage,
Curie,
Davinci,
#[serde(rename = "content-filter-alpha-c4")]
ContentFilter,
// Catch-all: any engine id this crate doesn't know about lands here.
#[serde(other)]
Other,
}
impl std::fmt::Display for Engine {
/// Renders the engine id exactly as the API expects it in URL paths.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let id = match self {
Engine::Ada => "ada",
Engine::Babbage => "babbage",
Engine::Curie => "curie",
Engine::Davinci => "davinci",
Engine::ContentFilter => "content-filter-alpha-c4",
// `Other` is a deserialization catch-all with no canonical id, so
// displaying it is a programming error.
Engine::Other => panic!("Can't write out Other engine id"),
};
f.write_str(id)
}
}
/// Arguments for a completion request; construct via [`CompletionArgs::builder`].
///
/// Builder defaults mirror the API defaults (e.g. `max_tokens = 16`,
/// `temperature = 1.0`). Only `prompt` and the sampling knobs are serialized;
/// the engine travels in the URL path instead.
#[derive(Serialize, Debug, Builder, Clone)]
pub struct CompletionArgs {
#[builder(setter(into), default = "\"<|endoftext|>\".into()")]
prompt: String,
// Not serialized: the engine selects the endpoint, not a body field.
#[builder(default = "Engine::Davinci")]
#[serde(skip_serializing)]
pub(super) engine: Engine,
#[builder(default = "16")]
max_tokens: u64,
#[builder(default = "1.0")]
temperature: f64,
#[builder(default = "1.0")]
top_p: f64,
// Number of alternative completions to generate.
#[builder(default = "1")]
n: u64,
// When set, the response includes per-token log-probabilities.
#[builder(setter(strip_option), default)]
logprobs: Option<u64>,
#[builder(default = "false")]
echo: bool,
// Sequences that terminate generation when produced.
#[builder(setter(strip_option), default)]
stop: Option<Vec<String>>,
#[builder(default = "0.0")]
presence_penalty: f64,
#[builder(default = "0.0")]
frequency_penalty: f64,
// Keys are token ids as strings; values bias the sampling logits.
//   NOTE(review): key format inferred from the test using "1"/"23" — confirm.
#[builder(default)]
logit_bias: HashMap<String, f64>,
}
impl From<&str> for CompletionArgs {
/// Builds default completion arguments with `prompt` replaced by the
/// given string.
fn from(prompt_string: &str) -> Self {
let mut args = CompletionArgsBuilder::default()
.build()
.expect("default should build");
args.prompt = prompt_string.into();
args
}
}
impl CompletionArgs {
/// Returns a builder pre-populated with the API default values.
#[must_use]
pub fn builder() -> CompletionArgsBuilder {
CompletionArgsBuilder::default()
}
/// Sends this completion request using `client`.
/// Clones the args so they remain reusable by the caller.
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
client.complete(self.clone()).await
}
}
impl CompletionArgsBuilder {
/// Builds the args and sends the request in one step; builder validation
/// errors are converted into `OpenAIError` by `?` (via the crate's From
/// impls — presumably `From<String>`, matching derive_builder's error type;
/// TODO confirm against the derive_builder version in use).
pub async fn complete(&self, client: &OpenAIClient) -> super::Result<Completion> {
client.complete(self.build()?).await
}
}
/// A completion response from the API.
#[derive(Deserialize, Debug)]
pub struct Completion {
pub id: String,
// Creation time as a Unix timestamp.
pub created: u64,
// Concrete model snapshot that served the request, e.g. "davinci:2020-05-03".
pub model: String,
pub choices: Vec<Choice>,
}
impl std::fmt::Display for Completion {
/// Displays the text of the first returned choice.
///
/// Writes nothing when the API returned no choices — the previous
/// implementation indexed `choices[0]` unconditionally and panicked on an
/// empty vector.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.choices.first() {
Some(choice) => write!(f, "{}", choice),
None => Ok(()),
}
}
}
/// One generated alternative within a completion.
#[derive(Deserialize, Debug)]
pub struct Choice {
pub text: String,
// Position of this choice among the `n` requested alternatives.
pub index: u64,
// Present only when `logprobs` was requested.
pub logprobs: Option<LogProbs>,
pub finish_reason: FinishReason,
}
impl std::fmt::Display for Choice {
/// A choice displays as its raw completion text.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.text)
}
}
/// Per-token log-probability data returned when `logprobs` is requested.
#[derive(Deserialize, Debug)]
pub struct LogProbs {
pub tokens: Vec<String>,
// One entry per token; the Option reflects that the API can send null
// entries. NOTE(review): exact null semantics not visible here — confirm.
pub token_logprobs: Vec<Option<f64>>,
// For each position, the most likely alternatives and their logprobs.
pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
// Offset of each token within the text.
//   NOTE(review): offset base (prompt vs. completion) not visible here.
pub text_offset: Vec<u64>,
}
/// Why the model stopped generating.
#[derive(Deserialize, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FinishReason {
/// Generation hit the `max_tokens` limit.
#[serde(rename = "max_tokens")]
MaxTokensReached,
/// Generation produced one of the configured stop sequences.
#[serde(rename = "stop")]
StopSequenceReached,
}
/// Structured error document returned by the API.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ErrorMessage {
pub message: String,
// "type" is a Rust keyword, hence the serde rename.
#[serde(rename = "type")]
pub error_type: String,
}
/// Envelope for error responses: the API nests the message under "error".
#[derive(Deserialize, Debug)]
pub(super) struct ErrorWrapper {
pub error: ErrorMessage,
}
}
/// Errors produced by this crate.
#[derive(Error, Debug)]
pub enum OpenAIError {
/// The API responded with a structured error document.
#[error("API Returned an Error document")]
APIError(api::ErrorMessage),
/// Request arguments failed local validation (e.g. builder errors).
#[error("Bad arguments")]
BadArguments(String),
/// Transport-level failure: connection, serialization, or HTTP error.
#[error("Error at the protocol level")]
ProtocolError(surf::Error),
}
impl From<api::ErrorMessage> for OpenAIError {
/// Wraps an API-level error document.
fn from(e: api::ErrorMessage) -> Self {
Self::APIError(e)
}
}
impl From<String> for OpenAIError {
fn from(e: String) -> Self {
OpenAIError::BadArguments(e)
}
}
impl From<surf::Error> for OpenAIError {
/// Wraps a transport-level failure from the HTTP client.
fn from(e: surf::Error) -> Self {
Self::ProtocolError(e)
}
}
/// Client for the OpenAI REST API; construct with [`OpenAIClient::new`].
pub struct OpenAIClient {
// surf client pre-configured with the base URL and bearer-token middleware.
client: surf::Client,
}
/// Surf middleware that attaches an `Authorization: Bearer <token>` header
/// to every outgoing request.
struct BearerToken {
token: String,
}
impl std::fmt::Debug for BearerToken {
/// Redacted debug form: shows at most the first 8 characters of the token
/// so full credentials never end up in logs.
///
/// The previous implementation used `token.get(0..8)`, which returns `None`
/// (turning `{:?}` into a panic via `fmt::Error`) for tokens shorter than 8
/// bytes or when byte 8 is not a char boundary; taking chars is always safe.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let preview: String = self.token.chars().take(8).collect();
write!(f, r#"Bearer {{ token: "{}" }}"#, preview)
}
}
impl BearerToken {
/// Stores an owned copy of the API token.
fn new(token: &str) -> Self {
Self {
token: String::from(token),
}
}
}
#[surf::utils::async_trait]
impl surf::middleware::Middleware for BearerToken {
/// Injects the `Authorization` header into each request and logs the
/// request/response pair at debug level.
async fn handle(
&self,
mut req: surf::Request,
client: surf::Client,
next: surf::middleware::Next<'_>,
) -> surf::Result<surf::Response> {
// Logged before the header is attached, so the token itself never
// appears in the request log line.
log::debug!("Request: {:?}", req);
req.insert_header("Authorization", format!("Bearer {}", self.token));
let response = next.run(req, client).await?;
log::debug!("Response: {:?}", response);
Ok(response)
}
}
impl OpenAIClient {
#[must_use]
pub fn new(token: &str) -> Self {
let mut client = surf::client();
client.set_base_url(
surf::Url::parse("https://api.openai.com/v1/").expect("Static string should parse"),
);
client = client.with(BearerToken::new(token));
Self { client }
}
#[cfg(test)]
fn set_api_root(&mut self, url: surf::Url) {
self.client.set_base_url(url);
}
async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
let mut response = self.client.get(endpoint).await?;
if let surf::StatusCode::Ok = response.status() {
Ok(response.body_json::<T>().await?)
} else {
{
let err = response.body_json::<api::ErrorWrapper>().await?.error;
Err(OpenAIError::APIError(err))
}
}
}
pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
self.get("engines").await.map(|r: api::Container<_>| r.data)
}
pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
self.get(&format!("engines/{}", engine)).await
}
async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
where
B: serde::ser::Serialize,
R: serde::de::DeserializeOwned,
{
let mut response = self
.client
.post(endpoint)
.body(surf::Body::from_json(&body)?)
.await?;
match response.status() {
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
_ => Err(OpenAIError::APIError(
response
.body_json::<api::ErrorWrapper>()
.await
.expect("The API has returned something funky")
.error,
)),
}
}
pub async fn complete(
&self,
prompt: impl Into<api::CompletionArgs>,
) -> Result<api::Completion> {
let args = prompt.into();
Ok(self
.post(&format!("engines/{}/completions", args.engine), args)
.await?)
}
}
#[cfg(test)]
mod unit {
use crate::{api, OpenAIClient, OpenAIError};
/// Builds a client pointed at the local mockito server instead of the real API.
fn mocked_client() -> OpenAIClient {
// The logger may already be initialized by another test; ignore the error.
let _ = env_logger::builder().is_test(true).try_init();
let mut client = OpenAIClient::new("bogus");
client.set_api_root(
surf::Url::parse(&mockito::server_url()).expect("mockito url didn't parse"),
);
client
}
#[test]
fn can_create_client() {
// Smoke test: constructing a mocked client must not panic.
let _client = mocked_client();
}
/// A single engine document deserializes into the expected `EngineInfo`.
#[test]
fn parse_engine_info() -> Result<(), Box<dyn std::error::Error>> {
let example = r#"{
"id": "ada",
"object": "engine",
"owner": "openai",
"ready": true
}"#;
let ei: api::EngineInfo = serde_json::from_str(example)?;
assert_eq!(
ei,
api::EngineInfo {
id: api::Engine::Ada,
owner: "openai".into(),
ready: true,
}
);
Ok(())
}
/// The engines listing parses end-to-end against a mocked server; an id the
/// crate doesn't know ("experimental-engine-v7") must map to `Engine::Other`.
#[tokio::test]
async fn parse_engines() -> crate::Result<()> {
use api::{Engine, EngineInfo};
let _m = mockito::mock("GET", "/engines")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(
r#"{
"object": "list",
"data": [
{
"id": "ada",
"object": "engine",
"owner": "openai",
"ready": true
},
{
"id": "babbage",
"object": "engine",
"owner": "openai",
"ready": true
},
{
"id": "experimental-engine-v7",
"object": "engine",
"owner": "openai",
"ready": false
},
{
"id": "curie",
"object": "engine",
"owner": "openai",
"ready": true
},
{
"id": "davinci",
"object": "engine",
"owner": "openai",
"ready": true
},
{
"id": "content-filter-alpha-c4",
"object": "engine",
"owner": "openai",
"ready": true
}
]
}"#,
)
.create();
let expected = vec![
EngineInfo {
id: Engine::Ada,
owner: "openai".into(),
ready: true,
},
EngineInfo {
id: Engine::Babbage,
owner: "openai".into(),
ready: true,
},
// The unknown experimental id falls back to the catch-all variant.
EngineInfo {
id: Engine::Other,
owner: "openai".into(),
ready: false,
},
EngineInfo {
id: Engine::Curie,
owner: "openai".into(),
ready: true,
},
EngineInfo {
id: Engine::Davinci,
owner: "openai".into(),
ready: true,
},
EngineInfo {
id: Engine::ContentFilter,
owner: "openai".into(),
ready: true,
},
];
let response = mocked_client().engines().await?;
assert_eq!(response, expected);
Ok(())
}
/// A 404 with an error document must surface as `OpenAIError::APIError`.
#[tokio::test]
async fn engine_error_response() -> crate::Result<()> {
let _m = mockito::mock("GET", "/engines/davinci")
.with_status(404)
.with_header("content-type", "application/json")
.with_body(
r#"{
"error": {
"code": null,
"message": "Some kind of error happened",
"type": "some_error_type"
}
}"#,
)
.create();
let expected = api::ErrorMessage {
message: "Some kind of error happened".into(),
error_type: "some_error_type".into(),
};
let response = mocked_client().engine(api::Engine::Davinci).await;
// The previous `if let Err(APIError(..))` silently passed when the result
// was Ok or a different error variant; fail loudly on anything unexpected.
match response {
Err(OpenAIError::APIError(msg)) => assert_eq!(expected, msg),
other => panic!("Expected an APIError, got {:?}", other),
}
Ok(())
}
/// A builder-constructed completion request round-trips through a mocked
/// completions endpoint.
#[tokio::test]
async fn completion_args() -> crate::Result<()> {
let _m = mockito::mock("POST", "/engines/davinci/completions")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(
r#"{
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "davinci:2020-05-03",
"choices": [
{
"text": " there was a girl who",
"index": 0,
"logprobs": null,
"finish_reason": "max_tokens"
}
]
}"#,
)
.create();
let expected = api::Completion {
id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
created: 1589478378,
model: "davinci:2020-05-03".into(),
choices: vec![api::Choice {
text: " there was a girl who".into(),
index: 0,
logprobs: None,
finish_reason: api::FinishReason::MaxTokensReached,
}],
};
let args = api::CompletionArgs::builder()
.engine(api::Engine::Davinci)
.prompt("Once upon a time")
.max_tokens(5)
.temperature(1.0)
.top_p(1.0)
.n(1)
.stop(vec!["\n".into()])
.build()?;
let response = mocked_client().complete(args).await?;
// Completion/Choice don't derive PartialEq, so compare field by field.
assert_eq!(response.model, expected.model);
assert_eq!(response.id, expected.id);
assert_eq!(response.created, expected.created);
let (resp_choice, expected_choice) = (&response.choices[0], &expected.choices[0]);
assert_eq!(resp_choice.text, expected_choice.text);
assert_eq!(resp_choice.index, expected_choice.index);
assert!(resp_choice.logprobs.is_none());
assert_eq!(resp_choice.finish_reason, expected_choice.finish_reason);
Ok(())
}
}
#[cfg(test)]
mod integration {
use api::ErrorMessage;
use crate::{api, OpenAIClient, OpenAIError};
/// Builds a client from the `OPENAI_SK` env var; integration tests hit the
/// live API. Panics with instructions when the variable is unset.
fn get_client() -> OpenAIClient {
// The logger may already be initialized by another test; ignore the error.
let _ = env_logger::builder().is_test(true).try_init();
// Fixed garbled message ("you must put set the ..." -> "you must set the ...").
let sk = std::env::var("OPENAI_SK").expect(
"To run integration tests, you must set the OPENAI_SK env var to your api token",
);
OpenAIClient::new(&sk)
}
/// Live-API smoke test: listing engines should succeed with a valid token.
#[tokio::test]
async fn can_get_engines() {
let client = get_client();
client.engines().await.unwrap();
}
/// Fetching a single engine by id against the live API.
#[tokio::test]
async fn can_get_engine() {
let client = get_client();
let result = client.engine(api::Engine::Ada).await;
// NOTE(review): this expects the live API to reject "ada" as an unknown
// engine id, which presumably reflected server state when the test was
// written — verify it still holds before relying on this test.
match result {
Err(OpenAIError::APIError(ErrorMessage {
message,
error_type,
})) => {
assert_eq!(message, "No engine with that ID: ada");
assert_eq!(error_type, "invalid_request_error");
}
_ => {
panic!("Expected an error message, got {:?}", result)
}
}
}
/// A bare &str prompt converts via `From<&str>` and completes with defaults.
#[tokio::test]
async fn complete_string() -> crate::Result<()> {
let client = get_client();
client.complete("Hey there").await?;
Ok(())
}
/// Exercises every builder setter at once against the live API.
#[tokio::test]
async fn complete_explicit_params() -> crate::Result<()> {
let client = get_client();
let args = api::CompletionArgsBuilder::default()
.prompt("Once upon a time,")
.max_tokens(10)
.temperature(0.5)
.top_p(0.5)
.n(1)
.logprobs(3)
.echo(false)
.stop(vec!["\n".into()])
.presence_penalty(0.5)
.frequency_penalty(0.5)
// Keys are token ids as strings mapped to logit bias values.
.logit_bias(maplit::hashmap! {
"1".into() => 1.0,
"23".into() => 0.0,
})
.build()
.expect("Build should have succeeded");
client.complete(args).await?;
Ok(())
}
/// With temperature/top_p at 0 the model should deterministically hit a stop
/// sequence, so the finish reason must be `StopSequenceReached`.
#[tokio::test]
async fn complete_stop_condition() -> crate::Result<()> {
let client = get_client();
let mut args = api::CompletionArgs::builder();
let completion = args
.prompt(
r#"
Q: Please type `#` now
A:"#,
)
.temperature(0.0)
.top_p(0.0)
.max_tokens(100)
.stop(vec!["#".into(), "\n".into()])
.complete(&client).await?;
assert_eq!(completion.choices[0].finish_reason, api::FinishReason::StopSequenceReached);
Ok(())
}
}