// spf-milter 0.6.0
//
// Milter for SPF verification
// Documentation
// SPF Milter – milter for SPF verification
// Copyright © 2020–2023 David Bürgin <dbuergin@gluet.ch>
//
// This program is free software: you can redistribute it and/or modify it under
// the terms of the GNU General Public License as published by the Free Software
// Foundation, either version 3 of the License, or (at your option) any later
// version.
//
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
// details.
//
// You should have received a copy of the GNU General Public License along with
// this program. If not, see <https://www.gnu.org/licenses/>.

use futures::stream::StreamExt;
use indymilter::Listener;
use signal_hook::consts::{SIGHUP, SIGINT, SIGTERM};
use signal_hook_tokio::{Handle, Signals};
use spf_milter::{CliOptions, Config, Socket, MILTER_NAME, VERSION};
use std::{
    env,
    error::Error,
    io::{stderr, stdout, Write},
    os::unix::fs::FileTypeExt,
    path::Path,
    process,
};
use tokio::{
    fs,
    net::{TcpListener, UnixListener},
    sync::{mpsc, oneshot},
    task::JoinHandle,
};

/// Name of this binary, taken at compile time from Cargo's `CARGO_BIN_NAME`
/// environment variable. Used as the prefix of diagnostics written to standard
/// error and in the usage line.
const PROGRAM_NAME: &str = env!("CARGO_BIN_NAME");

/// Program entry point.
///
/// Parses the command line, reads the configuration, installs the signal
/// handling task, binds the configured listening socket, and then awaits
/// `spf_milter::run`. Afterwards the signal task is shut down and, for a UNIX
/// domain socket, the socket file is removed.
///
/// Every failure is reported on standard error, prefixed with the program
/// name, and ends the process with exit status 1.
#[tokio::main]
async fn main() {
    /// Writes `msg` to standard error, prefixed with the program name, then
    /// terminates the process with exit status 1.
    fn fail(msg: impl std::fmt::Display) -> ! {
        let _ = writeln!(stderr(), "{PROGRAM_NAME}: {msg}");
        process::exit(1)
    }

    let opts = parse_args().unwrap_or_else(|e| fail(e));
    let config = Config::read(opts).await.unwrap_or_else(|e| fail(e));

    // Channels through which the signal task requests a reload or a shutdown.
    let (reload_tx, reload) = mpsc::channel(1);
    let (shutdown_tx, shutdown) = oneshot::channel();

    let signals =
        Signals::new([SIGHUP, SIGINT, SIGTERM]).expect("failed to install signal handler");
    let signals_handle = signals.handle();
    let signals_task = spawn_signals_task(signals, reload_tx, shutdown_tx);

    // Declared out here so that `socket_path`, which borrows from it, remains
    // valid until cleanup.
    let bound_addr;
    let mut socket_path = None;
    let listener = match config.socket() {
        Socket::Inet(inet_addr) => {
            let tcp = TcpListener::bind(inet_addr)
                .await
                .unwrap_or_else(|e| fail(format!("could not bind TCP socket: {e}")));

            Listener::Tcp(tcp)
        }
        Socket::Unix(path) => {
            // A previous, aborted execution may have left a socket file at the
            // target path; try to clear it out before binding.
            try_remove_socket(&path).await;

            let unix = UnixListener::bind(path).unwrap_or_else(|e| {
                fail(format!("could not create UNIX domain socket: {e}"))
            });

            // Keep the path of the socket file so it can be deleted on
            // shutdown.
            bound_addr = unix.local_addr().unwrap();
            socket_path = bound_addr.as_pathname();

            Listener::Unix(unix)
        }
    };

    let result = spf_milter::run(listener, config, reload, shutdown).await;

    // Clean up before reporting the outcome, so that it also happens when the
    // run ended with an error.
    cleanup(signals_handle, signals_task, socket_path).await;

    if let Err(e) = result {
        fail(e);
    }
}

/// Usage text printed for `-h`/`--help`. It is written directly after the
/// literal `Usage: ` and the program name, which is why it starts with
/// `[OPTIONS]`.
const USAGE_TEXT: &str = "\
[OPTIONS]

Options:
  -c, --config-file <PATH>          Path to configuration file
  -n, --dry-run                     Process messages without taking action
  -h, --help                        Print usage information
  -l, --log-destination <TARGET>    Destination for log messages
  -L, --log-level <LEVEL>           Minimum severity of messages to log
  -p, --socket <SOCKET>             Listening socket of the milter
  -s, --syslog-facility <NAME>      Facility to use for syslog messages
  -V, --version                     Print version information
";

/// Parses the process's command-line arguments (skipping the program name)
/// into [`CliOptions`].
///
/// `-h`/`--help` and `-V`/`--version` print to standard output and exit the
/// process with status 0 instead of returning.
///
/// # Errors
///
/// Returns an error if an argument is not valid UTF-8, an option is not
/// recognized, an option that takes a value is the last argument, a value
/// fails to parse, or writing to standard output fails.
fn parse_args() -> Result<CliOptions, Box<dyn Error>> {
    /// Pulls the value belonging to `option` from the remaining arguments.
    fn value_for(
        option: &str,
        rest: &mut impl Iterator<Item = Result<String, &'static str>>,
    ) -> Result<String, Box<dyn Error>> {
        match rest.next() {
            Some(value) => Ok(value?),
            None => Err(format!("missing value for option {option}").into()),
        }
    }

    let mut args = env::args_os().skip(1).map(|raw| {
        raw.into_string()
            .map_err(|_| "invalid UTF-8 bytes in argument")
    });

    let mut builder = CliOptions::builder();

    // Not a `for` loop: options that take a value advance the iterator
    // themselves from inside the body.
    while let Some(next) = args.next() {
        let flag = next?;

        match flag.as_str() {
            "-h" | "--help" => {
                write!(stdout(), "Usage: {PROGRAM_NAME} {USAGE_TEXT}")?;
                process::exit(0);
            }
            "-V" | "--version" => {
                writeln!(stdout(), "{MILTER_NAME} {VERSION}")?;
                process::exit(0);
            }
            "-n" | "--dry-run" => {
                builder = builder.dry_run(true);
            }
            "-c" | "--config-file" => {
                let path = value_for(&flag, &mut args)?;

                builder = builder.config_file(path);
            }
            "-l" | "--log-destination" => {
                let value = value_for(&flag, &mut args)?;
                let target = value
                    .parse()
                    .map_err(|_| format!("invalid value for log destination: \"{value}\""))?;

                builder = builder.log_destination(target);
            }
            "-L" | "--log-level" => {
                let value = value_for(&flag, &mut args)?;
                let level = value
                    .parse()
                    .map_err(|_| format!("invalid value for log level: \"{value}\""))?;

                builder = builder.log_level(level);
            }
            "-p" | "--socket" => {
                let value = value_for(&flag, &mut args)?;
                let socket = value
                    .parse()
                    .map_err(|_| format!("invalid value for socket: \"{value}\""))?;

                builder = builder.socket(socket);
            }
            "-s" | "--syslog-facility" => {
                let value = value_for(&flag, &mut args)?;
                let facility = value
                    .parse()
                    .map_err(|_| format!("invalid value for syslog facility: \"{value}\""))?;

                builder = builder.syslog_facility(facility);
            }
            unknown => return Err(format!("unrecognized option: \"{unknown}\"").into()),
        }
    }

    Ok(builder.build())
}

/// Spawns a Tokio task that translates incoming signals into channel messages.
///
/// * `SIGHUP` sends `()` on `reload_config`.
/// * `SIGINT` or `SIGTERM` sends `()` on `shutdown_milter` and ends the task.
///
/// Failures to send are ignored. The task also ends when the signal stream
/// yields no further items.
///
/// # Panics
///
/// The spawned task panics if the stream yields any other signal.
fn spawn_signals_task(
    mut signals: Signals,
    reload_config: mpsc::Sender<()>,
    shutdown_milter: oneshot::Sender<()>,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            match signals.next().await {
                Some(SIGHUP) => {
                    let _ = reload_config.send(()).await;
                }
                Some(SIGINT | SIGTERM) => {
                    // Sending consumes the oneshot sender, so the loop has to
                    // stop here.
                    let _ = shutdown_milter.send(());
                    break;
                }
                Some(_) => panic!("unexpected signal"),
                None => break,
            }
        }
    })
}

/// Shuts down signal handling and removes the socket file, if there is one.
///
/// Closes `signals_handle`, then waits for `signals_task` to finish. If
/// `socket_path` is `Some`, a best-effort attempt is made to delete the socket
/// at that path.
///
/// # Panics
///
/// Panics if awaiting the signal task returns an error.
async fn cleanup(signals_handle: Handle, signals_task: JoinHandle<()>, socket_path: Option<&Path>) {
    signals_handle.close();

    let joined = signals_task.await;
    joined.expect("signal handler task failed");

    match socket_path {
        Some(path) => try_remove_socket(path).await,
        None => {}
    }
}

/// Best-effort removal of the file at `path`, but only if it is a socket.
///
/// Nothing happens when the metadata cannot be read or the file is of a
/// different type. An error from the removal itself is ignored.
async fn try_remove_socket(path: impl AsRef<Path>) {
    let target = path.as_ref();

    let is_socket = match fs::metadata(target).await {
        Ok(metadata) => metadata.file_type().is_socket(),
        Err(_) => false,
    };

    if is_socket {
        let _ = fs::remove_file(target).await;
    }
}