Skip to main content

tracing_console_host/
server.rs

1//! protosocket-rpc server that streams closed spans from a [`tracing_cache::SpanCache`].
2
3use std::net::SocketAddr;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{Arc, RwLock};
6
7use protosocket::TcpSocketListener;
8use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
9use protosocket_rpc::Message;
10use protosocket_rpc::server::{ConnectionService, RpcResponder, SocketRpcServer, SocketService};
11use tokio::sync::watch;
12use tracing::metadata::LevelFilter;
13use tracing_cache::{ChanceHandle, EnabledPredicate, LevelHandle, SpanCache, SpanRecord};
14
15use crate::protocol::{Request, RequestBody, Response, WireLevel, WireLevelFilter};
16use crate::wire::{TimeBase, span_to_wire};
17
18// One messagepack frame per direction:
19//   server reads `Request`, writes `Response`.
20type ServerCodec = (MessagePackSerializer<Response>, MessagePackDecoder<Request>);
21
/// Buffer size requested for each streaming RPC's cache subscription
/// (2^16 entries), large enough that a brief client stall doesn't cost
/// spans.  When a receiver falls this far behind, the driver logs and drops
/// a whole batch (see `tracing_cache::driver::fanout_to_subscribers`).
const STREAM_SUBSCRIBER_CAPACITY: u64 = 1 << 16;
27
28/// Per-connection mutable state.  Read by the streaming RPC, mutated by the
29/// `Set*` RPCs and `StartStream` / `StopStream`.
30#[derive(Debug, Default)]
31struct StreamState {
32    streaming: bool,
33    min_level: Option<WireLevel>,
34    sampling_rate: f64,
35}
36
37impl StreamState {
38    fn new() -> Self {
39        Self {
40            streaming: false,
41            min_level: None,
42            sampling_rate: 1.0,
43        }
44    }
45}
46
47// ── Cache-level broadcaster ──────────────────────────────────────────────────
48
49/// Holds the `LevelHandle` for the cache's `LevelPredicate` plus a
50/// `tokio::sync::watch` channel so every active streaming connection
51/// can observe level changes without polling the handle.  A
52/// `SetCacheLevel` from any client flips the handle and pushes the
53/// new value into the watch — receivers wake up and forward a
54/// `CacheLevel` message down their span stream.
55///
56/// Also tracks the number of active streaming sessions via
57/// [`StreamGuard`] so the host can fall back to `OFF` when the last
58/// console drops — the cache costs zero work when nobody's watching.
59#[derive(Clone)]
60pub(crate) struct CacheLevelBroadcast {
61    level_handle: LevelHandle,
62    level_tx: watch::Sender<WireLevelFilter>,
63    chance_handle: ChanceHandle,
64    chance_tx: watch::Sender<f64>,
65    active_streams: Arc<AtomicUsize>,
66}
67
68impl CacheLevelBroadcast {
69    pub fn new(level_handle: LevelHandle, chance_handle: ChanceHandle) -> Self {
70        let initial_level = WireLevelFilter::from_tracing(level_handle.get());
71        let initial_chance = chance_handle.get();
72        let (level_tx, _) = watch::channel(initial_level);
73        let (chance_tx, _) = watch::channel(initial_chance);
74        Self {
75            level_handle,
76            level_tx,
77            chance_handle,
78            chance_tx,
79            active_streams: Arc::new(AtomicUsize::new(0)),
80        }
81    }
82
83    fn set_level(&self, filter: WireLevelFilter) {
84        self.level_handle.set(filter.to_tracing());
85        let _ = self.level_tx.send(filter);
86    }
87
88    fn set_chance(&self, pct: f64) {
89        // Clamp to a sensible range — the cache also clamps, but we
90        // broadcast the *effective* value so clients show what the
91        // host is actually applying.
92        let pct = if pct.is_nan() {
93            0.0
94        } else {
95            pct.clamp(0.0, 100.0)
96        };
97        self.chance_handle.set(pct);
98        let _ = self.chance_tx.send(pct);
99    }
100
101    fn subscribe_level(&self) -> watch::Receiver<WireLevelFilter> {
102        self.level_tx.subscribe()
103    }
104
105    fn subscribe_chance(&self) -> watch::Receiver<f64> {
106        self.chance_tx.subscribe()
107    }
108
109    /// Register a new console *streaming session* (one StartStream
110    /// RPC).  Returns a guard whose `Drop` decrements the counter;
111    /// when the counter hits zero (the last streaming RPC ended),
112    /// the level resets to `OFF` and the new state is broadcast so
113    /// any still-active stream picks it up.  Scoped to the streaming
114    /// RPC (not the connection) so liveness probes that open + close
115    /// a TCP socket without issuing StartStream don't trigger a
116    /// spurious reset.
117    fn enter_stream(&self) -> StreamGuard {
118        self.active_streams.fetch_add(1, Ordering::SeqCst);
119        StreamGuard {
120            broadcast: self.clone(),
121        }
122    }
123}
124
125/// RAII guard tracking a single active StartStream RPC.  Held by
126/// the `span_stream` async generator, so its `Drop` runs when the
127/// generator (and thus the spawned responder.stream future) ends —
128/// e.g. when the client cancels the streaming RPC by dropping the
129/// `StreamingCompletion`.
130pub(crate) struct StreamGuard {
131    broadcast: CacheLevelBroadcast,
132}
133
134impl Drop for StreamGuard {
135    fn drop(&mut self) {
136        let prev = self.broadcast.active_streams.fetch_sub(1, Ordering::SeqCst);
137        if prev == 1 {
138            // Last active streaming session — drop the cache back to
139            // OFF so an idle host pays nothing for tracing dispatch,
140            // and reset chance to 100% so the next console reconnects
141            // to a clean slate.
142            self.broadcast.level_handle.set(LevelFilter::OFF);
143            let _ = self.broadcast.level_tx.send(WireLevelFilter::Off);
144            self.broadcast.chance_handle.set(100.0);
145            let _ = self.broadcast.chance_tx.send(100.0);
146        }
147    }
148}
149
150// ── Per-connection service ───────────────────────────────────────────────────
151
152/// One per active client connection.  Holds an `Arc` to the shared cache so
153/// it can subscribe to closed-span fan-out, plus its own filter / sampling /
154/// level state.
155pub(crate) struct ConnectionState<P: EnabledPredicate> {
156    cache: Arc<SpanCache<P>>,
157    base: TimeBase,
158    state: Arc<RwLock<StreamState>>,
159    level_bus: CacheLevelBroadcast,
160    /// Lazily set on the first StartStream RPC.  Liveness probes that
161    /// open + close a TCP connection without ever issuing a streaming
162    /// RPC leave this `None`, so their drop doesn't decrement the
163    /// active-stream counter and trigger a spurious reset.  A
164    /// console always sends StartStream once, so its connection's
165    /// drop reliably fires the reset.
166    stream_guard: Option<StreamGuard>,
167}
168
169impl<P: EnabledPredicate> ConnectionState<P> {
170    fn new(cache: Arc<SpanCache<P>>, base: TimeBase, level_bus: CacheLevelBroadcast) -> Self {
171        Self {
172            cache,
173            base,
174            state: Arc::new(RwLock::new(StreamState::new())),
175            level_bus,
176            stream_guard: None,
177        }
178    }
179}
180
181impl<P: EnabledPredicate> ConnectionService for ConnectionState<P> {
182    type Request = Request;
183    type Response = Response;
184
185    #[allow(clippy::expect_used, reason = "poisoned lock")]
186    fn new_rpc(&mut self, msg: Request, responder: RpcResponder<'_, Response>) {
187        // Every Response must echo the request id so the client's
188        // completion registry (keyed by id) routes it back to the
189        // right pending RPC — see `Response::with_id`.
190        let request_id = msg.message_id();
191        match msg.body {
192            RequestBody::StartStream => {
193                self.state
194                    .write()
195                    .expect("lock must not be poisoned")
196                    .streaming = true;
197                // First StartStream on this connection — register a
198                // stream guard tied to the connection's lifetime.
199                // Subsequent StartStreams are idempotent: the guard
200                // already exists, the counter doesn't move twice.
201                if self.stream_guard.is_none() {
202                    self.stream_guard = Some(self.level_bus.enter_stream());
203                }
204                let cache = Arc::clone(&self.cache);
205                let state = Arc::clone(&self.state);
206                let base = self.base;
207                let level_rx = self.level_bus.subscribe_level();
208                let chance_rx = self.level_bus.subscribe_chance();
209                tokio::spawn(responder.stream(span_stream(
210                    cache, state, base, level_rx, chance_rx, request_id,
211                )));
212            }
213            RequestBody::StopStream => {
214                self.state
215                    .write()
216                    .expect("lock must not be poisoned")
217                    .streaming = false;
218                responder.immediate(Response::ack().with_id(request_id));
219            }
220            RequestBody::SetLevel(level) => {
221                self.state
222                    .write()
223                    .expect("lock must not be poisoned")
224                    .min_level = Some(level);
225                responder.immediate(Response::ack().with_id(request_id));
226            }
227            RequestBody::SetCacheLevel(filter) => {
228                self.level_bus.set_level(filter);
229                responder.immediate(Response::ack().with_id(request_id));
230            }
231            RequestBody::SetCacheChance(pct) => {
232                self.level_bus.set_chance(pct);
233                responder.immediate(Response::ack().with_id(request_id));
234            }
235            RequestBody::SetSamplingRate(rate) => {
236                if !(0.0..=1.0).contains(&rate) || rate.is_nan() {
237                    responder.immediate(
238                        Response::error(format!("sampling rate {rate} out of range [0.0, 1.0]"))
239                            .with_id(request_id),
240                    );
241                    return;
242                }
243                self.state
244                    .write()
245                    .expect("lock must not be poisoned")
246                    .sampling_rate = rate;
247                responder.immediate(Response::ack().with_id(request_id));
248            }
249            RequestBody::Noop => {}
250        }
251    }
252}
253
254/// Build the async stream of `Response` messages that satisfies a
255/// `StartStream` RPC.  Yields:
256///
257/// * an initial `CacheLevel` carrying the current cache level (so the
258///   client's UI is in sync the moment streaming begins),
259/// * every span the cache produces (after per-connection level /
260///   sampling / filter), and
261/// * a fresh `CacheLevel` every time the level changes (broadcast
262///   from any client's `SetCacheLevel`).
263fn span_stream<P: EnabledPredicate>(
264    cache: Arc<SpanCache<P>>,
265    state: Arc<RwLock<StreamState>>,
266    base: TimeBase,
267    mut level_rx: watch::Receiver<WireLevelFilter>,
268    mut chance_rx: watch::Receiver<f64>,
269    request_id: u64,
270) -> impl futures_core::Stream<Item = Response> {
271    async_stream::stream! {
272        // Push current level + chance first so the client renders
273        // its switcher / chance UI before any spans land.
274        let initial_level = *level_rx.borrow_and_update();
275        yield Response::cache_level(initial_level).with_id(request_id);
276        let initial_chance = *chance_rx.borrow_and_update();
277        yield Response::cache_chance(initial_chance).with_id(request_id);
278
279        // Register a subscriber — the driver fans every closed span
280        // into this receiver in commit (close-time) order, replacing
281        // the previous open-time `page(after_id, _)` cursor that
282        // silently dropped spans whenever close order diverged from
283        // open order (typical for async workloads).  Dropping this
284        // receiver (e.g. when the stream future ends) tells the
285        // driver to prune the sender on its next fan-out.
286        let mut span_rx = cache.subscribe(STREAM_SUBSCRIBER_CAPACITY);
287
288        loop {
289            tokio::select! {
290                changed = level_rx.changed() => {
291                    if changed.is_err() { break; }
292                    let lvl = *level_rx.borrow_and_update();
293                    yield Response::cache_level(lvl).with_id(request_id);
294                }
295                changed = chance_rx.changed() => {
296                    if changed.is_err() { break; }
297                    let pct = *chance_rx.borrow_and_update();
298                    yield Response::cache_chance(pct).with_id(request_id);
299                }
300                batch = span_rx.next_batch() => {
301                    let Some(batch) = batch else { break };
302                    let (streaming, min_level, sampling_rate) = {
303                        #[allow(clippy::expect_used, reason = "poisoned lock")]
304                        let s = state.read().expect("lock must not be poisoned");
305                        (s.streaming, s.min_level, s.sampling_rate)
306                    };
307                    if !streaming {
308                        // StopStream is in effect — keep draining the
309                        // batch (so the subscriber channel doesn't
310                        // back up) but discard.
311                        drop(batch);
312                        continue;
313                    }
314                    for record in batch {
315                        if let Some(min) = min_level
316                            && !level_at_least(record.metadata.level(), min)
317                        {
318                            continue;
319                        }
320                        if !sampling_passes(&record, sampling_rate) {
321                            continue;
322                        }
323                        yield Response::span(span_to_wire(&record, base)).with_id(request_id);
324                    }
325                }
326            }
327        }
328    }
329}
330
331/// True iff `record_level` is at least as severe as `floor`.  In tracing's
332/// reversed `Ord`, lower-severity levels compare *greater* (ERROR < WARN <
333/// INFO < DEBUG < TRACE), so the "at least as severe" relation is `<=`.
334/// E.g. with floor=INFO: INFO/WARN/ERROR pass, DEBUG/TRACE don't.
335fn level_at_least(record_level: &tracing::Level, floor: WireLevel) -> bool {
336    record_level <= &floor.to_tracing()
337}
338
339/// Hash-based sampling so the same root id deterministically passes/fails.
340/// Descendants follow the root's decision (the cache feeds us spans in
341/// increasing actual_id order, so a root is always streamed before its kids
342/// — but we still memoise via `root_decisions` to handle late-arriving ids).
343fn sampling_passes(record: &SpanRecord, rate: f64) -> bool {
344    if rate >= 1.0 {
345        return true;
346    }
347    if rate <= 0.0 {
348        return false;
349    }
350    // Use the root id (or this id, if a root) to pick the bucket.
351    let bucket_id = record.parent_id.unwrap_or(record.id);
352    // Cheap deterministic hash — splitmix-style.
353    let mut x = bucket_id.wrapping_mul(0x9E37_79B9_7F4A_7C15);
354    x ^= x >> 33;
355    x = x.wrapping_mul(0xC2B2_AE3D_27D4_EB4F);
356    x ^= x >> 29;
357    let frac = (x as f64) / (u64::MAX as f64);
358    frac < rate
359}
360
361// ── Top-level acceptor ───────────────────────────────────────────────────────
362
363struct Service<P: EnabledPredicate> {
364    cache: Arc<SpanCache<P>>,
365    base: TimeBase,
366    level_bus: CacheLevelBroadcast,
367}
368
369impl<P: EnabledPredicate> SocketService for Service<P> {
370    type Codec = ServerCodec;
371    type ConnectionService = ConnectionState<P>;
372    type SocketListener = TcpSocketListener;
373
374    fn codec(&self) -> Self::Codec {
375        (
376            MessagePackSerializer::default(),
377            MessagePackDecoder::default(),
378        )
379    }
380
381    fn new_stream_service(
382        &self,
383        _stream: &<Self::SocketListener as protosocket::SocketListener>::Stream,
384    ) -> Self::ConnectionService {
385        ConnectionState::new(Arc::clone(&self.cache), self.base, self.level_bus.clone())
386    }
387}
388
389/// Errors returned by [`serve`].
390#[derive(Debug)]
391pub enum ServeError {
392    Io(std::io::Error),
393    Rpc(protosocket_rpc::Error),
394}
395impl std::fmt::Display for ServeError {
396    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397        match self {
398            ServeError::Io(e) => write!(f, "io: {e}"),
399            ServeError::Rpc(e) => write!(f, "rpc: {e}"),
400        }
401    }
402}
403impl std::error::Error for ServeError {}
404impl From<std::io::Error> for ServeError {
405    fn from(e: std::io::Error) -> Self {
406        ServeError::Io(e)
407    }
408}
409impl From<protosocket_rpc::Error> for ServeError {
410    fn from(e: protosocket_rpc::Error) -> Self {
411        ServeError::Rpc(e)
412    }
413}
414
415/// Bind to `addr` and serve the console RPC protocol against `cache`.
416///
417/// `level_handle` is the `LevelHandle` returned by the cache's
418/// `LevelPredicate`; the server uses it to apply `SetCacheLevel`
419/// requests and to broadcast the resulting level to every connected
420/// stream.  Caller is responsible for spawning the cache's `Driver`
421/// and for keeping `level_handle` consistent with what the cache
422/// actually uses.  The future runs until the listener errors out.
423pub async fn serve<P: EnabledPredicate>(
424    cache: Arc<SpanCache<P>>,
425    level_handle: LevelHandle,
426    chance_handle: ChanceHandle,
427    addr: SocketAddr,
428) -> Result<(), ServeError> {
429    // listen(addr, listen_backlog, accept_timeout) — last two are optional knobs.
430    let listener = TcpSocketListener::listen(addr, 1024, None)?;
431
432    let service = Service {
433        cache,
434        base: TimeBase::now(),
435        level_bus: CacheLevelBroadcast::new(level_handle, chance_handle),
436    };
437    let server: SocketRpcServer<Service<P>, _> = SocketRpcServer::new(
438        listener,
439        service,
440        /* max_buffer_length */ 16 * 1024 * 1024,
441        /* buffer_allocation_increment */ 64 * 1024,
442        /* max_queued_outbound_messages */ 4096,
443    )?;
444    server.await?;
445    Ok(())
446}
447
448// ── Integration tests ────────────────────────────────────────────────────────
449
450#[cfg(test)]
451mod tests {
452    use super::*;
453    use std::net::TcpListener as StdTcpListener;
454    use std::time::Duration;
455
456    use futures::StreamExt;
457    use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
458    use protosocket_rpc::client::{self, Configuration, RpcClient, TcpStreamConnector};
459    use tracing_cache::{ChancePredicate, SpanCache};
460
461    use crate::protocol::{ResponseBody, WireLevel};
462
463    type ClientCodec = (MessagePackSerializer<Request>, MessagePackDecoder<Response>);
464
465    /// Bind a std listener to ephemeral port, capture the port, drop it.  The
466    /// next bind on this port (by `serve`) reuses it (SO_REUSEADDR is set).
467    /// There's a tiny race window — fine on a developer box and CI.
468    fn pick_addr() -> SocketAddr {
469        let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
470        let port = listener.local_addr().unwrap().port();
471        drop(listener);
472        format!("127.0.0.1:{port}").parse().unwrap()
473    }
474
475    /// Build a SpanCache and spawn its driver so the subscriber model
476    /// can fan out closed spans live.  Tests emit spans *after*
477    /// subscribing — the cache holds no history, so anything emitted
478    /// before the first `StartStream` is lost.
479    fn prepare_cache() -> (
480        Arc<SpanCache<ChancePredicate<tracing_cache::LevelPredicate>>>,
481        LevelHandle,
482        ChanceHandle,
483    ) {
484        let level =
485            tracing_cache::LevelPredicate::with_filter(tracing::metadata::LevelFilter::TRACE);
486        let level_handle = level.handle();
487        let predicate = ChancePredicate::new(level, 100.0);
488        let chance_handle = predicate.handle();
489        let (cache, driver) = SpanCache::with_predicate(1024, predicate);
490        let cache = Arc::new(cache);
491        tokio::spawn(driver.run());
492        (cache, level_handle, chance_handle)
493    }
494
495    /// Run `f` with `cache` as the active tracing subscriber, then
496    /// flush this thread's PENDING buffer so the spans leave for the
497    /// driver immediately.
498    fn emit_under<P: EnabledPredicate>(cache: &Arc<SpanCache<P>>, f: impl FnOnce()) {
499        tracing::subscriber::with_default(Arc::clone(cache), f);
500        cache.flush_pending();
501    }
502
503    /// Drain the initial `CacheLevel` and `CacheChance` messages
504    /// `span_stream` always pushes before subscribing.  Receiving
505    /// both is the sync point that proves the subscriber is
506    /// registered, so any span emitted afterward will be fanned
507    /// out into this stream.
508    async fn wait_for_initial(
509        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
510    ) {
511        let mut got_level = false;
512        let mut got_chance = false;
513        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
514        while !(got_level && got_chance) && tokio::time::Instant::now() < deadline {
515            match tokio::time::timeout(Duration::from_millis(200), stream.next()).await {
516                Ok(Some(Ok(resp))) => match resp.body {
517                    ResponseBody::CacheLevel(_) => got_level = true,
518                    ResponseBody::CacheChance(_) => got_chance = true,
519                    _ => {}
520                },
521                _ => break,
522            }
523        }
524        assert!(
525            got_level && got_chance,
526            "stream did not yield initial CacheLevel/CacheChance",
527        );
528    }
529
530    /// Spawn `serve` on a free port; return the address and a JoinHandle for
531    /// abort-on-drop semantics.  Briefly retries connect to confirm bind.
532    async fn spawn_server<P: EnabledPredicate>(
533        cache: Arc<SpanCache<P>>,
534        level_handle: LevelHandle,
535        chance_handle: ChanceHandle,
536    ) -> (SocketAddr, tokio::task::JoinHandle<()>) {
537        let addr = pick_addr();
538        let server_cache = Arc::clone(&cache);
539        let serve_level = level_handle.clone();
540        let serve_chance = chance_handle.clone();
541        let handle = tokio::spawn(async move {
542            // Discard the result; the test aborts this task at the end.
543            let _ = serve(server_cache, serve_level, serve_chance, addr).await;
544        });
545        // Wait for the server to actually be listening.
546        for _ in 0..50 {
547            if std::net::TcpStream::connect(addr).is_ok() {
548                return (addr, handle);
549            }
550            tokio::time::sleep(Duration::from_millis(10)).await;
551        }
552        panic!("server never came up on {addr}");
553    }
554
555    async fn connect_client(addr: SocketAddr) -> RpcClient<Request, Response> {
556        let cfg = Configuration::new(TcpStreamConnector);
557        let (rpc_client, conn) = client::connect::<ClientCodec, _>(addr, &cfg).await.unwrap();
558        // Drive the connection's I/O loop in the background.
559        tokio::spawn(conn);
560        rpc_client
561    }
562
563    /// Try to receive `n` Span responses from the stream within `total_timeout`.
564    async fn collect_spans(
565        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
566        n: usize,
567        total_timeout: Duration,
568    ) -> Vec<crate::WireSpan> {
569        let mut out = Vec::with_capacity(n);
570        let deadline = tokio::time::Instant::now() + total_timeout;
571        while out.len() < n {
572            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
573            match tokio::time::timeout(remaining, stream.next()).await {
574                Ok(Some(Ok(resp))) => {
575                    if let ResponseBody::Span(s) = resp.body {
576                        out.push(s);
577                    }
578                }
579                Ok(Some(Err(_))) | Ok(None) => break,
580                Err(_) => break,
581            }
582        }
583        out
584    }
585
586    // ── tests ─────────────────────────────────────────────────────────────────
587
588    #[tokio::test]
589    async fn start_stream_delivers_closed_spans() {
590        let (cache, level_handle, chance_handle) = prepare_cache();
591        let (addr, server) = spawn_server(
592            Arc::clone(&cache),
593            level_handle.clone(),
594            chance_handle.clone(),
595        )
596        .await;
597        let client = connect_client(addr).await;
598        let mut stream = client
599            .send_streaming(Request::new(RequestBody::StartStream))
600            .unwrap();
601        wait_for_initial(&mut stream).await;
602
603        emit_under(&cache, || {
604            for _ in 0..3 {
605                let span = tracing::span!(parent: None, tracing::Level::INFO, "test_a");
606                let _g = span.enter();
607            }
608        });
609
610        let received = collect_spans(&mut stream, 3, Duration::from_secs(2)).await;
611        assert_eq!(received.len(), 3);
612        assert!(received.iter().all(|s| s.name == "test_a"));
613        assert!(received.iter().all(|s| s.closed_at_ns.is_some()));
614
615        server.abort();
616    }
617
618    #[tokio::test]
619    async fn stop_stream_halts_delivery() {
620        let (cache, level_handle, chance_handle) = prepare_cache();
621        let (addr, server) = spawn_server(
622            Arc::clone(&cache),
623            level_handle.clone(),
624            chance_handle.clone(),
625        )
626        .await;
627        let client = connect_client(addr).await;
628        let mut stream = client
629            .send_streaming(Request::new(RequestBody::StartStream))
630            .unwrap();
631        wait_for_initial(&mut stream).await;
632
633        // First wave: emit one span and confirm it flows through.
634        emit_under(&cache, || {
635            let _g = tracing::span!(parent: None, tracing::Level::INFO, "test_b").entered();
636        });
637        let initial = collect_spans(&mut stream, 1, Duration::from_secs(2)).await;
638        assert_eq!(initial.len(), 1);
639
640        // Pause delivery.
641        let ack = client
642            .send_unary(Request::new(RequestBody::StopStream))
643            .unwrap()
644            .await
645            .unwrap();
646        assert!(matches!(ack.body, ResponseBody::Ack));
647        // Give the streaming task a chance to observe streaming=false.
648        tokio::time::sleep(Duration::from_millis(50)).await;
649
650        // Second wave: emit 5 spans; the server should drain them but
651        // not forward any to the client.
652        emit_under(&cache, || {
653            for _ in 0..5 {
654                let _g = tracing::span!(parent: None, tracing::Level::INFO, "test_b").entered();
655            }
656        });
657        let drained_after_stop = collect_spans(&mut stream, 5, Duration::from_millis(300)).await;
658        assert!(
659            drained_after_stop.len() < 5,
660            "stream did not stop: got {} more spans after StopStream",
661            drained_after_stop.len(),
662        );
663
664        server.abort();
665    }
666
667    #[tokio::test]
668    async fn set_level_filters_below_threshold() {
669        let (cache, level_handle, chance_handle) = prepare_cache();
670        let (addr, server) = spawn_server(
671            Arc::clone(&cache),
672            level_handle.clone(),
673            chance_handle.clone(),
674        )
675        .await;
676        let client = connect_client(addr).await;
677
678        let ack = client
679            .send_unary(Request::new(RequestBody::SetLevel(WireLevel::Info)))
680            .unwrap()
681            .await
682            .unwrap();
683        assert!(matches!(ack.body, ResponseBody::Ack));
684
685        let mut stream = client
686            .send_streaming(Request::new(RequestBody::StartStream))
687            .unwrap();
688        wait_for_initial(&mut stream).await;
689
690        // The cache predicate is TRACE so both spans reach the driver;
691        // the host's SetLevel must be what filters DEBUG on the wire.
692        emit_under(&cache, || {
693            drop(tracing::span!(parent: None, tracing::Level::INFO, "info_span"));
694            drop(tracing::span!(parent: None, tracing::Level::DEBUG, "debug_span"));
695        });
696
697        let received = collect_spans(&mut stream, 2, Duration::from_millis(500)).await;
698        let names: Vec<_> = received.iter().map(|s| s.name.as_str()).collect();
699        assert_eq!(names, vec!["info_span"], "got: {names:?}");
700
701        server.abort();
702    }
703
704    #[tokio::test]
705    async fn set_sampling_rate_zero_drops_all() {
706        let (cache, level_handle, chance_handle) = prepare_cache();
707        let (addr, server) = spawn_server(
708            Arc::clone(&cache),
709            level_handle.clone(),
710            chance_handle.clone(),
711        )
712        .await;
713        let client = connect_client(addr).await;
714
715        client
716            .send_unary(Request::new(RequestBody::SetSamplingRate(0.0)))
717            .unwrap()
718            .await
719            .unwrap();
720        let mut stream = client
721            .send_streaming(Request::new(RequestBody::StartStream))
722            .unwrap();
723        wait_for_initial(&mut stream).await;
724
725        emit_under(&cache, || {
726            for _ in 0..5 {
727                let _g = tracing::span!(parent: None, tracing::Level::INFO, "sampled").entered();
728            }
729        });
730
731        let received = collect_spans(&mut stream, 5, Duration::from_millis(400)).await;
732        assert!(
733            received.is_empty(),
734            "rate=0 should drop everything; got {received:?}",
735        );
736
737        server.abort();
738    }
739
740    /// Setting the cache level must not end the streaming RPC — the
741    /// server should push a fresh `CacheLevel` notification on the
742    /// existing stream and continue streaming spans after.
743    #[tokio::test]
744    async fn set_cache_level_keeps_stream_open() {
745        let (cache, level_handle, chance_handle) = prepare_cache();
746        let (addr, server) = spawn_server(
747            Arc::clone(&cache),
748            level_handle.clone(),
749            chance_handle.clone(),
750        )
751        .await;
752        let client = connect_client(addr).await;
753        // Distinct ids — the framework matches responses to RPCs by
754        // request id, and id=0 would clobber on the client side.
755        let mut start = Request::new(RequestBody::StartStream);
756        start.id = 100;
757        let mut stream = client.send_streaming(start).unwrap();
758
759        // First push is always the initial CacheLevel snapshot.
760        let first = tokio::time::timeout(Duration::from_secs(1), stream.next())
761            .await
762            .unwrap()
763            .unwrap()
764            .unwrap();
765        assert!(
766            matches!(first.body, ResponseBody::CacheLevel(_)),
767            "first message should be CacheLevel, got {:?}",
768            first.body
769        );
770
771        // Change the level — the unary should ack while the streaming
772        // RPC stays open.
773        let mut set = Request::new(RequestBody::SetCacheLevel(WireLevelFilter::Off));
774        set.id = 101;
775        let ack = client.send_unary(set).unwrap().await.unwrap();
776        assert!(matches!(ack.body, ResponseBody::Ack));
777
778        // Next stream item must be the updated CacheLevel; it must
779        // arrive (not end-of-stream) within a generous window.
780        let mut next_level: Option<WireLevelFilter> = None;
781        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
782        while tokio::time::Instant::now() < deadline && next_level.is_none() {
783            let item = tokio::time::timeout(Duration::from_millis(200), stream.next()).await;
784            let Ok(Some(Ok(resp))) = item else { continue };
785            match resp.body {
786                ResponseBody::CacheLevel(l) => next_level = Some(l),
787                // Initial chance push + any chance broadcasts are fine
788                // — we just don't care about them in this test.
789                ResponseBody::CacheChance(_) => continue,
790                ResponseBody::Span(_) => continue,
791                other => panic!("unexpected stream item: {other:?}"),
792            }
793        }
794        assert_eq!(
795            next_level,
796            Some(WireLevelFilter::Off),
797            "stream did not yield the updated CacheLevel (probably ended)",
798        );
799
800        server.abort();
801    }
802
803    /// When the last streaming RPC drops, the server should reset
804    /// the cache level to `OFF`.  Verified by: connect a client at
805    /// non-OFF level, drop the client, then reconnect and observe
806    /// the initial CacheLevel is OFF.
807    #[tokio::test]
808    async fn level_resets_to_off_when_last_console_disconnects() {
809        let (cache, level_handle, chance_handle) = prepare_cache();
810        // Start at INFO via the cache predicate handle.
811        level_handle.set(LevelFilter::INFO);
812
813        let (addr, server) = spawn_server(
814            Arc::clone(&cache),
815            level_handle.clone(),
816            chance_handle.clone(),
817        )
818        .await;
819
820        // Open a streaming RPC; drop it immediately to mimic a console
821        // disconnect.
822        {
823            let client = connect_client(addr).await;
824            let mut start = Request::new(RequestBody::StartStream);
825            start.id = 200;
826            let _stream = client.send_streaming(start).unwrap();
827            // Wait a beat for the StartStream to register on the
828            // server side, then drop everything.
829            tokio::time::sleep(Duration::from_millis(50)).await;
830        }
831        // Give the server time to notice the disconnect and run the
832        // StreamGuard's Drop.
833        tokio::time::sleep(Duration::from_millis(500)).await;
834
835        // Level handle should now read OFF.
836        assert_eq!(
837            level_handle.get(),
838            LevelFilter::OFF,
839            "level should have reset to OFF after last console disconnected",
840        );
841
842        server.abort();
843    }
844
845    // ── sampling_passes unit tests ───────────────────────────────────────────
846
847    use std::time::Instant;
848    use tracing::callsite::{Callsite, DefaultCallsite, Identifier};
849    use tracing::field::FieldSet;
850    use tracing::metadata::Kind;
851    use tracing_cache::{FieldList, SpanRecord};
852
    /// Static callsite backing the synthetic spans built by `synth_span`.
    ///
    /// The `Metadata` lives in a nested `static` because its `FieldSet`
    /// needs an `Identifier` that points back at this same callsite; the
    /// two statics reference each other's addresses.
    static SAMPLING_CALLSITE: DefaultCallsite = {
        static META: tracing::Metadata<'static> = tracing::Metadata::new(
            // name
            "sampling_test",
            // target
            "sampling::test",
            tracing::Level::INFO,
            // file, line, module_path: not needed for synthetic spans.
            None,
            None,
            None,
            // No fields; the identifier ties the field set to this callsite.
            FieldSet::new(&[], Identifier(&SAMPLING_CALLSITE)),
            Kind::SPAN,
        );
        DefaultCallsite::new(&META)
    };
866
867    fn synth_span(id: u64, parent_id: Option<u64>) -> SpanRecord {
868        SpanRecord {
869            id,
870            parent_id,
871            metadata: SAMPLING_CALLSITE.metadata(),
872            fields: FieldList::default(),
873            events: Vec::new(),
874            opened_at: Instant::now(),
875            closed_at: Some(Instant::now()),
876        }
877    }
878
879    #[test]
880    fn sampling_passes_rate_one_short_circuits_true() {
881        // rate >= 1.0 must accept every span regardless of id hash.
882        for id in [0u64, 1, 17, u64::MAX, 0x9E37_79B9_7F4A_7C15] {
883            assert!(sampling_passes(&synth_span(id, None), 1.0));
884        }
885    }
886
887    #[test]
888    fn sampling_passes_rate_zero_short_circuits_false() {
889        for id in [0u64, 1, 17, u64::MAX] {
890            assert!(!sampling_passes(&synth_span(id, None), 0.0));
891        }
892    }
893
894    #[test]
895    fn sampling_passes_is_deterministic_per_root_id() {
896        // Repeating the call with the same record must yield the same
897        // answer — otherwise children inheriting a root's decision
898        // would race against their root's hash.
899        for id in 1u64..=20 {
900            let r = synth_span(id, None);
901            let first = sampling_passes(&r, 0.5);
902            for _ in 0..3 {
903                assert_eq!(sampling_passes(&r, 0.5), first, "id={id}");
904            }
905        }
906    }
907
908    #[test]
909    fn sampling_passes_children_inherit_parents_root_id_bucket() {
910        // Children with `parent_id = Some(root)` must hash on the
911        // root, not on their own id.  Pick a root id that does pass
912        // at rate=0.5 and demonstrate the child gets the same answer.
913        let root = synth_span(7, None);
914        let want = sampling_passes(&root, 0.5);
915        // Several different child ids, all under root=7 → all match.
916        for child_id in [100u64, 200, 300, u64::MAX] {
917            let child = synth_span(child_id, Some(7));
918            assert_eq!(sampling_passes(&child, 0.5), want);
919        }
920    }
921
922    #[test]
923    fn sampling_passes_partitions_population_near_target_rate() {
924        // Coarse distribution sanity-check — splitmix should produce
925        // close-to-uniform fractions, so rate=0.3 over a large pool
926        // should pass roughly 30% of distinct root ids.
927        let rate = 0.3;
928        let n = 5_000u64;
929        let mut passed = 0usize;
930        for id in 1..=n {
931            if sampling_passes(&synth_span(id, None), rate) {
932                passed += 1;
933            }
934        }
935        let frac = passed as f64 / n as f64;
936        assert!(
937            (frac - rate).abs() < 0.03,
938            "frac={frac} rate={rate} — hash distribution drifted",
939        );
940    }
941
942    // ── SetSamplingRate RPC validation ───────────────────────────────────────
943
944    #[tokio::test]
945    async fn set_sampling_rate_rejects_out_of_range() {
946        let (cache, level_handle, chance_handle) = prepare_cache();
947        let (addr, server) = spawn_server(
948            Arc::clone(&cache),
949            level_handle.clone(),
950            chance_handle.clone(),
951        )
952        .await;
953        let client = connect_client(addr).await;
954
955        for bad in [1.5_f64, -0.1, f64::NAN] {
956            let resp = client
957                .send_unary(Request::new(RequestBody::SetSamplingRate(bad)))
958                .unwrap()
959                .await
960                .unwrap();
961            match resp.body {
962                ResponseBody::Error(msg) => {
963                    assert!(
964                        msg.contains("sampling rate"),
965                        "unexpected error message for {bad}: {msg}",
966                    );
967                }
968                other => panic!("expected Error for rate={bad}, got {other:?}"),
969            }
970        }
971        server.abort();
972    }
973}