use std::{fmt::Debug, sync::Arc};
use futures::future::BoxFuture;
use rsasl::{
callback::SessionCallback,
config::SASLConfig,
property::{AuthzId, OAuthBearerKV, OAuthBearerToken},
};
use crate::messenger::SaslError;
/// SASL authentication configuration.
///
/// Each variant selects one SASL mechanism (see `SaslConfig::mechanism`) and carries the
/// data that mechanism needs to authenticate.
#[derive(Debug, Clone)]
pub enum SaslConfig {
    /// `PLAIN` mechanism, authenticating with a username and password.
    Plain(Credentials),
    /// `SCRAM-SHA-256` mechanism, authenticating with a username and password.
    ScramSha256(Credentials),
    /// `SCRAM-SHA-512` mechanism, authenticating with a username and password.
    ScramSha512(Credentials),
    /// `OAUTHBEARER` mechanism, authenticating with a token produced by a user callback.
    Oauthbearer(OauthBearerCredentials),
}
/// Username/password credentials used by the `PLAIN` and SCRAM variants of `SaslConfig`.
///
/// `Debug` is implemented by hand rather than derived so that the password is never
/// written to logs, error reports or panic messages (including via a derived `Debug` on
/// any type that embeds `Credentials`, such as `SaslConfig`).
#[derive(Clone)]
pub struct Credentials {
    /// Authentication identity (user name).
    pub username: String,
    /// Secret used to authenticate `username`.
    pub password: String,
}

impl Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("username", &self.username)
            // Never print the secret itself; only signal that one is present.
            .field("password", &"<redacted>")
            .finish()
    }
}

impl Credentials {
    /// Creates credentials from a username and password.
    pub fn new(username: String, password: String) -> Self {
        Self { username, password }
    }
}
impl SaslConfig {
    /// Builds the `rsasl` [`SASLConfig`] that drives the SASL exchange for this configuration.
    ///
    /// The username/password variants (`Plain`, `ScramSha256`, `ScramSha512`) all produce a
    /// credentials-based configuration with no authorization identity. For `Oauthbearer`,
    /// the user callback is awaited once per call to obtain a token. That token, the
    /// optional authorization identity and the extra key/value pairs are then served to
    /// `rsasl` through a [`SessionCallback`].
    ///
    /// # Errors
    ///
    /// Returns `SaslError::Callback` if the OAuth token callback fails. Any error raised
    /// while building the `rsasl` configuration is converted into [`SaslError`].
    pub(crate) async fn get_sasl_config(&self) -> Result<Arc<SASLConfig>, SaslError> {
        /// Answers `rsasl` property requests for the `OAUTHBEARER` mechanism.
        struct OauthProvider {
            authz_id: Option<String>,
            bearer_kvs: Vec<(String, String)>,
            token: String,
        }

        impl SessionCallback for OauthProvider {
            fn callback(
                &self,
                _session_data: &rsasl::callback::SessionData,
                _context: &rsasl::callback::Context<'_>,
                request: &mut rsasl::callback::Request<'_>,
            ) -> Result<(), rsasl::prelude::SessionError> {
                // Borrowed view of the owned pairs, in the `(&str, &str)` shape rsasl expects.
                let kv_pairs: Vec<(&str, &str)> = self
                    .bearer_kvs
                    .iter()
                    .map(|(key, value)| (key.as_str(), value.as_str()))
                    .collect();
                // Each `satisfy` returns the same request, so separate statements are
                // equivalent to chaining the calls.
                request.satisfy::<OAuthBearerKV>(&kv_pairs)?;
                request.satisfy::<OAuthBearerToken>(&self.token)?;
                if let Some(authz_id) = self.authz_id.as_deref() {
                    request.satisfy::<AuthzId>(authz_id)?;
                }
                Ok(())
            }
        }

        // Username/password mechanisms need no callback: return early with their config.
        let oauth = match self {
            Self::Plain(credentials)
            | Self::ScramSha256(credentials)
            | Self::ScramSha512(credentials) => {
                let config = SASLConfig::with_credentials(
                    None,
                    credentials.username.clone(),
                    credentials.password.clone(),
                )?;
                return Ok(config);
            }
            Self::Oauthbearer(oauth) => oauth,
        };

        let token = (*oauth.callback)()
            .await
            .map_err(SaslError::Callback)?;

        let provider = OauthProvider {
            authz_id: oauth.authz_id.clone(),
            bearer_kvs: oauth.bearer_kvs.clone(),
            token,
        };
        let config = SASLConfig::builder()
            .with_default_mechanisms()
            .with_callback(provider)?;
        Ok(config)
    }

    /// Returns the name of the `rsasl` mechanism selected by this configuration.
    pub(crate) fn mechanism(&self) -> &str {
        use rsasl::mechanisms::{oauthbearer, plain, scram};

        match self {
            Self::Plain(_) => plain::PLAIN.mechanism.as_str(),
            Self::ScramSha256(_) => scram::SCRAM_SHA256.mechanism.as_str(),
            Self::ScramSha512(_) => scram::SCRAM_SHA512.mechanism.as_str(),
            Self::Oauthbearer(_) => oauthbearer::OAUTHBEARER.mechanism.as_str(),
        }
    }
}
/// Boxed, thread-safe error type that an OAuth token callback may return.
type DynError = Box<dyn std::error::Error + Send + Sync>;
/// Shared, thread-safe callback that asynchronously produces an OAuth bearer token.
///
/// It is called and awaited each time `SaslConfig::get_sasl_config` builds a configuration
/// for the `Oauthbearer` variant. The returned `String` is used as the bearer token, and an
/// `Err` is surfaced as `SaslError::Callback`.
pub type OauthCallback =
    Arc<dyn Fn() -> BoxFuture<'static, Result<String, DynError>> + Send + Sync>;
/// Credentials for the `OAUTHBEARER` SASL mechanism.
#[derive(Clone)]
pub struct OauthBearerCredentials {
    /// Produces the bearer token; awaited every time a SASL configuration is built.
    pub callback: OauthCallback,
    /// Optional authorization identity handed to the mechanism together with the token.
    pub authz_id: Option<String>,
    /// Extra key/value pairs handed to the mechanism together with the token.
    pub bearer_kvs: Vec<(String, String)>,
}

impl Debug for OauthBearerCredentials {
    /// Formats every field except the opaque `callback` closure, which has no `Debug`
    /// impl; its omission is signalled by the trailing `..` from `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces a decision about printing it.
        let Self {
            callback: _,
            authz_id,
            bearer_kvs,
        } = self;

        let mut out = f.debug_struct("OauthBearerCredentials");
        out.field("authz_id", authz_id);
        out.field("bearer_kvs", bearer_kvs);
        out.finish_non_exhaustive()
    }
}