Skip to main content

tracing_console_host/
server.rs

1//! protosocket-rpc server that streams closed spans from a [`tracing_cache::SpanCache`].
2
3use std::net::SocketAddr;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{Arc, RwLock};
6
7use protosocket::TcpSocketListener;
8use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
9use protosocket_rpc::Message;
10use protosocket_rpc::server::{ConnectionService, RpcResponder, SocketRpcServer, SocketService};
11use tokio::sync::watch;
12use tracing::metadata::LevelFilter;
13use tracing_cache::{ChanceHandle, EnabledPredicate, LevelHandle, SpanCache, SpanRecord};
14
15use crate::protocol::{Request, RequestBody, Response, WireLevel, WireLevelFilter};
16use crate::wire::{TimeBase, span_to_wire};
17
18// One messagepack frame per direction:
19//   server reads `Request`, writes `Response`.
20type ServerCodec = (MessagePackSerializer<Response>, MessagePackDecoder<Request>);
21
/// Capacity requested from `SpanCache::subscribe` for each streaming RPC.
/// Large enough that a brief client stall doesn't cost spans.  Once a
/// receiver falls this far behind, the driver logs and drops a whole batch
/// (see `tracing_cache::driver::fanout_to_subscribers`).
const STREAM_SUBSCRIBER_CAPACITY: u64 = 1 << 16;
27
28/// Per-connection mutable state.  Read by the streaming RPC, mutated by the
29/// `Set*` RPCs and `StartStream` / `StopStream`.
30#[derive(Debug, Default)]
31struct StreamState {
32    streaming: bool,
33    min_level: Option<WireLevel>,
34    sampling_rate: f64,
35}
36
37impl StreamState {
38    fn new() -> Self {
39        Self {
40            streaming: false,
41            min_level: None,
42            sampling_rate: 1.0,
43        }
44    }
45}
46
47// ── Cache-level broadcaster ──────────────────────────────────────────────────
48
49/// Holds the `LevelHandle` for the cache's `LevelPredicate` plus a
50/// `tokio::sync::watch` channel so every active streaming connection
51/// can observe level changes without polling the handle.  A
52/// `SetCacheLevel` from any client flips the handle and pushes the
53/// new value into the watch — receivers wake up and forward a
54/// `CacheLevel` message down their span stream.
55///
56/// Also tracks the number of active streaming sessions via
57/// [`StreamGuard`] so the host can fall back to `OFF` when the last
58/// console drops — the cache costs zero work when nobody's watching.
59#[derive(Clone)]
60pub(crate) struct CacheLevelBroadcast {
61    level_handle: LevelHandle,
62    level_tx: watch::Sender<WireLevelFilter>,
63    chance_handle: ChanceHandle,
64    chance_tx: watch::Sender<f64>,
65    active_streams: Arc<AtomicUsize>,
66}
67
68impl CacheLevelBroadcast {
69    pub fn new(level_handle: LevelHandle, chance_handle: ChanceHandle) -> Self {
70        let initial_level = WireLevelFilter::from_tracing(level_handle.get());
71        let initial_chance = chance_handle.get();
72        let (level_tx, _) = watch::channel(initial_level);
73        let (chance_tx, _) = watch::channel(initial_chance);
74        Self {
75            level_handle,
76            level_tx,
77            chance_handle,
78            chance_tx,
79            active_streams: Arc::new(AtomicUsize::new(0)),
80        }
81    }
82
83    fn set_level(&self, filter: WireLevelFilter) {
84        self.level_handle.set(filter.to_tracing());
85        let _ = self.level_tx.send(filter);
86    }
87
88    fn set_chance(&self, pct: f64) {
89        // Clamp to a sensible range — the cache also clamps, but we
90        // broadcast the *effective* value so clients show what the
91        // host is actually applying.
92        let pct = if pct.is_nan() {
93            0.0
94        } else {
95            pct.clamp(0.0, 100.0)
96        };
97        self.chance_handle.set(pct);
98        let _ = self.chance_tx.send(pct);
99    }
100
101    fn subscribe_level(&self) -> watch::Receiver<WireLevelFilter> {
102        self.level_tx.subscribe()
103    }
104
105    fn subscribe_chance(&self) -> watch::Receiver<f64> {
106        self.chance_tx.subscribe()
107    }
108
109    /// Register a new console *streaming session* (one StartStream
110    /// RPC).  Returns a guard whose `Drop` decrements the counter;
111    /// when the counter hits zero (the last streaming RPC ended),
112    /// the level resets to `OFF` and the new state is broadcast so
113    /// any still-active stream picks it up.  Scoped to the streaming
114    /// RPC (not the connection) so liveness probes that open + close
115    /// a TCP socket without issuing StartStream don't trigger a
116    /// spurious reset.
117    fn enter_stream(&self) -> StreamGuard {
118        self.active_streams.fetch_add(1, Ordering::SeqCst);
119        StreamGuard {
120            broadcast: self.clone(),
121        }
122    }
123}
124
125/// RAII guard tracking a single active StartStream RPC.  Held by
126/// the `span_stream` async generator, so its `Drop` runs when the
127/// generator (and thus the spawned responder.stream future) ends —
128/// e.g. when the client cancels the streaming RPC by dropping the
129/// `StreamingCompletion`.
130pub(crate) struct StreamGuard {
131    broadcast: CacheLevelBroadcast,
132}
133
134impl Drop for StreamGuard {
135    fn drop(&mut self) {
136        let prev = self.broadcast.active_streams.fetch_sub(1, Ordering::SeqCst);
137        if prev == 1 {
138            // Last active streaming session — drop the cache back to
139            // OFF so an idle host pays nothing for tracing dispatch,
140            // and reset chance to 100% so the next console reconnects
141            // to a clean slate.
142            self.broadcast.level_handle.set(LevelFilter::OFF);
143            let _ = self.broadcast.level_tx.send(WireLevelFilter::Off);
144            self.broadcast.chance_handle.set(100.0);
145            let _ = self.broadcast.chance_tx.send(100.0);
146        }
147    }
148}
149
150// ── Per-connection service ───────────────────────────────────────────────────
151
152/// One per active client connection.  Holds an `Arc` to the shared cache so
153/// it can subscribe to closed-span fan-out, plus its own filter / sampling /
154/// level state.
155pub(crate) struct ConnectionState<P: EnabledPredicate> {
156    cache: Arc<SpanCache<P>>,
157    base: TimeBase,
158    state: Arc<RwLock<StreamState>>,
159    level_bus: CacheLevelBroadcast,
160    /// Lazily set on the first StartStream RPC.  Liveness probes that
161    /// open + close a TCP connection without ever issuing a streaming
162    /// RPC leave this `None`, so their drop doesn't decrement the
163    /// active-stream counter and trigger a spurious reset.  A
164    /// console always sends StartStream once, so its connection's
165    /// drop reliably fires the reset.
166    stream_guard: Option<StreamGuard>,
167}
168
169impl<P: EnabledPredicate> ConnectionState<P> {
170    fn new(cache: Arc<SpanCache<P>>, base: TimeBase, level_bus: CacheLevelBroadcast) -> Self {
171        Self {
172            cache,
173            base,
174            state: Arc::new(RwLock::new(StreamState::new())),
175            level_bus,
176            stream_guard: None,
177        }
178    }
179}
180
181impl<P: EnabledPredicate> ConnectionService for ConnectionState<P> {
182    type Request = Request;
183    type Response = Response;
184
185    #[allow(clippy::expect_used, reason = "poisoned lock")]
186    fn new_rpc(&mut self, msg: Request, responder: RpcResponder<'_, Response>) {
187        // Every Response must echo the request id so the client's
188        // completion registry (keyed by id) routes it back to the
189        // right pending RPC — see `Response::with_id`.
190        let request_id = msg.message_id();
191        match msg.body {
192            RequestBody::StartStream => {
193                self.state
194                    .write()
195                    .expect("lock must not be poisoned")
196                    .streaming = true;
197                // First StartStream on this connection — register a
198                // stream guard tied to the connection's lifetime.
199                // Subsequent StartStreams are idempotent: the guard
200                // already exists, the counter doesn't move twice.
201                if self.stream_guard.is_none() {
202                    self.stream_guard = Some(self.level_bus.enter_stream());
203                }
204                let cache = Arc::clone(&self.cache);
205                let state = Arc::clone(&self.state);
206                let base = self.base;
207                let level_rx = self.level_bus.subscribe_level();
208                let chance_rx = self.level_bus.subscribe_chance();
209                tokio::spawn(responder.stream(span_stream(
210                    cache, state, base, level_rx, chance_rx, request_id,
211                )));
212            }
213            RequestBody::StopStream => {
214                self.state
215                    .write()
216                    .expect("lock must not be poisoned")
217                    .streaming = false;
218                responder.immediate(Response::ack().with_id(request_id));
219            }
220            RequestBody::SetLevel(level) => {
221                self.state
222                    .write()
223                    .expect("lock must not be poisoned")
224                    .min_level = Some(level);
225                responder.immediate(Response::ack().with_id(request_id));
226            }
227            RequestBody::SetCacheLevel(filter) => {
228                self.level_bus.set_level(filter);
229                responder.immediate(Response::ack().with_id(request_id));
230            }
231            RequestBody::SetCacheChance(pct) => {
232                self.level_bus.set_chance(pct);
233                responder.immediate(Response::ack().with_id(request_id));
234            }
235            RequestBody::SetSamplingRate(rate) => {
236                if !(0.0..=1.0).contains(&rate) || rate.is_nan() {
237                    responder.immediate(
238                        Response::error(format!("sampling rate {rate} out of range [0.0, 1.0]"))
239                            .with_id(request_id),
240                    );
241                    return;
242                }
243                self.state
244                    .write()
245                    .expect("lock must not be poisoned")
246                    .sampling_rate = rate;
247                responder.immediate(Response::ack().with_id(request_id));
248            }
249            RequestBody::Noop => {}
250        }
251    }
252}
253
254/// Build the async stream of `Response` messages that satisfies a
255/// `StartStream` RPC.  Yields:
256///
257/// * an initial `CacheLevel` carrying the current cache level (so the
258///   client's UI is in sync the moment streaming begins),
259/// * every span the cache produces (after per-connection level /
260///   sampling / filter), and
261/// * a fresh `CacheLevel` every time the level changes (broadcast
262///   from any client's `SetCacheLevel`).
263fn span_stream<P: EnabledPredicate>(
264    cache: Arc<SpanCache<P>>,
265    state: Arc<RwLock<StreamState>>,
266    base: TimeBase,
267    mut level_rx: watch::Receiver<WireLevelFilter>,
268    mut chance_rx: watch::Receiver<f64>,
269    request_id: u64,
270) -> impl futures_core::Stream<Item = Response> {
271    async_stream::stream! {
272        // Identify the host crate first so the client can spot a
273        // version mismatch before consuming any spans.  `CARGO_PKG_VERSION`
274        // is the workspace-pinned version, same as the client binary's.
275        yield Response::server_info(env!("CARGO_PKG_VERSION")).with_id(request_id);
276        // Push current level + chance next so the client renders
277        // its switcher / chance UI before any spans land.
278        let initial_level = *level_rx.borrow_and_update();
279        yield Response::cache_level(initial_level).with_id(request_id);
280        let initial_chance = *chance_rx.borrow_and_update();
281        yield Response::cache_chance(initial_chance).with_id(request_id);
282
283        // Register a subscriber — the driver fans every closed span
284        // into this receiver in commit (close-time) order, replacing
285        // the previous open-time `page(after_id, _)` cursor that
286        // silently dropped spans whenever close order diverged from
287        // open order (typical for async workloads).  Dropping this
288        // receiver (e.g. when the stream future ends) tells the
289        // driver to prune the sender on its next fan-out.
290        let mut span_rx = cache.subscribe(STREAM_SUBSCRIBER_CAPACITY);
291
292        loop {
293            tokio::select! {
294                changed = level_rx.changed() => {
295                    if changed.is_err() { break; }
296                    let lvl = *level_rx.borrow_and_update();
297                    yield Response::cache_level(lvl).with_id(request_id);
298                }
299                changed = chance_rx.changed() => {
300                    if changed.is_err() { break; }
301                    let pct = *chance_rx.borrow_and_update();
302                    yield Response::cache_chance(pct).with_id(request_id);
303                }
304                batch = span_rx.next_batch() => {
305                    let Some(batch) = batch else { break };
306                    let (streaming, min_level, sampling_rate) = {
307                        #[allow(clippy::expect_used, reason = "poisoned lock")]
308                        let s = state.read().expect("lock must not be poisoned");
309                        (s.streaming, s.min_level, s.sampling_rate)
310                    };
311                    if !streaming {
312                        // StopStream is in effect — keep draining the
313                        // batch (so the subscriber channel doesn't
314                        // back up) but discard.
315                        drop(batch);
316                        continue;
317                    }
318                    for record in batch {
319                        if let Some(min) = min_level
320                            && !level_at_least(record.metadata.level(), min)
321                        {
322                            continue;
323                        }
324                        if !sampling_passes(&record, sampling_rate) {
325                            continue;
326                        }
327                        yield Response::span(span_to_wire(&record, base)).with_id(request_id);
328                    }
329                }
330            }
331        }
332    }
333}
334
335/// True iff `record_level` is at least as severe as `floor`.  In tracing's
336/// reversed `Ord`, lower-severity levels compare *greater* (ERROR < WARN <
337/// INFO < DEBUG < TRACE), so the "at least as severe" relation is `<=`.
338/// E.g. with floor=INFO: INFO/WARN/ERROR pass, DEBUG/TRACE don't.
339fn level_at_least(record_level: &tracing::Level, floor: WireLevel) -> bool {
340    record_level <= &floor.to_tracing()
341}
342
343/// Hash-based sampling so the same root id deterministically passes/fails.
344/// Descendants follow the root's decision (the cache feeds us spans in
345/// increasing actual_id order, so a root is always streamed before its kids
346/// — but we still memoise via `root_decisions` to handle late-arriving ids).
347fn sampling_passes(record: &SpanRecord, rate: f64) -> bool {
348    if rate >= 1.0 {
349        return true;
350    }
351    if rate <= 0.0 {
352        return false;
353    }
354    // Use the root id (or this id, if a root) to pick the bucket.
355    let bucket_id = record.parent_id.unwrap_or(record.id);
356    // Cheap deterministic hash — splitmix-style.
357    let mut x = bucket_id.wrapping_mul(0x9E37_79B9_7F4A_7C15);
358    x ^= x >> 33;
359    x = x.wrapping_mul(0xC2B2_AE3D_27D4_EB4F);
360    x ^= x >> 29;
361    let frac = (x as f64) / (u64::MAX as f64);
362    frac < rate
363}
364
365// ── Top-level acceptor ───────────────────────────────────────────────────────
366
367struct Service<P: EnabledPredicate> {
368    cache: Arc<SpanCache<P>>,
369    base: TimeBase,
370    level_bus: CacheLevelBroadcast,
371}
372
373impl<P: EnabledPredicate> SocketService for Service<P> {
374    type Codec = ServerCodec;
375    type ConnectionService = ConnectionState<P>;
376    type SocketListener = TcpSocketListener;
377
378    fn codec(&self) -> Self::Codec {
379        (
380            MessagePackSerializer::default(),
381            MessagePackDecoder::default(),
382        )
383    }
384
385    fn new_stream_service(
386        &self,
387        _stream: &<Self::SocketListener as protosocket::SocketListener>::Stream,
388    ) -> Self::ConnectionService {
389        ConnectionState::new(Arc::clone(&self.cache), self.base, self.level_bus.clone())
390    }
391}
392
393/// Errors returned by [`serve`].
394#[derive(Debug)]
395pub enum ServeError {
396    Io(std::io::Error),
397    Rpc(protosocket_rpc::Error),
398}
399impl std::fmt::Display for ServeError {
400    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        match self {
402            ServeError::Io(e) => write!(f, "io: {e}"),
403            ServeError::Rpc(e) => write!(f, "rpc: {e}"),
404        }
405    }
406}
407impl std::error::Error for ServeError {}
408impl From<std::io::Error> for ServeError {
409    fn from(e: std::io::Error) -> Self {
410        ServeError::Io(e)
411    }
412}
413impl From<protosocket_rpc::Error> for ServeError {
414    fn from(e: protosocket_rpc::Error) -> Self {
415        ServeError::Rpc(e)
416    }
417}
418
419/// Bind to `addr` and serve the console RPC protocol against `cache`.
420///
421/// `level_handle` is the `LevelHandle` returned by the cache's
422/// `LevelPredicate`; the server uses it to apply `SetCacheLevel`
423/// requests and to broadcast the resulting level to every connected
424/// stream.  Caller is responsible for spawning the cache's `Driver`
425/// and for keeping `level_handle` consistent with what the cache
426/// actually uses.  The future runs until the listener errors out.
427pub async fn serve<P: EnabledPredicate>(
428    cache: Arc<SpanCache<P>>,
429    level_handle: LevelHandle,
430    chance_handle: ChanceHandle,
431    addr: SocketAddr,
432) -> Result<(), ServeError> {
433    // listen(addr, listen_backlog, accept_timeout) — last two are optional knobs.
434    let listener = TcpSocketListener::listen(addr, 1024, None)?;
435
436    let service = Service {
437        cache,
438        base: TimeBase::now(),
439        level_bus: CacheLevelBroadcast::new(level_handle, chance_handle),
440    };
441    let server: SocketRpcServer<Service<P>, _> = SocketRpcServer::new(
442        listener,
443        service,
444        /* max_buffer_length */ 16 * 1024 * 1024,
445        /* buffer_allocation_increment */ 64 * 1024,
446        /* max_queued_outbound_messages */ 4096,
447    )?;
448    server.await?;
449    Ok(())
450}
451
452// ── Integration tests ────────────────────────────────────────────────────────
453
454#[cfg(test)]
455mod tests {
456    use super::*;
457    use std::net::TcpListener as StdTcpListener;
458    use std::time::Duration;
459
460    use futures::StreamExt;
461    use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
462    use protosocket_rpc::client::{self, Configuration, RpcClient, TcpStreamConnector};
463    use tracing_cache::{ChancePredicate, SpanCache};
464
465    use crate::protocol::{ResponseBody, WireLevel};
466
467    type ClientCodec = (MessagePackSerializer<Request>, MessagePackDecoder<Response>);
468
469    /// Bind a std listener to ephemeral port, capture the port, drop it.  The
470    /// next bind on this port (by `serve`) reuses it (SO_REUSEADDR is set).
471    /// There's a tiny race window — fine on a developer box and CI.
472    fn pick_addr() -> SocketAddr {
473        let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
474        let port = listener.local_addr().unwrap().port();
475        drop(listener);
476        format!("127.0.0.1:{port}").parse().unwrap()
477    }
478
479    /// Build a SpanCache and spawn its driver so the subscriber model
480    /// can fan out closed spans live.  Tests emit spans *after*
481    /// subscribing — the cache holds no history, so anything emitted
482    /// before the first `StartStream` is lost.
483    fn prepare_cache() -> (
484        Arc<SpanCache<ChancePredicate<tracing_cache::LevelPredicate>>>,
485        LevelHandle,
486        ChanceHandle,
487    ) {
488        let level =
489            tracing_cache::LevelPredicate::with_filter(tracing::metadata::LevelFilter::TRACE);
490        let level_handle = level.handle();
491        let predicate = ChancePredicate::new(level, 100.0);
492        let chance_handle = predicate.handle();
493        let (cache, driver) = SpanCache::with_predicate(1024, predicate);
494        let cache = Arc::new(cache);
495        tokio::spawn(driver.run());
496        (cache, level_handle, chance_handle)
497    }
498
499    /// Run `f` with `cache` as the active tracing subscriber, then
500    /// flush this thread's PENDING buffer so the spans leave for the
501    /// driver immediately.
502    fn emit_under<P: EnabledPredicate>(cache: &Arc<SpanCache<P>>, f: impl FnOnce()) {
503        tracing::subscriber::with_default(Arc::clone(cache), f);
504        cache.flush_pending();
505    }
506
507    /// Drain the initial `ServerInfo` + `CacheLevel` + `CacheChance`
508    /// messages `span_stream` always pushes before subscribing.
509    /// Receiving all three is the sync point that proves the
510    /// subscriber is registered, so any span emitted afterward will
511    /// be fanned out into this stream.
512    async fn wait_for_initial(
513        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
514    ) {
515        let mut got_server_info = false;
516        let mut got_level = false;
517        let mut got_chance = false;
518        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
519        while !(got_server_info && got_level && got_chance)
520            && tokio::time::Instant::now() < deadline
521        {
522            match tokio::time::timeout(Duration::from_millis(200), stream.next()).await {
523                Ok(Some(Ok(resp))) => match resp.body {
524                    ResponseBody::ServerInfo(_) => got_server_info = true,
525                    ResponseBody::CacheLevel(_) => got_level = true,
526                    ResponseBody::CacheChance(_) => got_chance = true,
527                    _ => {}
528                },
529                _ => break,
530            }
531        }
532        assert!(
533            got_server_info && got_level && got_chance,
534            "stream did not yield initial ServerInfo/CacheLevel/CacheChance",
535        );
536    }
537
538    /// Spawn `serve` on a free port; return the address and a JoinHandle for
539    /// abort-on-drop semantics.  Briefly retries connect to confirm bind.
540    async fn spawn_server<P: EnabledPredicate>(
541        cache: Arc<SpanCache<P>>,
542        level_handle: LevelHandle,
543        chance_handle: ChanceHandle,
544    ) -> (SocketAddr, tokio::task::JoinHandle<()>) {
545        let addr = pick_addr();
546        let server_cache = Arc::clone(&cache);
547        let serve_level = level_handle.clone();
548        let serve_chance = chance_handle.clone();
549        let handle = tokio::spawn(async move {
550            // Discard the result; the test aborts this task at the end.
551            let _ = serve(server_cache, serve_level, serve_chance, addr).await;
552        });
553        // Wait for the server to actually be listening.
554        for _ in 0..50 {
555            if std::net::TcpStream::connect(addr).is_ok() {
556                return (addr, handle);
557            }
558            tokio::time::sleep(Duration::from_millis(10)).await;
559        }
560        panic!("server never came up on {addr}");
561    }
562
563    async fn connect_client(addr: SocketAddr) -> RpcClient<Request, Response> {
564        let cfg = Configuration::new(TcpStreamConnector);
565        let (rpc_client, conn) = client::connect::<ClientCodec, _>(addr, &cfg).await.unwrap();
566        // Drive the connection's I/O loop in the background.
567        tokio::spawn(conn);
568        rpc_client
569    }
570
571    /// Try to receive `n` Span responses from the stream within `total_timeout`.
572    async fn collect_spans(
573        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
574        n: usize,
575        total_timeout: Duration,
576    ) -> Vec<crate::WireSpan> {
577        let mut out = Vec::with_capacity(n);
578        let deadline = tokio::time::Instant::now() + total_timeout;
579        while out.len() < n {
580            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
581            match tokio::time::timeout(remaining, stream.next()).await {
582                Ok(Some(Ok(resp))) => {
583                    if let ResponseBody::Span(s) = resp.body {
584                        out.push(s);
585                    }
586                }
587                Ok(Some(Err(_))) | Ok(None) => break,
588                Err(_) => break,
589            }
590        }
591        out
592    }
593
594    // ── tests ─────────────────────────────────────────────────────────────────
595
596    #[tokio::test]
597    async fn start_stream_delivers_closed_spans() {
598        let (cache, level_handle, chance_handle) = prepare_cache();
599        let (addr, server) = spawn_server(
600            Arc::clone(&cache),
601            level_handle.clone(),
602            chance_handle.clone(),
603        )
604        .await;
605        let client = connect_client(addr).await;
606        let mut stream = client
607            .send_streaming(Request::new(RequestBody::StartStream))
608            .unwrap();
609        wait_for_initial(&mut stream).await;
610
611        emit_under(&cache, || {
612            for _ in 0..3 {
613                let span = tracing::span!(parent: None, tracing::Level::INFO, "test_a");
614                let _g = span.enter();
615            }
616        });
617
618        let received = collect_spans(&mut stream, 3, Duration::from_secs(2)).await;
619        assert_eq!(received.len(), 3);
620        assert!(received.iter().all(|s| s.name == "test_a"));
621        assert!(received.iter().all(|s| s.closed_at_ns.is_some()));
622
623        server.abort();
624    }
625
626    #[tokio::test]
627    async fn stop_stream_halts_delivery() {
628        let (cache, level_handle, chance_handle) = prepare_cache();
629        let (addr, server) = spawn_server(
630            Arc::clone(&cache),
631            level_handle.clone(),
632            chance_handle.clone(),
633        )
634        .await;
635        let client = connect_client(addr).await;
636        let mut stream = client
637            .send_streaming(Request::new(RequestBody::StartStream))
638            .unwrap();
639        wait_for_initial(&mut stream).await;
640
641        // First wave: emit one span and confirm it flows through.
642        emit_under(&cache, || {
643            let _g = tracing::span!(parent: None, tracing::Level::INFO, "test_b").entered();
644        });
645        let initial = collect_spans(&mut stream, 1, Duration::from_secs(2)).await;
646        assert_eq!(initial.len(), 1);
647
648        // Pause delivery.
649        let ack = client
650            .send_unary(Request::new(RequestBody::StopStream))
651            .unwrap()
652            .await
653            .unwrap();
654        assert!(matches!(ack.body, ResponseBody::Ack));
655        // Give the streaming task a chance to observe streaming=false.
656        tokio::time::sleep(Duration::from_millis(50)).await;
657
658        // Second wave: emit 5 spans; the server should drain them but
659        // not forward any to the client.
660        emit_under(&cache, || {
661            for _ in 0..5 {
662                let _g = tracing::span!(parent: None, tracing::Level::INFO, "test_b").entered();
663            }
664        });
665        let drained_after_stop = collect_spans(&mut stream, 5, Duration::from_millis(300)).await;
666        assert!(
667            drained_after_stop.len() < 5,
668            "stream did not stop: got {} more spans after StopStream",
669            drained_after_stop.len(),
670        );
671
672        server.abort();
673    }
674
675    #[tokio::test]
676    async fn set_level_filters_below_threshold() {
677        let (cache, level_handle, chance_handle) = prepare_cache();
678        let (addr, server) = spawn_server(
679            Arc::clone(&cache),
680            level_handle.clone(),
681            chance_handle.clone(),
682        )
683        .await;
684        let client = connect_client(addr).await;
685
686        let ack = client
687            .send_unary(Request::new(RequestBody::SetLevel(WireLevel::Info)))
688            .unwrap()
689            .await
690            .unwrap();
691        assert!(matches!(ack.body, ResponseBody::Ack));
692
693        let mut stream = client
694            .send_streaming(Request::new(RequestBody::StartStream))
695            .unwrap();
696        wait_for_initial(&mut stream).await;
697
698        // The cache predicate is TRACE so both spans reach the driver;
699        // the host's SetLevel must be what filters DEBUG on the wire.
700        emit_under(&cache, || {
701            drop(tracing::span!(parent: None, tracing::Level::INFO, "info_span"));
702            drop(tracing::span!(parent: None, tracing::Level::DEBUG, "debug_span"));
703        });
704
705        let received = collect_spans(&mut stream, 2, Duration::from_millis(500)).await;
706        let names: Vec<_> = received.iter().map(|s| s.name.as_str()).collect();
707        assert_eq!(names, vec!["info_span"], "got: {names:?}");
708
709        server.abort();
710    }
711
712    #[tokio::test]
713    async fn set_sampling_rate_zero_drops_all() {
714        let (cache, level_handle, chance_handle) = prepare_cache();
715        let (addr, server) = spawn_server(
716            Arc::clone(&cache),
717            level_handle.clone(),
718            chance_handle.clone(),
719        )
720        .await;
721        let client = connect_client(addr).await;
722
723        client
724            .send_unary(Request::new(RequestBody::SetSamplingRate(0.0)))
725            .unwrap()
726            .await
727            .unwrap();
728        let mut stream = client
729            .send_streaming(Request::new(RequestBody::StartStream))
730            .unwrap();
731        wait_for_initial(&mut stream).await;
732
733        emit_under(&cache, || {
734            for _ in 0..5 {
735                let _g = tracing::span!(parent: None, tracing::Level::INFO, "sampled").entered();
736            }
737        });
738
739        let received = collect_spans(&mut stream, 5, Duration::from_millis(400)).await;
740        assert!(
741            received.is_empty(),
742            "rate=0 should drop everything; got {received:?}",
743        );
744
745        server.abort();
746    }
747
748    /// Setting the cache level must not end the streaming RPC — the
749    /// server should push a fresh `CacheLevel` notification on the
750    /// existing stream and continue streaming spans after.
751    #[tokio::test]
752    async fn set_cache_level_keeps_stream_open() {
753        let (cache, level_handle, chance_handle) = prepare_cache();
754        let (addr, server) = spawn_server(
755            Arc::clone(&cache),
756            level_handle.clone(),
757            chance_handle.clone(),
758        )
759        .await;
760        let client = connect_client(addr).await;
761        // Distinct ids — the framework matches responses to RPCs by
762        // request id, and id=0 would clobber on the client side.
763        let mut start = Request::new(RequestBody::StartStream);
764        start.id = 100;
765        let mut stream = client.send_streaming(start).unwrap();
766
767        // First push is always the ServerInfo handshake.
768        let first = tokio::time::timeout(Duration::from_secs(1), stream.next())
769            .await
770            .unwrap()
771            .unwrap()
772            .unwrap();
773        let server_info = match first.body {
774            ResponseBody::ServerInfo(info) => info,
775            other => panic!("first message should be ServerInfo, got {other:?}"),
776        };
777        assert_eq!(
778            server_info.version,
779            env!("CARGO_PKG_VERSION"),
780            "server should advertise its own CARGO_PKG_VERSION",
781        );
782
783        // Drain the initial CacheLevel + CacheChance pushes so the
784        // loop below observes the *updated* CacheLevel rather than
785        // the initial one.
786        let mut drained_level = false;
787        let mut drained_chance = false;
788        while !(drained_level && drained_chance) {
789            let item = tokio::time::timeout(Duration::from_millis(500), stream.next())
790                .await
791                .unwrap()
792                .unwrap()
793                .unwrap();
794            match item.body {
795                ResponseBody::CacheLevel(_) => drained_level = true,
796                ResponseBody::CacheChance(_) => drained_chance = true,
797                other => panic!("unexpected message during initial drain: {other:?}"),
798            }
799        }
800
801        // Change the level — the unary should ack while the streaming
802        // RPC stays open.
803        let mut set = Request::new(RequestBody::SetCacheLevel(WireLevelFilter::Off));
804        set.id = 101;
805        let ack = client.send_unary(set).unwrap().await.unwrap();
806        assert!(matches!(ack.body, ResponseBody::Ack));
807
808        // Next stream item must be the updated CacheLevel; it must
809        // arrive (not end-of-stream) within a generous window.
810        let mut next_level: Option<WireLevelFilter> = None;
811        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
812        while tokio::time::Instant::now() < deadline && next_level.is_none() {
813            let item = tokio::time::timeout(Duration::from_millis(200), stream.next()).await;
814            let Ok(Some(Ok(resp))) = item else { continue };
815            match resp.body {
816                ResponseBody::CacheLevel(l) => next_level = Some(l),
817                // Initial chance push + any chance broadcasts are fine
818                // — we just don't care about them in this test.
819                ResponseBody::CacheChance(_) => continue,
820                ResponseBody::ServerInfo(_) => continue,
821                ResponseBody::Span(_) => continue,
822                other => panic!("unexpected stream item: {other:?}"),
823            }
824        }
825        assert_eq!(
826            next_level,
827            Some(WireLevelFilter::Off),
828            "stream did not yield the updated CacheLevel (probably ended)",
829        );
830
831        server.abort();
832    }
833
834    /// When the last streaming RPC drops, the server should reset
835    /// the cache level to `OFF`.  Verified by: connect a client at
836    /// non-OFF level, drop the client, then reconnect and observe
837    /// the initial CacheLevel is OFF.
838    #[tokio::test]
839    async fn level_resets_to_off_when_last_console_disconnects() {
840        let (cache, level_handle, chance_handle) = prepare_cache();
841        // Start at INFO via the cache predicate handle.
842        level_handle.set(LevelFilter::INFO);
843
844        let (addr, server) = spawn_server(
845            Arc::clone(&cache),
846            level_handle.clone(),
847            chance_handle.clone(),
848        )
849        .await;
850
851        // Open a streaming RPC; drop it immediately to mimic a console
852        // disconnect.
853        {
854            let client = connect_client(addr).await;
855            let mut start = Request::new(RequestBody::StartStream);
856            start.id = 200;
857            let _stream = client.send_streaming(start).unwrap();
858            // Wait a beat for the StartStream to register on the
859            // server side, then drop everything.
860            tokio::time::sleep(Duration::from_millis(50)).await;
861        }
862        // Give the server time to notice the disconnect and run the
863        // StreamGuard's Drop.
864        tokio::time::sleep(Duration::from_millis(500)).await;
865
866        // Level handle should now read OFF.
867        assert_eq!(
868            level_handle.get(),
869            LevelFilter::OFF,
870            "level should have reset to OFF after last console disconnected",
871        );
872
873        server.abort();
874    }
875
876    // ── sampling_passes unit tests ───────────────────────────────────────────
877
878    use std::time::Instant;
879    use tracing::callsite::{Callsite, DefaultCallsite, Identifier};
880    use tracing::field::FieldSet;
881    use tracing::metadata::Kind;
882    use tracing_cache::{FieldList, SpanRecord};
883
    /// Callsite backing every record built by `synth_span`.
    ///
    /// `DefaultCallsite::new` takes a `&'static Metadata<'static>`, so the
    /// metadata lives in its own nested `static`.  The metadata's `FieldSet`
    /// is tied back to this callsite via `Identifier(&SAMPLING_CALLSITE)`,
    /// so the two statics refer to each other by address.  It describes an
    /// `INFO`-level span named "sampling_test" with target "sampling::test",
    /// no fields, and no file/line/module-path information.
    static SAMPLING_CALLSITE: DefaultCallsite = {
        static META: tracing::Metadata<'static> = tracing::Metadata::new(
            "sampling_test",  // name
            "sampling::test", // target
            tracing::Level::INFO,
            None, // file
            None, // line
            None, // module_path
            FieldSet::new(&[], Identifier(&SAMPLING_CALLSITE)),
            Kind::SPAN,
        );
        DefaultCallsite::new(&META)
    };
897
898    fn synth_span(id: u64, parent_id: Option<u64>) -> SpanRecord {
899        SpanRecord {
900            id,
901            parent_id,
902            metadata: SAMPLING_CALLSITE.metadata(),
903            fields: FieldList::default(),
904            events: Vec::new(),
905            opened_at: Instant::now(),
906            closed_at: Some(Instant::now()),
907        }
908    }
909
910    #[test]
911    fn sampling_passes_rate_one_short_circuits_true() {
912        // rate >= 1.0 must accept every span regardless of id hash.
913        for id in [0u64, 1, 17, u64::MAX, 0x9E37_79B9_7F4A_7C15] {
914            assert!(sampling_passes(&synth_span(id, None), 1.0));
915        }
916    }
917
918    #[test]
919    fn sampling_passes_rate_zero_short_circuits_false() {
920        for id in [0u64, 1, 17, u64::MAX] {
921            assert!(!sampling_passes(&synth_span(id, None), 0.0));
922        }
923    }
924
925    #[test]
926    fn sampling_passes_is_deterministic_per_root_id() {
927        // Repeating the call with the same record must yield the same
928        // answer — otherwise children inheriting a root's decision
929        // would race against their root's hash.
930        for id in 1u64..=20 {
931            let r = synth_span(id, None);
932            let first = sampling_passes(&r, 0.5);
933            for _ in 0..3 {
934                assert_eq!(sampling_passes(&r, 0.5), first, "id={id}");
935            }
936        }
937    }
938
939    #[test]
940    fn sampling_passes_children_inherit_parents_root_id_bucket() {
941        // Children with `parent_id = Some(root)` must hash on the
942        // root, not on their own id.  Pick a root id that does pass
943        // at rate=0.5 and demonstrate the child gets the same answer.
944        let root = synth_span(7, None);
945        let want = sampling_passes(&root, 0.5);
946        // Several different child ids, all under root=7 → all match.
947        for child_id in [100u64, 200, 300, u64::MAX] {
948            let child = synth_span(child_id, Some(7));
949            assert_eq!(sampling_passes(&child, 0.5), want);
950        }
951    }
952
953    #[test]
954    fn sampling_passes_partitions_population_near_target_rate() {
955        // Coarse distribution sanity-check — splitmix should produce
956        // close-to-uniform fractions, so rate=0.3 over a large pool
957        // should pass roughly 30% of distinct root ids.
958        let rate = 0.3;
959        let n = 5_000u64;
960        let mut passed = 0usize;
961        for id in 1..=n {
962            if sampling_passes(&synth_span(id, None), rate) {
963                passed += 1;
964            }
965        }
966        let frac = passed as f64 / n as f64;
967        assert!(
968            (frac - rate).abs() < 0.03,
969            "frac={frac} rate={rate} — hash distribution drifted",
970        );
971    }
972
973    // ── SetSamplingRate RPC validation ───────────────────────────────────────
974
975    #[tokio::test]
976    async fn set_sampling_rate_rejects_out_of_range() {
977        let (cache, level_handle, chance_handle) = prepare_cache();
978        let (addr, server) = spawn_server(
979            Arc::clone(&cache),
980            level_handle.clone(),
981            chance_handle.clone(),
982        )
983        .await;
984        let client = connect_client(addr).await;
985
986        for bad in [1.5_f64, -0.1, f64::NAN] {
987            let resp = client
988                .send_unary(Request::new(RequestBody::SetSamplingRate(bad)))
989                .unwrap()
990                .await
991                .unwrap();
992            match resp.body {
993                ResponseBody::Error(msg) => {
994                    assert!(
995                        msg.contains("sampling rate"),
996                        "unexpected error message for {bad}: {msg}",
997                    );
998                }
999                other => panic!("expected Error for rate={bad}, got {other:?}"),
1000            }
1001        }
1002        server.abort();
1003    }
1004}