use crate::{receiver::ErrorCounter, ReceiverContext, ReceiverHandler};
use tokio::io::AsyncWriteExt;
use vsmtp_common::Reply;
/// Thin wrapper around an asynchronous byte sink, used to send text and
/// replies to the remote peer.
///
/// The wrapped value is only reachable through [`AsMut`], or by consuming the
/// wrapper with `into_inner`.
pub struct Writer<W>
where
    W: tokio::io::AsyncWrite + Unpin + Send,
{
    /// The underlying sink every write is forwarded to.
    inner: W,
}
impl<W: tokio::io::AsyncWrite + Unpin + Send> AsMut<W> for Writer<W> {
#[inline]
fn as_mut(&mut self) -> &mut W {
&mut self.inner
}
}
impl<W: tokio::io::AsyncWrite + Unpin + Send> Writer<W> {
#[inline]
#[must_use]
pub const fn new(inner: W) -> Self {
Self { inner }
}
#[inline]
#[must_use]
#[allow(clippy::missing_const_for_fn)]
pub fn into_inner(self) -> W {
self.inner
}
#[inline]
pub async fn write_all(&mut self, buffer: &str) -> std::io::Result<()> {
tracing::trace!(">> {:?}", buffer);
self.write_all_bytes(buffer.as_bytes()).await
}
#[inline]
pub async fn write_all_bytes(&mut self, buffer: &[u8]) -> std::io::Result<()> {
self.inner.write_all(buffer).await
}
#[inline]
pub async fn send_reply<T: ReceiverHandler + Send>(
&mut self,
ctx: &mut ReceiverContext,
error_counter: &mut ErrorCounter,
handler: &mut T,
reply: Reply,
) -> std::io::Result<()> {
if !reply.code().is_error() {
return self.write_all(reply.as_ref()).await;
}
error_counter.error_count += 1;
let hard_error = error_counter.threshold_hard_error;
let soft_error = error_counter.threshold_soft_error;
if hard_error != -1 && error_counter.error_count >= hard_error {
let reply = handler.on_hard_error(ctx, reply).await;
return self.write_all(reply.as_ref()).await;
}
if soft_error != -1 && error_counter.error_count >= soft_error {
let reply = handler.on_soft_error(ctx, reply).await;
return self.write_all(reply.as_ref()).await;
}
self.write_all(reply.as_ref()).await
}
}