1use std::borrow::Cow;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use atrium_api::com::atproto::label::defs::Label;
11use atrium_api::com::atproto::label::query_labels;
12use base64::Engine;
13use miette::{Diagnostic, NamedSource, SourceSpan};
14use thiserror::Error;
15use url::Url;
16
17use crate::commands::test::labeler::report::{CheckResult, CheckStatus, Stage};
18use crate::common::diagnostics::{pretty_json_for_display, span_at_line_column};
19
/// Individual checks performed by the HTTP stage of the labeler test.
///
/// Each variant has a stable identifier (see [`Check::id`]) and is turned into
/// a [`CheckResult`] through the constructors in `impl Check`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Check {
    /// Whether the first `queryLabels` request produced a response at all.
    EndpointReachable,
    /// Whether the first `queryLabels` page decoded into the expected schema.
    QueryLabelsSchemaFirstPage,
    /// Advisory raised when the first page contains no labels.
    QueryLabelsEmptyAdvisory,
    /// Whether the page fetched with the first page's cursor decoded into the expected schema.
    QueryLabelsSchemaSecondPage,
    /// Following the first page's cursor behaved as expected (or no cursor was
    /// returned, so pagination was not exercised).
    PaginationRoundTrip,
    /// Spec violation: the page fetched with the cursor repeated the first page's labels.
    PaginationIgnoredCursor,
}
36
37impl Check {
38 pub fn id(self) -> &'static str {
40 match self {
41 Check::EndpointReachable => "http::endpoint_reachable",
42 Check::QueryLabelsSchemaFirstPage => "http::query_labels_schema_first_page",
43 Check::QueryLabelsEmptyAdvisory => "http::query_labels_empty_advisory",
44 Check::QueryLabelsSchemaSecondPage => "http::query_labels_schema_second_page",
45 Check::PaginationRoundTrip => "http::pagination_round_trip",
46 Check::PaginationIgnoredCursor => "http::pagination_ignored_cursor",
47 }
48 }
49
50 pub fn pass(self) -> CheckResult {
51 CheckResult {
52 id: self.id(),
53 stage: Stage::Http,
54 status: CheckStatus::Pass,
55 summary: Cow::Borrowed(match self {
56 Check::EndpointReachable => "Labeler endpoint reachability",
57 Check::QueryLabelsSchemaFirstPage => "First page schema",
58 Check::QueryLabelsSchemaSecondPage => "Second page schema",
59 Check::PaginationRoundTrip => "Pagination round-trip",
60 _ => "HTTP check passed",
61 }),
62 diagnostic: None,
63 skipped_reason: None,
64 }
65 }
66
67 pub fn spec_violation(
68 self,
69 diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
70 ) -> CheckResult {
71 CheckResult {
72 id: self.id(),
73 stage: Stage::Http,
74 status: CheckStatus::SpecViolation,
75 summary: Cow::Borrowed(match self {
76 Check::QueryLabelsSchemaFirstPage => "Schema validation failed",
77 Check::QueryLabelsSchemaSecondPage => "Second page schema validation failed",
78 Check::PaginationIgnoredCursor => "Labeler ignored the cursor parameter",
79 _ => "HTTP check failed",
80 }),
81 diagnostic,
82 skipped_reason: None,
83 }
84 }
85
86 pub fn network_error(self) -> CheckResult {
87 CheckResult {
88 id: self.id(),
89 stage: Stage::Http,
90 status: CheckStatus::NetworkError,
91 summary: Cow::Borrowed(match self {
92 Check::EndpointReachable => "Labeler endpoint unreachable",
93 Check::QueryLabelsSchemaSecondPage => "Second page fetch failed",
94 _ => "HTTP network error",
95 }),
96 diagnostic: None,
97 skipped_reason: None,
98 }
99 }
100
101 pub fn advisory(self) -> CheckResult {
102 CheckResult {
103 id: self.id(),
104 stage: Stage::Http,
105 status: CheckStatus::Advisory,
106 summary: Cow::Borrowed(match self {
107 Check::QueryLabelsEmptyAdvisory => "Labeler has no published labels",
108 _ => "HTTP advisory",
109 }),
110 diagnostic: None,
111 skipped_reason: None,
112 }
113 }
114}
115
/// Facts gathered by a successful HTTP stage run (presumably consumed by later
/// stages — confirm against callers).
#[derive(Debug, Clone)]
pub struct HttpFacts {
    /// Labels decoded from the first `queryLabels` page.
    pub first_page: Vec<Label>,
    /// The undecoded body of the first page, exactly as received.
    pub first_page_raw_bytes: Arc<[u8]>,
    /// Full URL (including query string) used to fetch the first page.
    pub first_page_source_url: String,
    /// `true` when no cursor was returned or following it looked correct;
    /// `false` when the second page failed to fetch/decode or repeated the first.
    pub pagination_ok: bool,
}
128
/// Everything produced by one run of the HTTP stage.
#[derive(Debug)]
pub struct HttpStageOutput {
    /// `None` when the first page could not be fetched or decoded.
    pub facts: Option<HttpFacts>,
    /// Check results, in the order they were evaluated.
    pub results: Vec<CheckResult>,
}
137
/// A successfully decoded `queryLabels` response together with its raw form.
pub struct RawXrpcResponse {
    /// HTTP status code as received (it may be non-2xx; callers decide what that means).
    pub status: reqwest::StatusCode,
    /// The response body bytes, unmodified.
    pub raw_body: Arc<[u8]>,
    /// The body decoded as `com.atproto.label.queryLabels` output.
    pub decoded: query_labels::Output,
    /// Full request URL, including the query string.
    pub source_url: String,
}
149
/// miette diagnostic explaining why a `queryLabels` response body failed to decode.
///
/// In this file it is built against a display copy of the body produced by
/// `pretty_json_for_display`, with a label at the decode error's position.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::http::schema_failure")]
pub struct HttpDecodeFailure {
    /// Human-readable error text, used as the diagnostic message.
    pub message: String,
    /// The body rendered to the user, named after the request URL.
    #[source_code]
    pub source_code: NamedSource<Arc<[u8]>>,
    /// Location of the decode error within `source_code`, if known.
    #[label("JSON error")]
    pub span: Option<SourceSpan>,
}
164
/// Errors returned by a [`RawHttpTee`] when fetching a `queryLabels` page.
#[derive(Debug, Error)]
pub enum HttpStageError {
    /// No usable response body was obtained (request or body read failed).
    #[error("HTTP transport error: {message}")]
    Transport {
        /// Display text of the underlying error.
        message: String,
        /// The underlying error, if any, exposed through `Error::source`.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },

    /// A body was received but did not decode into the expected schema.
    #[error("Schema decode failure")]
    DecodeFailed {
        /// The body exactly as received, kept for diagnostics.
        raw_body: Arc<[u8]>,
        /// The decode error; thiserror treats a field named `source` as the error source.
        source: serde_json::Error,
        /// Full request URL, used to name the diagnostic source.
        source_url: String,
    },
}
189
/// Source of `com.atproto.label.queryLabels` responses for the HTTP stage.
///
/// Implementations hand back both the raw body and its decoded form (see
/// [`RawXrpcResponse`]) so diagnostics can be rendered against the bytes that
/// were actually received. NOTE(review): presumably a trait so tests can
/// substitute a fake transport — confirm.
#[async_trait]
pub trait RawHttpTee: Send + Sync {
    /// Fetches one `queryLabels` page; `cursor` is `None` for the first page
    /// and the previous page's cursor for the next one.
    async fn query_labels(&self, cursor: Option<&str>) -> Result<RawXrpcResponse, HttpStageError>;
}
204
/// [`RawHttpTee`] backed by a `reqwest` client talking to a live labeler.
pub struct RealHttpTee {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// Labeler base URL; its path is replaced by the XRPC route on each request.
    endpoint: Url,
}
212
213impl RealHttpTee {
214 pub fn new(client: reqwest::Client, endpoint: Url) -> Self {
216 RealHttpTee { client, endpoint }
217 }
218}
219
220#[async_trait]
221impl RawHttpTee for RealHttpTee {
222 async fn query_labels(&self, cursor: Option<&str>) -> Result<RawXrpcResponse, HttpStageError> {
223 let mut url = self.endpoint.clone();
225 url.set_path("xrpc/com.atproto.label.queryLabels");
226
227 {
229 let mut query = url.query_pairs_mut();
230 query.append_pair("uriPatterns", "*");
231 query.append_pair("limit", "50");
232 if let Some(c) = cursor {
233 query.append_pair("cursor", c);
234 }
235 }
236
237 let source_url = url.to_string();
238
239 tracing::debug!(
240 url = %source_url,
241 cursor = ?cursor,
242 "http stage: issuing queryLabels GET"
243 );
244
245 let response =
247 self.client
248 .get(url.as_str())
249 .send()
250 .await
251 .map_err(|e| HttpStageError::Transport {
252 message: e.to_string(),
253 source: Some(Box::new(e)),
254 })?;
255
256 let status = response.status();
257 let body_bytes = response
258 .bytes()
259 .await
260 .map_err(|e| HttpStageError::Transport {
261 message: e.to_string(),
262 source: Some(Box::new(e)),
263 })?;
264
265 tracing::debug!(
266 url = %source_url,
267 status = %status,
268 body_len = body_bytes.len(),
269 "http stage: queryLabels response received"
270 );
271 let raw_body: Arc<[u8]> = Arc::from(body_bytes.as_ref());
272
273 let decoded = decode_query_labels_output(&raw_body).map_err(|source| {
281 HttpStageError::DecodeFailed {
282 raw_body: raw_body.clone(),
283 source,
284 source_url: source_url.clone(),
285 }
286 })?;
287
288 Ok(RawXrpcResponse {
289 status,
290 raw_body,
291 decoded,
292 source_url,
293 })
294 }
295}
296
297fn decode_query_labels_output(body: &[u8]) -> Result<query_labels::Output, serde_json::Error> {
305 let mut value: serde_json::Value = serde_json::from_slice(body)?;
306 rewrite_atproto_json_bytes(&mut value);
307 serde_json::from_value(value)
308}
309
310fn rewrite_atproto_json_bytes(value: &mut serde_json::Value) {
321 use serde_json::Value;
322 match value {
323 Value::Object(map) => {
324 if let Some(decoded) = decode_atproto_bytes_wrapper(map) {
325 *value = Value::Array(
326 decoded
327 .into_iter()
328 .map(|b| Value::Number(b.into()))
329 .collect(),
330 );
331 return;
332 }
333 for child in map.values_mut() {
334 rewrite_atproto_json_bytes(child);
335 }
336 }
337 Value::Array(arr) => {
338 for child in arr.iter_mut() {
339 rewrite_atproto_json_bytes(child);
340 }
341 }
342 _ => {}
343 }
344}
345
346fn decode_atproto_bytes_wrapper(
351 map: &serde_json::Map<String, serde_json::Value>,
352) -> Option<Vec<u8>> {
353 if map.len() != 1 {
354 return None;
355 }
356 let encoded = match map.get("$bytes")? {
357 serde_json::Value::String(s) => s,
358 _ => return None,
359 };
360 let stripped = encoded.trim_end_matches('=');
361 base64::engine::general_purpose::STANDARD_NO_PAD
362 .decode(stripped)
363 .ok()
364}
365
366fn decode_error_location_for_display(
372 pretty_body: &[u8],
373 raw_err: &serde_json::Error,
374) -> (usize, usize) {
375 if let Err(err) = decode_query_labels_output(pretty_body) {
376 (err.line(), err.column())
377 } else {
378 (raw_err.line(), raw_err.column())
379 }
380}
381
382pub async fn run(http: &dyn RawHttpTee) -> HttpStageOutput {
390 let mut results = Vec::new();
391
392 let first_response = match http.query_labels(None).await {
394 Ok(resp) => {
395 if resp.status.is_success() {
397 results.push(Check::EndpointReachable.pass());
398 } else {
399 let status_code = resp.status;
400 results.push(CheckResult {
401 summary: Cow::Owned(format!(
402 "Labeler endpoint reachability (status {status_code})"
403 )),
404 ..Check::EndpointReachable.pass()
405 });
406 }
407 resp
408 }
409 Err(HttpStageError::Transport { message, .. }) => {
410 results.push(CheckResult {
412 summary: Cow::Owned(format!("Network error: {message}")),
413 ..Check::EndpointReachable.network_error()
414 });
415 return HttpStageOutput {
416 facts: None,
417 results,
418 };
419 }
420 Err(HttpStageError::DecodeFailed {
421 raw_body,
422 source,
423 source_url,
424 }) => {
425 results.push(Check::EndpointReachable.pass());
427 let pretty_body = pretty_json_for_display(&raw_body);
428 let (line, column) = decode_error_location_for_display(&pretty_body, &source);
429 let diagnostic = Box::new(HttpDecodeFailure {
430 message: format!("Failed to decode query_labels response: {source}"),
431 source_code: NamedSource::new(source_url.clone(), pretty_body.clone()),
432 span: Some(span_at_line_column(&pretty_body, line, column)),
433 });
434 results.push(Check::QueryLabelsSchemaFirstPage.spec_violation(Some(diagnostic)));
435 return HttpStageOutput {
436 facts: None,
437 results,
438 };
439 }
440 };
441
442 let output = &first_response.decoded;
444 results.push(Check::QueryLabelsSchemaFirstPage.pass());
445
446 let first_page_labels = output.labels.clone();
447 let first_page_raw_bytes = first_response.raw_body.clone();
448 let first_page_source_url = first_response.source_url.clone();
449
450 if first_page_labels.is_empty() {
451 results.push(Check::QueryLabelsEmptyAdvisory.advisory());
452 }
453
454 let pagination_ok = if let Some(cursor) = &output.cursor {
455 match http.query_labels(Some(cursor)).await {
456 Ok(second_resp) => {
457 let second_output = &second_resp.decoded;
458 if second_output.labels == first_page_labels {
460 results.push(Check::QueryLabelsSchemaSecondPage.pass());
461 results.push(Check::PaginationIgnoredCursor.spec_violation(None));
462 false
463 } else {
464 results.push(Check::QueryLabelsSchemaSecondPage.pass());
465 results.push(Check::PaginationRoundTrip.pass());
466 true
467 }
468 }
469 Err(HttpStageError::Transport { message, .. }) => {
470 results.push(CheckResult {
471 summary: Cow::Owned(format!("Network error fetching second page: {message}")),
472 ..Check::QueryLabelsSchemaSecondPage.network_error()
473 });
474 false
475 }
476 Err(HttpStageError::DecodeFailed {
477 raw_body,
478 source,
479 source_url,
480 }) => {
481 let pretty_body = pretty_json_for_display(&raw_body);
482 let (line, column) = decode_error_location_for_display(&pretty_body, &source);
483 let diagnostic = Box::new(HttpDecodeFailure {
484 message: format!("Failed to decode second page response: {source}"),
485 source_code: NamedSource::new(source_url, pretty_body.clone()),
486 span: Some(span_at_line_column(&pretty_body, line, column)),
487 });
488 results.push(Check::QueryLabelsSchemaSecondPage.spec_violation(Some(diagnostic)));
489 false
490 }
491 }
492 } else {
493 results.push(CheckResult {
495 summary: Cow::Borrowed("First page was complete; pagination not exercised"),
496 ..Check::PaginationRoundTrip.pass()
497 });
498 true
499 };
500
501 let facts = HttpFacts {
503 first_page: first_page_labels,
504 first_page_raw_bytes,
505 first_page_source_url,
506 pagination_ok,
507 };
508
509 HttpStageOutput {
510 facts: Some(facts),
511 results,
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rewrite_atproto_json_bytes_replaces_wrapper() {
        let mut value: serde_json::Value =
            serde_json::from_str(r#"{"sig": {"$bytes": "AAECAw"}, "other": 1}"#).unwrap();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value["sig"], serde_json::json!([0, 1, 2, 3]));
        assert_eq!(value["other"], serde_json::json!(1));
    }

    #[test]
    fn rewrite_atproto_json_bytes_accepts_padded_base64() {
        let mut value: serde_json::Value =
            serde_json::from_str(r#"{"$bytes": "AAECAw=="}"#).unwrap();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value, serde_json::json!([0, 1, 2, 3]));
    }

    #[test]
    fn rewrite_atproto_json_bytes_ignores_non_wrapper_objects() {
        // Extra keys mean this is an ordinary object, not a bytes wrapper.
        let mut value: serde_json::Value =
            serde_json::from_str(r#"{"$bytes": "AAECAw", "extra": true}"#).unwrap();
        let before = value.clone();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value, before);
    }

    #[test]
    fn rewrite_atproto_json_bytes_recurses_into_arrays_and_nested_objects() {
        let mut value: serde_json::Value = serde_json::from_str(
            r#"{"items": [{"$bytes": "AQ"}, {"nested": {"$bytes": "Ag"}}]}"#,
        )
        .unwrap();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value["items"][0], serde_json::json!([1]));
        assert_eq!(value["items"][1]["nested"], serde_json::json!([2]));
    }

    #[test]
    fn rewrite_atproto_json_bytes_leaves_invalid_base64_untouched() {
        let mut value: serde_json::Value =
            serde_json::from_str(r#"{"$bytes": "not base64!"}"#).unwrap();
        let before = value.clone();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value, before);
    }

    #[test]
    fn rewrite_atproto_json_bytes_ignores_non_string_payload() {
        let mut value: serde_json::Value = serde_json::from_str(r#"{"$bytes": 42}"#).unwrap();
        let before = value.clone();
        rewrite_atproto_json_bytes(&mut value);
        assert_eq!(value, before);
    }

    #[test]
    fn decode_query_labels_output_handles_dollar_bytes_sig() {
        let body = br#"{"cursor":"c","labels":[{"ver":1,"src":"did:plc:aaa22222222222222222bbbbbb","uri":"at://did:plc:aaa22222222222222222bbbbbb/app.bsky.feed.post/abc","val":"spam","cts":"2026-01-01T00:00:00.000Z","sig":{"$bytes":"AAECAw"}}]}"#;
        let output = decode_query_labels_output(body).expect("should decode");
        assert_eq!(output.labels.len(), 1);
        let sig = output.labels[0].sig.as_ref().expect("sig present");
        assert_eq!(sig, &vec![0u8, 1, 2, 3]);
    }

    #[test]
    fn decode_query_labels_output_accepts_empty_page_without_cursor() {
        let output = decode_query_labels_output(br#"{"labels":[]}"#).expect("should decode");
        assert!(output.labels.is_empty());
        assert!(output.cursor.is_none());
    }

    #[test]
    fn decode_query_labels_output_rejects_malformed_json() {
        assert!(decode_query_labels_output(b"{not json").is_err());
    }
}