//! indymilter 0.3.0 – asynchronous milter library.
//!
//! See the crate-level documentation for an overview.
// indymilter – asynchronous milter library
// Copyright © 2021–2022 David Bürgin <dbuergin@gluet.ch>
//
// This program is free software: you can redistribute it and/or modify it under
// the terms of the GNU General Public License as published by the Free Software
// Foundation, either version 3 of the License, or (at your option) any later
// version.
//
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
// details.
//
// You should have received a copy of the GNU General Public License along with
// this program. If not, see <https://www.gnu.org/licenses/>.

use crate::message::{self, reply::Reply, Message};
use std::{io, time::Duration};
use tokio::{
    io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream},
    sync::{mpsc, oneshot},
    time,
};
use tracing::trace;

// A request sent from a `Connection` handle to the stream-handling actor
// task. Each variant carries a `oneshot::Sender` on which the actor reports
// the I/O result back to the requesting `Connection`.
enum Request {
    // Write the given message to the connection stream.
    WriteMsg {
        msg: Message,
        response: oneshot::Sender<io::Result<()>>,
    },
    // Read the next message from the connection stream.
    ReadMsg {
        response: oneshot::Sender<io::Result<Message>>,
    },
}

// Exclusive owner of the connection stream. A `StreamHandler` is moved into
// a spawned task (see `handle_requests`) and performs all reads and writes
// on behalf of `Connection` handles.
struct StreamHandler<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    // The raw stream, buffered in both directions.
    stream: BufStream<S>,
}

impl<S> StreamHandler<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    fn new(stream: S) -> Self {
        Self {
            stream: BufStream::new(stream),
        }
    }

    // Consumes this stream handler and handles requests for as long as requests
    // can be received on the given `Receiver`. Use with `tokio::spawn` to turn
    // this stream handler into an *actor*.
    async fn handle_requests(mut self, mut conn: mpsc::Receiver<Request>, timeout: Duration) {
        // The actor’s task is fail safe. It does not exit – and therefore does
        // not unexpectedly drop the `mpsc::Receiver` (or similarly the
        // `oneshot::Sender`) – until the `Connection` is dropped.

        while let Some(req) = conn.recv().await {
            match req {
                Request::ReadMsg { response } => {
                    let f = message::read(&mut self.stream);

                    let result = match time::timeout(timeout, f).await {
                        Ok(r) => r,
                        Err(e) => Err(e.into()),
                    };

                    let _ = response.send(result);
                }
                Request::WriteMsg { msg, response } => {
                    let f = message::write(&mut self.stream, msg);

                    let result = match time::timeout(timeout, f).await {
                        Ok(r) => r,
                        Err(e) => Err(e.into()),
                    };

                    let _ = response.send(result);
                }
            }
        }

        // When this actor exits it also shuts down and drops the wrapped
        // stream. An error result is no longer of interest.

        let _ = self.stream.shutdown().await;
    }
}

// A connection can be cloned. Cloned connections all carry a handle to the
// message-processing task (actor) holding the actual connection stream. Once
// the last `Connection` is dropped, so is the `mpsc::Sender` handle and the
// task then exits.
//
// Note that in a milter session, a connection is only ever read from or written
// to serially. However, during the eom stage, a clone of the connection is used
// to write replies from the eom callback.
/// A cloneable handle to a milter connection. All message I/O is delegated
/// to a background task that exclusively owns the underlying stream.
#[derive(Clone)]
pub struct Connection {
    // Sender half of the request channel into the actor task.
    conn: mpsc::Sender<Request>,
}

impl Connection {
    /// Creates a new `Connection` over the given stream. All reads and
    /// writes on the stream use the given timeout.
    pub fn new<S>(stream: S, timeout: Duration) -> Self
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        let (requests_tx, requests_rx) = mpsc::channel(1);

        // Creating a `Connection` spawns a task that processes requests for
        // as long as the connection (or a clone) exists. When the last
        // `mpsc::Sender` held by a connection is dropped, the task exits.
        //
        // The stream handler takes exclusive ownership of the stream, and is
        // handed off directly to `tokio::spawn`, becoming an *actor* solely
        // responsible for reading and writing messages to the stream.
        tokio::spawn(StreamHandler::new(stream).handle_requests(requests_rx, timeout));

        Self { conn: requests_tx }
    }

    /// Reads the next message from the connection.
    pub async fn read_message(&self) -> io::Result<Message> {
        let (tx, rx) = oneshot::channel();

        let result = self.do_request(Request::ReadMsg { response: tx }, rx).await;

        if let Ok(msg) = &result {
            trace!(?msg, "message read");
        }

        result
    }

    /// Writes the given reply to the connection.
    pub async fn write_reply(&self, reply: Reply) -> io::Result<()> {
        self.write_message(reply.into_message()).await
    }

    /// Writes the given message to the connection.
    pub async fn write_message(&self, msg: Message) -> io::Result<()> {
        trace!(?msg, "writing message");

        let (tx, rx) = oneshot::channel();

        self.do_request(Request::WriteMsg { msg, response: tx }, rx)
            .await
    }

    // Sends a request to the actor task and awaits its answer on the paired
    // oneshot channel. Panics if the actor task is gone; that would indicate
    // a programming error, since the actor outlives all `Connection`s.
    async fn do_request<T>(&self, request: Request, response: oneshot::Receiver<T>) -> T {
        if self.conn.send(request).await.is_err() {
            panic!("connection stream closed");
        }

        response.await.expect("connection stream exited")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;
    use tokio::{io::AsyncReadExt, join};

    #[tokio::test]
    async fn multiple_connections() {
        let (mut client, server) = tokio::io::duplex(100);

        let conn1 = Connection::new(server, Duration::from_secs(30));
        let conn2 = conn1.clone();

        // Exercise the first connection: read a message, echo it back.

        client.write_all(b"\0\0\0\x03xyz").await.unwrap();

        let msg = conn1.read_message().await.unwrap();
        assert_eq!(msg, Message::new(b'x', "yz"));
        conn1.write_message(msg).await.unwrap();

        let mut echoed = [0u8; 7];
        client.read_exact(&mut echoed).await.unwrap();
        assert_eq!(echoed, *b"\0\0\0\x03xyz");

        // After dropping the first connection, the second one keeps working.

        drop(conn1);

        conn2
            .write_message(Message::new(b'x', "abc"))
            .await
            .unwrap();

        let mut written = [0u8; 8];
        client.read_exact(&mut written).await.unwrap();
        assert_eq!(written, *b"\0\0\0\x04xabc");

        // Dropping the last remaining connection closes the stream.

        drop(conn2);

        let error = client.read_u8().await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn connection_timeout() {
        let timeout = Duration::from_secs(30);

        let (mut client, server) = tokio::io::duplex(100);
        let conn = Connection::new(server, timeout);

        // Both the `Connection` and the client end of the duplex stream are
        // moved into the joined futures. The connection is therefore dropped
        // and closes once the 30-second read timeout fires, so the client’s
        // second write, attempted a few seconds later, must fail.

        time::pause();

        let (read_result, write_result) = join!(
            async move { conn.read_message().await },
            async move {
                client.write_all(b"\0\0\0\x05").await.unwrap();
                time::sleep(timeout + Duration::from_secs(5)).await;
                client.write_all(b"Xyzabc").await
            },
        );

        time::resume();

        assert_eq!(read_result.unwrap_err().kind(), ErrorKind::TimedOut);
        assert_eq!(write_result.unwrap_err().kind(), ErrorKind::BrokenPipe);
    }
}