use crate::alloc::boxed::Box;
use crate::mechanism::Authentication;
use crate::mechname::Mechname;
use core::cmp::Ordering;
use core::fmt;
use crate::config::SASLConfig;
use crate::error::SASLError;
pub use crate::session::Side;
#[cfg(feature = "registry_static")]
pub use registry_static::*;
/// Function pointer stored in [`Mechanism`]'s `client` slot.
///
/// Takes no arguments and yields either a boxed [`Authentication`] or a [`SASLError`].
pub type StartFn = fn() -> Result<Box<dyn Authentication>, SASLError>;
/// Function pointer stored in [`Mechanism`]'s `server` slot.
///
/// Unlike [`StartFn`] it is handed a reference to the [`SASLConfig`]; it yields either a boxed
/// [`Authentication`] or a [`SASLError`].
pub type ServerStartFn = fn(sasl: &SASLConfig) -> Result<Box<dyn Authentication>, SASLError>;
/// Description of a single mechanism as stored in a [`Registry`].
///
/// Bundles the mechanism's name with its optional client/server start functions and the hooks
/// consulted during mechanism selection. The type is `Copy`; because it is `#[non_exhaustive]`
/// other crates cannot create it with a struct literal.
#[non_exhaustive]
#[derive(Copy, Clone)]
pub struct Mechanism {
/// Name of the mechanism. This is what `Display` prints and what the `Matches` selector
/// compares offered names against.
pub mechanism: &'static Mechname,
/// Priority value attached to the mechanism.
// NOTE(review): nothing in this file reads `priority`; presumably callers rank mechanisms by
// it during selection — confirm against the users of `Registry::select`.
pub(crate) priority: usize,
/// Start function for the client side, or `None` if the mechanism has none.
pub(crate) client: Option<StartFn>,
/// Start function for the server side, or `None` if the mechanism has none.
pub(crate) server: Option<ServerStartFn>,
/// A [`Side`] value associated with the mechanism.
// NOTE(review): presumably the side that sends the first message — confirm. The unused-lint
// is silenced when the `provider` feature is disabled.
#[cfg_attr(not(feature = "provider"), allow(unused))]
pub(crate) first: Side,
/// Selection hook: called with the `cb` flag of `Registry::select`; returning `None` keeps
/// the mechanism out of that selection round.
pub(crate) select: fn(bool) -> Option<Selection>,
/// Offer hook taking a `bool` and returning a `bool`.
// NOTE(review): not called anywhere in this file (hence `dead_code`); presumably decides
// whether a server offers the mechanism — confirm.
#[allow(dead_code)]
pub(crate) offer: fn(bool) -> bool,
}
#[cfg(feature = "unstable_custom_mechanism")]
impl Mechanism {
    /// Assemble a [`Mechanism`] from its individual parts.
    ///
    /// Every argument is stored unchanged in the field of the same name. Being a `const fn`,
    /// this can be used to initialise `static` or `const` mechanism descriptions.
    ///
    /// Only available with the `unstable_custom_mechanism` feature.
    #[must_use]
    pub const fn build(
        mechanism: &'static Mechname,
        priority: usize,
        client: Option<StartFn>,
        server: Option<ServerStartFn>,
        first: Side,
        select: fn(bool) -> Option<Selection>,
        offer: fn(bool) -> bool,
    ) -> Self {
        // Field-init shorthand: the order written here has no effect on the resulting value.
        Self {
            offer,
            select,
            first,
            server,
            client,
            priority,
            mechanism,
        }
    }
}
impl Mechanism {
    /// Invoke the client start function of this mechanism.
    ///
    /// Returns `None` if no client start function is registered, otherwise `Some` with
    /// whatever the start function returned.
    #[must_use]
    pub fn client(&self) -> Option<Result<Box<dyn Authentication>, SASLError>> {
        let start = self.client?;
        Some(start())
    }

    /// Invoke the server start function of this mechanism with the given configuration.
    ///
    /// Returns `None` if no server start function is registered, otherwise `Some` with
    /// whatever the start function returned.
    #[must_use]
    pub fn server(&self, sasl: &SASLConfig) -> Option<Result<Box<dyn Authentication>, SASLError>> {
        let start = self.server?;
        Some(start(sasl))
    }

    /// Run this mechanism's selection hook, forwarding `cb` unchanged.
    #[must_use]
    fn select(&self, cb: bool) -> Option<Selection> {
        let hook = self.select;
        hook(cb)
    }
}
impl fmt::Debug for Mechanism {
    /// Debug output shows the mechanism name and whether client/server start functions are
    /// present; the function pointers themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_client = self.client.is_some();
        let has_server = self.server.is_some();

        let mut out = f.debug_struct("Mechanism");
        out.field("name", &self.mechanism);
        out.field("has client", &has_client);
        out.field("has server", &has_server);
        out.finish()
    }
}
impl fmt::Display for Mechanism {
    /// Writes the mechanism name verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` is used on purpose: it ignores width/precision flags of the formatter.
        let name = self.mechanism.as_str();
        f.write_str(name)
    }
}
/// Set of mechanisms available for selection.
///
/// Holds a `'static` slice of [`Mechanism`] descriptions, so cloning a registry only copies the
/// slice reference.
#[derive(Debug, Clone)]
pub struct Registry {
/// The mechanisms this registry iterates over and selects from.
static_mechanisms: &'static [Mechanism],
}
#[cfg(any(test, feature = "config_builder", feature = "testutils"))]
mod config {
    use super::Registry;
    use crate::registry::Mechanism;

    #[cfg(feature = "config_builder")]
    impl Registry {
        /// Create a registry backed by exactly the given slice of mechanisms.
        #[inline(always)]
        #[must_use]
        pub const fn with_mechanisms(mechanisms: &'static [Mechanism]) -> Self {
            Self {
                static_mechanisms: mechanisms,
            }
        }

        /// Create a registry restricted to SCRAM-SHA-256, SCRAM-SHA-1, PLAIN and LOGIN.
        ///
        /// Each mechanism is only present if its cargo feature is enabled. When `authzid` is
        /// `true`, LOGIN is left out of the set.
        pub(crate) fn credentials(authzid: bool) -> Self {
            // Set used when `authzid` is true: identical to `CRED` minus LOGIN.
            static CRED_AUTHZID: &[Mechanism] = &[
                #[cfg(feature = "scram-sha-2")]
                crate::mechanisms::scram::SCRAM_SHA256,
                #[cfg(feature = "scram-sha-1")]
                crate::mechanisms::scram::SCRAM_SHA1,
                #[cfg(feature = "plain")]
                crate::mechanisms::plain::PLAIN,
            ];
            static CRED: &[Mechanism] = &[
                #[cfg(feature = "scram-sha-2")]
                crate::mechanisms::scram::SCRAM_SHA256,
                #[cfg(feature = "scram-sha-1")]
                crate::mechanisms::scram::SCRAM_SHA1,
                #[cfg(feature = "plain")]
                crate::mechanisms::plain::PLAIN,
                #[cfg(feature = "login")]
                crate::mechanisms::login::LOGIN,
            ];
            let mechanisms = if authzid { CRED_AUTHZID } else { CRED };
            Self::with_mechanisms(mechanisms)
        }
    }

    // The `Default` impls below build the struct directly instead of going through
    // `with_mechanisms`: that constructor only exists with the `config_builder` feature, while
    // this module is also compiled for `test` and `testutils`.

    #[cfg(feature = "registry_static")]
    impl Default for Registry {
        /// The default registry contains everything collected in the `MECHANISMS` static.
        fn default() -> Self {
            // The annotated binding coerces the static to a plain slice reference.
            let mechanisms: &'static [Mechanism] = &super::registry_static::MECHANISMS;
            Self {
                static_mechanisms: mechanisms,
            }
        }
    }

    #[cfg(not(feature = "registry_static"))]
    impl Default for Registry {
        /// The default registry contains every built-in mechanism whose cargo feature is
        /// enabled.
        fn default() -> Self {
            static BUILTIN: &[Mechanism] = &[
                #[cfg(feature = "scram-sha-2")]
                crate::mechanisms::scram::SCRAM_SHA256,
                #[cfg(feature = "scram-sha-1")]
                crate::mechanisms::scram::SCRAM_SHA1,
                #[cfg(feature = "plain")]
                crate::mechanisms::plain::PLAIN,
                #[cfg(feature = "login")]
                crate::mechanisms::login::LOGIN,
                #[cfg(feature = "anonymous")]
                crate::mechanisms::anonymous::ANONYMOUS,
                #[cfg(feature = "external")]
                crate::mechanisms::external::EXTERNAL,
                #[cfg(feature = "xoauth2")]
                crate::mechanisms::xoauth2::XOAUTH2,
                #[cfg(feature = "oauthbearer")]
                crate::mechanisms::oauthbearer::OAUTHBEARER,
            ];
            Self {
                static_mechanisms: BUILTIN,
            }
        }
    }
}
/// Iterator over the [`Mechanism`] entries of a [`Registry`]; a plain slice iterator.
pub type MechanismIter<'a> = core::slice::Iter<'a, Mechanism>;
impl Registry {
#[inline(always)]
pub(crate) fn get_mechanisms<'a>(&self) -> MechanismIter<'a> {
self.static_mechanisms.iter()
}
pub(crate) fn select<'a>(
&self,
cb: bool,
offered: impl Iterator<Item = &'a Mechname>,
mut fold: impl FnMut(Option<&'static Mechanism>, &'static Mechanism) -> Ordering,
) -> Result<(Box<dyn Authentication>, &'static Mechanism), SASLError> {
let mut selectors: Vec<Selection> =
self.get_mechanisms().filter_map(|m| m.select(cb)).collect();
for o in offered {
for s in &mut selectors {
s.select(o);
}
}
let (mut s, m) = selectors
.into_iter()
.filter_map(|mut s| s.done().map(|m| (s, m)))
.fold(None, |acc, (s, m)| {
let accmech = acc.as_ref().map(|(_, m)| *m);
match fold(accmech, m) {
Ordering::Greater => acc,
Ordering::Equal |
Ordering::Less => Some((s,m))
}
})
.ok_or(SASLError::NoSharedMechanism)?;
s.finalize().map(|a| (a, m))
}
}
// Variant of `registry_static` compiled when the `registry_static` feature is enabled.
#[cfg(feature = "registry_static")]
mod registry_static {
use super::Mechanism;
// Re-exported so the attribute macro is reachable through this module.
pub use linkme::distributed_slice;
// Slice of mechanisms declared through `linkme`'s `distributed_slice` attribute; the `[..]`
// initialiser is the placeholder that attribute expects. `Registry::default()` reads this
// static when the feature is on.
#[distributed_slice]
pub static MECHANISMS: [Mechanism] = [..];
}
// Variant of `registry_static` compiled when the `registry_static` feature is disabled.
#[cfg(not(feature = "registry_static"))]
mod registry_static {
use super::Mechanism;
// Empty stand-in so the `MECHANISMS` name exists without the feature. It is not read anywhere
// in this file in this configuration (`Registry::default()` uses its own built-in list).
pub static MECHANISMS: [Mechanism; 0] = [];
}
mod selector {
use super::{Authentication, Box, Mechanism, Mechname, SASLError};
use alloc::marker::PhantomData;
pub trait Selector {
fn select(&mut self, mechname: &Mechname) -> Option<&'static Mechanism>;
fn done(&mut self) -> Option<&'static Mechanism>;
fn finalize(&mut self) -> Result<Box<dyn Authentication>, SASLError>;
}
pub enum Selection {
Nothing(Box<dyn Selector>),
Done(&'static Mechanism),
}
impl Selection {
pub(super) fn select(&mut self, mechname: &Mechname) {
if let Self::Nothing(ref mut selector) = self {
if let Some(m) = selector.select(mechname) {
*self = Self::Done(m);
}
}
}
pub(super) fn done(&mut self) -> Option<&'static Mechanism> {
match self {
Self::Nothing(selector) => selector.done(),
Self::Done(m) => Some(m),
}
}
pub(super) fn finalize(&mut self) -> Result<Box<dyn Authentication>, SASLError> {
match self {
Self::Nothing(selector) => selector.finalize(),
Self::Done(m) => m.client().unwrap(),
}
}
}
pub trait Named {
fn mech() -> &'static Mechanism;
}
#[repr(transparent)]
pub struct Matches<T>(PhantomData<T>);
impl<T: Named + 'static> Matches<T> {
#[must_use]
pub fn name() -> Selection {
Selection::Nothing(Box::new(Self(PhantomData)))
}
}
impl<T: Named> Selector for Matches<T> {
fn select(&mut self, mechname: &Mechname) -> Option<&'static Mechanism> {
let m = T::mech();
if *mechname == *m.mechanism {
Some(m)
} else {
None
}
}
fn done(&mut self) -> Option<&'static Mechanism> {
None
}
fn finalize(&mut self) -> Result<Box<dyn Authentication>, SASLError> {
let m = T::mech();
(m.client.unwrap())()
}
}
}
pub use selector::*;