use reqwest::header::HeaderMap;
use serde::{de::DeserializeOwned, Serialize};
use crate::{
error::{map_deserialization_error, ApiError, StabilityAIError},
generate::Generate,
user::User,
Engines,
};
/// Asynchronous client for the Stability AI REST API.
///
/// Construct it with [`Client::new`] (or [`Default`]) and configure it with the
/// `with_*` builder methods.
///
/// `Debug` is implemented by hand so that the API key is never printed in logs or
/// tracing spans; all other fields are shown as-is.
#[derive(Clone)]
pub struct Client {
    /// HTTP client used to send every request.
    http_client: reqwest::Client,
    /// Secret sent as a bearer token on every request.
    api_key: String,
    /// Base URL that request paths are appended to.
    api_base: String,
    /// Value of the `Organization` header; the header is omitted when empty.
    organization: String,
    /// Value of the client-ID header; omitted when `None` or empty.
    client_id: Option<String>,
    /// Value of the client-version header; omitted when `None` or empty.
    client_version: Option<String>,
    /// Retry policy applied to rate-limited requests.
    backoff: backoff::ExponentialBackoff,
}

impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("http_client", &self.http_client)
            // Never expose the secret, even in debug output.
            .field("api_key", &"[REDACTED]")
            .field("api_base", &self.api_base)
            .field("organization", &self.organization)
            .field("client_id", &self.client_id)
            .field("client_version", &self.client_version)
            .field("backoff", &self.backoff)
            .finish()
    }
}
/// Default base URL of the Stability AI REST API; request paths are appended to it verbatim.
pub const API_BASE: &str = "https://api.stability.ai/v1";

// Names of the optional headers attached to every request when configured.

/// Header carrying the organization value.
pub const ORGANIZATION_HEADER: &str = "Organization";
/// Header carrying the client identifier.
pub const CLIENT_ID_HEADER: &str = "Stability-Client-ID";
/// Header carrying the client version.
pub const CLIENT_VERSION_HEADER: &str = "Stability-Client-Version";
impl Default for Client {
    /// Builds a client that targets [`API_BASE`] and authenticates with the
    /// `STABILITY_API_KEY` environment variable.
    ///
    /// If that variable is unset (or not valid Unicode), the API key is an empty
    /// string rather than an error. No organization, client ID or client version is
    /// configured, and the retry policy is the default exponential backoff.
    fn default() -> Self {
        let api_key = std::env::var("STABILITY_API_KEY").unwrap_or_default();
        Self {
            http_client: reqwest::Client::new(),
            api_key,
            api_base: String::from(API_BASE),
            organization: String::new(),
            client_id: None,
            client_version: None,
            backoff: Default::default(),
        }
    }
}
impl Client {
pub fn new() -> Self {
Default::default()
}
pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
self.http_client = http_client;
self
}
pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
self.api_key = api_key.into();
self
}
pub fn with_organization<S: Into<String>>(mut self, organization: S) -> Self {
self.organization = organization.into();
self
}
pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
self.api_base = api_base.into();
self
}
pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
self.backoff = backoff;
self
}
pub fn api_base(&self) -> &str {
&self.api_base
}
pub fn api_key(&self) -> &str {
&self.api_key
}
fn headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
if !self.organization.is_empty() {
headers.insert(
ORGANIZATION_HEADER,
self.organization.as_str().parse().unwrap(),
);
}
if let Some(ref client_id) = self.client_id {
if !client_id.is_empty() {
headers.insert(CLIENT_ID_HEADER, client_id.as_str().parse().unwrap());
}
}
if let Some(ref client_version) = self.client_version {
if !client_version.is_empty() {
headers.insert(
CLIENT_VERSION_HEADER,
client_version.as_str().parse().unwrap(),
);
}
}
headers
}
pub fn user(&self) -> User {
User::new(self)
}
pub fn engines(&self) -> Engines {
Engines::new(self)
}
pub fn generate<S: Into<String>>(&self, engine_id: S) -> Generate<S>
where
S: std::fmt::Display,
{
Generate::new(self, engine_id)
}
pub(crate) async fn get<O>(&self, path: &str) -> Result<O, StabilityAIError>
where
O: DeserializeOwned,
{
let request_maker = || async {
Ok(self
.http_client
.get(format!("{}{path}", self.api_base()))
.bearer_auth(self.api_key())
.headers(self.headers())
.build()?)
};
self.execute(request_maker).await
}
pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, StabilityAIError>
where
I: Serialize,
O: DeserializeOwned,
{
let request_maker = || async {
Ok(self
.http_client
.post(format!("{}{path}", self.api_base()))
.bearer_auth(self.api_key())
.headers(self.headers())
.json(&request)
.build()?)
};
self.execute(request_maker).await
}
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, StabilityAIError>
where
O: DeserializeOwned,
reqwest::multipart::Form: async_convert::TryFrom<F, Error = StabilityAIError>,
F: Clone,
{
let request_maker = || async {
Ok(self
.http_client
.post(format!("{}{path}", self.api_base()))
.bearer_auth(self.api_key())
.headers(self.headers())
.multipart(async_convert::TryInto::try_into(form.clone()).await?)
.build()?)
};
self.execute(request_maker).await
}
async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, StabilityAIError>
where
O: DeserializeOwned,
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, StabilityAIError>>,
{
let client = self.http_client.clone();
backoff::future::retry(self.backoff.clone(), || async {
let request = request_maker().await.map_err(backoff::Error::Permanent)?;
let response = client
.execute(request)
.await
.map_err(StabilityAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;
let status = response.status();
let bytes = response
.bytes()
.await
.map_err(StabilityAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;
if !status.is_success() {
let api_error: ApiError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
if status.as_u16() == 429 {
tracing::warn!("Rate limited: {}", api_error.message);
return Err(backoff::Error::Transient {
err: StabilityAIError::ApiError(api_error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(StabilityAIError::ApiError(
api_error,
)));
}
}
let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
Ok(response)
})
.await
}
}