Skip to main content

atproto_devtool/commands/test/labeler/
http.rs

1//! HTTP stage for the labeler conformance suite.
2//!
3//! Performs `com.atproto.label.queryLabels` requests against the labeler endpoint,
4//! verifies schema conformance, and exercises pagination.
5
6use std::borrow::Cow;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use atrium_api::com::atproto::label::defs::Label;
11use atrium_api::com::atproto::label::query_labels;
12use base64::Engine;
13use miette::{Diagnostic, NamedSource, SourceSpan};
14use thiserror::Error;
15use url::Url;
16
17use crate::commands::test::labeler::report::{CheckResult, CheckStatus, Stage};
18use crate::common::diagnostics::{pretty_json_for_display, span_at_line_column};
19
/// Identifies each check reported by the HTTP conformance stage.
///
/// Every variant maps to a stable string ID via `Check::id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Check {
    /// The labeler endpoint answered an HTTP request at all.
    EndpointReachable,
    /// The first `queryLabels` page decodes against the expected schema.
    QueryLabelsSchemaFirstPage,
    /// Informational: the labeler currently publishes no labels.
    QueryLabelsEmptyAdvisory,
    /// The second `queryLabels` page decodes against the expected schema.
    QueryLabelsSchemaSecondPage,
    /// Following the returned cursor yields a different page.
    PaginationRoundTrip,
    /// The labeler returned the same page regardless of the cursor.
    PaginationIgnoredCursor,
}
36
37impl Check {
38    /// Stable check ID string used in `CheckResult.id`.
39    pub fn id(self) -> &'static str {
40        match self {
41            Check::EndpointReachable => "http::endpoint_reachable",
42            Check::QueryLabelsSchemaFirstPage => "http::query_labels_schema_first_page",
43            Check::QueryLabelsEmptyAdvisory => "http::query_labels_empty_advisory",
44            Check::QueryLabelsSchemaSecondPage => "http::query_labels_schema_second_page",
45            Check::PaginationRoundTrip => "http::pagination_round_trip",
46            Check::PaginationIgnoredCursor => "http::pagination_ignored_cursor",
47        }
48    }
49
50    pub fn pass(self) -> CheckResult {
51        CheckResult {
52            id: self.id(),
53            stage: Stage::Http,
54            status: CheckStatus::Pass,
55            summary: Cow::Borrowed(match self {
56                Check::EndpointReachable => "Labeler endpoint reachability",
57                Check::QueryLabelsSchemaFirstPage => "First page schema",
58                Check::QueryLabelsSchemaSecondPage => "Second page schema",
59                Check::PaginationRoundTrip => "Pagination round-trip",
60                _ => "HTTP check passed",
61            }),
62            diagnostic: None,
63            skipped_reason: None,
64        }
65    }
66
67    pub fn spec_violation(
68        self,
69        diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
70    ) -> CheckResult {
71        CheckResult {
72            id: self.id(),
73            stage: Stage::Http,
74            status: CheckStatus::SpecViolation,
75            summary: Cow::Borrowed(match self {
76                Check::QueryLabelsSchemaFirstPage => "Schema validation failed",
77                Check::QueryLabelsSchemaSecondPage => "Second page schema validation failed",
78                Check::PaginationIgnoredCursor => "Labeler ignored the cursor parameter",
79                _ => "HTTP check failed",
80            }),
81            diagnostic,
82            skipped_reason: None,
83        }
84    }
85
86    pub fn network_error(self) -> CheckResult {
87        CheckResult {
88            id: self.id(),
89            stage: Stage::Http,
90            status: CheckStatus::NetworkError,
91            summary: Cow::Borrowed(match self {
92                Check::EndpointReachable => "Labeler endpoint unreachable",
93                Check::QueryLabelsSchemaSecondPage => "Second page fetch failed",
94                _ => "HTTP network error",
95            }),
96            diagnostic: None,
97            skipped_reason: None,
98        }
99    }
100
101    pub fn advisory(self) -> CheckResult {
102        CheckResult {
103            id: self.id(),
104            stage: Stage::Http,
105            status: CheckStatus::Advisory,
106            summary: Cow::Borrowed(match self {
107                Check::QueryLabelsEmptyAdvisory => "Labeler has no published labels",
108                _ => "HTTP advisory",
109            }),
110            diagnostic: None,
111            skipped_reason: None,
112        }
113    }
114}
115
116/// Facts gathered from the HTTP stage, populated only when all checks pass.
117#[derive(Debug, Clone)]
118pub struct HttpFacts {
119    /// The parsed labels from the first page of the query response.
120    pub first_page: Vec<Label>,
121    /// Raw bytes of the first page response body.
122    pub first_page_raw_bytes: Arc<[u8]>,
123    /// The source URL where the first page was retrieved.
124    pub first_page_source_url: String,
125    /// Whether pagination passed the round-trip check.
126    pub pagination_ok: bool,
127}
128
129/// Output from the HTTP stage: facts (if all checks pass) plus all check results.
130#[derive(Debug)]
131pub struct HttpStageOutput {
132    /// Facts populated only when all checks pass and no check is blocking.
133    pub facts: Option<HttpFacts>,
134    /// All check results from this stage.
135    pub results: Vec<CheckResult>,
136}
137
138/// Raw HTTP response with both decoded and raw body.
139pub struct RawXrpcResponse {
140    /// HTTP status code.
141    pub status: reqwest::StatusCode,
142    /// Raw response body bytes.
143    pub raw_body: Arc<[u8]>,
144    /// Decoded typed response.
145    pub decoded: query_labels::Output,
146    /// The source URL where the response came from.
147    pub source_url: String,
148}
149
150/// Diagnostic for schema decode failures with source context.
151#[derive(Debug, Error, Diagnostic)]
152#[error("{message}")]
153#[diagnostic(code = "labeler::http::schema_failure")]
154pub struct HttpDecodeFailure {
155    /// The error message.
156    pub message: String,
157    /// The raw response bytes.
158    #[source_code]
159    pub source_code: NamedSource<Arc<[u8]>>,
160    /// Span highlighting the error location.
161    #[label("JSON error")]
162    pub span: Option<SourceSpan>,
163}
164
165/// Error type for HTTP stage operations.
166#[derive(Debug, Error)]
167pub enum HttpStageError {
168    /// Network or TLS error reaching the endpoint.
169    #[error("HTTP transport error: {message}")]
170    Transport {
171        /// Human-readable error message.
172        message: String,
173        /// The underlying error, if available.
174        #[source]
175        source: Option<Box<dyn std::error::Error + Send + Sync>>,
176    },
177
178    /// Decode failure of a valid HTTP response.
179    #[error("Schema decode failure")]
180    DecodeFailed {
181        /// Raw response body bytes.
182        raw_body: Arc<[u8]>,
183        /// The JSON decode error.
184        source: serde_json::Error,
185        /// The source URL for diagnostic context.
186        source_url: String,
187    },
188}
189
190/// Trait for teeing HTTP responses, allowing both decode and raw bytes capture.
191///
192/// Implementations perform `com.atproto.label.queryLabels` calls and return both
193/// the decoded typed response and the raw bytes for diagnostics.
194#[async_trait]
195pub trait RawHttpTee: Send + Sync {
196    /// Perform a `com.atproto.label.queryLabels` call against the labeler.
197    ///
198    /// Returns both the raw response body and, if decoding succeeded, the typed Output.
199    ///
200    /// # Arguments
201    /// * `cursor` - Optional cursor string for pagination.
202    async fn query_labels(&self, cursor: Option<&str>) -> Result<RawXrpcResponse, HttpStageError>;
203}
204
205/// Real HTTP client implementation using reqwest.
206pub struct RealHttpTee {
207    /// The base HTTP client.
208    client: reqwest::Client,
209    /// The labeler endpoint URL.
210    endpoint: Url,
211}
212
213impl RealHttpTee {
214    /// Create a new RealHttpTee with the given endpoint.
215    pub fn new(client: reqwest::Client, endpoint: Url) -> Self {
216        RealHttpTee { client, endpoint }
217    }
218}
219
220#[async_trait]
221impl RawHttpTee for RealHttpTee {
222    async fn query_labels(&self, cursor: Option<&str>) -> Result<RawXrpcResponse, HttpStageError> {
223        // Build the XRPC endpoint URL.
224        let mut url = self.endpoint.clone();
225        url.set_path("xrpc/com.atproto.label.queryLabels");
226
227        // Set query parameters.
228        {
229            let mut query = url.query_pairs_mut();
230            query.append_pair("uriPatterns", "*");
231            query.append_pair("limit", "50");
232            if let Some(c) = cursor {
233                query.append_pair("cursor", c);
234            }
235        }
236
237        let source_url = url.to_string();
238
239        tracing::debug!(
240            url = %source_url,
241            cursor = ?cursor,
242            "http stage: issuing queryLabels GET"
243        );
244
245        // Perform the GET request.
246        let response =
247            self.client
248                .get(url.as_str())
249                .send()
250                .await
251                .map_err(|e| HttpStageError::Transport {
252                    message: e.to_string(),
253                    source: Some(Box::new(e)),
254                })?;
255
256        let status = response.status();
257        let body_bytes = response
258            .bytes()
259            .await
260            .map_err(|e| HttpStageError::Transport {
261                message: e.to_string(),
262                source: Some(Box::new(e)),
263            })?;
264
265        tracing::debug!(
266            url = %source_url,
267            status = %status,
268            body_len = body_bytes.len(),
269            "http stage: queryLabels response received"
270        );
271        let raw_body: Arc<[u8]> = Arc::from(body_bytes.as_ref());
272
273        // Attempt to decode the response. The atproto JSON encoding wraps
274        // every bytes value as `{"$bytes": "<base64>"}`, but atrium-api's
275        // generated types annotate byte fields with `#[serde(with = "serde_bytes")]`
276        // which expects a raw byte sequence. Parse the body into a
277        // `serde_json::Value` first, rewrite every `{"$bytes": "<base64>"}`
278        // object into an array of byte integers, then hand the transformed
279        // value to atrium for typed decoding.
280        let decoded = decode_query_labels_output(&raw_body).map_err(|source| {
281            HttpStageError::DecodeFailed {
282                raw_body: raw_body.clone(),
283                source,
284                source_url: source_url.clone(),
285            }
286        })?;
287
288        Ok(RawXrpcResponse {
289            status,
290            raw_body,
291            decoded,
292            source_url,
293        })
294    }
295}
296
297/// Deserialize a `com.atproto.label.queryLabels` response body into the
298/// atrium-generated `query_labels::Output`, translating the atproto JSON
299/// `{"$bytes": "<base64>"}` wrapper into a plain byte sequence so that
300/// `serde_bytes`-annotated fields (notably `Label.sig`) deserialize cleanly.
301///
302/// On failure the returned error is a `serde_json::Error`, so the existing
303/// `span_for_json_error` helper can still point at the relevant source line.
304fn decode_query_labels_output(body: &[u8]) -> Result<query_labels::Output, serde_json::Error> {
305    let mut value: serde_json::Value = serde_json::from_slice(body)?;
306    rewrite_atproto_json_bytes(&mut value);
307    serde_json::from_value(value)
308}
309
310/// Walk a `serde_json::Value` and rewrite every
311/// `{"$bytes": "<base64>"}` object into a `Value::Array` of byte integers.
312///
313/// This mirrors the atproto JSON representation of `bytes` values (see
314/// https://atproto.com/specs/data-model#json-representation) so that downstream
315/// deserialization into `Vec<u8>` (or `serde_bytes`-wrapped equivalents) works
316/// without a custom `Visitor`. Objects whose `$bytes` value is not a base64
317/// string (malformed, extra keys, non-string) are left unchanged so that the
318/// eventual typed deserialization surfaces a meaningful error instead of
319/// silently corrupting data.
320fn rewrite_atproto_json_bytes(value: &mut serde_json::Value) {
321    use serde_json::Value;
322    match value {
323        Value::Object(map) => {
324            if let Some(decoded) = decode_atproto_bytes_wrapper(map) {
325                *value = Value::Array(
326                    decoded
327                        .into_iter()
328                        .map(|b| Value::Number(b.into()))
329                        .collect(),
330                );
331                return;
332            }
333            for child in map.values_mut() {
334                rewrite_atproto_json_bytes(child);
335            }
336        }
337        Value::Array(arr) => {
338            for child in arr.iter_mut() {
339                rewrite_atproto_json_bytes(child);
340            }
341        }
342        _ => {}
343    }
344}
345
346/// If `map` is the single-key object `{"$bytes": "<base64>"}`, decode and
347/// return the bytes. Otherwise return `None`. Accepts both padded and
348/// unpadded standard base64 (the atproto spec says padded, but real servers
349/// commonly omit padding).
350fn decode_atproto_bytes_wrapper(
351    map: &serde_json::Map<String, serde_json::Value>,
352) -> Option<Vec<u8>> {
353    if map.len() != 1 {
354        return None;
355    }
356    let encoded = match map.get("$bytes")? {
357        serde_json::Value::String(s) => s,
358        _ => return None,
359    };
360    let stripped = encoded.trim_end_matches('=');
361    base64::engine::general_purpose::STANDARD_NO_PAD
362        .decode(stripped)
363        .ok()
364}
365
366/// Re-run the decode against a pretty-printed copy of the body to obtain a
367/// line/column pair that points into `pretty_body` rather than into the
368/// original one-line wire body. If pretty-printing doesn't change the error
369/// site (or the pretty body doesn't parse for some other reason), falls back
370/// to the original error's location.
371fn decode_error_location_for_display(
372    pretty_body: &[u8],
373    raw_err: &serde_json::Error,
374) -> (usize, usize) {
375    if let Err(err) = decode_query_labels_output(pretty_body) {
376        (err.line(), err.column())
377    } else {
378        (raw_err.line(), raw_err.column())
379    }
380}
381
382/// Run the HTTP stage against a labeler endpoint.
383///
384/// # Arguments
385/// * `http` - The HTTP client implementation (usually `RealHttpTee` in production, or fake in tests).
386///
387/// # Returns
388/// `HttpStageOutput` containing check results and facts (if all checks pass).
389pub async fn run(http: &dyn RawHttpTee) -> HttpStageOutput {
390    let mut results = Vec::new();
391
392    // Fetch first page; derive both endpoint_reachable and query_labels_schema_first_page from the same request.
393    let first_response = match http.query_labels(None).await {
394        Ok(resp) => {
395            // Endpoint is reachable if we got any response (2xx or non-2xx).
396            if resp.status.is_success() {
397                results.push(Check::EndpointReachable.pass());
398            } else {
399                let status_code = resp.status;
400                results.push(CheckResult {
401                    summary: Cow::Owned(format!(
402                        "Labeler endpoint reachability (status {status_code})"
403                    )),
404                    ..Check::EndpointReachable.pass()
405                });
406            }
407            resp
408        }
409        Err(HttpStageError::Transport { message, .. }) => {
410            // Network/TLS error: endpoint is unreachable.
411            results.push(CheckResult {
412                summary: Cow::Owned(format!("Network error: {message}")),
413                ..Check::EndpointReachable.network_error()
414            });
415            return HttpStageOutput {
416                facts: None,
417                results,
418            };
419        }
420        Err(HttpStageError::DecodeFailed {
421            raw_body,
422            source,
423            source_url,
424        }) => {
425            // Endpoint is reachable, but schema decode failed.
426            results.push(Check::EndpointReachable.pass());
427            let pretty_body = pretty_json_for_display(&raw_body);
428            let (line, column) = decode_error_location_for_display(&pretty_body, &source);
429            let diagnostic = Box::new(HttpDecodeFailure {
430                message: format!("Failed to decode query_labels response: {source}"),
431                source_code: NamedSource::new(source_url.clone(), pretty_body.clone()),
432                span: Some(span_at_line_column(&pretty_body, line, column)),
433            });
434            results.push(Check::QueryLabelsSchemaFirstPage.spec_violation(Some(diagnostic)));
435            return HttpStageOutput {
436                facts: None,
437                results,
438            };
439        }
440    };
441
442    // Decode succeeded (we return early on decode failure above).
443    let output = &first_response.decoded;
444    results.push(Check::QueryLabelsSchemaFirstPage.pass());
445
446    let first_page_labels = output.labels.clone();
447    let first_page_raw_bytes = first_response.raw_body.clone();
448    let first_page_source_url = first_response.source_url.clone();
449
450    if first_page_labels.is_empty() {
451        results.push(Check::QueryLabelsEmptyAdvisory.advisory());
452    }
453
454    let pagination_ok = if let Some(cursor) = &output.cursor {
455        match http.query_labels(Some(cursor)).await {
456            Ok(second_resp) => {
457                let second_output = &second_resp.decoded;
458                // Check if cursor was actually honored.
459                if second_output.labels == first_page_labels {
460                    results.push(Check::QueryLabelsSchemaSecondPage.pass());
461                    results.push(Check::PaginationIgnoredCursor.spec_violation(None));
462                    false
463                } else {
464                    results.push(Check::QueryLabelsSchemaSecondPage.pass());
465                    results.push(Check::PaginationRoundTrip.pass());
466                    true
467                }
468            }
469            Err(HttpStageError::Transport { message, .. }) => {
470                results.push(CheckResult {
471                    summary: Cow::Owned(format!("Network error fetching second page: {message}")),
472                    ..Check::QueryLabelsSchemaSecondPage.network_error()
473                });
474                false
475            }
476            Err(HttpStageError::DecodeFailed {
477                raw_body,
478                source,
479                source_url,
480            }) => {
481                let pretty_body = pretty_json_for_display(&raw_body);
482                let (line, column) = decode_error_location_for_display(&pretty_body, &source);
483                let diagnostic = Box::new(HttpDecodeFailure {
484                    message: format!("Failed to decode second page response: {source}"),
485                    source_code: NamedSource::new(source_url, pretty_body.clone()),
486                    span: Some(span_at_line_column(&pretty_body, line, column)),
487                });
488                results.push(Check::QueryLabelsSchemaSecondPage.spec_violation(Some(diagnostic)));
489                false
490            }
491        }
492    } else {
493        // No cursor: pagination not exercised, but that's OK.
494        results.push(CheckResult {
495            summary: Cow::Borrowed("First page was complete; pagination not exercised"),
496            ..Check::PaginationRoundTrip.pass()
497        });
498        true
499    };
500
501    // Facts consumed by the crypto stage.
502    let facts = HttpFacts {
503        first_page: first_page_labels,
504        first_page_raw_bytes,
505        first_page_source_url,
506        pagination_ok,
507    };
508
509    HttpStageOutput {
510        facts: Some(facts),
511        results,
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rewrite_atproto_json_bytes_replaces_wrapper() {
        let mut value: serde_json::Value =
            serde_json::from_str(r#"{"sig": {"$bytes": "AAECAw"}, "other": 1}"#).unwrap();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value["sig"], serde_json::json!([0, 1, 2, 3]));
        assert_eq!(value["other"], serde_json::json!(1));
    }

    #[test]
    fn rewrite_atproto_json_bytes_accepts_padded_base64() {
        let mut value: serde_json::Value =
            serde_json::from_str(r#"{"$bytes": "AAECAw=="}"#).unwrap();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value, serde_json::json!([0, 1, 2, 3]));
    }

    #[test]
    fn rewrite_atproto_json_bytes_ignores_non_wrapper_objects() {
        // Extra keys: not a wrapper, must not be rewritten.
        let mut value: serde_json::Value =
            serde_json::from_str(r#"{"$bytes": "AAECAw", "extra": true}"#).unwrap();
        let before = value.clone();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value, before);
    }

    #[test]
    fn rewrite_atproto_json_bytes_leaves_non_string_payload_unchanged() {
        let mut value = serde_json::json!({"$bytes": 5});
        let before = value.clone();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value, before);
    }

    #[test]
    fn rewrite_atproto_json_bytes_leaves_invalid_base64_unchanged() {
        let mut value = serde_json::json!({"$bytes": "not base64!"});
        let before = value.clone();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value, before);
    }

    #[test]
    fn rewrite_atproto_json_bytes_recurses_into_arrays() {
        let mut value = serde_json::json!({
            "labels": [{"sig": {"$bytes": "AAECAw"}}, {"sig": {"$bytes": "BA"}}]
        });
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(
            value,
            serde_json::json!({"labels": [{"sig": [0, 1, 2, 3]}, {"sig": [4]}]})
        );
    }

    #[test]
    fn decode_query_labels_output_handles_dollar_bytes_sig() {
        // Minimal queryLabels Output with one label whose `sig` is wrapped.
        let body = br#"{"cursor":"c","labels":[{"ver":1,"src":"did:plc:aaa22222222222222222bbbbbb","uri":"at://did:plc:aaa22222222222222222bbbbbb/app.bsky.feed.post/abc","val":"spam","cts":"2026-01-01T00:00:00.000Z","sig":{"$bytes":"AAECAw"}}]}"#;
        let output = decode_query_labels_output(body).expect("should decode");
        assert_eq!(output.labels.len(), 1);
        let sig = output.labels[0].sig.as_ref().expect("sig present");
        assert_eq!(sig, &vec![0u8, 1, 2, 3]);
    }
}