use crate::{
deserialize_as, path::Format, Connector, DataSource, Error, GetRequest, ListRequest, Page,
PathBuilder,
};
use hyper::{Body, Response};
use serde::de::DeserializeOwned;
use std::{num::NonZeroU32, sync::Arc};
/// The concrete `governor` limiter used by [`RateLimitedDataSource`].
///
/// - It is a direct limiter: a single quota, with no per-key buckets (`NotKeyed`).
/// - Its state is held in memory (`InMemoryState`).
/// - It is timed by the quanta-based clock (`QuantaClock`).
pub type RateLimiter = governor::RateLimiter<
governor::state::NotKeyed,
governor::state::InMemoryState,
governor::clock::QuantaClock,
>;
pub use governor::Quota;
/// A [`DataSource`] wrapper whose requests are throttled by a [`RateLimiter`].
///
/// The limiter is held behind an [`Arc`]. Cloning this value therefore produces a
/// handle that draws from the same quota as the original, not from a fresh one.
#[derive(Debug, Clone)]
pub struct RateLimitedDataSource<C: Connector> {
    /// Wrapped data source that actually performs each request.
    datasource: DataSource<C>,
    /// Limiter awaited before every request; shared by all clones of this value.
    rate_limiter: Arc<RateLimiter>,
}
impl<C> RateLimitedDataSource<C>
where
C: Connector,
{
pub fn new(datasource: DataSource<C>, quota: Quota) -> Self {
Self {
datasource,
rate_limiter: Arc::new(RateLimiter::direct(quota)),
}
}
pub fn per_second(datasource: DataSource<C>, per_second: NonZeroU32) -> Self {
Self::new(datasource, Quota::per_second(per_second))
}
async fn execute<R>(&self, request: R) -> Result<Response<Body>, Error>
where
R: Into<PathBuilder>,
{
self.rate_limiter.until_ready().await;
self.datasource.execute(request).await
}
pub async fn fetch<T>(&self, request: GetRequest) -> Result<T, Error>
where
T: DeserializeOwned,
{
let response = self
.execute(Into::<PathBuilder>::into(request).format(Format::Json))
.await?;
deserialize_as::<T>(response).await
}
pub async fn fetch_paged<T>(&self, request: ListRequest) -> Result<Page<T>, Error>
where
T: DeserializeOwned,
{
let response = self
.execute(Into::<PathBuilder>::into(request).format(Format::Json))
.await?;
deserialize_as::<Page<T>>(response).await
}
}