use std::sync::{
Arc,
atomic::{AtomicU32, AtomicUsize, Ordering},
};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::Serialize;
use tokio::io::AsyncWriteExt;
use tokio::sync::broadcast;
/// Capacity passed to `broadcast::channel` when a `TraceEmitter` is created.
pub const TRACE_CHANNEL_CAPACITY: usize = 1_024;
/// Maximum number of subscribers a single trace listener serves at once;
/// connections beyond this receive a `max_subscribers_reached` error line.
pub const MAX_SUBSCRIBERS: usize = 4;
/// One trace record, built by `TraceEmitter::emit` and written to subscribers
/// as a single JSON line.
#[derive(Clone, Debug, Serialize)]
pub struct MatchTraceEvent {
/// Sequence number taken from the emitter's event counter (first event is 0).
pub event_id: u64,
/// Wire-format version; `emit` always sets this to 1.
pub schema_version: u8,
/// Caller-supplied timestamp. Presumably Unix-epoch milliseconds as produced
/// by `now_ms` — TODO confirm with callers.
pub received_at_ms: u64,
/// Caller-supplied duration in milliseconds. Presumably the time spent
/// handling the request — TODO confirm with callers.
pub duration_ms: u32,
/// Summary of the request this event describes.
pub request: RequestSummary,
/// How the request was resolved.
pub outcome: Outcome,
/// Value of the emitter's dropped-event counter, read and reset to zero when
/// this event was built.
pub dropped_count: u32,
}
/// The parts of a request that are included in a trace event.
#[derive(Clone, Debug, Serialize)]
pub struct RequestSummary {
/// Request method as text (the tests use values such as "GET" and "POST").
pub method: String,
/// Path of the requested URL (the tests use values such as "/api/test").
pub url_path: String,
/// Header pairs. Presumably (name, value) in that order — TODO confirm with
/// the code that builds this summary.
pub headers: Vec<(String, String)>,
}
/// Result recorded for a traced request.
///
/// Serialised as an internally tagged object: the variant name, in snake_case,
/// is stored under the `"type"` key next to the variant's own fields
/// (for example `{"type":"matched","rule_set_index":0,"rule_index":2}`).
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Outcome {
/// Presumably: a rule handled the request; the indices locate the rule set
/// and the rule inside it — TODO confirm index base with the matcher.
Matched { rule_set_index: usize, rule_index: usize },
/// Presumably: no rule applied and a file at `file_path` was served with
/// `status` — TODO confirm with the caller.
Fallback { file_path: String, status: u16 },
/// Presumably: nothing handled the request; `status` is the code returned.
Miss { status: u16 },
/// Presumably: handling failed; `kind` classifies the failure and `message`
/// describes it — TODO confirm the set of `kind` values.
Error { kind: String, message: String },
}
/// Publishes `MatchTraceEvent`s to every current subscriber over a tokio
/// broadcast channel.
///
/// Both counters live behind `Arc`, so clones of an emitter share them.
#[derive(Clone)]
pub struct TraceEmitter {
// Sending half of the broadcast channel; receivers are created from it.
sender: broadcast::Sender<MatchTraceEvent>,
// Source of `MatchTraceEvent::event_id`.
// NOTE(review): the counter is 32-bit while `event_id` is `u64`, so ids wrap
// back to 0 after `u32::MAX` events — consider widening.
event_counter: Arc<AtomicU32>,
// Counts events whose `send` failed; read and reset by each `emit`.
dropped_counter: Arc<AtomicU32>,
}
impl TraceEmitter {
    /// Creates an emitter with an empty channel of `TRACE_CHANNEL_CAPACITY`
    /// slots and both counters at zero.
    ///
    /// The receiver returned by `broadcast::channel` is discarded, so a new
    /// emitter starts with no subscribers.
    pub fn new() -> Self {
        let (sender, _) = broadcast::channel(TRACE_CHANNEL_CAPACITY);
        Self {
            sender,
            event_counter: Arc::new(AtomicU32::new(0)),
            dropped_counter: Arc::new(AtomicU32::new(0)),
        }
    }

    /// Registers a new subscriber; it receives events emitted from now on.
    pub fn subscribe(&self) -> broadcast::Receiver<MatchTraceEvent> {
        self.sender.subscribe()
    }

    /// Builds a `MatchTraceEvent` from the arguments and broadcasts it.
    ///
    /// The event receives the next `event_id`, `schema_version` 1, and the
    /// number of events dropped since the last event that was delivered.
    /// When the send fails (the channel has no receivers) the event is
    /// counted as dropped instead.
    pub fn emit(
        &self,
        received_at_ms: u64,
        duration_ms: u32,
        request: RequestSummary,
        outcome: Outcome,
    ) {
        let event_id = u64::from(self.event_counter.fetch_add(1, Ordering::Relaxed));
        // Take ownership of the pending drop tally; it rides on this event.
        let dropped_count = self.dropped_counter.swap(0, Ordering::Relaxed);
        let event = MatchTraceEvent {
            event_id,
            schema_version: 1,
            received_at_ms,
            duration_ms,
            request,
            outcome,
            dropped_count,
        };
        if self.sender.send(event).is_err() {
            // The event never reached anyone, so the tally it carried would be
            // lost with it. Put that tally back, plus one for this event, so
            // consecutive drops accumulate instead of collapsing to 1.
            self.dropped_counter
                .fetch_add(dropped_count.saturating_add(1), Ordering::Relaxed);
        }
    }

    /// Reports whether at least one receiver is currently attached.
    pub fn has_subscribers(&self) -> bool {
        self.sender.receiver_count() > 0
    }
}
impl Default for TraceEmitter { fn default() -> Self { Self::new() } }
/// Selects which listener `TraceTransport::accept_loop` runs.
#[derive(Clone, Debug, Default)]
pub enum TraceTransportConfig {
/// Listen on a Unix domain socket bound at `path` (compiled on Unix only).
#[cfg(unix)]
Uds { path: String },
/// Listen on TCP; `addr` is handed to `TcpListener::bind` as written.
Tcp { addr: String },
/// No listener; `accept_loop` returns immediately. This is the default.
#[default]
Disabled,
}
pub struct TraceTransport;
impl TraceTransport {
pub async fn accept_loop(config: TraceTransportConfig, emitter: TraceEmitter) {
match config {
#[cfg(unix)]
TraceTransportConfig::Uds { path } => {
Self::uds_accept_loop(path, emitter).await
}
TraceTransportConfig::Tcp { addr } => {
Self::tcp_accept_loop(addr, emitter).await
}
TraceTransportConfig::Disabled => {
}
}
}
async fn tcp_accept_loop(addr: String, emitter: TraceEmitter) {
let listener = match tokio::net::TcpListener::bind(&addr).await {
Ok(l) => {
let bound = l.local_addr().map(|a| a.to_string())
.unwrap_or_else(|_| addr.clone());
log::info!("trace transport: TCP listening on {}", bound);
l
}
Err(e) => {
log::error!("trace transport: failed to bind TCP {}: {}", addr, e);
return;
}
};
let active = Arc::new(AtomicUsize::new(0));
loop {
match listener.accept().await {
Ok((stream, peer)) => {
let count = active.fetch_add(1, Ordering::Relaxed) + 1;
if count > MAX_SUBSCRIBERS {
active.fetch_sub(1, Ordering::Relaxed);
let active_clone = active.clone();
tokio::spawn(async move {
let (_, mut writer) = tokio::io::split(stream);
let _ = writer
.write_all(b"{\"error\":\"max_subscribers_reached\"}\n")
.await;
drop(active_clone);
});
continue;
}
log::debug!("trace: TCP subscriber connected from {}", peer);
let rx = emitter.subscribe();
let active_clone = active.clone();
tokio::spawn(async move {
let (_, writer) = tokio::io::split(stream);
Self::forward_events(writer, rx).await;
active_clone.fetch_sub(1, Ordering::Relaxed);
log::debug!("trace: TCP subscriber {} disconnected", peer);
});
}
Err(e) => {
log::error!("trace: TCP accept error: {}", e);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
}
#[cfg(unix)]
async fn uds_accept_loop(path: String, emitter: TraceEmitter) {
let _ = std::fs::remove_file(&path);
let listener = match tokio::net::UnixListener::bind(&path) {
Ok(l) => {
log::info!("trace transport: UDS listening at {}", path);
l
}
Err(e) => {
log::error!("trace transport: failed to bind UDS {}: {}", path, e);
return;
}
};
let active = Arc::new(AtomicUsize::new(0));
loop {
match listener.accept().await {
Ok((stream, _)) => {
let count = active.fetch_add(1, Ordering::Relaxed) + 1;
if count > MAX_SUBSCRIBERS {
active.fetch_sub(1, Ordering::Relaxed);
tokio::spawn(async move {
let (_, mut writer) = tokio::io::split(stream);
let _ = writer
.write_all(b"{\"error\":\"max_subscribers_reached\"}\n")
.await;
});
continue;
}
log::debug!("trace: UDS subscriber connected");
let rx = emitter.subscribe();
let active_clone = active.clone();
tokio::spawn(async move {
let (_, writer) = tokio::io::split(stream);
Self::forward_events(writer, rx).await;
active_clone.fetch_sub(1, Ordering::Relaxed);
log::debug!("trace: UDS subscriber disconnected");
});
}
Err(e) => {
log::error!("trace: UDS accept error: {}", e);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
}
async fn forward_events<W>(mut writer: W, mut rx: broadcast::Receiver<MatchTraceEvent>)
where
W: tokio::io::AsyncWrite + Unpin,
{
loop {
let event = match rx.recv().await {
Ok(e) => e,
Err(broadcast::error::RecvError::Lagged(n)) => {
log::debug!("trace: subscriber lagged, {} events dropped", n);
continue;
}
Err(broadcast::error::RecvError::Closed) => break,
};
let mut line = match serde_json::to_string(&event) {
Ok(s) => s,
Err(e) => {
log::error!("trace: serialise error: {}", e);
continue;
}
};
line.push('\n');
if writer.write_all(line.as_bytes()).await.is_err() {
break; }
}
}
}
/// Returns the current wall-clock time as whole milliseconds since the Unix
/// epoch, or 0 when the system clock reads earlier than the epoch.
pub fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        // Clock is set before 1970: report the epoch itself.
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An event emitted while a receiver exists is delivered with the fields
    /// that `emit` fills in.
    #[tokio::test]
    async fn emit_received_by_subscriber() {
        let emitter = TraceEmitter::new();
        let mut rx = emitter.subscribe();
        emitter.emit(
            1_000_000,
            5,
            RequestSummary {
                method: "GET".into(),
                url_path: "/api/test".into(),
                headers: vec![],
            },
            Outcome::Miss { status: 404 },
        );
        let event = rx.try_recv().expect("event in channel");
        assert_eq!(event.event_id, 0);
        assert_eq!(event.schema_version, 1);
        assert_eq!(event.request.method, "GET");
        assert_eq!(event.duration_ms, 5);
        assert_eq!(event.dropped_count, 0);
        assert!(matches!(event.outcome, Outcome::Miss { status: 404 }));
    }

    /// An event emitted with no receiver is reported through the
    /// `dropped_count` of the next delivered event.
    #[tokio::test]
    async fn emit_no_subscriber_increments_dropped() {
        let emitter = TraceEmitter::new();
        emitter.emit(
            0,
            0,
            RequestSummary {
                method: "GET".into(),
                url_path: "/".into(),
                headers: vec![],
            },
            Outcome::Miss { status: 404 },
        );
        let mut rx = emitter.subscribe();
        emitter.emit(
            0,
            0,
            RequestSummary {
                method: "GET".into(),
                url_path: "/".into(),
                headers: vec![],
            },
            Outcome::Miss { status: 200 },
        );
        let event = rx.try_recv().expect("second event visible");
        assert_eq!(event.dropped_count, 1, "first event should be counted dropped");
    }

    /// `has_subscribers` follows the receiver count.
    #[test]
    fn has_subscribers_reflects_state() {
        let emitter = TraceEmitter::new();
        assert!(!emitter.has_subscribers());
        let _rx = emitter.subscribe();
        assert!(emitter.has_subscribers());
    }

    /// The JSON form carries the snake_case variant tag and the plain fields.
    #[tokio::test]
    async fn outcome_serialises_correctly() {
        let event = MatchTraceEvent {
            event_id: 7,
            schema_version: 1,
            received_at_ms: 0,
            duration_ms: 0,
            request: RequestSummary {
                method: "POST".into(),
                url_path: "/x".into(),
                headers: vec![],
            },
            outcome: Outcome::Matched { rule_set_index: 0, rule_index: 2 },
            dropped_count: 0,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"matched\""));
        assert!(json.contains("\"rule_index\":2"));
        assert!(json.contains("\"schema_version\":1"));
    }

    /// `forward_events` writes an emitted event to a TCP peer as a JSON line.
    ///
    /// This drives `forward_events` over a hand-made listener; it does not go
    /// through `TraceTransport::accept_loop`.
    #[tokio::test]
    async fn tcp_transport_delivers_events() {
        use tokio::io::AsyncBufReadExt;

        let emitter = TraceEmitter::new();
        // Subscribe before emitting so the event is buffered in the channel
        // whenever the server task gets scheduled; no sleep is needed.
        let rx = emitter.subscribe();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let (_, writer) = tokio::io::split(stream);
            TraceTransport::forward_events(writer, rx).await;
        });
        let mut client = tokio::net::TcpStream::connect(bound_addr).await.unwrap();
        emitter.emit(
            42,
            3,
            RequestSummary {
                method: "GET".into(),
                url_path: "/ping".into(),
                headers: vec![],
            },
            Outcome::Miss { status: 404 },
        );
        let mut reader = tokio::io::BufReader::new(&mut client);
        let mut line = String::new();
        tokio::time::timeout(
            std::time::Duration::from_secs(2),
            reader.read_line(&mut line),
        )
        .await
        .expect("timeout")
        .expect("read ok");
        let value: serde_json::Value = serde_json::from_str(line.trim()).expect("valid JSON");
        assert_eq!(value["request"]["url_path"], "/ping");
        assert_eq!(value["outcome"]["type"], "miss");
        assert_eq!(value["schema_version"], 1);
    }
}