1use std::net::SocketAddr;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{Arc, RwLock};
6
7use protosocket::TcpSocketListener;
8use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
9use protosocket_rpc::Message;
10use protosocket_rpc::server::{ConnectionService, RpcResponder, SocketRpcServer, SocketService};
11use tokio::sync::watch;
12use tracing::metadata::LevelFilter;
13use tracing_cache::{ChanceHandle, EnabledPredicate, LevelHandle, SpanCache, SpanRecord};
14
15use crate::protocol::{Request, RequestBody, Response, WireLevel, WireLevelFilter};
16use crate::wire::{TimeBase, span_to_wire};
17
/// Server-side codec: serializes outgoing `Response`s and decodes incoming
/// `Request`s with the protosocket MessagePack serializer/decoder.
type ServerCodec = (MessagePackSerializer<Response>, MessagePackDecoder<Request>);

/// Capacity handed to `SpanCache::subscribe` for each console's span stream.
// NOTE(review): whether this counts records or batches, and what happens when a
// slow console overflows it, is defined by tracing_cache — confirm there.
const STREAM_SUBSCRIBER_CAPACITY: u64 = 65_536;
27
28#[derive(Debug, Default)]
31struct StreamState {
32 streaming: bool,
33 min_level: Option<WireLevel>,
34 sampling_rate: f64,
35}
36
37impl StreamState {
38 fn new() -> Self {
39 Self {
40 streaming: false,
41 min_level: None,
42 sampling_rate: 1.0,
43 }
44 }
45}
46
47#[derive(Clone)]
60pub(crate) struct CacheLevelBroadcast {
61 level_handle: LevelHandle,
62 level_tx: watch::Sender<WireLevelFilter>,
63 chance_handle: ChanceHandle,
64 chance_tx: watch::Sender<f64>,
65 active_streams: Arc<AtomicUsize>,
66}
67
68impl CacheLevelBroadcast {
69 pub fn new(level_handle: LevelHandle, chance_handle: ChanceHandle) -> Self {
70 let initial_level = WireLevelFilter::from_tracing(level_handle.get());
71 let initial_chance = chance_handle.get();
72 let (level_tx, _) = watch::channel(initial_level);
73 let (chance_tx, _) = watch::channel(initial_chance);
74 Self {
75 level_handle,
76 level_tx,
77 chance_handle,
78 chance_tx,
79 active_streams: Arc::new(AtomicUsize::new(0)),
80 }
81 }
82
83 fn set_level(&self, filter: WireLevelFilter) {
84 self.level_handle.set(filter.to_tracing());
85 let _ = self.level_tx.send(filter);
86 }
87
88 fn set_chance(&self, pct: f64) {
89 let pct = if pct.is_nan() {
93 0.0
94 } else {
95 pct.clamp(0.0, 100.0)
96 };
97 self.chance_handle.set(pct);
98 let _ = self.chance_tx.send(pct);
99 }
100
101 fn subscribe_level(&self) -> watch::Receiver<WireLevelFilter> {
102 self.level_tx.subscribe()
103 }
104
105 fn subscribe_chance(&self) -> watch::Receiver<f64> {
106 self.chance_tx.subscribe()
107 }
108
109 fn enter_stream(&self) -> StreamGuard {
118 self.active_streams.fetch_add(1, Ordering::SeqCst);
119 StreamGuard {
120 broadcast: self.clone(),
121 }
122 }
123}
124
125pub(crate) struct StreamGuard {
131 broadcast: CacheLevelBroadcast,
132}
133
134impl Drop for StreamGuard {
135 fn drop(&mut self) {
136 let prev = self.broadcast.active_streams.fetch_sub(1, Ordering::SeqCst);
137 if prev == 1 {
138 self.broadcast.level_handle.set(LevelFilter::OFF);
143 let _ = self.broadcast.level_tx.send(WireLevelFilter::Off);
144 self.broadcast.chance_handle.set(100.0);
145 let _ = self.broadcast.chance_tx.send(100.0);
146 }
147 }
148}
149
150pub(crate) struct ConnectionState<P: EnabledPredicate> {
156 cache: Arc<SpanCache<P>>,
157 base: TimeBase,
158 state: Arc<RwLock<StreamState>>,
159 level_bus: CacheLevelBroadcast,
160 stream_guard: Option<StreamGuard>,
167}
168
169impl<P: EnabledPredicate> ConnectionState<P> {
170 fn new(cache: Arc<SpanCache<P>>, base: TimeBase, level_bus: CacheLevelBroadcast) -> Self {
171 Self {
172 cache,
173 base,
174 state: Arc::new(RwLock::new(StreamState::new())),
175 level_bus,
176 stream_guard: None,
177 }
178 }
179}
180
181impl<P: EnabledPredicate> ConnectionService for ConnectionState<P> {
182 type Request = Request;
183 type Response = Response;
184
185 #[allow(clippy::expect_used, reason = "poisoned lock")]
186 fn new_rpc(&mut self, msg: Request, responder: RpcResponder<'_, Response>) {
187 let request_id = msg.message_id();
191 match msg.body {
192 RequestBody::StartStream => {
193 self.state
194 .write()
195 .expect("lock must not be poisoned")
196 .streaming = true;
197 if self.stream_guard.is_none() {
202 self.stream_guard = Some(self.level_bus.enter_stream());
203 }
204 let cache = Arc::clone(&self.cache);
205 let state = Arc::clone(&self.state);
206 let base = self.base;
207 let level_rx = self.level_bus.subscribe_level();
208 let chance_rx = self.level_bus.subscribe_chance();
209 tokio::spawn(responder.stream(span_stream(
210 cache, state, base, level_rx, chance_rx, request_id,
211 )));
212 }
213 RequestBody::StopStream => {
214 self.state
215 .write()
216 .expect("lock must not be poisoned")
217 .streaming = false;
218 responder.immediate(Response::ack().with_id(request_id));
219 }
220 RequestBody::SetLevel(level) => {
221 self.state
222 .write()
223 .expect("lock must not be poisoned")
224 .min_level = Some(level);
225 responder.immediate(Response::ack().with_id(request_id));
226 }
227 RequestBody::SetCacheLevel(filter) => {
228 self.level_bus.set_level(filter);
229 responder.immediate(Response::ack().with_id(request_id));
230 }
231 RequestBody::SetCacheChance(pct) => {
232 self.level_bus.set_chance(pct);
233 responder.immediate(Response::ack().with_id(request_id));
234 }
235 RequestBody::SetSamplingRate(rate) => {
236 if !(0.0..=1.0).contains(&rate) || rate.is_nan() {
237 responder.immediate(
238 Response::error(format!("sampling rate {rate} out of range [0.0, 1.0]"))
239 .with_id(request_id),
240 );
241 return;
242 }
243 self.state
244 .write()
245 .expect("lock must not be poisoned")
246 .sampling_rate = rate;
247 responder.immediate(Response::ack().with_id(request_id));
248 }
249 RequestBody::Noop => {}
250 }
251 }
252}
253
/// Builds the response stream for one `StartStream` request.
///
/// The stream first yields three responses, in order:
/// 1. a `server_info` response carrying this crate's `CARGO_PKG_VERSION`;
/// 2. the current cache level;
/// 3. the current cache chance.
///
/// It then subscribes to `cache` and loops. It yields `cache_level` /
/// `cache_chance` whenever the broadcast value changes, and forwards span
/// batches as `span` responses after applying the connection's filters.
/// Every response is tagged with `request_id`. The loop ends when either watch
/// channel reports its sender gone, or when the span subscription returns `None`.
fn span_stream<P: EnabledPredicate>(
    cache: Arc<SpanCache<P>>,
    state: Arc<RwLock<StreamState>>,
    base: TimeBase,
    mut level_rx: watch::Receiver<WireLevelFilter>,
    mut chance_rx: watch::Receiver<f64>,
    request_id: u64,
) -> impl futures_core::Stream<Item = Response> {
    async_stream::stream! {
        yield Response::server_info(env!("CARGO_PKG_VERSION")).with_id(request_id);
        // Report the current broadcast values up front and mark them as seen,
        // so the `changed()` arms below fire only on later updates.
        let initial_level = *level_rx.borrow_and_update();
        yield Response::cache_level(initial_level).with_id(request_id);
        let initial_chance = *chance_rx.borrow_and_update();
        yield Response::cache_chance(initial_chance).with_id(request_id);

        // NOTE(review): capacity units and overflow behavior are defined by
        // tracing_cache's `subscribe` — confirm there.
        let mut span_rx = cache.subscribe(STREAM_SUBSCRIBER_CAPACITY);

        loop {
            tokio::select! {
                changed = level_rx.changed() => {
                    if changed.is_err() { break; }
                    let lvl = *level_rx.borrow_and_update();
                    yield Response::cache_level(lvl).with_id(request_id);
                }
                changed = chance_rx.changed() => {
                    if changed.is_err() { break; }
                    let pct = *chance_rx.borrow_and_update();
                    yield Response::cache_chance(pct).with_id(request_id);
                }
                batch = span_rx.next_batch() => {
                    let Some(batch) = batch else { break };
                    // Snapshot the connection's filters once per batch. The
                    // read guard is dropped at the end of this block, before
                    // any yield.
                    let (streaming, min_level, sampling_rate) = {
                        #[allow(clippy::expect_used, reason = "poisoned lock")]
                        let s = state.read().expect("lock must not be poisoned");
                        (s.streaming, s.min_level, s.sampling_rate)
                    };
                    if !streaming {
                        // Paused by StopStream: keep draining the subscription
                        // but discard what it delivers.
                        drop(batch);
                        continue;
                    }
                    for record in batch {
                        // Per-connection level floor, then deterministic sampling.
                        if let Some(min) = min_level
                            && !level_at_least(record.metadata.level(), min)
                        {
                            continue;
                        }
                        if !sampling_passes(&record, sampling_rate) {
                            continue;
                        }
                        yield Response::span(span_to_wire(&record, base)).with_id(request_id);
                    }
                }
            }
        }
    }
}
334
335fn level_at_least(record_level: &tracing::Level, floor: WireLevel) -> bool {
340 record_level <= &floor.to_tracing()
341}
342
343fn sampling_passes(record: &SpanRecord, rate: f64) -> bool {
348 if rate >= 1.0 {
349 return true;
350 }
351 if rate <= 0.0 {
352 return false;
353 }
354 let bucket_id = record.parent_id.unwrap_or(record.id);
356 let mut x = bucket_id.wrapping_mul(0x9E37_79B9_7F4A_7C15);
358 x ^= x >> 33;
359 x = x.wrapping_mul(0xC2B2_AE3D_27D4_EB4F);
360 x ^= x >> 29;
361 let frac = (x as f64) / (u64::MAX as f64);
362 frac < rate
363}
364
365struct Service<P: EnabledPredicate> {
368 cache: Arc<SpanCache<P>>,
369 base: TimeBase,
370 level_bus: CacheLevelBroadcast,
371}
372
373impl<P: EnabledPredicate> SocketService for Service<P> {
374 type Codec = ServerCodec;
375 type ConnectionService = ConnectionState<P>;
376 type SocketListener = TcpSocketListener;
377
378 fn codec(&self) -> Self::Codec {
379 (
380 MessagePackSerializer::default(),
381 MessagePackDecoder::default(),
382 )
383 }
384
385 fn new_stream_service(
386 &self,
387 _stream: &<Self::SocketListener as protosocket::SocketListener>::Stream,
388 ) -> Self::ConnectionService {
389 ConnectionState::new(Arc::clone(&self.cache), self.base, self.level_bus.clone())
390 }
391}
392
393#[derive(Debug)]
395pub enum ServeError {
396 Io(std::io::Error),
397 Rpc(protosocket_rpc::Error),
398}
399impl std::fmt::Display for ServeError {
400 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 match self {
402 ServeError::Io(e) => write!(f, "io: {e}"),
403 ServeError::Rpc(e) => write!(f, "rpc: {e}"),
404 }
405 }
406}
407impl std::error::Error for ServeError {}
408impl From<std::io::Error> for ServeError {
409 fn from(e: std::io::Error) -> Self {
410 ServeError::Io(e)
411 }
412}
413impl From<protosocket_rpc::Error> for ServeError {
414 fn from(e: protosocket_rpc::Error) -> Self {
415 ServeError::Rpc(e)
416 }
417}
418
419pub async fn serve<P: EnabledPredicate>(
428 cache: Arc<SpanCache<P>>,
429 level_handle: LevelHandle,
430 chance_handle: ChanceHandle,
431 addr: SocketAddr,
432) -> Result<(), ServeError> {
433 let listener = TcpSocketListener::listen(addr, 1024, None)?;
435
436 let service = Service {
437 cache,
438 base: TimeBase::now(),
439 level_bus: CacheLevelBroadcast::new(level_handle, chance_handle),
440 };
441 let server: SocketRpcServer<Service<P>, _> = SocketRpcServer::new(
442 listener,
443 service,
444 16 * 1024 * 1024,
445 64 * 1024,
446 4096,
447 )?;
448 server.await?;
449 Ok(())
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener as StdTcpListener;
    use std::time::{Duration, Instant};

    use futures::StreamExt;
    use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
    use protosocket_rpc::client::{self, Configuration, RpcClient, TcpStreamConnector};
    use tracing::callsite::{Callsite, DefaultCallsite, Identifier};
    use tracing::field::FieldSet;
    use tracing::metadata::Kind;
    use tracing_cache::{ChancePredicate, FieldList, SpanCache, SpanRecord};

    use crate::protocol::{ResponseBody, WireLevel};

    type ClientCodec = (MessagePackSerializer<Request>, MessagePackDecoder<Response>);

    /// Returns a loopback address whose port was free a moment ago.
    ///
    /// The probe listener is dropped before the server binds, so another
    /// process could grab the port in between. That race is accepted here.
    fn pick_addr() -> SocketAddr {
        let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        format!("127.0.0.1:{port}").parse().unwrap()
    }

    /// Builds a span cache with a TRACE level filter and a 100% chance.
    ///
    /// Spawns the cache's driver and returns the cache together with its
    /// level and chance handles.
    fn prepare_cache() -> (
        Arc<SpanCache<ChancePredicate<tracing_cache::LevelPredicate>>>,
        LevelHandle,
        ChanceHandle,
    ) {
        let level =
            tracing_cache::LevelPredicate::with_filter(tracing::metadata::LevelFilter::TRACE);
        let level_handle = level.handle();
        let predicate = ChancePredicate::new(level, 100.0);
        let chance_handle = predicate.handle();
        let (cache, driver) = SpanCache::with_predicate(1024, predicate);
        let cache = Arc::new(cache);
        tokio::spawn(driver.run());
        (cache, level_handle, chance_handle)
    }

    /// Runs `f` with `cache` as the thread's default subscriber, then calls
    /// `flush_pending` on the cache.
    fn emit_under<P: EnabledPredicate>(cache: &Arc<SpanCache<P>>, f: impl FnOnce()) {
        tracing::subscriber::with_default(Arc::clone(cache), f);
        cache.flush_pending();
    }

    /// Consumes the initial ServerInfo/CacheLevel/CacheChance preamble.
    ///
    /// Panics if the preamble does not complete within about two seconds.
    async fn wait_for_initial(
        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
    ) {
        let mut got_server_info = false;
        let mut got_level = false;
        let mut got_chance = false;
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        while !(got_server_info && got_level && got_chance)
            && tokio::time::Instant::now() < deadline
        {
            match tokio::time::timeout(Duration::from_millis(200), stream.next()).await {
                Ok(Some(Ok(resp))) => match resp.body {
                    ResponseBody::ServerInfo(_) => got_server_info = true,
                    ResponseBody::CacheLevel(_) => got_level = true,
                    ResponseBody::CacheChance(_) => got_chance = true,
                    _ => {}
                },
                _ => break,
            }
        }
        assert!(
            got_server_info && got_level && got_chance,
            "stream did not yield initial ServerInfo/CacheLevel/CacheChance",
        );
    }

    /// Spawns `serve` on a fresh loopback port and waits until it accepts
    /// TCP connections (up to 50 probes, 10 ms apart).
    async fn spawn_server<P: EnabledPredicate>(
        cache: Arc<SpanCache<P>>,
        level_handle: LevelHandle,
        chance_handle: ChanceHandle,
    ) -> (SocketAddr, tokio::task::JoinHandle<()>) {
        let addr = pick_addr();
        let server_cache = Arc::clone(&cache);
        let serve_level = level_handle.clone();
        let serve_chance = chance_handle.clone();
        let handle = tokio::spawn(async move {
            let _ = serve(server_cache, serve_level, serve_chance, addr).await;
        });
        for _ in 0..50 {
            if std::net::TcpStream::connect(addr).is_ok() {
                return (addr, handle);
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        panic!("server never came up on {addr}");
    }

    /// Connects an RPC client to `addr`, driving the connection on a spawned task.
    async fn connect_client(addr: SocketAddr) -> RpcClient<Request, Response> {
        let cfg = Configuration::new(TcpStreamConnector);
        let (rpc_client, conn) = client::connect::<ClientCodec, _>(addr, &cfg).await.unwrap();
        tokio::spawn(conn);
        rpc_client
    }

    /// Collects up to `n` span payloads, skipping non-span responses.
    ///
    /// Stops early on stream end, a stream error, or when `total_timeout` elapses.
    async fn collect_spans(
        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
        n: usize,
        total_timeout: Duration,
    ) -> Vec<crate::WireSpan> {
        let mut out = Vec::with_capacity(n);
        let deadline = tokio::time::Instant::now() + total_timeout;
        while out.len() < n {
            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
            match tokio::time::timeout(remaining, stream.next()).await {
                Ok(Some(Ok(resp))) => {
                    if let ResponseBody::Span(s) = resp.body {
                        out.push(s);
                    }
                }
                Ok(Some(Err(_))) | Ok(None) => break,
                Err(_) => break,
            }
        }
        out
    }

    /// Closed spans emitted after StartStream are delivered with `closed_at_ns` set.
    #[tokio::test]
    async fn start_stream_delivers_closed_spans() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        wait_for_initial(&mut stream).await;

        emit_under(&cache, || {
            for _ in 0..3 {
                let span = tracing::span!(parent: None, tracing::Level::INFO, "test_a");
                let _g = span.enter();
            }
        });

        let received = collect_spans(&mut stream, 3, Duration::from_secs(2)).await;
        assert_eq!(received.len(), 3);
        assert!(received.iter().all(|s| s.name == "test_a"));
        assert!(received.iter().all(|s| s.closed_at_ns.is_some()));

        server.abort();
    }

    /// After StopStream is acked, newly emitted spans are (mostly) not delivered.
    #[tokio::test]
    async fn stop_stream_halts_delivery() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        wait_for_initial(&mut stream).await;

        emit_under(&cache, || {
            let _g = tracing::span!(parent: None, tracing::Level::INFO, "test_b").entered();
        });
        let initial = collect_spans(&mut stream, 1, Duration::from_secs(2)).await;
        assert_eq!(initial.len(), 1);

        let ack = client
            .send_unary(Request::new(RequestBody::StopStream))
            .unwrap()
            .await
            .unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));
        tokio::time::sleep(Duration::from_millis(50)).await;

        emit_under(&cache, || {
            for _ in 0..5 {
                let _g = tracing::span!(parent: None, tracing::Level::INFO, "test_b").entered();
            }
        });
        let drained_after_stop = collect_spans(&mut stream, 5, Duration::from_millis(300)).await;
        assert!(
            drained_after_stop.len() < 5,
            "stream did not stop: got {} more spans after StopStream",
            drained_after_stop.len(),
        );

        server.abort();
    }

    /// SetLevel(Info) drops DEBUG spans and keeps INFO spans.
    #[tokio::test]
    async fn set_level_filters_below_threshold() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        let ack = client
            .send_unary(Request::new(RequestBody::SetLevel(WireLevel::Info)))
            .unwrap()
            .await
            .unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));

        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        wait_for_initial(&mut stream).await;

        emit_under(&cache, || {
            drop(tracing::span!(parent: None, tracing::Level::INFO, "info_span"));
            drop(tracing::span!(parent: None, tracing::Level::DEBUG, "debug_span"));
        });

        let received = collect_spans(&mut stream, 2, Duration::from_millis(500)).await;
        let names: Vec<_> = received.iter().map(|s| s.name.as_str()).collect();
        assert_eq!(names, vec!["info_span"], "got: {names:?}");

        server.abort();
    }

    /// A sampling rate of 0.0 drops every span.
    #[tokio::test]
    async fn set_sampling_rate_zero_drops_all() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        client
            .send_unary(Request::new(RequestBody::SetSamplingRate(0.0)))
            .unwrap()
            .await
            .unwrap();
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        wait_for_initial(&mut stream).await;

        emit_under(&cache, || {
            for _ in 0..5 {
                let _g = tracing::span!(parent: None, tracing::Level::INFO, "sampled").entered();
            }
        });

        let received = collect_spans(&mut stream, 5, Duration::from_millis(400)).await;
        assert!(
            received.is_empty(),
            "rate=0 should drop everything; got {received:?}",
        );

        server.abort();
    }

    /// SetCacheLevel keeps an open stream alive and pushes the new level down it.
    #[tokio::test]
    async fn set_cache_level_keeps_stream_open() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut start = Request::new(RequestBody::StartStream);
        start.id = 100;
        let mut stream = client.send_streaming(start).unwrap();

        let first = tokio::time::timeout(Duration::from_secs(1), stream.next())
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        let server_info = match first.body {
            ResponseBody::ServerInfo(info) => info,
            other => panic!("first message should be ServerInfo, got {other:?}"),
        };
        assert_eq!(
            server_info.version,
            env!("CARGO_PKG_VERSION"),
            "server should advertise its own CARGO_PKG_VERSION",
        );

        let mut drained_level = false;
        let mut drained_chance = false;
        while !(drained_level && drained_chance) {
            let item = tokio::time::timeout(Duration::from_millis(500), stream.next())
                .await
                .unwrap()
                .unwrap()
                .unwrap();
            match item.body {
                ResponseBody::CacheLevel(_) => drained_level = true,
                ResponseBody::CacheChance(_) => drained_chance = true,
                other => panic!("unexpected message during initial drain: {other:?}"),
            }
        }

        let mut set = Request::new(RequestBody::SetCacheLevel(WireLevelFilter::Off));
        set.id = 101;
        let ack = client.send_unary(set).unwrap().await.unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));

        let mut next_level: Option<WireLevelFilter> = None;
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        while tokio::time::Instant::now() < deadline && next_level.is_none() {
            let item = tokio::time::timeout(Duration::from_millis(200), stream.next()).await;
            let resp = match item {
                Ok(Some(Ok(resp))) => resp,
                // Timed out or got an error item: poll again until the deadline.
                Ok(Some(Err(_))) | Err(_) => continue,
                // The stream ended. Polling again would return immediately and
                // spin until the deadline, so stop and let the assertion report it.
                Ok(None) => break,
            };
            match resp.body {
                ResponseBody::CacheLevel(l) => next_level = Some(l),
                ResponseBody::CacheChance(_) => continue,
                ResponseBody::ServerInfo(_) => continue,
                ResponseBody::Span(_) => continue,
                other => panic!("unexpected stream item: {other:?}"),
            }
        }
        assert_eq!(
            next_level,
            Some(WireLevelFilter::Off),
            "stream did not yield the updated CacheLevel (probably ended)",
        );

        server.abort();
    }

    /// Dropping the only streaming client resets the cache level to OFF.
    #[tokio::test]
    async fn level_resets_to_off_when_last_console_disconnects() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        level_handle.set(LevelFilter::INFO);

        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;

        {
            let client = connect_client(addr).await;
            let mut start = Request::new(RequestBody::StartStream);
            start.id = 200;
            let _stream = client.send_streaming(start).unwrap();
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        tokio::time::sleep(Duration::from_millis(500)).await;

        assert_eq!(
            level_handle.get(),
            LevelFilter::OFF,
            "level should have reset to OFF after last console disconnected",
        );

        server.abort();
    }

    static SAMPLING_CALLSITE: DefaultCallsite = {
        static META: tracing::Metadata<'static> = tracing::Metadata::new(
            "sampling_test",
            "sampling::test",
            tracing::Level::INFO,
            None,
            None,
            None,
            FieldSet::new(&[], Identifier(&SAMPLING_CALLSITE)),
            Kind::SPAN,
        );
        DefaultCallsite::new(&META)
    };

    /// Builds a closed INFO span record with the given ids and no fields/events.
    fn synth_span(id: u64, parent_id: Option<u64>) -> SpanRecord {
        SpanRecord {
            id,
            parent_id,
            metadata: SAMPLING_CALLSITE.metadata(),
            fields: FieldList::default(),
            events: Vec::new(),
            opened_at: Instant::now(),
            closed_at: Some(Instant::now()),
        }
    }

    #[test]
    fn sampling_passes_rate_one_short_circuits_true() {
        for id in [0u64, 1, 17, u64::MAX, 0x9E37_79B9_7F4A_7C15] {
            assert!(sampling_passes(&synth_span(id, None), 1.0));
        }
    }

    #[test]
    fn sampling_passes_rate_zero_short_circuits_false() {
        for id in [0u64, 1, 17, u64::MAX] {
            assert!(!sampling_passes(&synth_span(id, None), 0.0));
        }
    }

    #[test]
    fn sampling_passes_is_deterministic_per_root_id() {
        for id in 1u64..=20 {
            let r = synth_span(id, None);
            let first = sampling_passes(&r, 0.5);
            for _ in 0..3 {
                assert_eq!(sampling_passes(&r, 0.5), first, "id={id}");
            }
        }
    }

    #[test]
    fn sampling_passes_children_inherit_parents_root_id_bucket() {
        let root = synth_span(7, None);
        let want = sampling_passes(&root, 0.5);
        for child_id in [100u64, 200, 300, u64::MAX] {
            let child = synth_span(child_id, Some(7));
            assert_eq!(sampling_passes(&child, 0.5), want);
        }
    }

    #[test]
    fn sampling_passes_partitions_population_near_target_rate() {
        let rate = 0.3;
        let n = 5_000u64;
        let mut passed = 0usize;
        for id in 1..=n {
            if sampling_passes(&synth_span(id, None), rate) {
                passed += 1;
            }
        }
        let frac = passed as f64 / n as f64;
        assert!(
            (frac - rate).abs() < 0.03,
            "frac={frac} rate={rate} — hash distribution drifted",
        );
    }

    /// Out-of-range and NaN sampling rates are answered with an Error response.
    #[tokio::test]
    async fn set_sampling_rate_rejects_out_of_range() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        for bad in [1.5_f64, -0.1, f64::NAN] {
            let resp = client
                .send_unary(Request::new(RequestBody::SetSamplingRate(bad)))
                .unwrap()
                .await
                .unwrap();
            match resp.body {
                ResponseBody::Error(msg) => {
                    assert!(
                        msg.contains("sampling rate"),
                        "unexpected error message for {bad}: {msg}",
                    );
                }
                other => panic!("expected Error for rate={bad}, got {other:?}"),
            }
        }
        server.abort();
    }
}