indymilter 0.3.0

Asynchronous milter library
Documentation
// indymilter – asynchronous milter library
// Copyright © 2021–2024 David Bürgin <dbuergin@gluet.ch>
//
// This program is free software: you can redistribute it and/or modify it under
// the terms of the GNU General Public License as published by the Free Software
// Foundation, either version 3 of the License, or (at your option) any later
// version.
//
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
// details.
//
// You should have received a copy of the GNU General Public License along with
// this program. If not, see <https://www.gnu.org/licenses/>.

use std::{
    future::Future,
    io::{self, IoSlice},
    pin::Pin,
    task::{Context, Poll},
};
use tokio::{
    io::{AsyncRead, AsyncWrite, ReadBuf},
    net::{TcpListener, TcpStream},
};

// This API is similar to `Listener` in tokio-util.

/// A Tokio-based listener.
pub trait Listener {
    /// Type of the connection stream handed out for each accepted client.
    ///
    /// It must be readable and writable asynchronously, `Unpin`, `Send`, and
    /// `'static`.
    type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;

    /// Polls this listener for the next incoming connection.
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Io>>;

    /// Returns a future that resolves once the next incoming connection has
    /// been accepted by this listener (or accepting it failed).
    fn accept(&mut self) -> ListenerAcceptFut<'_, Self>
    where
        Self: Sized,
    {
        // The future only borrows the listener; all work happens in
        // `poll_accept` when the future is polled.
        let listener = self;
        ListenerAcceptFut { listener }
    }
}

impl Listener for TcpListener {
    type Io = TcpStream;

    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Io>> {
        Self::poll_accept(self, cx).map_ok(|(stream, _)| stream)
    }
}

#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
    type Io = tokio::net::UnixStream;

    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Io>> {
        // Delegate to the inherent `UnixListener::poll_accept`, which yields a
        // `(stream, peer address)` pair; only the stream is passed on.
        tokio::net::UnixListener::poll_accept(self, cx)
            .map(|accepted| accepted.map(|(stream, _peer_addr)| stream))
    }
}

/// Future returned by `Listener::accept`.
///
/// It mutably borrows the listener for its lifetime. Each time it is polled it
/// calls the listener's `poll_accept`. It resolves to the accepted connection
/// stream, or to the I/O error reported while accepting.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ListenerAcceptFut<'a, L> {
    /// The listener that connections are accepted from.
    listener: &'a mut L,
}

impl<L> Future for ListenerAcceptFut<'_, L>
where
    L: Listener,
{
    type Output = io::Result<L::Io>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The future only holds a `&mut L`, and mutable references are always
        // `Unpin`, so the pin can be unwrapped without `unsafe`.
        let this = self.get_mut();
        this.listener.poll_accept(cx)
    }
}

/// A listener that is one of two listener types.
///
/// The variants are named `Tcp` and `Unix` after their intended use with
/// `tokio::net::TcpListener` and `tokio::net::UnixListener`. The enum itself
/// puts no bound on its contents, however. Its `Listener` implementation only
/// requires that both type parameters implement `Listener`.
#[derive(Debug)]
pub enum EitherListener<T, U> {
    /// Listener for TCP sockets (or for any other kind of socket).
    Tcp(T),
    /// Listener for UNIX domain sockets (or for any other kind of socket).
    Unix(U),
}

impl<T, U> Listener for EitherListener<T, U>
where
    T: Listener,
    U: Listener,
{
    type Io = EitherStream<T::Io, U::Io>;

    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Io>> {
        match self {
            Self::Tcp(listener) => listener.poll_accept(cx).map_ok(Self::Io::Tcp),
            Self::Unix(listener) => listener.poll_accept(cx).map_ok(Self::Io::Unix),
        }
    }
}

/// Connection stream produced by `EitherListener`.
///
/// It wraps the stream type of whichever listener variant accepted the
/// connection.
#[derive(Debug)]
pub enum EitherStream<T, U> {
    /// Stream from the `Tcp` listener variant (a TCP or any other stream).
    Tcp(T),
    /// Stream from the `Unix` listener variant (a UNIX domain socket or any
    /// other stream).
    Unix(U),
}

impl<T, U> AsyncRead for EitherStream<T, U>
where
    T: AsyncRead + Unpin,
    U: AsyncRead + Unpin,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
            Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
        }
    }
}

impl<T, U> AsyncWrite for EitherStream<T, U>
where
    T: AsyncWrite + Unpin,
    U: AsyncWrite + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
            Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
            Self::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Tcp(stream) => stream.is_write_vectored(),
            Self::Unix(stream) => stream.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_flush(cx),
            Self::Unix(stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
            Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}