use byte_strings::c_str;
use indymilter::{
Actions, Callbacks, Context, ContextActions, EomContext, MacroStage, NegotiateContext,
SocketInfo, Status,
};
use std::{
collections::HashMap,
env,
net::IpAddr,
process,
sync::{Arc, Mutex},
};
use tokio::{net::TcpListener, signal};
/// Per-connection state gathered by the milter callbacks.
///
/// A fresh value is stored by `connect`, `helo` fills in the TLS version,
/// the end-of-message callback reads it, and `close` discards it.
#[derive(Clone, Debug, Default)]
struct TlsData {
    /// Address of the connecting client; `None` when the MTA reported a
    /// socket that is not a `SocketInfo::Inet` address.
    ip: Option<IpAddr>,
    /// Value of the MTA's `{tls_version}` macro as seen at HELO/EHLO, or
    /// `None` if the MTA did not supply it.
    tls_version: Option<String>,
}

/// Occurrence counts keyed by TLS version string (`"none"` when absent),
/// shared between the end-of-message callback and `main`.
type Frequencies = HashMap<String, usize>;
/// Runs the TLS-info milter on the TCP socket given as the sole argument.
///
/// The milter runs until Ctrl-C. It then prints the table of TLS versions
/// counted by the end-of-message callback.
///
/// Exits with status 1 and a usage message when the argument count is wrong.
/// It panics if the socket cannot be bound or the milter fails.
#[tokio::main]
async fn main() {
    let args = env::args().collect::<Vec<_>>();

    // `argv[0]` is only conventional and may be absent entirely, so it must
    // not be indexed unconditionally when printing usage.
    let socket = match args.as_slice() {
        [_, socket] => socket,
        _ => {
            let program = args.first().map_or("milter", String::as_str);
            eprintln!("usage: {program} <socket>");
            process::exit(1);
        }
    };

    let listener = TcpListener::bind(socket)
        .await
        .expect("cannot open milter socket");

    // Shared tally: the eom callback increments it, and it is printed after shutdown.
    let tls_versions = Arc::new(Mutex::new(Frequencies::new()));
    let tls_versions_eom = tls_versions.clone();

    let callbacks = Callbacks::new()
        .on_negotiate(|context, _, _| Box::pin(negotiate(context)))
        .on_connect(|context, _, socket_info| Box::pin(connect(context, socket_info)))
        .on_helo(|context, _| Box::pin(helo(context)))
        .on_eom(move |context| Box::pin(eom(tls_versions_eom.clone(), context)))
        .on_close(|context| Box::pin(close(context)));

    let config = Default::default();
    let shutdown = signal::ctrl_c();

    indymilter::run(listener, callbacks, config, shutdown)
        .await
        .expect("milter execution failed");

    let tls_versions = tls_versions.lock().unwrap();
    println!("Frequencies of TLS versions seen:");
    println!("{tls_versions:#?}");
}
/// Negotiation callback: asks the MTA for the `{tls_version}` macro at the
/// HELO stage and requests permission to add headers.
async fn negotiate(context: &mut NegotiateContext<TlsData>) -> Status {
    // `helo` reads this macro, so the MTA must be told to send it at that stage.
    context
        .requested_macros
        .insert(MacroStage::Helo, c_str!("{tls_version}").into());

    // A header is added at end-of-message, which requires this action.
    context.requested_actions |= Actions::ADD_HEADER;

    Status::Continue
}
/// Connect callback: starts fresh per-connection state that records the
/// client's IP address, when there is one, and no TLS version yet.
async fn connect(context: &mut Context<TlsData>, socket_info: SocketInfo) -> Status {
    // Only `SocketInfo::Inet` carries an IP address; anything else is recorded as `None`.
    let ip = if let SocketInfo::Inet(addr) = socket_info {
        Some(addr.ip())
    } else {
        None
    };

    context.data = Some(TlsData { ip, tls_version: None });

    Status::Continue
}
/// HELO callback: if the MTA sent the `{tls_version}` macro, it is stored
/// (lossily converted to UTF-8) in the connection data.
///
/// A missing macro or missing connection data leaves the state unchanged.
async fn helo(context: &mut Context<TlsData>) -> Status {
    let version = context
        .macros
        .get(c_str!("{tls_version}"))
        .map(|raw| raw.to_string_lossy().into_owned());

    if let (Some(state), Some(version)) = (context.data.as_mut(), version) {
        state.tls_version = Some(version);
    }

    Status::Continue
}
/// End-of-message callback: adds a `TLS-Version-Info` header to the message
/// and counts the connection's TLS version in the shared frequency table.
///
/// The connection data is only borrowed here. It is created in `connect` and
/// cleared in `close`, and one SMTP session may carry several messages.
/// Consuming it at the first message's end would leave every later message
/// in the same session without a header and uncounted.
///
/// Returns `Status::Tempfail` if the header could not be added; nothing is
/// counted in that case. Otherwise it returns `Status::Continue`. Messages
/// with no connection data are passed through untouched.
async fn eom(tls_versions: Arc<Mutex<Frequencies>>, context: &mut EomContext<TlsData>) -> Status {
    // Copy into owned strings so no borrow of `context.data` is held while
    // `add_header` is awaited.
    let (ip, tls_version) = match &context.data {
        Some(TlsData { ip, tls_version }) => (
            ip.map_or_else(|| "unknown".to_owned(), |ip| ip.to_string()),
            tls_version.as_deref().unwrap_or("none").to_owned(),
        ),
        None => return Status::Continue,
    };

    let name = "TLS-Version-Info";
    let value = format!("ip={ip} tls-version={tls_version}");
    if let Err(e) = context.actions.add_header(name, value).await {
        eprintln!("failed to add header: {e}");
        return Status::Tempfail;
    }

    // The lock guard is a temporary dropped at the end of this statement, so
    // it is never held across an `.await`.
    *tls_versions.lock().unwrap().entry(tls_version).or_insert(0) += 1;

    Status::Continue
}
/// Close callback: discards the per-connection state when the connection ends.
async fn close(context: &mut Context<TlsData>) -> Status {
    // Move the state out and drop it, leaving `None` behind.
    drop(context.data.take());
    Status::Continue
}