use std::sync::{
    Arc,
    atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering},
};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use serde::Serialize;
use tokio::io::AsyncWriteExt;
use tokio::sync::broadcast;
51
/// Capacity of the broadcast channel behind [`TraceEmitter`].
///
/// A subscriber that falls more than this many events behind starts losing
/// the oldest ones (it observes a lag and skips ahead).
pub const TRACE_CHANNEL_CAPACITY: usize = 1024;

/// Maximum number of simultaneously connected subscribers accepted by a single
/// transport listener; further connections receive an error line and are closed.
pub const MAX_SUBSCRIBERS: usize = 4;
56
/// One trace record describing how a single request was handled.
///
/// Built by `TraceEmitter::emit` and written to transport subscribers as one
/// compact JSON object per line. Field declaration order is the JSON key order.
#[derive(Clone, Debug, Serialize)]
pub struct MatchTraceEvent {
    /// Sequence number assigned by the emitter; the first event is 0.
    pub event_id: u64,
    /// Layout version of this record; `emit` always sets it to 1.
    pub schema_version: u8,
    /// Receive time supplied by the caller, in milliseconds.
    /// NOTE(review): presumably Unix-epoch ms obtained from `now_ms()` — confirm with callers.
    pub received_at_ms: u64,
    /// Handling duration supplied by the caller (milliseconds, per the field name).
    pub duration_ms: u32,
    /// Summary of the incoming request.
    pub request: RequestSummary,
    /// How the request was resolved.
    pub outcome: Outcome,
    /// Events the emitter recorded as dropped (sent while no subscriber was
    /// attached) and had not yet reported; see `TraceEmitter::emit` for how this
    /// is tracked. Events a subscriber loses by lagging behind the channel are
    /// only logged by the transport and are not counted here.
    pub dropped_count: u32,
}
77
/// The parts of an incoming request that are copied into a trace event.
#[derive(Clone, Debug, Serialize)]
pub struct RequestSummary {
    /// Request method as supplied by the caller (e.g. `"GET"`).
    pub method: String,
    /// Request path as supplied by the caller (e.g. `"/api/test"`).
    pub url_path: String,
    /// Header name/value pairs in the order supplied; serialised as a JSON
    /// array of two-element `[name, value]` arrays.
    /// NOTE(review): values are forwarded verbatim to every trace subscriber —
    /// confirm that credentials (Authorization, Cookie, ...) are redacted upstream.
    pub headers: Vec<(String, String)>,
}
86
/// How a traced request was resolved.
///
/// Serialised as an internally tagged object: a `"type"` field holding the
/// snake_case variant name (`"matched"`, `"fallback"`, `"miss"`, `"error"`)
/// next to the variant's own fields.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Outcome {
    /// A rule matched. Presumably the index of the rule set and of the rule
    /// within it — TODO confirm against the matcher.
    Matched { rule_set_index: usize, rule_index: usize },
    /// The request was answered from a fallback file with the given status.
    Fallback { file_path: String, status: u16 },
    /// Nothing matched; `status` is the status reported for the miss.
    Miss { status: u16 },
    /// Handling failed; `kind` classifies the error and `message` describes it.
    Error { kind: String, message: String },
}
96
97#[derive(Clone)]
103pub struct TraceEmitter {
104 sender: broadcast::Sender<MatchTraceEvent>,
105 event_counter: Arc<AtomicU32>,
106 dropped_counter: Arc<AtomicU32>,
107}
108
109impl TraceEmitter {
110 pub fn new() -> Self {
111 let (sender, _) = broadcast::channel(TRACE_CHANNEL_CAPACITY);
112 Self {
113 sender,
114 event_counter: Arc::new(AtomicU32::new(0)),
115 dropped_counter: Arc::new(AtomicU32::new(0)),
116 }
117 }
118
119 pub fn subscribe(&self) -> broadcast::Receiver<MatchTraceEvent> {
121 self.sender.subscribe()
122 }
123
124 pub fn emit(
127 &self,
128 received_at_ms: u64,
129 duration_ms: u32,
130 request: RequestSummary,
131 outcome: Outcome,
132 ) {
133 let event_id = self.event_counter.fetch_add(1, Ordering::Relaxed) as u64;
134 let dropped_count = self.dropped_counter.swap(0, Ordering::Relaxed);
135
136 let event = MatchTraceEvent {
137 event_id,
138 schema_version: 1,
139 received_at_ms,
140 duration_ms,
141 request,
142 outcome,
143 dropped_count,
144 };
145
146 if self.sender.send(event).is_err() {
147 self.dropped_counter.fetch_add(1, Ordering::Relaxed);
148 }
149 }
150
151 pub fn has_subscribers(&self) -> bool {
153 self.sender.receiver_count() > 0
154 }
155}
156
157impl Default for TraceEmitter { fn default() -> Self { Self::new() } }
158
/// Selects where (if anywhere) the trace stream is served.
///
/// Defaults to [`TraceTransportConfig::Disabled`].
#[derive(Clone, Debug)]
pub enum TraceTransportConfig {
    /// Serve on a Unix domain socket bound at `path` (Unix targets only).
    #[cfg(unix)]
    Uds { path: String },
    /// Serve on a TCP listener bound to `addr`.
    Tcp { addr: String },
    /// Do not serve the trace stream.
    Disabled,
}

impl Default for TraceTransportConfig {
    fn default() -> Self {
        Self::Disabled
    }
}
173
174pub struct TraceTransport;
177
178impl TraceTransport {
179 pub async fn accept_loop(config: TraceTransportConfig, emitter: TraceEmitter) {
191 match config {
192 #[cfg(unix)]
193 TraceTransportConfig::Uds { path } => {
194 Self::uds_accept_loop(path, emitter).await
195 }
196 TraceTransportConfig::Tcp { addr } => {
197 Self::tcp_accept_loop(addr, emitter).await
198 }
199 TraceTransportConfig::Disabled => {
200 }
202 }
203 }
204
205 async fn tcp_accept_loop(addr: String, emitter: TraceEmitter) {
208 let listener = match tokio::net::TcpListener::bind(&addr).await {
209 Ok(l) => {
210 let bound = l.local_addr().map(|a| a.to_string())
211 .unwrap_or_else(|_| addr.clone());
212 log::info!("trace transport: TCP listening on {}", bound);
213 l
214 }
215 Err(e) => {
216 log::error!("trace transport: failed to bind TCP {}: {}", addr, e);
217 return;
218 }
219 };
220
221 let active = Arc::new(AtomicUsize::new(0));
222 loop {
223 match listener.accept().await {
224 Ok((stream, peer)) => {
225 let count = active.fetch_add(1, Ordering::Relaxed) + 1;
226 if count > MAX_SUBSCRIBERS {
227 active.fetch_sub(1, Ordering::Relaxed);
228 let emitter_clone = emitter.clone();
229 let active_clone = active.clone();
230 tokio::spawn(async move {
231 let (_, mut writer) = tokio::io::split(stream);
232 let _ = writer
233 .write_all(b"{\"error\":\"max_subscribers_reached\"}\n")
234 .await;
235 drop(active_clone);
236 });
237 continue;
238 }
239 log::debug!("trace: TCP subscriber connected from {}", peer);
240 let rx = emitter.subscribe();
241 let active_clone = active.clone();
242 tokio::spawn(async move {
243 let (_, writer) = tokio::io::split(stream);
244 Self::forward_events(writer, rx).await;
245 active_clone.fetch_sub(1, Ordering::Relaxed);
246 log::debug!("trace: TCP subscriber {} disconnected", peer);
247 });
248 }
249 Err(e) => {
250 log::error!("trace: TCP accept error: {}", e);
251 tokio::time::sleep(Duration::from_millis(100)).await;
252 }
253 }
254 }
255 }
256
257 #[cfg(unix)]
260 async fn uds_accept_loop(path: String, emitter: TraceEmitter) {
261 let _ = std::fs::remove_file(&path);
263
264 let listener = match tokio::net::UnixListener::bind(&path) {
265 Ok(l) => {
266 log::info!("trace transport: UDS listening at {}", path);
267 l
268 }
269 Err(e) => {
270 log::error!("trace transport: failed to bind UDS {}: {}", path, e);
271 return;
272 }
273 };
274
275 let active = Arc::new(AtomicUsize::new(0));
276 loop {
277 match listener.accept().await {
278 Ok((stream, _)) => {
279 let count = active.fetch_add(1, Ordering::Relaxed) + 1;
280 if count > MAX_SUBSCRIBERS {
281 active.fetch_sub(1, Ordering::Relaxed);
282 tokio::spawn(async move {
283 let (_, mut writer) = tokio::io::split(stream);
284 let _ = writer
285 .write_all(b"{\"error\":\"max_subscribers_reached\"}\n")
286 .await;
287 });
288 continue;
289 }
290 log::debug!("trace: UDS subscriber connected");
291 let rx = emitter.subscribe();
292 let active_clone = active.clone();
293 tokio::spawn(async move {
294 let (_, writer) = tokio::io::split(stream);
295 Self::forward_events(writer, rx).await;
296 active_clone.fetch_sub(1, Ordering::Relaxed);
297 log::debug!("trace: UDS subscriber disconnected");
298 });
299 }
300 Err(e) => {
301 log::error!("trace: UDS accept error: {}", e);
302 tokio::time::sleep(Duration::from_millis(100)).await;
303 }
304 }
305 }
306 }
307
308 async fn forward_events<W>(mut writer: W, mut rx: broadcast::Receiver<MatchTraceEvent>)
313 where
314 W: tokio::io::AsyncWrite + Unpin,
315 {
316 loop {
317 let event = match rx.recv().await {
318 Ok(e) => e,
319 Err(broadcast::error::RecvError::Lagged(n)) => {
320 log::debug!("trace: subscriber lagged, {} events dropped", n);
324 continue;
325 }
326 Err(broadcast::error::RecvError::Closed) => break,
327 };
328
329 let mut line = match serde_json::to_string(&event) {
330 Ok(s) => s,
331 Err(e) => {
332 log::error!("trace: serialise error: {}", e);
333 continue;
334 }
335 };
336 line.push('\n');
337
338 if writer.write_all(line.as_bytes()).await.is_err() {
339 break; }
341 }
342 }
343}
344
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, returns `0` instead of
/// panicking.
pub fn now_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    // `as_millis` yields u128; the narrowing cast only truncates hundreds of
    // millions of years from now.
    since_epoch.as_millis() as u64
}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn emit_received_by_subscriber() {
        let emitter = TraceEmitter::new();
        let mut rx = emitter.subscribe();

        emitter.emit(
            1_000_000,
            5,
            RequestSummary { method: "GET".into(), url_path: "/api/test".into(), headers: vec![] },
            Outcome::Miss { status: 404 },
        );

        let event = rx.try_recv().expect("event in channel");
        assert_eq!(event.event_id, 0);
        assert_eq!(event.schema_version, 1);
        assert_eq!(event.request.method, "GET");
        assert_eq!(event.duration_ms, 5);
        assert_eq!(event.dropped_count, 0);
        assert!(matches!(event.outcome, Outcome::Miss { status: 404 }));
    }

    #[tokio::test]
    async fn emit_no_subscriber_increments_dropped() {
        let emitter = TraceEmitter::new();
        emitter.emit(
            0,
            0,
            RequestSummary { method: "GET".into(), url_path: "/".into(), headers: vec![] },
            Outcome::Miss { status: 404 },
        );
        let mut rx = emitter.subscribe();
        emitter.emit(
            0,
            0,
            RequestSummary { method: "GET".into(), url_path: "/".into(), headers: vec![] },
            Outcome::Miss { status: 200 },
        );
        let event = rx.try_recv().expect("second event visible");
        assert_eq!(event.dropped_count, 1, "first event should be counted dropped");
    }

    #[test]
    fn has_subscribers_reflects_state() {
        let emitter = TraceEmitter::new();
        assert!(!emitter.has_subscribers());
        let _rx = emitter.subscribe();
        assert!(emitter.has_subscribers());
    }

    #[tokio::test]
    async fn outcome_serialises_correctly() {
        let event = MatchTraceEvent {
            event_id: 7,
            schema_version: 1,
            received_at_ms: 0,
            duration_ms: 0,
            request: RequestSummary { method: "POST".into(), url_path: "/x".into(), headers: vec![] },
            outcome: Outcome::Matched { rule_set_index: 0, rule_index: 2 },
            dropped_count: 0,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"matched\""));
        assert!(json.contains("\"rule_index\":2"));
        assert!(json.contains("\"schema_version\":1"));
    }

    /// Exercises `forward_events` end to end over a real TCP socket. The test
    /// owns the listener (rather than going through `accept_loop`) so that it
    /// knows which ephemeral port was bound.
    #[tokio::test]
    async fn tcp_transport_delivers_events() {
        use tokio::io::AsyncBufReadExt;

        let emitter = TraceEmitter::new();
        let emitter_clone = emitter.clone();

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let rx = emitter_clone.subscribe();
            let (_, writer) = tokio::io::split(stream);
            TraceTransport::forward_events(writer, rx).await;
        });

        let mut client = tokio::net::TcpStream::connect(bound_addr).await.unwrap();

        // `connect` can complete before the server task has subscribed; an event
        // emitted before that point is dropped and the read below would time out.
        // Wait for the subscription itself instead of guessing with a fixed sleep.
        tokio::time::timeout(std::time::Duration::from_secs(2), async {
            while !emitter.has_subscribers() {
                tokio::time::sleep(std::time::Duration::from_millis(1)).await;
            }
        })
        .await
        .expect("server task never subscribed");

        emitter.emit(
            42,
            3,
            RequestSummary { method: "GET".into(), url_path: "/ping".into(), headers: vec![] },
            Outcome::Miss { status: 404 },
        );

        let mut reader = tokio::io::BufReader::new(&mut client);
        let mut line = String::new();
        tokio::time::timeout(std::time::Duration::from_secs(2), reader.read_line(&mut line))
            .await
            .expect("timeout")
            .expect("read ok");

        let value: serde_json::Value = serde_json::from_str(line.trim()).expect("valid JSON");
        assert_eq!(value["request"]["url_path"], "/ping");
        assert_eq!(value["outcome"]["type"], "miss");
        assert_eq!(value["schema_version"], 1);
    }
}