// indymilter 0.3.0
//
// Asynchronous milter library
// Documentation
// indymilter – asynchronous milter library
// Copyright © 2021–2024 David Bürgin <dbuergin@gluet.ch>
//
// This program is free software: you can redistribute it and/or modify it under
// the terms of the GNU General Public License as published by the Free Software
// Foundation, either version 3 of the License, or (at your option) any later
// version.
//
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
// details.
//
// You should have received a copy of the GNU General Public License along with
// this program. If not, see <https://www.gnu.org/licenses/>.

//! Milter protocol messages.
//!
//! This module contains low-level protocol helpers.

pub mod command;
pub mod reply;

use bytes::Bytes;
use std::{
    ascii,
    error::Error,
    ffi::{CStr, CString},
    fmt::{self, Debug, Display, Formatter, Write},
    io::{self, ErrorKind},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

/// The integer type used to represent a milter protocol version number.
pub type Version = u32;

/// The milter protocol version number declared by this module.
// NOTE(review): presumably the version advertised during option negotiation;
// confirm against the callers that use this constant.
pub const PROTOCOL_VERSION: Version = 6;

/// A milter protocol message.
///
/// A message consists of a one-byte kind and a payload buffer. When
/// transferred with [`read`] and [`write`] it is framed as a 32-bit length,
/// the kind byte, and then the payload; the length counts the kind byte plus
/// the payload bytes.
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct Message {
    /// The message kind, a single byte.
    pub kind: u8,
    /// The message payload buffer, that is, the bytes following the kind byte.
    pub buffer: Bytes,
}

impl Message {
    // Upper bound on the length of the payload buffer, one byte short of
    // 1 MiB; together with the kind byte the length announced on the wire is
    // therefore at most 1 MiB. The limit is somewhat arbitrary: it keeps an
    // erroneously sent jumbo message from costing a lot of time and memory
    // when reading or writing.
    //
    // The experimental, negotiable custom `Maxdatasize` feature of libmilter
    // is not supported.
    const MAX_BUFFER_LEN: usize = 1024 * 1024 - 1;

    /// Creates a new message from anything convertible into a kind byte and a
    /// payload buffer.
    pub fn new(kind: impl Into<u8>, buffer: impl Into<Bytes>) -> Self {
        let kind = kind.into();
        let buffer = buffer.into();

        Self { kind, buffer }
    }
}

impl Debug for Message {
    /// Formats the message as a struct, showing the kind as a byte literal
    /// (via `Byte`) instead of a bare number.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { kind, buffer } = self;

        let mut out = f.debug_struct("Message");
        out.field("kind", &Byte(*kind));
        out.field("buffer", buffer);
        out.finish()
    }
}

/// Wrapper around a single byte whose `Debug` output resembles a byte
/// literal, for example `b'x'` or `b'\n'`.
pub(crate) struct Byte(pub u8);

impl Debug for Byte {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(byte) = *self;

        f.write_str("b'")?;

        if byte == b'\0' {
            // Show NUL in the short form `\0` rather than as `\x00`.
            f.write_str("\\0")?;
        } else if byte == b'"' {
            // A double quote needs no escaping between single quotes.
            f.write_char('"')?;
        } else {
            // Everything else uses the standard ASCII escaping rules.
            ascii::escape_default(byte).try_for_each(|c| f.write_char(char::from(c)))?;
        }

        f.write_char('\'')
    }
}

/// The error returned when a wire format byte cannot be converted.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TryFromByteError(u8);

impl TryFromByteError {
    /// Creates an error that records the offending byte.
    pub(crate) fn new(byte: u8) -> Self {
        TryFromByteError(byte)
    }

    /// Returns the byte that caused the conversion error.
    pub fn byte(&self) -> u8 {
        let Self(byte) = *self;
        byte
    }
}

impl Display for TryFromByteError {
    /// Writes a message that names the offending byte in byte-literal form.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let byte = Byte(self.byte());
        f.write_fmt(format_args!("failed to convert byte {:?}", byte))
    }
}

impl Error for TryFromByteError {}

/// The error returned when a wire format index cannot be converted.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TryFromIndexError(i32);

impl TryFromIndexError {
    /// Creates an error that records the offending index.
    pub(crate) fn new(index: i32) -> Self {
        TryFromIndexError(index)
    }

    /// Returns the index that caused the conversion error.
    pub fn index(&self) -> i32 {
        let Self(index) = *self;
        index
    }
}

impl Display for TryFromIndexError {
    /// Writes a message that names the offending index.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let index = self.index();
        f.write_fmt(format_args!("failed to convert index {}", index))
    }
}

impl Error for TryFromIndexError {}

/// Reads a milter protocol message from a stream.
pub async fn read<S>(stream: &mut S) -> io::Result<Message>
where
    S: AsyncRead + Unpin,
{
    let len = stream.read_u32().await?;
    let kind = stream.read_u8().await?;

    let len = usize::try_from(len)
        .expect("unsupported pointer size")
        .saturating_sub(1);

    if len > Message::MAX_BUFFER_LEN {
        return Err(ErrorKind::InvalidData.into());
    }

    let mut buffer = vec![0; len];
    stream.read_exact(&mut buffer).await?;

    Ok(Message::new(kind, buffer))
}

/// Writes a milter protocol message to a stream.
pub async fn write<S>(stream: &mut S, msg: Message) -> io::Result<()>
where
    S: AsyncWrite + Unpin,
{
    let len = msg.buffer.len();

    if len > Message::MAX_BUFFER_LEN {
        return Err(ErrorKind::InvalidData.into());
    }

    let len = u32::try_from(len).unwrap().checked_add(1).unwrap();

    stream.write_u32(len).await?;
    stream.write_u8(msg.kind).await?;
    stream.write_all(msg.buffer.as_ref()).await?;
    stream.flush().await?;

    Ok(())
}

/// Marker error: the buffer contains no NUL byte and hence no C string.
struct NoCStringFoundError;

/// Splits the leading NUL-terminated string off the front of `buf`.
///
/// On success the returned string holds the bytes up to the first NUL byte,
/// and `buf` is advanced past that NUL byte. If `buf` contains no NUL byte it
/// is left untouched and an error is returned.
fn get_c_string(buf: &mut Bytes) -> Result<CString, NoCStringFoundError> {
    let nul_index = buf
        .iter()
        .position(|&byte| byte == 0)
        .ok_or(NoCStringFoundError)?;

    let head = buf.split_to(nul_index + 1);

    // `head` ends at the first NUL byte of the buffer, so it contains exactly
    // one NUL, in final position, and this conversion cannot fail.
    let c_str = CStr::from_bytes_with_nul(&head).unwrap();

    Ok(c_str.to_owned())
}

#[cfg(test)]
mod tests {
    use super::*;

    // `Byte`'s `Debug` output looks like a byte literal: NUL is shown as `\0`,
    // a double quote is left unescaped, and all other bytes are escaped as by
    // `ascii::escape_default`.
    #[test]
    fn byte_debug_ok() {
        assert_eq!(format!("{:?}", Byte(b'\0')), r"b'\0'");
        assert_eq!(format!("{:?}", Byte(b'\n')), r"b'\n'");
        assert_eq!(format!("{:?}", Byte(b'\'')), r"b'\''");
        assert_eq!(format!("{:?}", Byte(b'"')), r#"b'"'"#);
        assert_eq!(format!("{:?}", Byte(b'a')), r"b'a'");
        assert_eq!(format!("{:?}", Byte(b' ')), r"b' '");
        assert_eq!(format!("{:?}", Byte(b'\x07')), r"b'\x07'");
        assert_eq!(format!("{:?}", Byte(b'\xef')), r"b'\xef'");
    }

    // `Message`'s `Debug` output shows the kind as a byte literal and the
    // buffer in the `Debug` format of `Bytes`.
    #[test]
    fn message_debug_ok() {
        assert_eq!(
            format!("{:?}", Message::new(b'A', "abc\0")),
            r#"Message { kind: b'A', buffer: b"abc\0" }"#
        );
        assert_eq!(
            format!("{:?}", Message::new(b'\0', "")),
            r#"Message { kind: b'\0', buffer: b"" }"#
        );
    }

    // Note: despite its name, this test checks the `Display` output (via
    // `to_string`), not the derived `Debug` output.
    #[test]
    fn try_from_byte_error_debug() {
        assert_eq!(
            TryFromByteError(b'x').to_string(),
            "failed to convert byte b'x'"
        );
    }

    #[tokio::test]
    async fn read_message() {
        let mut stream = {
            let (mut client, stream) = tokio::io::duplex(100);

            // Length 2 (kind byte plus one payload byte), kind `x`, payload
            // `y`, and a stray trailing byte `z`.
            client.write_all(b"\0\0\0\x02xyz").await.unwrap();

            // `client` is dropped at the end of this block, so the stream ends
            // after the bytes written above.
            stream
        };

        let msg = read(&mut stream).await.unwrap();
        assert_eq!(msg, Message::new(b'x', "y"));

        // Only the stray `z` is left, which is not enough for another message.
        let error = read(&mut stream).await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn write_message() {
        let mut client = {
            let (client, mut stream) = tokio::io::duplex(100);

            let msg = Message::new(b'x', "abc");
            write(&mut stream, msg).await.unwrap();
            // An extra byte after the message, to check that the stream
            // remains usable and that nothing else was written in between.
            stream.write_u8(1).await.unwrap();

            // `stream` is dropped at the end of this block, so the client
            // sees the end of the stream after the bytes written above.
            client
        };

        // Four length bytes (payload length 3 plus one for the kind byte),
        // the kind byte, and the three payload bytes.
        let mut buffer = vec![0; 8];
        client.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, b"\0\0\0\x04xabc");

        let byte = client.read_u8().await.unwrap();
        assert_eq!(byte, 1);

        let error = client.read_u8().await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }
}