use std::time::Duration;
use reqwest::{Response, StatusCode};
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::error::{ErrorBody, VynFiError};
use crate::resources::{ApiKeys, Catalog, Credits, Jobs, Usage};
/// API root used when `ClientBuilder::base_url` is never called.
const DEFAULT_BASE_URL: &str = "https://api.vynfi.com";
/// Default request timeout, in seconds, handed to reqwest's `ClientBuilder::timeout`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Default number of retries after the initial attempt (i.e. up to three attempts in total).
const DEFAULT_MAX_RETRIES: u32 = 2;
/// `User-Agent` header value: `vynfi-rust/` followed by this crate's Cargo package version.
const USER_AGENT: &str = concat!("vynfi-rust/", env!("CARGO_PKG_VERSION"));
/// Async client for the VynFi HTTP API.
///
/// Construct one with [`Client::builder`]. Cloning is cheap for the HTTP part: reqwest's
/// `Client` shares its connection pool between clones.
#[derive(Clone)]
pub struct Client {
    /// Pre-configured reqwest client carrying the auth, user-agent and content-type headers
    /// plus the request timeout.
    http: reqwest::Client,
    /// API root; request paths are appended to it verbatim.
    base_url: String,
    /// Number of retries allowed after the first attempt for retryable failures.
    max_retries: u32,
}

impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `http` is omitted on purpose: its default headers contain the bearer token, and
        // deriving `Debug` would risk printing it into logs.
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("max_retries", &self.max_retries)
            .finish_non_exhaustive()
    }
}
impl Client {
pub fn builder(api_key: impl Into<String>) -> ClientBuilder {
ClientBuilder::new(api_key)
}
pub fn jobs(&self) -> Jobs<'_> {
Jobs::new(self)
}
pub fn catalog(&self) -> Catalog<'_> {
Catalog::new(self)
}
pub fn usage(&self) -> Usage<'_> {
Usage::new(self)
}
pub fn api_keys(&self) -> ApiKeys<'_> {
ApiKeys::new(self)
}
pub fn credits(&self) -> Credits<'_> {
Credits::new(self)
}
pub(crate) async fn request<T: DeserializeOwned>(
&self,
method: reqwest::Method,
path: &str,
) -> Result<T, VynFiError> {
self.request_with_body::<T, ()>(method, path, None).await
}
pub(crate) async fn request_with_body<T, B>(
&self,
method: reqwest::Method,
path: &str,
body: Option<&B>,
) -> Result<T, VynFiError>
where
T: DeserializeOwned,
B: serde::Serialize,
{
let url = format!("{}{}", self.base_url, path);
let mut last_err: Option<VynFiError> = None;
for attempt in 0..=self.max_retries {
let mut req = self.http.request(method.clone(), &url);
if let Some(b) = body {
req = req.json(b);
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
last_err = Some(VynFiError::Http(e));
if attempt < self.max_retries {
tokio::time::sleep(Self::backoff(attempt)).await;
continue;
}
return Err(last_err.unwrap());
}
};
let status = resp.status();
if Self::should_retry(status) && attempt < self.max_retries {
let retry_after = resp
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs);
let wait = retry_after.unwrap_or_else(|| Self::backoff(attempt));
let _ = resp.bytes().await;
tokio::time::sleep(wait).await;
continue;
}
if status == StatusCode::NO_CONTENT {
return serde_json::from_value(serde_json::Value::Null).map_err(VynFiError::from);
}
if status.is_client_error() || status.is_server_error() {
return Err(Self::error_from_response(resp).await);
}
let bytes = resp.bytes().await?;
return serde_json::from_slice(&bytes).map_err(VynFiError::from);
}
Err(last_err.unwrap_or_else(|| VynFiError::Config("max retries exceeded".into())))
}
pub(crate) async fn request_with_params<T: DeserializeOwned>(
&self,
method: reqwest::Method,
path: &str,
params: &[(&str, String)],
) -> Result<T, VynFiError> {
let url = format!("{}{}", self.base_url, path);
let mut last_err: Option<VynFiError> = None;
for attempt in 0..=self.max_retries {
let resp = match self
.http
.request(method.clone(), &url)
.query(params)
.send()
.await
{
Ok(r) => r,
Err(e) => {
last_err = Some(VynFiError::Http(e));
if attempt < self.max_retries {
tokio::time::sleep(Self::backoff(attempt)).await;
continue;
}
return Err(last_err.unwrap());
}
};
let status = resp.status();
if Self::should_retry(status) && attempt < self.max_retries {
let _ = resp.bytes().await;
tokio::time::sleep(Self::backoff(attempt)).await;
continue;
}
if status.is_client_error() || status.is_server_error() {
return Err(Self::error_from_response(resp).await);
}
let bytes = resp.bytes().await?;
return serde_json::from_slice(&bytes).map_err(VynFiError::from);
}
Err(last_err.unwrap_or_else(|| VynFiError::Config("max retries exceeded".into())))
}
fn should_retry(status: StatusCode) -> bool {
status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
}
fn backoff(attempt: u32) -> Duration {
Duration::from_millis(500 * 2u64.pow(attempt))
}
async fn error_from_response(resp: Response) -> VynFiError {
let status = resp.status();
let body: ErrorBody = resp.json().await.unwrap_or_else(|_| ErrorBody {
error_type: String::new(),
title: String::new(),
detail: String::new(),
status: status.as_u16(),
request_id: String::new(),
fields: vec![],
});
let body = Box::new(body);
match status {
StatusCode::UNAUTHORIZED => VynFiError::Authentication(body),
StatusCode::PAYMENT_REQUIRED => VynFiError::InsufficientCredits(body),
StatusCode::FORBIDDEN => VynFiError::Forbidden(body),
StatusCode::NOT_FOUND => VynFiError::NotFound(body),
StatusCode::CONFLICT => VynFiError::Conflict(body),
StatusCode::UNPROCESSABLE_ENTITY => VynFiError::Validation(body),
StatusCode::TOO_MANY_REQUESTS => VynFiError::RateLimit(body),
_ => VynFiError::Server(body),
}
}
}
/// Builder for [`Client`]; obtain one from [`Client::builder`] and finish with `build`.
#[must_use = "a ClientBuilder does nothing until `build` is called"]
pub struct ClientBuilder {
    /// Secret API key sent as a bearer token; never printed by `Debug`.
    api_key: String,
    /// API root with trailing slashes removed.
    base_url: String,
    /// Request timeout applied to the underlying HTTP client.
    timeout: Duration,
    /// Retries allowed after the first attempt.
    max_retries: u32,
}

impl std::fmt::Debug for ClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redact the key so builders can be logged safely.
        f.debug_struct("ClientBuilder")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
impl ClientBuilder {
    /// Creates a builder holding `api_key` plus the crate defaults: `DEFAULT_BASE_URL`,
    /// a `DEFAULT_TIMEOUT_SECS`-second timeout and `DEFAULT_MAX_RETRIES` retries.
    fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: DEFAULT_BASE_URL.to_string(),
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }

    /// Overrides the API root.
    ///
    /// Trailing `/` characters are stripped so request paths can be appended directly.
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into().trim_end_matches('/').to_string();
        self
    }

    /// Sets the request timeout passed to the underlying reqwest client.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets how many times a retryable failure is retried after the first attempt.
    ///
    /// `0` disables retries.
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Builds the [`Client`].
    ///
    /// The API key is trimmed of surrounding whitespace and sent as
    /// `Authorization: Bearer <key>`. That header is marked sensitive so HTTP-layer `Debug`
    /// output redacts it.
    ///
    /// # Errors
    ///
    /// - [`VynFiError::Config`] if the key is empty or whitespace-only, or contains characters
    ///   that cannot appear in an HTTP header value (e.g. an embedded newline).
    /// - The crate's conversion of `reqwest::Error` if the HTTP client cannot be constructed.
    pub fn build(self) -> Result<Client, VynFiError> {
        use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};

        let api_key = self.api_key.trim();
        if api_key.is_empty() {
            return Err(VynFiError::Config("api_key is required".into()));
        }

        // The key is caller-supplied, so an invalid header value must surface as an error
        // rather than a panic.
        let mut auth = HeaderValue::from_str(&format!("Bearer {}", api_key)).map_err(|_| {
            VynFiError::Config(
                "api_key contains characters that are not valid in an HTTP header".into(),
            )
        })?;
        auth.set_sensitive(true);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, auth);
        // Fully qualified to avoid clashing with this module's `USER_AGENT` constant.
        headers.insert(
            reqwest::header::USER_AGENT,
            HeaderValue::from_static(USER_AGENT),
        );
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

        let http = reqwest::Client::builder()
            .default_headers(headers)
            .timeout(self.timeout)
            .build()?;

        Ok(Client {
            http,
            base_url: self.base_url,
            max_retries: self.max_retries,
        })
    }
}
pub(crate) fn extract_list<T: DeserializeOwned>(value: Value) -> Result<Vec<T>, VynFiError> {
if value.is_array() {
return Ok(serde_json::from_value(value)?);
}
if let Some(obj) = value.as_object() {
if let Some(arr) = obj.get("data").filter(|v| v.is_array()) {
return Ok(serde_json::from_value(arr.clone())?);
}
for v in obj.values() {
if v.is_array() {
return Ok(serde_json::from_value(v.clone())?);
}
}
}
Ok(vec![])
}