use crate::{Receiver, ReceiverHandler};
use base64::{engine::general_purpose::STANDARD, Engine};
use tokio::io::AsyncWriteExt;
use tokio_stream::StreamExt;
use vsmtp_common::auth::Mechanism;
#[repr(transparent)]
#[allow(clippy::exhaustive_structs)]
pub struct CallbackWrap(pub Box<dyn rsasl::callback::SessionCallback + Send + Sync>);
#[allow(clippy::missing_trait_methods)]
impl rsasl::callback::SessionCallback for CallbackWrap {
#[inline]
fn callback(
&self,
session_data: &rsasl::callback::SessionData,
context: &rsasl::callback::Context<'_>,
request: &mut rsasl::callback::Request<'_>,
) -> Result<(), rsasl::prelude::SessionError> {
self.0.callback(session_data, context, request)
}
#[inline]
fn validate(
&self,
session_data: &rsasl::callback::SessionData,
context: &rsasl::callback::Context<'_>,
validate: &mut rsasl::validate::Validate<'_>,
) -> Result<(), rsasl::validate::ValidationError> {
self.0.validate(session_data, context, validate)
}
}
#[derive(Debug, thiserror::Error)]
#[allow(clippy::exhaustive_enums)]
pub enum AuthError {
#[error("client must not start with this mechanism")]
ClientMustNotStart,
#[error("validation failed: {0}")]
ValidationError(Box<dyn std::error::Error + Send + Sync>),
#[error("sasl challenge cancelled by the client")]
Canceled,
#[error("base64 decoding fail: {source}")]
Base64 {
#[from]
#[source]
source: base64::DecodeError,
},
#[error("io error: {0}")]
IO(#[from] std::io::Error),
#[error("error produced by the backend: {0}")]
SessionError(rsasl::prelude::SessionError),
#[error("error while initializing the SASL backend: {0}")]
ConfigError(#[from] rsasl::prelude::SASLError),
}
impl<
T: ReceiverHandler + Send,
V: rsasl::validate::Validation + Send,
W: tokio::io::AsyncWrite + Unpin + Send,
R: tokio::io::AsyncRead + Unpin + Send,
> Receiver<T, V, W, R>
where
V::Value: Send + Sync,
{
pub(crate) async fn authenticate(
&mut self,
mechanism: Mechanism,
initial_response: Option<Vec<u8>>,
) -> Result<(), AuthError> {
macro_rules! block_on {
($future:expr) => {
tokio::task::block_in_place(move || {
tokio::runtime::Handle::current().block_on($future)
})
};
}
struct AdapterSMTPandSASL<'writer, W: tokio::io::AsyncWrite + Unpin + Send>(&'writer mut W);
#[allow(clippy::missing_trait_methods)]
impl<'writer, W: tokio::io::AsyncWrite + Unpin + Send> std::io::Write
for AdapterSMTPandSASL<'writer, W>
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
block_on! { async move {
self.0.write_all(b"334 ").await?;
self.0.write_all(STANDARD.encode(buf).as_bytes()).await?;
self.0.write_all(b"\r\n").await?;
std::io::Result::Ok(())
}}
.map(|_| buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
block_on! { tokio::io::AsyncWriteExt::flush(&mut self.0) }
}
}
let callback = self.handler.generate_sasl_callback();
let rsasl_config = rsasl::config::SASLConfig::builder()
.with_default_mechanisms()
.with_callback(callback)?;
let sasl_server = rsasl::prelude::SASLServer::<V>::new(rsasl_config);
let temp = mechanism.to_string();
#[allow(clippy::expect_used)]
let selected =
rsasl::prelude::Mechname::parse(temp.as_bytes()).expect("mechanism is valid");
let mut session = sasl_server.start_suggested(selected)?;
let mut adapter = AdapterSMTPandSASL(&mut self.sink.inner);
let challenge_stream = self.stream.as_line_stream().map(|line| {
let l = line.map(|buffer| {
buffer
.strip_suffix(b"\r\n")
.map(<[u8]>::to_vec)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Line does not end with \r\n",
)
})
});
match l {
Ok(Ok(l)) => Ok(l),
Ok(Err(e)) | Err(e) => Err(e),
}
});
tokio::pin!(challenge_stream);
macro_rules! next_challenge_line {
($challenge_stream:expr) => {
match challenge_stream.next().await {
Some(Ok(buffer)) if buffer == b"*" => return Err(AuthError::Canceled),
Some(Ok(buffer)) => Some(
STANDARD
.decode(buffer)
.map_err(|source| AuthError::Base64 { source })?,
),
Some(Err(e)) => todo!("{}", e),
None => todo!("what happen when the client close the connection?"),
}
};
}
let mut data = match (initial_response, session.are_we_first()) {
(None, true) => None,
(None, false) => {
std::io::Write::write(&mut adapter, &[])?;
next_challenge_line!(challenge_stream)
}
(Some(_), true) => return Err(AuthError::ClientMustNotStart),
(Some(data), false) => Some(STANDARD.decode(data)?),
};
#[allow(clippy::wildcard_enum_match_arm)]
while session
.step(data.as_deref(), &mut adapter)
.map_err(|e| match e {
rsasl::prelude::SessionError::ValidationError(
rsasl::validate::ValidationError::Boxed(e),
) => AuthError::ValidationError(e),
otherwise => AuthError::SessionError(otherwise),
})?
.is_running()
{
data = next_challenge_line!(challenge_stream);
}
#[allow(clippy::todo)]
session.validation().map_or_else(
|| todo!("what happen when the validator return nothing ?"),
|_v| Ok(()),
)
}
}