use crate::Context;
use crate::error::BoxError;
use std::{convert::Infallible, fmt, sync::Arc};
mod concurrent;
#[doc(inline)]
pub use concurrent::{ConcurrentCounter, ConcurrentPolicy, ConcurrentTracker, LimitReached};
mod matcher;
/// The value produced by [`Policy::check`].
///
/// Besides the decision itself, it hands the context and the request back to the
/// caller, because `check` takes both by value.
pub struct PolicyResult<State, Request, Guard, Error> {
    /// Context returned by the policy after the check.
    pub ctx: Context<State>,
    /// Request returned by the policy after the check.
    pub request: Request,
    /// The decision reached by the policy.
    pub output: PolicyOutput<Guard, Error>,
}

impl<State, Request, Guard, Error> fmt::Debug for PolicyResult<State, Request, Guard, Error>
where
    State: fmt::Debug,
    Request: fmt::Debug,
    Guard: fmt::Debug,
    Error: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field becomes a compile error here
        // instead of silently being left out of the debug output.
        let Self {
            ctx,
            request,
            output,
        } = self;
        f.debug_struct("PolicyResult")
            .field("ctx", ctx)
            .field("request", request)
            .field("output", output)
            .finish()
    }
}
/// The decision reached by a [`Policy`] check.
pub enum PolicyOutput<Guard, Error> {
    /// The check passed; carries the guard produced by the policy.
    Ready(Guard),
    /// The check failed with the given error.
    Abort(Error),
    /// The policy asks for the check to be retried.
    ///
    /// NOTE(review): when and how a retry happens is decided by the caller, not by this type.
    Retry,
}

impl<Guard: fmt::Debug, Error: fmt::Debug> fmt::Debug for PolicyOutput<Guard, Error> {
    /// Formats as `PolicyOutput::Ready(..)`, `PolicyOutput::Abort(..)` or `PolicyOutput::Retry`.
    ///
    /// Uses the `debug_tuple` builder rather than `write!` with a nested `{:?}`.
    /// That way the caller's formatter flags (notably `{:#?}` pretty printing)
    /// propagate into the payload instead of being discarded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Ready(guard) => f.debug_tuple("PolicyOutput::Ready").field(guard).finish(),
            Self::Abort(error) => f.debug_tuple("PolicyOutput::Abort").field(error).finish(),
            Self::Retry => f.write_str("PolicyOutput::Retry"),
        }
    }
}
pub trait Policy<State, Request>: Send + Sync + 'static {
type Guard: Send + 'static;
type Error: Send + Sync + 'static;
fn check(
&self,
ctx: Context<State>,
request: Request,
) -> impl Future<Output = PolicyResult<State, Request, Self::Guard, Self::Error>> + Send + '_;
}
/// An optional policy.
///
/// `Some(policy)` delegates to the inner policy and wraps its guard in `Some`.
/// `None` performs no check and always reports `Ready(None)`.
impl<State, Request, P> Policy<State, Request> for Option<P>
where
    P: Policy<State, Request>,
    State: Clone + Send + Sync + 'static,
    Request: Send + 'static,
{
    type Guard = Option<P::Guard>;
    type Error = P::Error;

    async fn check(
        &self,
        ctx: Context<State>,
        request: Request,
    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
        let Some(inner) = self else {
            // No policy configured: nothing to enforce, and no guard to hand out.
            return PolicyResult {
                ctx,
                request,
                output: PolicyOutput::Ready(None),
            };
        };

        let PolicyResult {
            ctx,
            request,
            output,
        } = inner.check(ctx, request).await;

        // Only the guard type changes; errors and retries pass through untouched.
        let output = match output {
            PolicyOutput::Ready(guard) => PolicyOutput::Ready(Some(guard)),
            PolicyOutput::Abort(err) => PolicyOutput::Abort(err),
            PolicyOutput::Retry => PolicyOutput::Retry,
        };

        PolicyResult {
            ctx,
            request,
            output,
        }
    }
}
/// A `'static` reference to a policy behaves exactly like the policy itself.
impl<State, Request, P> Policy<State, Request> for &'static P
where
    P: Policy<State, Request>,
    State: Clone + Send + Sync + 'static,
    Request: Send + 'static,
{
    type Guard = P::Guard;
    type Error = P::Error;

    async fn check(
        &self,
        ctx: Context<State>,
        request: Request,
    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
        // `&&P` coerces to `&P`; call the referenced policy directly.
        let inner: &P = self;
        inner.check(ctx, request).await
    }
}
/// A shared (`Arc`) policy delegates to the policy it points to.
impl<State, Request, P> Policy<State, Request> for Arc<P>
where
    P: Policy<State, Request>,
    State: Clone + Send + Sync + 'static,
    Request: Send + 'static,
{
    type Guard = P::Guard;
    type Error = P::Error;

    async fn check(
        &self,
        ctx: Context<State>,
        request: Request,
    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
        // Deref coercion: `&Arc<P>` -> `&P`.
        let inner: &P = self;
        inner.check(ctx, request).await
    }
}
/// A boxed policy delegates to the policy it owns.
impl<State, Request, P> Policy<State, Request> for Box<P>
where
    P: Policy<State, Request>,
    State: Clone + Send + Sync + 'static,
    Request: Send + 'static,
{
    type Guard = P::Guard;
    type Error = P::Error;

    async fn check(
        &self,
        ctx: Context<State>,
        request: Request,
    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
        // Deref coercion: `&Box<P>` -> `&P`.
        let inner: &P = self;
        inner.check(ctx, request).await
    }
}
/// A policy that imposes no limit: its check always reports `Ready(())`.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct UnlimitedPolicy;

impl UnlimitedPolicy {
    /// Creates a new `UnlimitedPolicy`. Usable in `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}
/// Every check succeeds immediately with a unit guard and can never fail.
impl<State, Request> Policy<State, Request> for UnlimitedPolicy
where
    State: Clone + Send + Sync + 'static,
    Request: Send + 'static,
{
    type Guard = ();
    type Error = Infallible;

    async fn check(
        &self,
        ctx: Context<State>,
        request: Request,
    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
        let output = PolicyOutput::Ready(());
        PolicyResult {
            ctx,
            request,
            output,
        }
    }
}
/// Implements `Policy` for one of the `crate::combinators` either-style enums.
///
/// `$id` is the enum's name. Each `$param` is used both as a generic type
/// parameter and as the variant name. The policy held by the active variant is
/// checked; its guard is re-wrapped in the same variant, and its error is
/// converted into a `BoxError`, so every variant shares one error type.
macro_rules! impl_limit_policy_either {
    ($id:ident, $($param:ident),+ $(,)?) => {
        impl<$($param),+, State, Request> Policy<State, Request> for crate::combinators::$id<$($param),+>
        where
            $(
                $param: Policy<State, Request>,
                $param::Error: Into<BoxError>,
            )+
            Request: Send + 'static,
            State: Clone + Send + Sync + 'static,
        {
            type Guard = crate::combinators::$id<$($param::Guard),+>;
            type Error = BoxError;

            async fn check(
                &self,
                ctx: Context<State>,
                req: Request,
            ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
                match self {
                    $(
                        crate::combinators::$id::$param(policy) => {
                            let PolicyResult {
                                ctx,
                                request,
                                output,
                            } = policy.check(ctx, req).await;

                            // Re-tag the guard with the active variant and erase the error type.
                            let output: PolicyOutput<Self::Guard, Self::Error> = match output {
                                PolicyOutput::Ready(guard) => {
                                    PolicyOutput::Ready(crate::combinators::$id::$param(guard))
                                }
                                PolicyOutput::Abort(err) => PolicyOutput::Abort(err.into()),
                                PolicyOutput::Retry => PolicyOutput::Retry,
                            };

                            PolicyResult {
                                ctx,
                                request,
                                output,
                            }
                        }
                    )+
                }
            }
        }
    };
}
// Generate the impl for every either-enum arity provided by `crate::combinators`.
crate::combinators::impl_either!(impl_limit_policy_either);