use std::time::Duration;
use reqwest::{RequestBuilder, Response, StatusCode};
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::error::{ErrorBody, VynFiError};
use crate::resources::{
ApiKeys, Billing, Catalog, Configs, Credits, Jobs, Notifications, Quality, Scenarios, Sessions,
Usage, Webhooks,
};
/// API root used when [`ClientBuilder::base_url`] is not called.
const DEFAULT_BASE_URL: &str = "https://api.vynfi.com";
/// Request timeout, in seconds, used when [`ClientBuilder::timeout`] is not called.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Retries allowed after the first attempt when [`ClientBuilder::max_retries`] is not called
/// (so up to `DEFAULT_MAX_RETRIES + 1` attempts in total).
const DEFAULT_MAX_RETRIES: u32 = 2;
/// `User-Agent` header value sent with every request: `vynfi-rust/<crate version>`.
const USER_AGENT: &str = concat!("vynfi-rust/", env!("CARGO_PKG_VERSION"));
#[derive(Clone)]
pub struct Client {
http: reqwest::Client,
base_url: String,
max_retries: u32,
}
impl Client {
pub fn builder(api_key: impl Into<String>) -> ClientBuilder {
ClientBuilder::new(api_key)
}
pub fn jobs(&self) -> Jobs<'_> {
Jobs::new(self)
}
pub fn catalog(&self) -> Catalog<'_> {
Catalog::new(self)
}
pub fn usage(&self) -> Usage<'_> {
Usage::new(self)
}
pub fn api_keys(&self) -> ApiKeys<'_> {
ApiKeys::new(self)
}
pub fn quality(&self) -> Quality<'_> {
Quality::new(self)
}
pub fn webhooks(&self) -> Webhooks<'_> {
Webhooks::new(self)
}
pub fn billing(&self) -> Billing<'_> {
Billing::new(self)
}
pub fn configs(&self) -> Configs<'_> {
Configs::new(self)
}
pub fn credits(&self) -> Credits<'_> {
Credits::new(self)
}
pub fn sessions(&self) -> Sessions<'_> {
Sessions::new(self)
}
pub fn scenarios(&self) -> Scenarios<'_> {
Scenarios::new(self)
}
pub fn notifications(&self) -> Notifications<'_> {
Notifications::new(self)
}
pub(crate) async fn request<T: DeserializeOwned>(
&self,
method: reqwest::Method,
path: &str,
) -> Result<T, VynFiError> {
self.send_with_retry(method, path, |req| req).await
}
pub(crate) async fn request_with_body<T, B>(
&self,
method: reqwest::Method,
path: &str,
body: Option<&B>,
) -> Result<T, VynFiError>
where
T: DeserializeOwned,
B: serde::Serialize,
{
let body_value = body.map(|b| serde_json::to_value(b).expect("serializable body"));
self.send_with_retry(method, path, move |req| match &body_value {
Some(v) => req.json(v),
None => req,
})
.await
}
pub(crate) async fn request_with_params<T: DeserializeOwned>(
&self,
method: reqwest::Method,
path: &str,
params: &[(&str, String)],
) -> Result<T, VynFiError> {
let params = params.to_vec();
self.send_with_retry(method, path, move |req| req.query(¶ms))
.await
}
pub(crate) async fn request_raw(
&self,
method: reqwest::Method,
path: &str,
params: &[(&str, String)],
) -> Result<Response, VynFiError> {
let url = format!("{}{}", self.base_url, path);
let mut req = self.http.request(method, &url);
if !params.is_empty() {
req = req.query(params);
}
let resp = req.send().await?;
if resp.status().is_client_error() || resp.status().is_server_error() {
return Err(Self::error_from_response(resp).await);
}
Ok(resp)
}
pub(crate) fn url(&self, path: &str) -> String {
format!("{}{}", self.base_url, path)
}
pub(crate) fn http(&self) -> &reqwest::Client {
&self.http
}
async fn send_with_retry<T, F>(
&self,
method: reqwest::Method,
path: &str,
configure: F,
) -> Result<T, VynFiError>
where
T: DeserializeOwned,
F: Fn(RequestBuilder) -> RequestBuilder,
{
let url = format!("{}{}", self.base_url, path);
let mut last_err: Option<VynFiError> = None;
for attempt in 0..=self.max_retries {
let req = configure(self.http.request(method.clone(), &url));
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
last_err = Some(VynFiError::Http(e));
if attempt < self.max_retries {
tokio::time::sleep(Self::backoff(attempt)).await;
continue;
}
return Err(last_err.unwrap());
}
};
let status = resp.status();
if Self::should_retry(status) && attempt < self.max_retries {
let retry_after = resp
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs);
let wait = retry_after.unwrap_or_else(|| Self::backoff(attempt));
let _ = resp.bytes().await;
tokio::time::sleep(wait).await;
continue;
}
if status == StatusCode::NO_CONTENT {
return serde_json::from_value(serde_json::Value::Null).map_err(VynFiError::from);
}
if status.is_client_error() || status.is_server_error() {
return Err(Self::error_from_response(resp).await);
}
let bytes = resp.bytes().await?;
return serde_json::from_slice(&bytes).map_err(VynFiError::from);
}
Err(last_err.unwrap_or_else(|| VynFiError::Config("max retries exceeded".into())))
}
fn should_retry(status: StatusCode) -> bool {
status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
}
fn backoff(attempt: u32) -> Duration {
Duration::from_millis(500 * 2u64.pow(attempt))
}
async fn error_from_response(resp: Response) -> VynFiError {
let status = resp.status();
let body: ErrorBody = resp.json().await.unwrap_or_else(|_| ErrorBody {
error_type: String::new(),
title: String::new(),
detail: format!("HTTP {}", status.as_u16()),
status: status.as_u16(),
instance: None,
});
let body = Box::new(body);
match status {
StatusCode::UNAUTHORIZED => VynFiError::Authentication(body),
StatusCode::PAYMENT_REQUIRED => VynFiError::InsufficientCredits(body),
StatusCode::FORBIDDEN => VynFiError::Forbidden(body),
StatusCode::NOT_FOUND => VynFiError::NotFound(body),
StatusCode::CONFLICT => VynFiError::Conflict(body),
StatusCode::UNPROCESSABLE_ENTITY => VynFiError::Validation(body),
StatusCode::TOO_MANY_REQUESTS => VynFiError::RateLimit(body),
_ => VynFiError::Server(body),
}
}
}
pub struct ClientBuilder {
api_key: String,
base_url: String,
timeout: Duration,
max_retries: u32,
}
impl ClientBuilder {
fn new(api_key: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
base_url: DEFAULT_BASE_URL.to_string(),
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
max_retries: DEFAULT_MAX_RETRIES,
}
}
pub fn base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = url.into().trim_end_matches('/').to_string();
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn max_retries(mut self, retries: u32) -> Self {
self.max_retries = retries;
self
}
pub fn build(self) -> Result<Client, VynFiError> {
if self.api_key.is_empty() {
return Err(VynFiError::Config("api_key is required".into()));
}
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", self.api_key)
.parse()
.expect("valid authorization header value"),
);
headers.insert(
reqwest::header::USER_AGENT,
USER_AGENT.parse().expect("valid user-agent header value"),
);
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json"
.parse()
.expect("valid content-type header value"),
);
let http = reqwest::Client::builder()
.default_headers(headers)
.timeout(self.timeout)
.build()?;
Ok(Client {
http,
base_url: self.base_url,
max_retries: self.max_retries,
})
}
}
pub(crate) fn extract_list<T: DeserializeOwned>(value: Value) -> Result<Vec<T>, VynFiError> {
if value.is_array() {
return Ok(serde_json::from_value(value)?);
}
if let Value::Object(mut map) = value {
if let Some(arr) = map.remove("data").filter(|v| v.is_array()) {
return Ok(serde_json::from_value(arr)?);
}
for (_, v) in map {
if v.is_array() {
return Ok(serde_json::from_value(v)?);
}
}
}
Ok(vec![])
}