use crate::ynab::errors::{Error, ErrorResponse};
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use std::fmt;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
/// Asynchronous client for the YNAB HTTP API.
///
/// Cloning is cheap: `reqwest::Client` shares its connection pool between clones, and the
/// optional rate limiter is held in an `Arc`, so every clone draws from the same request
/// budget instead of getting a fresh one.
#[derive(Clone)]
pub struct Client {
    /// Base URL that endpoint paths are appended to.
    base_url: reqwest::Url,
    /// HTTP client used for every request (rebuilt when a timeout is configured).
    http_client: reqwest::Client,
    /// Optional limiter awaited before each request; shared between clones via `Arc`.
    limiter: Option<Arc<DefaultDirectRateLimiter>>,
    /// API key sent as a bearer token for authentication.
    api_key: String,
    /// Request timeout the HTTP client was built with, if any.
    timeout: Option<Duration>,
}
/// Hand-written so the API key never appears in `{:?}` output (logs, panic messages).
impl fmt::Debug for Client {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("api_key", &"[redacted]")
            .field("timeout", &self.timeout)
            .field("rate_limited", &self.limiter.is_some())
            // `http_client` is intentionally left out; `..` signals the omission.
            .finish_non_exhaustive()
    }
}
impl Client {
pub fn new(api_key: impl Into<String>) -> Result<Self, Error> {
let api_key = api_key.into();
let http_client = Self::build_http_client(&api_key, None)?;
Ok(Self {
base_url: reqwest::Url::parse("https://api.ynab.com/v1").unwrap(),
http_client,
limiter: None,
api_key,
timeout: None,
})
}
fn build_http_client(
api_key: &str,
timeout: Option<Duration>,
) -> Result<reqwest::Client, Error> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", api_key)
.parse()
.expect("api key must be valid ASCII"),
);
let mut builder = reqwest::Client::builder().default_headers(headers);
if let Some(t) = timeout {
builder = builder.timeout(t);
}
builder.build().map_err(Into::into)
}
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self, Error> {
self.http_client = Self::build_http_client(&self.api_key, Some(timeout))?;
self.timeout = Some(timeout);
Ok(self)
}
pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self, Error> {
self.base_url = reqwest::Url::parse(base_url.as_ref())?;
Ok(self)
}
pub fn with_rate_limiter(
mut self,
requests_per_hour: usize,
burst_volume: Option<usize>,
) -> Result<Self, Error> {
let requests = NonZeroU32::new(requests_per_hour as u32)
.ok_or_else(|| Error::InvalidRateLimit("requests_per_hour must be non-zero".into()))?;
let quota = match burst_volume {
None => Quota::per_hour(requests),
Some(burst) => {
let effective = (requests_per_hour as u32)
.checked_sub(burst as u32)
.ok_or_else(|| {
Error::InvalidRateLimit(
"requests_per_hour must be greater than burst_volume".into(),
)
})?;
let effective_rate = NonZeroU32::new(effective).ok_or_else(|| {
Error::InvalidRateLimit(
"requests_per_hour - burst_volume must be non-zero".into(),
)
})?;
let burst = NonZeroU32::new(burst as u32).ok_or_else(|| {
Error::InvalidRateLimit("burst_volume must be non-zero".into())
})?;
Quota::per_hour(effective_rate).allow_burst(burst)
}
};
self.limiter = Some(Arc::new(RateLimiter::direct(quota)));
Ok(self)
}
pub(crate) async fn get<T: serde::de::DeserializeOwned, Q: serde::ser::Serialize + ?Sized>(
&self,
endpoint: &str,
params: Option<&Q>,
) -> Result<T, Error> {
if let Some(limiter) = &self.limiter {
limiter.until_ready().await;
}
let mut url = self.base_url.clone();
url.path_segments_mut()
.expect("base URL must be a valid base")
.extend(endpoint.split('/'));
let mut builder = self.http_client.get(url);
if let Some(p) = params {
builder = builder.query(p);
}
let res = builder.send().await?;
let status = res.status();
if !status.is_success() {
let err_body: ErrorResponse = res.json().await?;
return Err(Error::new_api_error(status, err_body.error));
}
res.json().await.map_err(Into::into)
}
pub(crate) async fn post<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
&self,
endpoint: &str,
body: B,
) -> Result<T, Error> {
if let Some(limiter) = &self.limiter {
limiter.until_ready().await;
}
let mut url = self.base_url.clone();
url.path_segments_mut()
.expect("base URL must be a valid base")
.extend(endpoint.split('/'));
let res = self.http_client.post(url).json(&body).send().await?;
let status = res.status();
if !status.is_success() {
let err_body: ErrorResponse = res.json().await?;
return Err(Error::new_api_error(status, err_body.error));
}
res.json().await.map_err(Into::into)
}
pub(crate) async fn patch<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
&self,
endpoint: &str,
body: B,
) -> Result<T, Error> {
if let Some(limiter) = &self.limiter {
limiter.until_ready().await;
}
let mut url = self.base_url.clone();
url.path_segments_mut()
.expect("base URL must be a valid base")
.extend(endpoint.split('/'));
let res = self.http_client.patch(url).json(&body).send().await?;
let status = res.status();
if !status.is_success() {
let err_body: ErrorResponse = res.json().await?;
return Err(Error::new_api_error(status, err_body.error));
}
res.json().await.map_err(Into::into)
}
pub(crate) async fn put<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
&self,
endpoint: &str,
body: B,
) -> Result<T, Error> {
if let Some(limiter) = &self.limiter {
limiter.until_ready().await;
}
let mut url = self.base_url.clone();
url.path_segments_mut()
.expect("base URL must be a valid base")
.extend(endpoint.split('/'));
let res = self.http_client.put(url).json(&body).send().await?;
let status = res.status();
if !status.is_success() {
let err_body: ErrorResponse = res.json().await?;
return Err(Error::new_api_error(status, err_body.error));
}
res.json().await.map_err(Into::into)
}
pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
&self,
endpoint: &str,
) -> Result<T, Error> {
if let Some(limiter) = &self.limiter {
limiter.until_ready().await;
}
let mut url = self.base_url.clone();
url.path_segments_mut()
.expect("base URL must be a valid base")
.extend(endpoint.split('/'));
let res = self.http_client.delete(url).send().await?;
let status = res.status();
if !status.is_success() {
let err_body: ErrorResponse = res.json().await?;
return Err(Error::new_api_error(status, err_body.error));
}
res.json().await.map_err(Into::into)
}
}