Skip to main content

atproto_devtool/commands/test/labeler/
subscription.rs

1//! Subscription stage for the labeler conformance suite.
2//!
3//! Performs `com.atproto.label.subscribeLabels` requests against the labeler endpoint,
4//! using a two-connection strategy: backfill with cursor=0, and live-tail if backfill
5//! did not complete within the budget.
6
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use atrium_api::com::atproto::label::defs::Label;
12use futures_util::StreamExt;
13use miette::{Diagnostic, NamedSource, SourceSpan};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16use tokio::time::Instant;
17use url::Url;
18
/// Header block of a `subscribeLabels` WebSocket frame, parsed from CBOR.
///
/// [`decode_frame`] reads this block first and uses `op` together with `t` to
/// decide which payload type to decode from the bytes that follow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrameHeader {
    /// Operation type: 1 for message, -1 for error.
    pub op: i64,
    /// Message type identifier (e.g., "#labels", "#info"), optional for error frames.
    ///
    /// Left out of the serialized header entirely when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub t: Option<String>,
}
28
/// Payload for `#labels` message frames.
///
/// Decoded by [`decode_frame`] when the header has `op == 1` and `t == "#labels"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubscribeLabelsPayload {
    /// Sequence number of this label batch.
    pub seq: i64,
    /// Array of labels in this batch.
    pub labels: Vec<Label>,
}
37
/// Payload for `#info` message frames.
///
/// Decoded by [`decode_frame`] when the header has `op == 1` and `t == "#info"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubscribeInfoPayload {
    /// Service name.
    pub name: String,
    /// Optional additional message.
    pub message: Option<String>,
}
46
/// Payload for error frames (op == -1).
///
/// [`decode_frame`] decodes this for any header with `op == -1`, regardless of
/// whether the header carries a `t` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubscribeErrorPayload {
    /// Error code or identifier.
    pub error: String,
    /// Optional error description.
    pub message: Option<String>,
}
55
/// A decoded WebSocket frame from subscribeLabels.
///
/// This is the success type of [`decode_frame`]; each variant wraps the payload
/// that followed the frame header.
#[derive(Debug, Clone)]
pub enum DecodedFrame {
    /// A labels message frame (`op == 1`, `t == "#labels"`).
    Labels(SubscribeLabelsPayload),
    /// An info message frame (`op == 1`, `t == "#info"`).
    Info(SubscribeInfoPayload),
    /// An error frame (`op == -1`).
    Error(SubscribeErrorPayload),
}
66
/// Errors that can occur when decoding a WebSocket frame.
///
/// Every variant keeps the raw frame bytes so a diagnostic can later be built
/// with the offending frame attached as source.
#[derive(Debug, Clone)]
pub enum FrameDecodeError {
    /// Failed to decode the header CBOR block.
    HeaderDecode {
        /// Raw bytes of the frame.
        raw: Arc<[u8]>,
        /// Human-readable error message.
        cause: String,
    },
    /// Failed to decode the payload CBOR block.
    PayloadDecode {
        /// Header successfully decoded.
        header: FrameHeader,
        /// Raw bytes of the frame.
        raw: Arc<[u8]>,
        /// Human-readable error message.
        cause: String,
    },
    /// Message type not recognized.
    UnknownMessageType {
        /// The unrecognized type identifier. When the header carried no `t`
        /// at all, `decode_frame` stores a synthesized description of the
        /// `op`/`t` pair here instead.
        t: String,
        /// Raw bytes of the frame.
        raw: Arc<[u8]>,
    },
    /// Text frame received (not allowed).
    ///
    /// NOTE(review): `decode_frame` never produces this variant; it is
    /// presumably constructed by a caller outside this view — confirm.
    TextFrameRejected(Arc<[u8]>),
}
96
97/// Decode a two-CBOR-block WebSocket frame into a typed message.
98pub fn decode_frame(bytes: &[u8]) -> Result<DecodedFrame, FrameDecodeError> {
99    let mut cursor = bytes;
100
101    // Decode the header CBOR block.
102    let header = ciborium::de::from_reader::<FrameHeader, _>(&mut cursor).map_err(|e| {
103        FrameDecodeError::HeaderDecode {
104            raw: Arc::from(bytes),
105            cause: e.to_string(),
106        }
107    })?;
108
109    // Based on op and t, decode the payload block accordingly.
110    match (header.op, &header.t) {
111        (1, Some(t)) if t == "#labels" => {
112            let payload = ciborium::de::from_reader::<SubscribeLabelsPayload, _>(&mut cursor)
113                .map_err(|e| FrameDecodeError::PayloadDecode {
114                    header: header.clone(),
115                    raw: Arc::from(bytes),
116                    cause: e.to_string(),
117                })?;
118            Ok(DecodedFrame::Labels(payload))
119        }
120        (1, Some(t)) if t == "#info" => {
121            let payload = ciborium::de::from_reader::<SubscribeInfoPayload, _>(&mut cursor)
122                .map_err(|e| FrameDecodeError::PayloadDecode {
123                    header: header.clone(),
124                    raw: Arc::from(bytes),
125                    cause: e.to_string(),
126                })?;
127            Ok(DecodedFrame::Info(payload))
128        }
129        (-1, _) => {
130            let payload = ciborium::de::from_reader::<SubscribeErrorPayload, _>(&mut cursor)
131                .map_err(|e| FrameDecodeError::PayloadDecode {
132                    header: header.clone(),
133                    raw: Arc::from(bytes),
134                    cause: e.to_string(),
135                })?;
136            Ok(DecodedFrame::Error(payload))
137        }
138        (_, Some(t)) => Err(FrameDecodeError::UnknownMessageType {
139            t: t.clone(),
140            raw: Arc::from(bytes),
141        }),
142        _ => Err(FrameDecodeError::UnknownMessageType {
143            t: format!("unknown op={} t={:?}", header.op, header.t),
144            raw: Arc::from(bytes),
145        }),
146    }
147}
148
/// Outcome of the backfill phase.
///
/// Produced by `run` from the first (cursor=0) connection and used there to
/// decide whether a second, live-tail connection is attempted.
#[derive(Debug, Clone)]
pub enum BackfillOutcome {
    /// Backfill completed with an idle gap (no frames for 500ms).
    CompletedWithIdleGap {
        /// Number of frames observed during backfill.
        frames_observed: usize,
        /// Duration of idle gap in milliseconds.
        idle_gap_ms: u64,
    },
    /// Backfill exceeded the time budget while still producing frames.
    ExceededBudget {
        /// Number of frames observed before timeout.
        frames_observed: usize,
    },
    /// Server closed the stream before the idle-gap budget was exhausted.
    StreamClosedDuringBackfill {
        /// Number of frames observed before the stream closed.
        frames_observed: usize,
    },
    /// No frames received during the entire budget.
    ///
    /// Also reported when the stream closes before any frame was received.
    NoFramesWithinBudget,
}
172
/// Outcome of the live-tail phase.
#[derive(Debug, Clone)]
pub enum LiveTailOutcome {
    /// Live tail observed after backfill completed (implicit pass).
    ///
    /// Set together with `BackfillOutcome::CompletedWithIdleGap`; no second
    /// connection is opened in this case.
    FromBackfill,
    /// Live-tail connection held open, frames may have been observed.
    CleanHold {
        /// Number of frames observed during live tail.
        frames_observed: usize,
    },
    /// Live tail skipped because no frames were observed in backfill.
    SkippedEmpty,
    /// Live-tail connection attempt failed (second connect error after
    /// `ExceededBudget` or `StreamClosedDuringBackfill`).
    ConnectFailed,
}
188
/// Maximum number of labels to retain from subscribeLabels frames for
/// downstream crypto verification.
///
/// Sized to match a typical first-page response so the crypto stage has a
/// comparable sample when HTTP is unavailable, without holding an unbounded
/// number of labels in memory on noisy streams. Enforced by
/// `collect_sample_labels`.
pub const SAMPLE_LABEL_CAP: usize = 256;
195
/// Facts gathered from the subscription stage.
///
/// Built at the end of `run` once the backfill connection has been
/// established and both phases have been evaluated.
#[derive(Debug, Clone)]
pub struct SubscriptionFacts {
    /// Outcome of the backfill phase.
    pub backfill_outcome: BackfillOutcome,
    /// Outcome of the live-tail phase.
    pub live_tail_outcome: LiveTailOutcome,
    /// Any frame decode errors encountered, from both the backfill and the
    /// live-tail connection.
    pub decode_errors: Vec<FrameDecodeError>,
    /// Labels decoded from `#labels` frames, capped at `SAMPLE_LABEL_CAP`.
    /// Used by the crypto stage as an alternative/additional label source
    /// when the HTTP stage cannot provide a sample.
    pub sample_labels: Vec<Label>,
}
210
/// Diagnostic for frame decode failures with source context.
///
/// `run` builds one of these per distinct `FrameDecodeError` variant it saw
/// and attaches it to a `subscription::frame_decode` spec-violation result.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::subscription::frame_decode")]
pub struct FrameDecodeFailureDiagnostic {
    /// The error message; also used verbatim as the `Display` output.
    pub message: String,
    /// The raw frame bytes.
    #[source_code]
    pub source_code: NamedSource<Arc<[u8]>>,
    /// Span highlighting the first byte of the frame.
    #[label("frame decode failure")]
    pub span: SourceSpan,
}
225
/// Errors that can occur in the subscription stage.
///
/// This is the error type of both `WebSocketClient::connect` and
/// `FrameStream::next_frame`.
#[derive(Debug, Error)]
pub enum SubscriptionStageError {
    /// Network or WebSocket transport error.
    #[error("Subscription transport error: {message}")]
    Transport {
        /// Human-readable error message.
        message: String,
        /// The underlying error, if available. `None` when the error was
        /// raised by this module itself (e.g. a rejected text frame).
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
}
239
/// A stream of WebSocket frames from a subscription connection.
///
/// Abstracted as a trait so the stage can be driven by something other than a
/// real socket; `RealFrameStream` is the implementation in this file.
#[async_trait]
pub trait FrameStream: Send {
    /// Retrieve the next frame from the stream, or None if the stream is closed.
    ///
    /// `Some(Ok(bytes))` carries the raw bytes of one frame; `Some(Err(_))`
    /// reports a transport-level problem.
    async fn next_frame(&mut self) -> Option<Result<Vec<u8>, SubscriptionStageError>>;

    /// Close the stream gracefully.
    async fn close(&mut self);
}
249
/// A WebSocket client for connecting to subscription endpoints.
///
/// `run` takes this as `&dyn WebSocketClient`, so the transport can be
/// substituted; `RealWebSocketClient` is the implementation in this file.
#[async_trait]
pub trait WebSocketClient: Send + Sync {
    /// Connect to a WebSocket endpoint and return a frame stream.
    ///
    /// # Errors
    ///
    /// Returns a `SubscriptionStageError` when the connection cannot be
    /// established.
    async fn connect(&self, url: &Url) -> Result<Box<dyn FrameStream>, SubscriptionStageError>;
}
256
/// Real WebSocket client using tokio-tungstenite.
///
/// Stateless unit struct; each `connect` call opens a new connection.
pub struct RealWebSocketClient;
259
/// Real frame stream wrapping a tokio-tungstenite WebSocketStream.
///
/// Created by `RealWebSocketClient::connect`.
pub struct RealFrameStream {
    /// The underlying WebSocket stream (TCP, optionally wrapped in TLS).
    stream: tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >,
}
267
268#[async_trait]
269impl FrameStream for RealFrameStream {
270    async fn next_frame(&mut self) -> Option<Result<Vec<u8>, SubscriptionStageError>> {
271        use tokio_tungstenite::tungstenite::Message;
272
273        loop {
274            match self.stream.next().await? {
275                Ok(Message::Binary(data)) => {
276                    return Some(Ok(data.to_vec()));
277                }
278                Ok(Message::Text(_)) => {
279                    return Some(Err(SubscriptionStageError::Transport {
280                        message: "received text frame, expected binary".to_string(),
281                        source: None,
282                    }));
283                }
284                Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
285                    continue;
286                }
287                Ok(Message::Close(_)) => {
288                    return None;
289                }
290                Ok(Message::Frame(_)) => {
291                    continue;
292                }
293                Err(e) => {
294                    return Some(Err(SubscriptionStageError::Transport {
295                        message: e.to_string(),
296                        source: Some(Box::new(e)),
297                    }));
298                }
299            }
300        }
301    }
302
303    async fn close(&mut self) {
304        let _ = self.stream.close(None).await;
305    }
306}
307
308#[async_trait]
309impl WebSocketClient for RealWebSocketClient {
310    async fn connect(&self, url: &Url) -> Result<Box<dyn FrameStream>, SubscriptionStageError> {
311        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
312
313        let request = url.to_string().into_client_request().map_err(|e| {
314            SubscriptionStageError::Transport {
315                message: e.to_string(),
316                source: Some(Box::new(e)),
317            }
318        })?;
319
320        let (stream, _response) = tokio_tungstenite::connect_async(request)
321            .await
322            .map_err(|e| SubscriptionStageError::Transport {
323                message: e.to_string(),
324                source: Some(Box::new(e)),
325            })?;
326
327        Ok(Box::new(RealFrameStream { stream }))
328    }
329}
330
/// Checks emitted by the subscription stage.
///
/// Each variant maps to a stable ID string via `Check::id` and can be turned
/// into a `CheckResult` through the constructor methods on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Check {
    /// Whether the backfill connection succeeded.
    EndpointReachable,
    /// Whether the live-tail connection succeeded.
    LiveTailEndpointReachable,
    /// Backfill phase outcome.
    Backfill,
    /// Live-tail phase outcome.
    LiveTail,
    /// WebSocket frame decode error.
    FrameDecode,
}
345
346impl Check {
347    /// Stable check ID string used in `CheckResult.id`.
348    pub fn id(self) -> &'static str {
349        match self {
350            Check::EndpointReachable => "subscription::endpoint_reachable",
351            Check::LiveTailEndpointReachable => "subscription::live_tail_endpoint_reachable",
352            Check::Backfill => "subscription::backfill",
353            Check::LiveTail => "subscription::live_tail",
354            Check::FrameDecode => "subscription::frame_decode",
355        }
356    }
357
358    pub fn pass(self) -> crate::commands::test::labeler::report::CheckResult {
359        use crate::commands::test::labeler::report::{CheckStatus, Stage};
360        crate::commands::test::labeler::report::CheckResult {
361            id: self.id(),
362            stage: Stage::Subscription,
363            status: CheckStatus::Pass,
364            summary: std::borrow::Cow::Borrowed(match self {
365                Check::Backfill => "Subscription backfill completed",
366                Check::LiveTail => "Subscription live-tail connection held",
367                _ => "subscription check passed",
368            }),
369            diagnostic: None,
370            skipped_reason: None,
371        }
372    }
373
374    pub fn spec_violation(
375        self,
376        diagnostic: Box<dyn miette::Diagnostic + Send + Sync>,
377    ) -> crate::commands::test::labeler::report::CheckResult {
378        use crate::commands::test::labeler::report::{CheckStatus, Stage};
379        crate::commands::test::labeler::report::CheckResult {
380            id: self.id(),
381            stage: Stage::Subscription,
382            status: CheckStatus::SpecViolation,
383            summary: std::borrow::Cow::Borrowed(match self {
384                Check::FrameDecode => "Subscription frame decode failure",
385                _ => "subscription check failed",
386            }),
387            diagnostic: Some(diagnostic),
388            skipped_reason: None,
389        }
390    }
391
392    pub fn network_error(self) -> crate::commands::test::labeler::report::CheckResult {
393        use crate::commands::test::labeler::report::{CheckStatus, Stage};
394        crate::commands::test::labeler::report::CheckResult {
395            id: self.id(),
396            stage: Stage::Subscription,
397            status: CheckStatus::NetworkError,
398            summary: std::borrow::Cow::Borrowed(match self {
399                Check::EndpointReachable => "Subscription endpoint reachability",
400                Check::LiveTailEndpointReachable => "Subscription live-tail reachability",
401                _ => "subscription network error",
402            }),
403            diagnostic: None,
404            skipped_reason: None,
405        }
406    }
407
408    pub fn advisory(self) -> crate::commands::test::labeler::report::CheckResult {
409        use crate::commands::test::labeler::report::{CheckStatus, Stage};
410        crate::commands::test::labeler::report::CheckResult {
411            id: self.id(),
412            stage: Stage::Subscription,
413            status: CheckStatus::Advisory,
414            summary: std::borrow::Cow::Borrowed(match self {
415                Check::Backfill => "Subscription backfill advisory",
416                _ => "subscription advisory",
417            }),
418            diagnostic: None,
419            skipped_reason: None,
420        }
421    }
422
423    pub fn skip(
424        self,
425        reason: impl Into<std::borrow::Cow<'static, str>>,
426    ) -> crate::commands::test::labeler::report::CheckResult {
427        use crate::commands::test::labeler::report::{CheckStatus, Stage};
428        crate::commands::test::labeler::report::CheckResult {
429            id: self.id(),
430            stage: Stage::Subscription,
431            status: CheckStatus::Skipped,
432            summary: std::borrow::Cow::Borrowed(match self {
433                Check::LiveTail => "Subscription live-tail skipped",
434                _ => "subscription check skipped",
435            }),
436            diagnostic: None,
437            skipped_reason: Some(reason.into()),
438        }
439    }
440}
441
/// Output from the subscription stage: facts (if any) plus all check results.
#[derive(Debug)]
pub struct SubscriptionStageOutput {
    /// Facts populated only when the stage completes without blocking errors.
    ///
    /// `run` leaves this as `None` when the initial backfill connection
    /// cannot be established.
    pub facts: Option<SubscriptionFacts>,
    /// All check results from this stage.
    pub results: Vec<crate::commands::test::labeler::report::CheckResult>,
}
450
451/// Push labels from a decoded `#labels` frame into the sample buffer,
452/// stopping once the buffer reaches `SAMPLE_LABEL_CAP`.
453fn collect_sample_labels(buffer: &mut Vec<Label>, frame_labels: Vec<Label>) {
454    if buffer.len() >= SAMPLE_LABEL_CAP {
455        return;
456    }
457    let remaining = SAMPLE_LABEL_CAP - buffer.len();
458    if frame_labels.len() <= remaining {
459        buffer.extend(frame_labels);
460    } else {
461        buffer.extend(frame_labels.into_iter().take(remaining));
462    }
463}
464
465/// Run live-tail on a fresh connection and drain frames until budget exhausted or stream closes.
466async fn run_live_tail(
467    endpoint: &Url,
468    ws: &dyn WebSocketClient,
469    budget: Duration,
470    decode_errors: &mut Vec<FrameDecodeError>,
471    sample_labels: &mut Vec<Label>,
472) -> Result<LiveTailOutcome, SubscriptionStageError> {
473    // Build live-tail URL (no cursor parameter to stream from latest).
474    let mut live_tail_url = endpoint.clone();
475    live_tail_url.set_path("xrpc/com.atproto.label.subscribeLabels");
476    if live_tail_url.scheme() == "https" {
477        let _ = live_tail_url.set_scheme("wss");
478    }
479
480    tracing::debug!(url = %live_tail_url, "subscription stage: connecting for live-tail");
481
482    match ws.connect(&live_tail_url).await {
483        Ok(mut live_stream) => {
484            let mut live_frames_observed = 0;
485            let live_deadline = Instant::now() + budget;
486
487            loop {
488                if Instant::now() >= live_deadline {
489                    break;
490                }
491                let time_left = live_deadline.saturating_duration_since(Instant::now());
492                match tokio::time::timeout(time_left, live_stream.next_frame()).await {
493                    Ok(Some(Ok(frame))) => {
494                        live_frames_observed += 1;
495                        tracing::trace!(
496                            frame_num = live_frames_observed,
497                            frame_len = frame.len(),
498                            "subscription stage: live-tail frame received"
499                        );
500                        match decode_frame(&frame) {
501                            Ok(DecodedFrame::Labels(payload)) => {
502                                collect_sample_labels(sample_labels, payload.labels);
503                            }
504                            Ok(_) => {}
505                            Err(e) => decode_errors.push(e),
506                        }
507                    }
508                    Ok(Some(Err(_))) => {
509                        live_frames_observed += 1;
510                    }
511                    Ok(None) | Err(_) => break,
512                }
513            }
514
515            tracing::debug!(
516                live_frames_observed,
517                "subscription stage: live-tail phase finished"
518            );
519            live_stream.close().await;
520            Ok(LiveTailOutcome::CleanHold {
521                frames_observed: live_frames_observed,
522            })
523        }
524        Err(e) => {
525            tracing::debug!(url = %live_tail_url, "subscription stage: live-tail connect failed");
526            Err(e)
527        }
528    }
529}
530
531/// Run the subscription stage with a two-connection backfill + live-tail strategy.
532pub async fn run(
533    labeler_endpoint: &Url,
534    ws: &dyn WebSocketClient,
535    budget_per_connection: Duration,
536) -> SubscriptionStageOutput {
537    use crate::commands::test::labeler::report::CheckResult;
538    use std::borrow::Cow;
539    use std::collections::HashSet;
540
541    // Build the subscription URL with cursor=0 for backfill.
542    let backfill_url = {
543        let mut url = labeler_endpoint.clone();
544        url.set_path("xrpc/com.atproto.label.subscribeLabels");
545        {
546            let mut query = url.query_pairs_mut();
547            query.append_pair("cursor", "0");
548        }
549        // Ensure the scheme is wss.
550        if url.scheme() == "https" {
551            let _ = url.set_scheme("wss");
552        }
553        url
554    };
555
556    tracing::debug!(url = %backfill_url, "subscription stage: connecting for backfill");
557
558    // Attempt to connect for backfill.
559    let mut stream = match ws.connect(&backfill_url).await {
560        Ok(s) => s,
561        Err(_e) => {
562            tracing::debug!(url = %backfill_url, "subscription stage: backfill connect failed");
563            return SubscriptionStageOutput {
564                facts: None,
565                results: vec![Check::EndpointReachable.network_error()],
566            };
567        }
568    };
569
570    // Backfill phase: drain frames with a budget and idle-gap detection.
571    let mut backfill_outcome = BackfillOutcome::NoFramesWithinBudget;
572    let mut live_tail_outcome: Option<LiveTailOutcome> = None;
573    let mut decode_errors: Vec<FrameDecodeError> = vec![];
574    let mut sample_labels: Vec<Label> = Vec::new();
575    let mut frames_observed = 0;
576    let mut last_frame_at: Option<Instant> = None;
577
578    let backfill_deadline = Instant::now() + budget_per_connection;
579
580    loop {
581        // Check if the deadline has been exceeded.
582        if Instant::now() >= backfill_deadline {
583            if frames_observed > 0 {
584                backfill_outcome = BackfillOutcome::ExceededBudget { frames_observed };
585            }
586            break;
587        }
588
589        // Compute the timeout for the next frame: either budget remaining or idle gap.
590        let idle_gap_deadline = last_frame_at.map(|t| t + Duration::from_millis(500));
591        let timeout = if let Some(idle_deadline) = idle_gap_deadline {
592            if idle_deadline <= Instant::now() {
593                backfill_outcome = BackfillOutcome::CompletedWithIdleGap {
594                    frames_observed,
595                    idle_gap_ms: 500,
596                };
597                live_tail_outcome = Some(LiveTailOutcome::FromBackfill);
598                break;
599            }
600            let idle_time_left = idle_deadline.saturating_duration_since(Instant::now());
601            let budget_time_left = backfill_deadline.saturating_duration_since(Instant::now());
602            idle_time_left.min(budget_time_left)
603        } else {
604            backfill_deadline.saturating_duration_since(Instant::now())
605        };
606
607        // Wait for the next frame with timeout.
608        match tokio::time::timeout(timeout, stream.next_frame()).await {
609            Ok(Some(Ok(frame_bytes))) => {
610                last_frame_at = Some(Instant::now());
611                frames_observed += 1;
612                tracing::trace!(
613                    frame_num = frames_observed,
614                    frame_len = frame_bytes.len(),
615                    "subscription stage: backfill frame received"
616                );
617                match decode_frame(&frame_bytes) {
618                    Ok(DecodedFrame::Labels(payload)) => {
619                        collect_sample_labels(&mut sample_labels, payload.labels);
620                    }
621                    Ok(_) => {}
622                    Err(e) => decode_errors.push(e),
623                }
624            }
625            Ok(Some(Err(_e))) => {
626                // Transport error: do not reset the idle-gap timer.
627            }
628            Ok(None) => {
629                // Stream closed. The server closed before the idle-gap budget was exhausted.
630                if frames_observed > 0 {
631                    backfill_outcome =
632                        BackfillOutcome::StreamClosedDuringBackfill { frames_observed };
633                } else {
634                    backfill_outcome = BackfillOutcome::NoFramesWithinBudget;
635                }
636                break;
637            }
638            Err(_e) => {
639                if frames_observed > 0 {
640                    if let Some(idle_deadline) = idle_gap_deadline {
641                        if Instant::now() >= idle_deadline {
642                            backfill_outcome = BackfillOutcome::CompletedWithIdleGap {
643                                frames_observed,
644                                idle_gap_ms: 500,
645                            };
646                            live_tail_outcome = Some(LiveTailOutcome::FromBackfill);
647                        } else {
648                            backfill_outcome = BackfillOutcome::ExceededBudget { frames_observed };
649                        }
650                    } else {
651                        backfill_outcome = BackfillOutcome::ExceededBudget { frames_observed };
652                    }
653                }
654                break;
655            }
656        }
657    }
658
659    tracing::debug!(
660        frames_observed,
661        outcome = ?backfill_outcome,
662        "subscription stage: backfill phase finished"
663    );
664
665    // Close the stream. When we exit normally (idle gap or stream closed), the stream is already
666    // closed, so this is a noop. When we exit due to timeout or error, we close explicitly.
667    // Either way, calling close() on an already-closed stream is harmless.
668    stream.close().await;
669
670    // Determine the live-tail outcome if not already set.
671    let live_tail_outcome = if let Some(outcome) = live_tail_outcome {
672        outcome
673    } else {
674        match &backfill_outcome {
675            // Server has more labels than we can check; make sure we observe the tail.
676            BackfillOutcome::ExceededBudget { .. }
677            // Server closed the stream during backfill; attempt a second live-tail
678            // connection to detect if the labeler supports live-tail separately.
679            | BackfillOutcome::StreamClosedDuringBackfill { .. } => {
680                run_live_tail(
681                    labeler_endpoint,
682                    ws,
683                    budget_per_connection,
684                    &mut decode_errors,
685                    &mut sample_labels,
686                )
687                .await
688                .ok()
689                // Live-tail connection failed; mark with ConnectFailed outcome.
690                .unwrap_or(LiveTailOutcome::ConnectFailed)
691            }
692            BackfillOutcome::NoFramesWithinBudget => LiveTailOutcome::SkippedEmpty,
693            BackfillOutcome::CompletedWithIdleGap { .. } => {
694                unreachable!(
695                    "live_tail_outcome is already Some(FromBackfill) for CompletedWithIdleGap"
696                );
697            }
698        }
699    };
700
701    // Build check results.
702    let mut results = vec![];
703
704    // Live-tail connect error result (if applicable).
705    if matches!(live_tail_outcome, LiveTailOutcome::ConnectFailed) {
706        results.push(Check::LiveTailEndpointReachable.network_error());
707    }
708
709    // Backfill check result.
710    results.push(match &backfill_outcome {
711        BackfillOutcome::CompletedWithIdleGap { .. } => Check::Backfill.pass(),
712        BackfillOutcome::ExceededBudget { .. } => CheckResult {
713            summary: Cow::Borrowed("Subscription backfill exceeded budget"),
714            ..Check::Backfill.advisory()
715        },
716        BackfillOutcome::StreamClosedDuringBackfill { .. } => CheckResult {
717            summary: Cow::Borrowed("Subscription backfill stream closed unexpectedly"),
718            ..Check::Backfill.advisory()
719        },
720        BackfillOutcome::NoFramesWithinBudget => CheckResult {
721            summary: Cow::Borrowed("Subscription backfill had no frames"),
722            skipped_reason: Some(Cow::Borrowed("labeler has no published labels")),
723            ..Check::Backfill.advisory()
724        },
725    });
726
727    // Live-tail check result.
728    // ConnectFailed is already handled with a NetworkError result above, so skip the live-tail row.
729    if !matches!(live_tail_outcome, LiveTailOutcome::ConnectFailed) {
730        results.push(match live_tail_outcome {
731            LiveTailOutcome::FromBackfill => CheckResult {
732                summary: Cow::Borrowed("Subscription live-tail observed after backfill"),
733                ..Check::LiveTail.pass()
734            },
735            LiveTailOutcome::CleanHold { .. } => Check::LiveTail.pass(),
736            LiveTailOutcome::SkippedEmpty => {
737                Check::LiveTail.skip("labeler has no published labels")
738            }
739            LiveTailOutcome::ConnectFailed => {
740                unreachable!("ConnectFailed case should be filtered by outer guard")
741            }
742        });
743    }
744
745    // Add spec violation results for unique decode error variants.
746    let mut seen_variants = HashSet::new();
747    for err in decode_errors.iter() {
748        let variant_key = std::mem::discriminant(err);
749        if seen_variants.insert(variant_key) {
750            let (raw_bytes, msg) = match err {
751                FrameDecodeError::HeaderDecode { raw, cause } => {
752                    (raw.clone(), format!("Header decode failed: {cause}"))
753                }
754                FrameDecodeError::PayloadDecode { raw, cause, .. } => {
755                    (raw.clone(), format!("Payload decode failed: {cause}"))
756                }
757                FrameDecodeError::UnknownMessageType { t, raw } => {
758                    (raw.clone(), format!("Unknown message type: {t}"))
759                }
760                FrameDecodeError::TextFrameRejected(raw) => (
761                    raw.clone(),
762                    "Text frame rejected (expected binary)".to_string(),
763                ),
764            };
765
766            let diagnostic = FrameDecodeFailureDiagnostic {
767                message: msg,
768                source_code: NamedSource::new("frame", raw_bytes),
769                span: SourceSpan::new(0.into(), 1),
770            };
771
772            results.push(Check::FrameDecode.spec_violation(Box::new(diagnostic)));
773        }
774    }
775
776    let facts = Some(SubscriptionFacts {
777        backfill_outcome,
778        live_tail_outcome,
779        decode_errors,
780        sample_labels,
781    });
782
783    SubscriptionStageOutput { facts, results }
784}
785
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` as CBOR into a new byte vector.
    ///
    /// Panics when serialization fails, since that means the fixture itself is broken.
    fn encode_cbor<T: Serialize>(value: &T) -> Vec<u8> {
        let mut encoded = Vec::new();
        ciborium::ser::into_writer(value, &mut encoded).expect("failed to encode CBOR");
        encoded
    }

    /// Header for a message frame (`op == 1`) carrying the type tag `t`.
    fn message_header(t: &str) -> FrameHeader {
        FrameHeader {
            op: 1,
            t: Some(t.to_string()),
        }
    }

    /// Header for an error frame (`op == -1`) with no type tag.
    fn error_header() -> FrameHeader {
        FrameHeader { op: -1, t: None }
    }

    /// Concatenates the CBOR encoding of `header` with the CBOR encoding of `payload`.
    fn frame_with_payload<P: Serialize>(header: &FrameHeader, payload: &P) -> Vec<u8> {
        [encode_cbor(header), encode_cbor(payload)].concat()
    }

    /// Encodes `header` and appends a single `0xff` byte in place of a payload.
    fn frame_with_garbage_payload(header: &FrameHeader) -> Vec<u8> {
        let mut bytes = encode_cbor(header);
        bytes.push(0xff);
        bytes
    }

    #[test]
    fn collect_sample_labels_respects_cap() {
        use atrium_api::com::atproto::label::defs::LabelData;
        use atrium_api::types::string::Datetime;

        // Builds a label whose `uri` embeds `i`; every other field is fixed.
        fn numbered_label(i: usize) -> Label {
            let data = LabelData {
                cid: None,
                cts: Datetime::new("2026-01-01T00:00:00.000Z".parse().expect("valid datetime")),
                exp: None,
                neg: None,
                sig: Some(vec![0u8; 64]),
                src: "did:plc:test123456789abcdefghijklmnop"
                    .parse()
                    .expect("valid did"),
                uri: format!("at://did:plc:test123456789abcdefghijklmnop/x/{i}"),
                val: "spam".to_string(),
                ver: Some(1),
            };
            data.into()
        }

        let half = SAMPLE_LABEL_CAP / 2;
        let mut collected: Vec<Label> = Vec::new();

        // A batch of `half` labels is expected to be kept in full.
        let partial: Vec<Label> = (0..half).map(numbered_label).collect();
        collect_sample_labels(&mut collected, partial);
        assert_eq!(collected.len(), half);

        // A batch twice the cap must top the buffer up to the cap and no further.
        let oversized: Vec<Label> = (0..(SAMPLE_LABEL_CAP * 2)).map(numbered_label).collect();
        collect_sample_labels(&mut collected, oversized);
        assert_eq!(collected.len(), SAMPLE_LABEL_CAP);

        // With the buffer at the cap, one more label must leave the length unchanged.
        collect_sample_labels(&mut collected, vec![numbered_label(99999)]);
        assert_eq!(collected.len(), SAMPLE_LABEL_CAP);
    }

    #[test]
    fn decode_labels_frame_valid() {
        let frame = frame_with_payload(
            &message_header("#labels"),
            &SubscribeLabelsPayload {
                seq: 0,
                labels: vec![],
            },
        );

        let outcome = decode_frame(&frame);
        let Ok(DecodedFrame::Labels(payload)) = &outcome else {
            panic!("expected DecodedFrame::Labels, got {outcome:?}");
        };
        assert_eq!(payload.seq, 0);
        assert!(payload.labels.is_empty());
    }

    #[test]
    fn decode_info_frame_valid() {
        let frame = frame_with_payload(
            &message_header("#info"),
            &SubscribeInfoPayload {
                name: "test-service".to_string(),
                message: Some("info message".to_string()),
            },
        );

        let outcome = decode_frame(&frame);
        let Ok(DecodedFrame::Info(payload)) = &outcome else {
            panic!("expected DecodedFrame::Info, got {outcome:?}");
        };
        assert_eq!(payload.name, "test-service");
        assert_eq!(payload.message, Some("info message".to_string()));
    }

    #[test]
    fn decode_error_frame_valid() {
        let frame = frame_with_payload(
            &error_header(),
            &SubscribeErrorPayload {
                error: "TestError".to_string(),
                message: Some("Test error message".to_string()),
            },
        );

        let outcome = decode_frame(&frame);
        let Ok(DecodedFrame::Error(payload)) = &outcome else {
            panic!("expected DecodedFrame::Error, got {outcome:?}");
        };
        assert_eq!(payload.error, "TestError");
        assert_eq!(payload.message, Some("Test error message".to_string()));
    }

    #[test]
    fn decode_frame_header_decode_failure() {
        let garbage = vec![0x1f, 0x2f, 0x3f]; // Invalid CBOR

        let outcome = decode_frame(&garbage);
        let Err(FrameDecodeError::HeaderDecode { raw, .. }) = &outcome else {
            panic!("expected HeaderDecode error, got {outcome:?}");
        };
        // The error must carry the exact bytes that were handed to the decoder.
        assert_eq!(raw.as_ref(), &garbage);
    }

    #[test]
    fn decode_frame_payload_decode_failure() {
        // Valid `#labels` header followed by a byte that is not a valid payload.
        let frame_bytes = frame_with_garbage_payload(&message_header("#labels"));

        let outcome = decode_frame(&frame_bytes);
        let Err(FrameDecodeError::PayloadDecode { raw, .. }) = &outcome else {
            panic!("expected PayloadDecode error, got {outcome:?}");
        };
        // `raw` must hold the whole frame (header included), not just the payload part.
        assert_eq!(raw.as_ref(), &frame_bytes);
    }

    #[test]
    fn decode_frame_unknown_message_type() {
        // Header only: no payload bytes follow the unrecognized type tag.
        let frame_bytes = encode_cbor(&message_header("#futureType"));

        let outcome = decode_frame(&frame_bytes);
        let Err(FrameDecodeError::UnknownMessageType { t, .. }) = &outcome else {
            panic!("expected UnknownMessageType error, got {outcome:?}");
        };
        assert_eq!(*t, "#futureType");
    }

    #[test]
    fn decode_frame_error_payload_malformed() {
        // Error-frame header followed by a byte that is not a valid payload.
        let frame_bytes = frame_with_garbage_payload(&error_header());

        let outcome = decode_frame(&frame_bytes);
        let Err(FrameDecodeError::PayloadDecode { raw, .. }) = &outcome else {
            panic!("expected PayloadDecode error, got {outcome:?}");
        };
        // `raw` must hold the whole frame (header included), not just the payload part.
        assert_eq!(raw.as_ref(), &frame_bytes);
    }
}