use std::cmp::PartialEq;
use crate::memdx::auth_mechanism::AuthMechanism;
use crate::memdx::dispatcher::Dispatcher;
use crate::memdx::error::Error;
use crate::memdx::error::Result;
use crate::memdx::op_auth_saslbyname::{
OpSASLAuthByNameEncoder, OpsSASLAuthByName, SASLAuthByNameOptions,
};
use crate::memdx::pendingop::StandardPendingOp;
use crate::memdx::request::SASLListMechsRequest;
use crate::memdx::response::SASLListMechsResponse;
use tokio::time::Instant;
/// Credentials used when authenticating a connection via SASL.
///
/// `Debug` is implemented by hand (instead of derived) so that secrets — the
/// password and the JWT — are redacted and cannot leak into logs or error
/// messages through `{:?}` formatting of this type or of types containing it.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Credentials {
    /// A username/password pair.
    UserPass { username: String, password: String },
    /// A JWT token, stored as a raw string.
    JwtToken(String),
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The username is not secret and is useful for diagnostics; the password is.
            Credentials::UserPass { username, .. } => f
                .debug_struct("UserPass")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
            Credentials::JwtToken(_) => f.debug_tuple("JwtToken").field(&"<redacted>").finish(),
        }
    }
}
impl Credentials {
pub fn user_pass(&self) -> Result<(&str, &str)> {
match self {
Credentials::UserPass { username, password } => {
Ok((username.as_str(), password.as_str()))
}
_ => Err(Error::new_invalid_argument_error(
"credentials do not contain username/password",
None,
)),
}
}
pub fn jwt(&self) -> Result<&str> {
match self {
Credentials::JwtToken(token) => Ok(token.as_str()),
_ => Err(Error::new_invalid_argument_error(
"credentials do not contain jwt",
None,
)),
}
}
}
/// Options for `OpsSASLAuthAuto::sasl_auth_auto`.
// Field order matters: `Ord`/`PartialOrd` are derived and compare fields in declaration order.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct SASLAuthAutoOptions {
    /// Credentials passed to every authentication attempt.
    pub credentials: Credentials,
    /// Mechanisms the client is willing to use, in order of preference.
    ///
    /// Must not be empty. The first entry is always tried first. If it fails and the
    /// server does not advertise it, the first later entry the server does advertise
    /// is used for a second attempt.
    pub enabled_mechs: Vec<AuthMechanism>,
}
/// Options for listing the server's SASL mechanisms; currently carries no fields.
// NOTE(review): not referenced in this file — presumably reserved for a list-mechs op; confirm.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct SASLListMechsOptions {}
/// Encoder capabilities needed by `OpsSASLAuthAuto`: everything required for
/// authentication by name, plus listing the SASL mechanisms the server offers.
pub trait OpSASLAutoEncoder: OpSASLAuthByNameEncoder {
    /// Issues a SASL list-mechanisms request through `dispatcher`.
    ///
    /// Resolves to a pending operation whose response exposes the server's
    /// advertised mechanisms (`available_mechs`). Resolves to `Err` if the
    /// request cannot be issued.
    fn sasl_list_mechs<D>(
        &self,
        dispatcher: &D,
        request: SASLListMechsRequest,
    ) -> impl std::future::Future<Output = Result<StandardPendingOp<SASLListMechsResponse>>>
    where
        D: Dispatcher;
}
/// Field-less entry point for SASL authentication with automatic mechanism
/// selection (see `OpsSASLAuthAuto::sasl_auth_auto`).
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct OpsSASLAuthAuto {}
impl OpsSASLAuthAuto {
pub async fn sasl_auth_auto<E, D>(
&self,
encoder: &E,
dispatcher: &D,
deadline: Instant,
opts: SASLAuthAutoOptions,
) -> Result<()>
where
E: OpSASLAutoEncoder,
D: Dispatcher,
{
if opts.enabled_mechs.is_empty() {
return Err(Error::new_invalid_argument_error(
"no enabled mechanisms",
"enabled_mechanisms".to_string(),
));
}
let mut op = encoder
.sasl_list_mechs(dispatcher, SASLListMechsRequest {})
.await?;
let server_mechs = op.recv().await?.available_mechs;
let default_mech = opts.enabled_mechs.first().unwrap();
let by_name = OpsSASLAuthByName {};
match by_name
.sasl_auth_by_name(
encoder,
dispatcher,
SASLAuthByNameOptions {
credentials: opts.credentials.clone(),
auth_mechanism: default_mech.clone(),
deadline,
},
)
.await
{
Ok(()) => Ok(()),
Err(e) => {
if e.is_cancellation_error() {
return Err(e);
}
let supports_default_mech = server_mechs.contains(default_mech);
if supports_default_mech {
return Err(e);
}
let selected_mech = opts
.enabled_mechs
.iter()
.find(|item| server_mechs.contains(item));
let selected_mech = match selected_mech {
Some(mech) => mech,
None => {
return Err(Error::new_message_error("no supported mechanisms found"));
}
};
OpsSASLAuthByName {}
.sasl_auth_by_name(
encoder,
dispatcher,
SASLAuthByNameOptions {
credentials: opts.credentials.clone(),
auth_mechanism: selected_mech.clone(),
deadline,
},
)
.await
}
}
}
}