mod common;
pub use crate::common::*;
use byte_strings::c_str;
use indymilter::{
message::{
command::{Command, ConnInfoPayload, OptNegPayload},
reply::Reply,
PROTOCOL_VERSION,
},
Actions, Callbacks, Config, ProtoOpts, SocketInfo, Status,
};
use rand::Rng;
use std::{
collections::HashMap,
io,
net::SocketAddr,
ops::RangeInclusive,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
time::Duration,
};
use tokio::time;
/// A point in a milter session's lifecycle, as logged by the test callbacks.
#[derive(Debug)]
enum Stage {
/// Logged by the `on_connect` callback after its simulated delay.
Connect,
/// Logged by the `on_close` callback after its simulated delay.
Close,
}
/// Session identifier: taken from a shared counter in the connect callback and
/// kept as the session's context data so the close callback can find it again.
type Id = usize;
#[tokio::test]
async fn max_connections() {
init_tracing_subscriber();
let client_count = 50;
let max_connections = 10;
let config = Config {
max_connections,
..default_config()
};
let events = Arc::new(Mutex::new(Vec::new()));
let callbacks = make_callbacks(&events);
let milter = Milter::spawn(LOCALHOST, callbacks, config).await.unwrap();
let addr = milter.addr();
time::pause();
let mut clients = Vec::new();
for _ in 0..client_count {
clients.push(tokio::spawn(run_client(addr)));
}
for c in clients {
c.await.unwrap().unwrap();
}
time::resume();
milter.shutdown().await.unwrap();
let events = Arc::try_unwrap(events).unwrap().into_inner().unwrap();
let event_count = events.len();
let mut sessions = HashMap::new();
for (i, (id, stage)) in events.into_iter().enumerate() {
sessions.entry(id).or_insert_with(Vec::new).push((i, stage));
}
let sessions = sessions
.into_iter()
.map(|(id, events)| match events[..] {
[(start, Stage::Connect), (end, Stage::Close)] => (id, RangeInclusive::new(start, end)),
_ => panic!("not a pair of connect/close events: {events:?}"),
})
.collect::<HashMap<_, _>>();
assert_eq!(sessions.len(), client_count);
eprintln!("{sessions:#?}");
let mut ranges = sessions.into_values().collect::<Vec<_>>();
for i in 0..event_count {
let active = ranges.iter().filter(|r| r.contains(&i)).count();
assert!(active <= max_connections);
}
ranges.sort_unstable_by_key(|r| *r.start());
for r in ranges {
eprint!("{}", " ".repeat(*r.start()));
eprint!("{}", "-".repeat(r.end() - r.start()));
eprintln!();
}
}
fn make_callbacks(events: &Arc<Mutex<Vec<(Id, Stage)>>>) -> Callbacks<Id> {
let session_id = Arc::new(AtomicUsize::new(0));
let events_connect = events.clone();
let events_close = events.clone();
Callbacks::new()
.on_connect(move |cx, _, _| {
let id = session_id.clone();
let events = events_connect.clone();
Box::pin(async move {
let id = id.fetch_add(1, Ordering::SeqCst);
cx.data = Some(id);
sleep_a_while().await;
events.lock().unwrap().push((id, Stage::Connect));
Status::Continue
})
})
.on_close(move |cx| {
let events = events_close.clone();
Box::pin(async move {
sleep_a_while().await;
if let Some(id) = cx.data.take() {
events.lock().unwrap().push((id, Stage::Close));
}
Status::Continue
})
})
}
/// Sleeps for a random duration between 500 and 1500 milliseconds (inclusive),
/// using the Tokio timer.
async fn sleep_a_while() {
    let delay = Duration::from_millis(rand::thread_rng().gen_range(500..=1500));
    time::sleep(delay).await;
}
/// Runs one minimal client session against the milter listening at `addr`.
///
/// The client negotiates options, sends connection information, then quits
/// and disconnects.
///
/// # Errors
///
/// Returns any I/O error raised while connecting, writing, reading, or
/// disconnecting.
///
/// # Panics
///
/// Panics if the milter does not answer the negotiation with an `OptNeg`
/// reply, or does not answer the connection information with `Continue`.
async fn run_client(addr: SocketAddr) -> io::Result<()> {
    let mut client = Client::connect(addr).await?;

    // Step 1: option negotiation, offering every action and protocol option.
    let negotiate = Command::OptNeg(OptNegPayload {
        version: PROTOCOL_VERSION,
        actions: Actions::all(),
        opts: ProtoOpts::all(),
    });
    client.write_command(negotiate).await?;

    let reply = client.read_reply().await?;
    assert!(matches!(reply, Reply::OptNeg { .. }));

    // Step 2: connection information; this is what triggers `on_connect`.
    let conn_info = Command::ConnInfo(ConnInfoPayload {
        hostname: c_str!("example.com").into(),
        socket_info: SocketInfo::Unknown,
    });
    client.write_command(conn_info).await?;

    let reply = client.read_reply().await?;
    assert_eq!(reply, Reply::Continue);

    // Step 3: end the session and drop the connection.
    client.write_command(Command::Quit).await?;
    client.disconnect().await?;

    Ok(())
}