// spamassassin-milter 0.0.3
//
// Milter for spam filtering with SpamAssassin
// Documentation
use crate::{
    collections::{StrVecMap, StrVecSet},
    config::Config,
    error::{Error, Result},
};
use milter::ActionContext;
use once_cell::sync::Lazy;
use std::{
    cmp,
    fmt::{self, Display, Formatter},
    str,
};

/// A single email header field, borrowing its text from the underlying
/// message bytes.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Header<'a> {
    /// The field name. When produced by `parse_header_line` this is the text
    /// before the first `:` of the header line, taken verbatim (not trimmed).
    pub name: &'a str,
    /// The field value. When produced by `parse_header_line` this is
    /// everything after the first `:`, including any leading whitespace and
    /// any embedded line breaks of a folded header.
    pub value: &'a str,
}

/// An email message split into its parsed header fields and its raw body.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Email<'a> {
    /// The header fields, in the order in which they occur in the message.
    pub header: Vec<Header<'a>>,
    /// The bytes following the empty line that terminates the header section.
    pub body: &'a [u8],
}

impl<'a> Email<'a> {
    pub fn parse(bytes: &'a [u8]) -> Result<Self> {
        let (header, body) = split_at_eoh(bytes)?;

        let header = header_lines(header)
            .into_iter()
            .map(parse_header_line)
            .collect::<Result<Vec<_>>>()?;

        Ok(Self { header, body })
    }
}

fn split_at_eoh(bytes: &[u8]) -> Result<(&[u8], &[u8])> {
    bytes
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|i| (&bytes[..(i + 2)], &bytes[(i + 4)..]))
        .ok_or(Error::ParseEmail)
}

/// Splits a header section into logical header lines.
///
/// Lines are separated by `"\r\n"`. A line break that is immediately followed
/// by a space or tab continues the current (folded) header line and is kept
/// inside the returned slice. A trailing fragment without a terminating line
/// break is returned as a final line.
fn header_lines(header: &[u8]) -> Vec<&[u8]> {
    let mut lines = Vec::new();

    let mut line_start = 0;
    let mut pos = 0;

    while pos < header.len() {
        // Assume line breaks are always encoded as b"\r\n".
        if !header[pos..].starts_with(b"\r\n") {
            pos += 1;
            continue;
        }

        match header.get(pos + 2) {
            // Folded line: skip over the line break and the whitespace byte.
            Some(b' ') | Some(b'\t') => pos += 3,
            // End of a logical line.
            _ => {
                lines.push(&header[line_start..pos]);
                pos += 2;
                line_start = pos;
            }
        }
    }

    if line_start != pos {
        lines.push(&header[line_start..pos]);
    }

    lines
}

fn parse_header_line(bytes: &[u8]) -> Result<Header<'_>> {
    let line = str::from_utf8(bytes).map_err(|_| Error::ParseEmail)?;
    let (name, value) = line.split_at(line.find(':').ok_or(Error::ParseEmail)?);

    if name.trim().is_empty() {
        return Err(Error::ParseEmail);
    }

    let value = &value[1..];

    Ok(Header { name, value })
}

/// Returns a copy of `s` in which every line break is encoded as `"\r\n"`.
///
/// Each `'\n'` is emitted as `"\r\n"`. For symmetry, existing occurrences of
/// `"\r\n"` remain unchanged. A `'\r'` that is not directly followed by
/// `'\n'` is not a line break and is preserved as is, including a bare `'\r'`
/// at the very end of the input.
pub fn ensure_crlf(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(i) = rest.find('\n') {
        let line = &rest[..i];
        // Only a '\r' directly in front of this '\n' belongs to the line
        // break; drop it so that it is not doubled below.
        let line = if line.ends_with('\r') {
            &line[..(line.len() - 1)]
        } else {
            line
        };

        result.push_str(line);
        result.push_str("\r\n");

        rest = &rest[(i + 1)..];
    }

    // The final segment is not followed by '\n': copy it verbatim.
    result.push_str(rest);

    result
}

/// Returns a copy of `s` in which every `"\r\n"` sequence is replaced with a
/// single `"\n"`. All other characters, including bare `'\r'`, are kept.
pub fn ensure_lf(s: &str) -> String {
    let segments: Vec<&str> = s.split("\r\n").collect();
    segments.join("\n")
}

/// Returns whether `name` starts with the prefix `X-Spam-`, compared
/// ASCII-case-insensitively.
pub fn is_spam_assassin_header(name: &str) -> bool {
    const PREFIX: &[u8] = b"X-Spam-";

    // Compare bytes rather than `str` slices, so that a multi-byte character
    // straddling the prefix length cannot cause a slicing panic.
    let bytes = name.as_bytes();
    let head_len = cmp::min(bytes.len(), PREFIX.len());
    let (head, _) = bytes.split_at(head_len);

    // A name shorter than the prefix yields a shorter head and never matches.
    head.eq_ignore_ascii_case(PREFIX)
}

/// Map from header name to header value. Values use CRLF line breaks and
/// include leading whitespace.
pub type HeaderMap = StrVecMap<String, String>;
/// Set of borrowed header names.
pub type HeaderSet<'e> = StrVecSet<&'e str>;

/// Selected subset of ‘X-Spam-’ headers for which we assume responsibility.
pub static SPAM_ASSASSIN_HEADERS: Lazy<HeaderSet<'static>> = Lazy::new(|| {
    const NAMES: [&str; 5] = [
        "X-Spam-Checker-Version",
        "X-Spam-Flag",
        "X-Spam-Level",
        "X-Spam-Status",
        "X-Spam-Report",
    ];

    let mut set = HeaderSet::new();
    for name in NAMES.iter() {
        set.insert(*name);
    }
    set
});

/// Names of the headers that `HeaderRewriter::process_header` records as
/// rewrite modifications.
pub static REWRITE_HEADERS: Lazy<HeaderSet<'static>> = Lazy::new(|| {
    const NAMES: [&str; 3] = ["Subject", "From", "To"];

    let mut set = HeaderSet::new();
    for name in NAMES.iter() {
        set.insert(*name);
    }
    set
});

/// Names of the headers that `HeaderRewriter::process_header` records as
/// report modifications.
pub static REPORT_HEADERS: Lazy<HeaderSet<'static>> = Lazy::new(|| {
    const NAMES: [&str; 2] = ["MIME-Version", "Content-Type"];

    let mut set = HeaderSet::new();
    for name in NAMES.iter() {
        set.insert(*name);
    }
    set
});

/// A header rewriter that processes headers returned by SpamAssassin, and
/// computes and applies modifications by referring back to the original set of
/// headers. The rewriter operates only on the first occurrence of headers with
/// the same name.
#[derive(Clone, Debug)]
pub struct HeaderRewriter<'a> {
    /// The original set of headers that processed headers are compared with.
    original: HeaderMap,
    /// Names of the headers processed so far; used to skip repeated
    /// occurrences of a header name.
    processed: HeaderSet<'a>,
    /// Pending modifications for headers whose name starts with `X-Spam-`.
    spam_assassin_mods: Vec<HeaderMod<'a>>,
    /// Pending modifications for headers listed in `REWRITE_HEADERS`.
    rewrite_mods: Vec<HeaderMod<'a>>,
    /// Pending modifications for headers listed in `REPORT_HEADERS`.
    report_mods: Vec<HeaderMod<'a>>,
    /// Configuration; consulted for the dry-run setting and verbose logging
    /// when modifications are executed.
    config: &'a Config,
}

impl<'a> HeaderRewriter<'a> {
    pub fn new(original: HeaderMap, config: &'a Config) -> Self {
        Self {
            original,
            processed: HeaderSet::new(),
            spam_assassin_mods: vec![],
            rewrite_mods: vec![],
            report_mods: vec![],
            config,
        }
    }

    pub fn process_header(&mut self, name: &'a str, value: &'a str) {
        // Assumes that the value is normalised to using CRLF line breaks, and
        // includes leading whitespace.
        if is_spam_assassin_header(name) {
            if let Some(m) = self.convert_to_header_mod(name, value) {
                self.spam_assassin_mods.push(m);
            }
        } else if REWRITE_HEADERS.contains(name) {
            if let Some(m) = self.convert_to_header_mod(name, value) {
                self.rewrite_mods.push(m);
            }
        } else if REPORT_HEADERS.contains(name) {
            if let Some(m) = self.convert_to_header_mod(name, value) {
                self.report_mods.push(m);
            }
        }
    }

    fn convert_to_header_mod(&mut self, name: &'a str, value: &'a str) -> Option<HeaderMod<'a>> {
        if !self.processed.insert(name) {
            return None;
        }

        match self.original.get(name) {
            Some(original_value) => {
                if original_value != value {
                    Some(HeaderMod::Replace { name, value })
                } else {
                    None
                }
            }
            None => Some(HeaderMod::Add { name, value }),
        }
    }

    pub fn is_flagged_spam(&self) -> bool {
        use HeaderMod::*;

        self.spam_assassin_mods.iter().any(|m| match m {
            Add { name, value } | Replace { name, value } => {
                name.eq_ignore_ascii_case("X-Spam-Flag") && value.trim().eq_ignore_ascii_case("YES")
            }
            _ => false,
        })
    }

    pub fn rewrite_spam_assassin_headers(
        &self,
        id: &str,
        actions: &impl ActionContext,
    ) -> milter::Result<()> {
        execute_mods(id, self.spam_assassin_mods.iter(), actions, self.config)?;

        // Delete certain incoming SpamAssassin headers not returned by
        // SpamAssassin, to get rid of foreign `X-Spam-Flag` etc. headers.
        let deletions = SPAM_ASSASSIN_HEADERS.iter()
            .filter(|n| self.original.contains_key(n) && !self.processed.contains(n))
            .map(|name| HeaderMod::Delete { name })
            .collect::<Vec<_>>();

        execute_mods(id, deletions.iter(), actions, self.config)
    }

    pub fn rewrite_rewrite_headers(
        &self,
        id: &str,
        actions: &impl ActionContext,
    ) -> milter::Result<()> {
        execute_mods(id, self.rewrite_mods.iter(), actions, self.config)
    }

    pub fn rewrite_report_headers(
        &self,
        id: &str,
        actions: &impl ActionContext,
    ) -> milter::Result<()> {
        execute_mods(id, self.report_mods.iter(), actions, self.config)
    }
}

/// Executes the given header modifications in order through `actions`,
/// logging each one with `id` as prefix.
///
/// In dry-run mode the modifications are only logged, not executed.
///
/// # Errors
///
/// Stops at and returns the first error reported while executing a
/// modification.
fn execute_mods<'a, I>(
    id: &str,
    mods: I,
    actions: &impl ActionContext,
    config: &Config,
) -> milter::Result<()>
where
    I: IntoIterator<Item = &'a HeaderMod<'a>>,
{
    for header_mod in mods {
        if !config.dry_run() {
            verbose!(config, "{}: rewriting header: {}", id, header_mod);
            header_mod.execute(actions)?;
            continue;
        }

        verbose!(config, "{}: rewriting header: {} [dry-run, not done]", id, header_mod);
    }

    Ok(())
}

/// Hands `body` to the milter action context as the message body, logging
/// the operation with `id` as prefix.
///
/// In dry-run mode the operation is only logged, not executed.
///
/// # Errors
///
/// Returns the error reported by the milter action context, if any.
pub fn replace_body(
    id: &str,
    body: &[u8],
    actions: &impl ActionContext,
    config: &Config,
) -> milter::Result<()> {
    if config.dry_run() {
        verbose!(config, "{}: replacing message body [dry-run, not done]", id);
        return Ok(());
    }

    verbose!(config, "{}: replacing message body", id);
    actions.append_body_chunk(body)?;

    Ok(())
}

/// A header rewriting modification operation. These are intended to operate
/// only on the first instance of headers occurring multiple times.
#[derive(Clone, Copy, Debug)]
enum HeaderMod<'a> {
    /// Add a new header with the given name and value.
    Add { name: &'a str, value: &'a str },
    /// Replace the value of the first header with the given name.
    Replace { name: &'a str, value: &'a str },
    /// Delete the first header with the given name.
    Delete { name: &'a str },
}

impl HeaderMod<'_> {
    fn execute(&self, actions: &impl ActionContext) -> milter::Result<()> {
        use HeaderMod::*;

        // The milter library is smart enough to treat the name in a
        // case-insensitive manner, eg ‘Subject’ may replace ‘sUbject’.
        match self {
            Add { name, value } => actions.add_header(name, &ensure_lf(value)),
            Replace { name, value } => actions.replace_header(name, 1, Some(&ensure_lf(value))),
            Delete { name } => actions.replace_header(name, 1, None),
        }
    }
}

impl Display for HeaderMod<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        use HeaderMod::*;

        match self {
            Add { name, .. } => write!(f, "add header \"{}\"", name),
            Replace { name, .. } => write!(f, "replace header \"{}\"", name),
            Delete { name } => write!(f, "delete header \"{}\"", name),
        }
    }
}

// Unit tests for the email parsing helpers, the line break conversions, and
// the header rewriter.
#[cfg(test)]
mod tests {
    use super::*;

    // The header keeps its terminating line break; the body starts after the
    // empty line. Input without an empty line is rejected.
    #[test]
    fn email_split_at_eoh() {
        assert_eq!(split_at_eoh(b"x\r\n\r\ny"), Ok((b"x\r\n" as &[_], b"y" as &[_])));
        assert_eq!(split_at_eoh(b"x\r\n\r\n"), Ok((b"x\r\n" as &[_], b"" as &[_])));
        assert_eq!(split_at_eoh(b"\r\n\r\ny"), Ok((b"\r\n" as &[_], b"y" as &[_])));
        assert_eq!(split_at_eoh(b"\r\ny"), Err(Error::ParseEmail));
        assert_eq!(split_at_eoh(b"y"), Err(Error::ParseEmail));
    }

    // Each bare line break yields one empty line.
    #[test]
    fn email_header_lines_empty() {
        assert_eq!(header_lines(b""), Vec::<&[_]>::new());
        assert_eq!(header_lines(b"\r\n"), vec![b"" as &[_]]);
        assert_eq!(header_lines(b"\r\n\r\n"), vec![b"" as &[_], b"" as &[_]]);
    }

    // A final line is returned with or without a terminating line break.
    #[test]
    fn email_header_lines_simple() {
        assert_eq!(header_lines(b"x\r\n"), vec![b"x" as &[_]]);
        assert_eq!(header_lines(b"x\r\ny"), vec![b"x" as &[_], b"y" as &[_]]);
        assert_eq!(header_lines(b"x\r\ny\r\n"), vec![b"x" as &[_], b"y" as &[_]]);
    }

    // A line break followed by a tab continues the current line.
    #[test]
    fn email_header_lines_multi() {
        assert_eq!(header_lines(b"x\r\n\t"), vec![b"x\r\n\t" as &[_]]);
        assert_eq!(header_lines(b"x\r\n\ty"), vec![b"x\r\n\ty" as &[_]]);
        assert_eq!(header_lines(b"x\r\n\ty\r\n"), vec![b"x\r\n\ty" as &[_]]);
        assert_eq!(
            header_lines(b"x\r\n\ty\r\n\tz\r\nq"),
            vec![b"x\r\n\ty\r\n\tz" as &[_], b"q" as &[_]]
        );
    }

    // Names must be non-blank; values keep their leading whitespace.
    #[test]
    fn email_parse_header_line() {
        assert_eq!(parse_header_line(b"no colon"), Err(Error::ParseEmail));
        assert_eq!(parse_header_line(b":empty name"), Err(Error::ParseEmail));
        assert_eq!(parse_header_line(b"\t : whitespace name"), Err(Error::ParseEmail));
        assert_eq!(parse_header_line(b"name:value"), Ok(Header { name: "name", value: "value" }));
        assert_eq!(parse_header_line(b"name: value"), Ok(Header { name: "name", value: " value" }));
        assert_eq!(
            parse_header_line(b"name:\r\n\tvalue"),
            Ok(Header { name: "name", value: "\r\n\tvalue" })
        );
    }

    // LF becomes CRLF; existing CRLF is left alone.
    #[test]
    fn ensure_crlf_ok() {
        assert_eq!(&ensure_crlf(""), "");
        assert_eq!(&ensure_crlf("\n"), "\r\n");
        assert_eq!(&ensure_crlf("\r\n"), "\r\n");
        assert_eq!(&ensure_crlf("a\nb"), "a\r\nb");
        assert_eq!(&ensure_crlf("a\n\nb"), "a\r\n\r\nb");
        assert_eq!(&ensure_crlf("a\r\n\nb"), "a\r\n\r\nb");
        assert_eq!(&ensure_crlf("a\n\r\nb"), "a\r\n\r\nb");
        assert_eq!(&ensure_crlf("a\r\nb\n"), "a\r\nb\r\n");
    }

    // CRLF becomes LF; existing LF is left alone.
    #[test]
    fn ensure_lf_ok() {
        assert_eq!(&ensure_lf(""), "");
        assert_eq!(&ensure_lf("\n"), "\n");
        assert_eq!(&ensure_lf("\r\n"), "\n");
        assert_eq!(&ensure_lf("a\nb"), "a\nb");
        assert_eq!(&ensure_lf("a\n\nb"), "a\n\nb");
        assert_eq!(&ensure_lf("a\r\n\nb"), "a\n\nb");
        assert_eq!(&ensure_lf("a\n\r\nb"), "a\n\nb");
        assert_eq!(&ensure_lf("a\r\nb\n"), "a\nb\n");
    }

    // The prefix match is case-insensitive and requires the trailing hyphen.
    #[test]
    fn spam_assassin_header_predicate() {
        assert!(is_spam_assassin_header("x-spam-status"));
        assert!(is_spam_assassin_header("x-spam-bogus"));
        assert!(is_spam_assassin_header("x-spam-"));
        assert!(!is_spam_assassin_header("x-spam"));
        assert!(!is_spam_assassin_header("bogus"));
    }

    // A processed `X-Spam-Flag: YES` header flags the message as spam, even
    // if the original header said otherwise.
    #[test]
    fn header_rewriter_flags_spam() {
        let mut headers = HeaderMap::new();
        headers.insert(String::from("x-spam-flag"), String::from(" no"));
        let config = Default::default();

        let mut rewriter = HeaderRewriter::new(headers, &config);
        rewriter.process_header("X-Spam-Flag", " YES");

        assert!(rewriter.is_flagged_spam());
    }

    // A second header with the same name is ignored.
    #[test]
    fn header_rewriter_processes_first_occurrence_only() {
        let headers = HeaderMap::new();
        let config = Default::default();

        let mut rewriter = HeaderRewriter::new(headers, &config);
        rewriter.process_header("X-Spam-Flag", " NO");
        rewriter.process_header("X-Spam-Flag", " YES");

        let mut mods = rewriter.spam_assassin_mods.into_iter();
        match mods.next().unwrap() {
            HeaderMod::Add { name, value } => {
                assert_eq!(name, "X-Spam-Flag");
                assert_eq!(value, " NO");
            }
            _ => panic!(),
        }
        assert!(mods.next().is_none());
    }

    // An unchanged value produces no modification; a changed value produces a
    // replacement.
    #[test]
    fn header_rewriter_replaces_different_values() {
        let mut headers = HeaderMap::new();
        headers.insert(String::from("x-spam-level"), String::from(" ***"));
        headers.insert(String::from("x-spam-report"), String::from(" original"));
        let config = Default::default();

        let mut rewriter = HeaderRewriter::new(headers, &config);
        rewriter.process_header("X-Spam-Level", " ***");
        rewriter.process_header("X-Spam-Report", " new");

        let mut mods = rewriter.spam_assassin_mods.into_iter();
        match mods.next().unwrap() {
            HeaderMod::Replace { name, value } => {
                assert_eq!(name, "X-Spam-Report");
                assert_eq!(value, " new");
            }
            _ => panic!(),
        }
        assert!(mods.next().is_none());
    }
}