use std::{
ops::{BitAnd, Deref},
sync::Arc,
};
use tokio::sync::SetOnce;
use crate::prelude::{LocalAgent, RemoteAgent};
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub enum ClientNameVerifyResult {
#[default]
Accept,
Refuse(String),
SilentRefuse(String),
}
impl BitAnd for ClientNameVerifyResult {
type Output = Self;
fn bitand(self, rhs: Self) -> Self::Output {
use ClientNameVerifyResult::*;
match (self, rhs) {
(Accept, Accept) => Accept,
(SilentRefuse(reason), ..) | (.., SilentRefuse(reason)) => SilentRefuse(reason),
(Refuse(reason), ..) | (.., Refuse(reason)) => Refuse(reason),
}
}
}
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub enum ClientAgentVerifyResult {
#[default]
Accept,
Refuse(String),
}
impl BitAnd for ClientAgentVerifyResult {
type Output = Self;
fn bitand(self, rhs: Self) -> Self::Output {
use ClientAgentVerifyResult::*;
match (self, rhs) {
(Accept, Accept) => Accept,
(Refuse(reason), ..) | (.., Refuse(reason)) => Refuse(reason),
}
}
}
pub trait AuthClient: Send + Sync {
fn verify_client_name(
&self,
server_agent: &LocalAgent,
client_name: Option<&str>,
) -> ClientNameVerifyResult;
fn verify_client_agent(
&self,
server_agent: &LocalAgent,
client_agent: &RemoteAgent,
) -> ClientAgentVerifyResult;
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct AcceptAllClientAuther;
impl AuthClient for AcceptAllClientAuther {
fn verify_client_name(&self, _: &LocalAgent, _: Option<&str>) -> ClientNameVerifyResult {
ClientNameVerifyResult::Accept
}
fn verify_client_agent(&self, _: &LocalAgent, _: &RemoteAgent) -> ClientAgentVerifyResult {
ClientAgentVerifyResult::Accept
}
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ClientNameAuther;
impl AuthClient for ClientNameAuther {
fn verify_client_name(&self, _: &LocalAgent, _: Option<&str>) -> ClientNameVerifyResult {
ClientNameVerifyResult::Accept
}
fn verify_client_agent(
&self,
_: &LocalAgent,
client_agent: &RemoteAgent,
) -> ClientAgentVerifyResult {
use x509_parser::prelude::*;
macro_rules! refuse {
($($tt:tt)*) => {
return ClientAgentVerifyResult::Refuse(format!($($tt)*))
};
}
let cert = match x509_parser::parse_x509_certificate(&client_agent.cert_chain()[0]) {
Ok((_remain, cert)) => cert,
Err(error) => refuse!("Invalid certificate: {error}"),
};
let san = match cert.subject_alternative_name() {
Ok(Some(san)) => san,
Ok(None) => refuse!("Missing SAN in certificate"),
Err(error) => refuse!("Invalid SAN in certificate: {error}"),
};
if san.value.general_names.iter().any(|name| match name {
GeneralName::DNSName(name) => *name == client_agent.name(),
_ => false,
}) {
return ClientAgentVerifyResult::Accept;
}
refuse!("Client name not verified by client certificate")
}
}
impl<A: AuthClient + ?Sized> AuthClient for &A {
fn verify_client_name(
&self,
server_agent: &LocalAgent,
client_name: Option<&str>,
) -> ClientNameVerifyResult {
A::verify_client_name(self, server_agent, client_name)
}
fn verify_client_agent(
&self,
server_agent: &LocalAgent,
client_agent: &RemoteAgent,
) -> ClientAgentVerifyResult {
A::verify_client_agent(self, server_agent, client_agent)
}
}
impl<A: AuthClient + ?Sized> AuthClient for Box<A> {
fn verify_client_name(
&self,
server_agent: &LocalAgent,
client_name: Option<&str>,
) -> ClientNameVerifyResult {
self.deref().verify_client_name(server_agent, client_name)
}
fn verify_client_agent(
&self,
server_agent: &LocalAgent,
client_agent: &RemoteAgent,
) -> ClientAgentVerifyResult {
self.deref().verify_client_agent(server_agent, client_agent)
}
}
impl<A: AuthClient + ?Sized> AuthClient for Arc<A> {
fn verify_client_name(
&self,
server_agent: &LocalAgent,
client_name: Option<&str>,
) -> ClientNameVerifyResult {
self.deref().verify_client_name(server_agent, client_name)
}
fn verify_client_agent(
&self,
server_agent: &LocalAgent,
client_agent: &RemoteAgent,
) -> ClientAgentVerifyResult {
self.deref().verify_client_agent(server_agent, client_agent)
}
}
macro_rules! impl_auth_client_for_tuple {
($head:ident $($tail:ident)*) => {
impl_auth_client_for_tuple!(@impl $head $($tail)*);
impl_auth_client_for_tuple!($($tail)*);
};
(@impl $($t:ident)*) => {
impl<$($t,)*> AuthClient for ($($t,)*)
where
$($t: AuthClient,)*
{
fn verify_client_name(
&self,
server_agent: &LocalAgent,
client_name: Option<&str>
) -> ClientNameVerifyResult {
#[allow(non_snake_case)]
let ($($t,)*) = self;
$($t.verify_client_name(server_agent, client_name) &)* Default::default()
}
fn verify_client_agent(
&self,
server_agent: &LocalAgent,
client_agent: &RemoteAgent
) -> ClientAgentVerifyResult {
#[allow(non_snake_case)]
let ($($t,)*) = self;
$($t.verify_client_agent(server_agent, client_agent) &)* Default::default()
}
}
};
() => {}
}
impl_auth_client_for_tuple! {
Z Y X W V U T S R Q P O N M L K J I H G F E D C B A
}
#[derive(Default, Debug, Clone)]
pub struct ArcSendLock(Arc<SetOnce<()>>);
impl ArcSendLock {
pub fn new() -> Self {
Self::default()
}
pub fn unrestricted() -> Self {
Self(Arc::new(SetOnce::new_with(Some(()))))
}
pub async fn request_permit(&self) {
_ = self.0.wait().await
}
pub fn is_permitted(&self) -> bool {
self.0.get().is_some()
}
pub fn grant_permit(&self) {
_ = self.0.set(());
}
}