indymilter 0.3.0

Asynchronous milter library
Documentation
mod common;

pub use crate::common::*;
use byte_strings::c_str;
use indymilter::{
    message::{
        command::{Command, ConnInfoPayload, OptNegPayload},
        reply::Reply,
        PROTOCOL_VERSION,
    },
    Actions, Callbacks, Config, ProtoOpts, SocketInfo, Status,
};
use rand::Rng;
use std::{
    collections::HashMap,
    io,
    net::SocketAddr,
    ops::RangeInclusive,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    },
    time::Duration,
};
use tokio::time;

/// Session identifier handed out by the connect callback's shared counter.
type Id = usize;

/// Point in a session's lifecycle at which a callback appended an event to
/// the shared event log.
#[derive(Debug)]
enum Stage {
    /// Appended by the `on_connect` callback.
    Connect,
    /// Appended by the `on_close` callback.
    Close,
}

/// Verifies that the milter serves at most `max_connections` sessions at once
/// when `client_count` clients all connect concurrently.
///
/// The callbacks built by `make_callbacks` append one `Stage::Connect` and one
/// `Stage::Close` event per session to a shared log. An event's position in
/// that log acts as a logical timestamp, so each session spans the inclusive
/// index range `connect..=close`. No log index may be covered by more than
/// `max_connections` of these ranges.
#[tokio::test]
async fn max_connections() {
    init_tracing_subscriber();

    // Adjust these settings to observe different throttling behaviour:
    let client_count = 50;
    let max_connections = 10;

    let config = Config {
        max_connections,
        ..default_config()
    };

    let events = Arc::new(Mutex::new(Vec::new()));

    let callbacks = make_callbacks(&events);

    // Now execute the test scenario:
    // Start the milter with `max_connections`.
    // Spawn `client_count` clients that all try to communicate with the milter
    // at the same time.

    let milter = Milter::spawn(LOCALHOST, callbacks, config).await.unwrap();

    let addr = milter.addr();

    // Pause tokio's clock while the clients run. With a paused clock tokio
    // auto-advances time whenever the runtime is idle, so the random sleeps
    // in the callbacks do not cost wall-clock time.
    time::pause();

    let clients = (0..client_count)
        .map(|_| tokio::spawn(run_client(addr)))
        .collect::<Vec<_>>();
    for c in clients {
        c.await.unwrap().unwrap();
    }

    time::resume();

    milter.shutdown().await.unwrap();

    // Now examine the recorded test events and make sure everything is as
    // expected. Especially no more than `max_connections` simultaneously active
    // clients.

    let events = Arc::try_unwrap(events).unwrap().into_inner().unwrap();
    let event_count = events.len();

    // Group the events by session ID, remembering each event's log index.
    let mut sessions = HashMap::new();
    for (i, (id, stage)) in events.into_iter().enumerate() {
        sessions.entry(id).or_insert_with(Vec::new).push((i, stage));
    }

    // Every session must consist of exactly one connect followed by one close.
    let sessions = sessions
        .into_iter()
        .map(|(id, events)| match events[..] {
            [(start, Stage::Connect), (end, Stage::Close)] => (id, RangeInclusive::new(start, end)),
            _ => panic!("not a pair of connect/close events: {events:?}"),
        })
        .collect::<HashMap<_, _>>();

    assert_eq!(sessions.len(), client_count);

    eprintln!("{sessions:#?}");

    // For every log index, count the sessions whose inclusive range covers it;
    // that count must never exceed the limit. (Quadratic, but the log is tiny.)
    let mut ranges = sessions.into_values().collect::<Vec<_>>();

    for i in 0..event_count {
        let active = ranges.iter().filter(|r| r.contains(&i)).count();
        assert!(
            active <= max_connections,
            "{active} sessions active at event {i}, but max_connections is {max_connections}"
        );
    }

    // Print a simple graphic of the session lifetimes. The ranges are
    // inclusive, so a bar covers `end - start + 1` columns, starting at the
    // connect event and ending on the close event.
    ranges.sort_unstable_by_key(|r| *r.start());
    for r in &ranges {
        let (start, end) = (*r.start(), *r.end());
        eprintln!("{}{}", " ".repeat(start), "-".repeat(end - start + 1));
    }
}

fn make_callbacks(events: &Arc<Mutex<Vec<(Id, Stage)>>>) -> Callbacks<Id> {
    let session_id = Arc::new(AtomicUsize::new(0));

    let events_connect = events.clone();
    let events_close = events.clone();

    Callbacks::new()
        .on_connect(move |cx, _, _| {
            let id = session_id.clone();
            let events = events_connect.clone();

            Box::pin(async move {
                let id = id.fetch_add(1, Ordering::SeqCst);

                cx.data = Some(id);

                sleep_a_while().await;

                events.lock().unwrap().push((id, Stage::Connect));

                Status::Continue
            })
        })
        .on_close(move |cx| {
            let events = events_close.clone();

            Box::pin(async move {
                sleep_a_while().await;

                if let Some(id) = cx.data.take() {
                    events.lock().unwrap().push((id, Stage::Close));
                }

                Status::Continue
            })
        })
}

/// Sleeps for a random duration between 500 and 1500 milliseconds inclusive.
async fn sleep_a_while() {
    // The thread-local RNG is a temporary of this statement, so it is dropped
    // before the `.await` below.
    let delay = Duration::from_millis(rand::thread_rng().gen_range(500..=1500));
    time::sleep(delay).await;
}

/// Runs one complete client session against the milter listening at `addr`.
///
/// The client negotiates options (offering all actions and protocol options),
/// then sends connection info for host `example.com` and expects `Continue`.
/// Finally it sends `Quit` and disconnects. I/O errors are propagated; an
/// unexpected reply panics.
async fn run_client(addr: SocketAddr) -> io::Result<()> {
    let mut client = Client::connect(addr).await?;

    let negotiation = Command::OptNeg(OptNegPayload {
        version: PROTOCOL_VERSION,
        actions: Actions::all(),
        opts: ProtoOpts::all(),
    });
    client.write_command(negotiation).await?;
    let reply = client.read_reply().await?;
    assert!(matches!(reply, Reply::OptNeg { .. }));

    let conn_info = Command::ConnInfo(ConnInfoPayload {
        hostname: c_str!("example.com").into(),
        socket_info: SocketInfo::Unknown,
    });
    client.write_command(conn_info).await?;
    let reply = client.read_reply().await?;
    assert_eq!(reply, Reply::Continue);

    client.write_command(Command::Quit).await?;
    client.disconnect().await?;

    Ok(())
}