use std::borrow::Cow;
use std::sync::Arc;
use async_trait::async_trait;
use atrium_api::com::atproto::label::defs::Label;
use atrium_api::com::atproto::label::query_labels;
use base64::Engine;
use miette::{Diagnostic, NamedSource, SourceSpan};
use thiserror::Error;
use url::Url;
use crate::commands::test::labeler::report::{CheckResult, CheckStatus, Stage};
use crate::common::diagnostics::{pretty_json_for_display, span_at_line_column};
/// Identifies each individual check the HTTP stage can report on.
///
/// Every variant maps to a stable string id (see `Check::id`) and is turned
/// into a `CheckResult` through the constructor helpers on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Check {
    /// The first `queryLabels` request produced an HTTP response.
    EndpointReachable,
    /// The first page's body decoded as a `queryLabels` output.
    QueryLabelsSchemaFirstPage,
    /// Advisory emitted when the first page contains no labels.
    QueryLabelsEmptyAdvisory,
    /// The page requested with the returned cursor decoded as a `queryLabels` output.
    QueryLabelsSchemaSecondPage,
    /// Following the cursor yielded labels that differ from the first page
    /// (also reported as passing when no cursor was returned at all).
    PaginationRoundTrip,
    /// The page requested with the cursor carried the same labels as the first page.
    PaginationIgnoredCursor,
}
impl Check {
pub fn id(self) -> &'static str {
match self {
Check::EndpointReachable => "http::endpoint_reachable",
Check::QueryLabelsSchemaFirstPage => "http::query_labels_schema_first_page",
Check::QueryLabelsEmptyAdvisory => "http::query_labels_empty_advisory",
Check::QueryLabelsSchemaSecondPage => "http::query_labels_schema_second_page",
Check::PaginationRoundTrip => "http::pagination_round_trip",
Check::PaginationIgnoredCursor => "http::pagination_ignored_cursor",
}
}
pub fn pass(self) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Http,
status: CheckStatus::Pass,
summary: Cow::Borrowed(match self {
Check::EndpointReachable => "Labeler endpoint reachability",
Check::QueryLabelsSchemaFirstPage => "First page schema",
Check::QueryLabelsSchemaSecondPage => "Second page schema",
Check::PaginationRoundTrip => "Pagination round-trip",
_ => "HTTP check passed",
}),
diagnostic: None,
skipped_reason: None,
}
}
pub fn spec_violation(
self,
diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Http,
status: CheckStatus::SpecViolation,
summary: Cow::Borrowed(match self {
Check::QueryLabelsSchemaFirstPage => "Schema validation failed",
Check::QueryLabelsSchemaSecondPage => "Second page schema validation failed",
Check::PaginationIgnoredCursor => "Labeler ignored the cursor parameter",
_ => "HTTP check failed",
}),
diagnostic,
skipped_reason: None,
}
}
pub fn network_error(self) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Http,
status: CheckStatus::NetworkError,
summary: Cow::Borrowed(match self {
Check::EndpointReachable => "Labeler endpoint unreachable",
Check::QueryLabelsSchemaSecondPage => "Second page fetch failed",
_ => "HTTP network error",
}),
diagnostic: None,
skipped_reason: None,
}
}
pub fn advisory(self) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Http,
status: CheckStatus::Advisory,
summary: Cow::Borrowed(match self {
Check::QueryLabelsEmptyAdvisory => "Labeler has no published labels",
_ => "HTTP advisory",
}),
diagnostic: None,
skipped_reason: None,
}
}
}
/// Facts collected by the HTTP stage once the first page has been fetched and
/// decoded successfully.
#[derive(Debug, Clone)]
pub struct HttpFacts {
    /// Labels decoded from the first `queryLabels` page.
    pub first_page: Vec<Label>,
    /// Body bytes of the first-page response, as stored on the response's `raw_body`.
    pub first_page_raw_bytes: Arc<[u8]>,
    /// URL the first page was requested from.
    pub first_page_source_url: String,
    /// `true` when no cursor was returned or the second page's labels differed
    /// from the first page's; `false` when the second request failed or
    /// returned the same labels.
    pub pagination_ok: bool,
}
/// Everything the HTTP stage produces: the check results plus, when available,
/// the facts gathered from the first page.
#[derive(Debug)]
pub struct HttpStageOutput {
    /// `None` when the first request failed (transport error or undecodable body).
    pub facts: Option<HttpFacts>,
    /// Check results in the order they were recorded.
    pub results: Vec<CheckResult>,
}
/// A `queryLabels` response carrying both the raw body and its decoded form.
///
/// `Debug` is derived so that `Result<RawXrpcResponse, HttpStageError>` can be
/// used with `unwrap_err` / `expect_err` (which require `T: Debug`) and so the
/// type is consistent with the other public types in this module.
#[derive(Debug)]
pub struct RawXrpcResponse {
    /// HTTP status code of the response.
    pub status: reqwest::StatusCode,
    /// Response body bytes as received.
    pub raw_body: Arc<[u8]>,
    /// The body decoded as a `queryLabels` output.
    pub decoded: query_labels::Output,
    /// Full URL (including query string) the request was sent to.
    pub source_url: String,
}
/// Diagnostic describing a response body that could not be decoded as a
/// `queryLabels` output.
///
/// Its `Display` output is `message`; the diagnostic code is
/// `labeler::http::schema_failure`.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::http::schema_failure")]
pub struct HttpDecodeFailure {
    /// Human-readable description of the failure.
    pub message: String,
    /// Source text the diagnostic is rendered against.
    #[source_code]
    pub source_code: NamedSource<Arc<[u8]>>,
    /// Location within `source_code` to label as the JSON error, if known.
    #[label("JSON error")]
    pub span: Option<SourceSpan>,
}
/// Errors a `RawHttpTee` implementation can return from `query_labels`.
#[derive(Debug, Error)]
pub enum HttpStageError {
    /// The request could not be sent or its body could not be read.
    #[error("HTTP transport error: {message}")]
    Transport {
        /// Text of the underlying failure.
        message: String,
        /// The underlying error, when one is available.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
    /// A body was received but did not decode as a `queryLabels` output.
    #[error("Schema decode failure")]
    DecodeFailed {
        /// The body bytes that failed to decode.
        raw_body: Arc<[u8]>,
        /// The JSON error reported by the decoder.
        source: serde_json::Error,
        /// URL the body was fetched from.
        source_url: String,
    },
}
/// Abstraction over the `queryLabels` HTTP call that returns the raw body
/// alongside the decoded output, so the stage can be driven by a real client
/// or by a substitute implementation.
#[async_trait]
pub trait RawHttpTee: Send + Sync {
    /// Fetches one page of labels, passing `cursor` along when it is `Some`.
    ///
    /// # Errors
    ///
    /// Returns an [`HttpStageError`] when the request fails or the body cannot
    /// be decoded.
    async fn query_labels(&self, cursor: Option<&str>) -> Result<RawXrpcResponse, HttpStageError>;
}
/// `RawHttpTee` implementation backed by a `reqwest::Client`.
pub struct RealHttpTee {
    /// Client used to send requests.
    client: reqwest::Client,
    /// Base URL of the labeler; the XRPC path and query are set per request.
    endpoint: Url,
}
impl RealHttpTee {
    /// Creates a tee that sends its requests with `client` to `endpoint`.
    pub fn new(client: reqwest::Client, endpoint: Url) -> Self {
        Self { client, endpoint }
    }
}
#[async_trait]
impl RawHttpTee for RealHttpTee {
async fn query_labels(&self, cursor: Option<&str>) -> Result<RawXrpcResponse, HttpStageError> {
let mut url = self.endpoint.clone();
url.set_path("xrpc/com.atproto.label.queryLabels");
{
let mut query = url.query_pairs_mut();
query.append_pair("uriPatterns", "*");
query.append_pair("limit", "50");
if let Some(c) = cursor {
query.append_pair("cursor", c);
}
}
let source_url = url.to_string();
tracing::debug!(
url = %source_url,
cursor = ?cursor,
"http stage: issuing queryLabels GET"
);
let response =
self.client
.get(url.as_str())
.send()
.await
.map_err(|e| HttpStageError::Transport {
message: e.to_string(),
source: Some(Box::new(e)),
})?;
let status = response.status();
let body_bytes = response
.bytes()
.await
.map_err(|e| HttpStageError::Transport {
message: e.to_string(),
source: Some(Box::new(e)),
})?;
tracing::debug!(
url = %source_url,
status = %status,
body_len = body_bytes.len(),
"http stage: queryLabels response received"
);
let raw_body: Arc<[u8]> = Arc::from(body_bytes.as_ref());
let decoded = decode_query_labels_output(&raw_body).map_err(|source| {
HttpStageError::DecodeFailed {
raw_body: raw_body.clone(),
source,
source_url: source_url.clone(),
}
})?;
Ok(RawXrpcResponse {
status,
raw_body,
decoded,
source_url,
})
}
}
/// Decodes a `queryLabels` response body.
///
/// The body is first parsed into a generic JSON value so that
/// `{"$bytes": "..."}` wrappers can be rewritten into byte arrays, and only
/// then deserialized into the typed output.
///
/// # Errors
///
/// Returns the `serde_json::Error` from either the initial parse or the typed
/// deserialization.
fn decode_query_labels_output(body: &[u8]) -> Result<query_labels::Output, serde_json::Error> {
    serde_json::from_slice::<serde_json::Value>(body).and_then(|mut json| {
        rewrite_atproto_json_bytes(&mut json);
        serde_json::from_value(json)
    })
}
/// Recursively replaces every `{"$bytes": "<base64>"}` wrapper object inside
/// `value` with a JSON array of the decoded byte values.
///
/// Objects that are not a valid wrapper (extra keys, non-string payload,
/// undecodable base64) are left in place and their children are visited.
fn rewrite_atproto_json_bytes(value: &mut serde_json::Value) {
    use serde_json::Value;

    // Detect the wrapper first; the replacement needs `value` itself, so the
    // borrow of the inner map must be finished before assigning.
    let wrapped_bytes = match value {
        Value::Object(fields) => decode_atproto_bytes_wrapper(fields),
        _ => None,
    };
    if let Some(bytes) = wrapped_bytes {
        *value = Value::Array(bytes.into_iter().map(Value::from).collect());
        return;
    }

    match value {
        Value::Object(fields) => fields.values_mut().for_each(rewrite_atproto_json_bytes),
        Value::Array(items) => items.iter_mut().for_each(rewrite_atproto_json_bytes),
        _ => {}
    }
}
/// Decodes `map` when it is exactly a `{"$bytes": "<base64>"}` wrapper.
///
/// Returns `None` when the object has any other number of keys, when the
/// `$bytes` value is not a string, or when the string is not valid base64.
/// Trailing `=` characters are stripped before decoding with the unpadded
/// standard alphabet, so padded and unpadded input are both accepted.
fn decode_atproto_bytes_wrapper(
    map: &serde_json::Map<String, serde_json::Value>,
) -> Option<Vec<u8>> {
    if map.len() != 1 {
        return None;
    }
    let encoded = map.get("$bytes")?.as_str()?;
    let unpadded = encoded.trim_end_matches('=');
    base64::engine::general_purpose::STANDARD_NO_PAD
        .decode(unpadded)
        .ok()
}
/// Chooses the `(line, column)` to highlight when showing a decode failure
/// against `pretty_body`.
///
/// The decode is re-run on `pretty_body` so the position refers to the text
/// that is displayed; if that decode unexpectedly succeeds, the position from
/// `raw_err` is returned instead.
fn decode_error_location_for_display(
    pretty_body: &[u8],
    raw_err: &serde_json::Error,
) -> (usize, usize) {
    let pretty_err = decode_query_labels_output(pretty_body).err();
    let located = pretty_err.as_ref().unwrap_or(raw_err);
    (located.line(), located.column())
}
pub async fn run(http: &dyn RawHttpTee) -> HttpStageOutput {
let mut results = Vec::new();
let first_response = match http.query_labels(None).await {
Ok(resp) => {
if resp.status.is_success() {
results.push(Check::EndpointReachable.pass());
} else {
let status_code = resp.status;
results.push(CheckResult {
summary: Cow::Owned(format!(
"Labeler endpoint reachability (status {status_code})"
)),
..Check::EndpointReachable.pass()
});
}
resp
}
Err(HttpStageError::Transport { message, .. }) => {
results.push(CheckResult {
summary: Cow::Owned(format!("Network error: {message}")),
..Check::EndpointReachable.network_error()
});
return HttpStageOutput {
facts: None,
results,
};
}
Err(HttpStageError::DecodeFailed {
raw_body,
source,
source_url,
}) => {
results.push(Check::EndpointReachable.pass());
let pretty_body = pretty_json_for_display(&raw_body);
let (line, column) = decode_error_location_for_display(&pretty_body, &source);
let diagnostic = Box::new(HttpDecodeFailure {
message: format!("Failed to decode query_labels response: {source}"),
source_code: NamedSource::new(source_url.clone(), pretty_body.clone()),
span: Some(span_at_line_column(&pretty_body, line, column)),
});
results.push(Check::QueryLabelsSchemaFirstPage.spec_violation(Some(diagnostic)));
return HttpStageOutput {
facts: None,
results,
};
}
};
let output = &first_response.decoded;
results.push(Check::QueryLabelsSchemaFirstPage.pass());
let first_page_labels = output.labels.clone();
let first_page_raw_bytes = first_response.raw_body.clone();
let first_page_source_url = first_response.source_url.clone();
if first_page_labels.is_empty() {
results.push(Check::QueryLabelsEmptyAdvisory.advisory());
}
let pagination_ok = if let Some(cursor) = &output.cursor {
match http.query_labels(Some(cursor)).await {
Ok(second_resp) => {
let second_output = &second_resp.decoded;
if second_output.labels == first_page_labels {
results.push(Check::QueryLabelsSchemaSecondPage.pass());
results.push(Check::PaginationIgnoredCursor.spec_violation(None));
false
} else {
results.push(Check::QueryLabelsSchemaSecondPage.pass());
results.push(Check::PaginationRoundTrip.pass());
true
}
}
Err(HttpStageError::Transport { message, .. }) => {
results.push(CheckResult {
summary: Cow::Owned(format!("Network error fetching second page: {message}")),
..Check::QueryLabelsSchemaSecondPage.network_error()
});
false
}
Err(HttpStageError::DecodeFailed {
raw_body,
source,
source_url,
}) => {
let pretty_body = pretty_json_for_display(&raw_body);
let (line, column) = decode_error_location_for_display(&pretty_body, &source);
let diagnostic = Box::new(HttpDecodeFailure {
message: format!("Failed to decode second page response: {source}"),
source_code: NamedSource::new(source_url, pretty_body.clone()),
span: Some(span_at_line_column(&pretty_body, line, column)),
});
results.push(Check::QueryLabelsSchemaSecondPage.spec_violation(Some(diagnostic)));
false
}
}
} else {
results.push(CheckResult {
summary: Cow::Borrowed("First page was complete; pagination not exercised"),
..Check::PaginationRoundTrip.pass()
});
true
};
let facts = HttpFacts {
first_page: first_page_labels,
first_page_raw_bytes,
first_page_source_url,
pagination_ok,
};
HttpStageOutput {
facts: Some(facts),
results,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parses a JSON fixture, panicking if the fixture itself is malformed.
    fn parse(text: &str) -> serde_json::Value {
        serde_json::from_str(text).unwrap()
    }

    /// A nested `$bytes` wrapper becomes a byte array; siblings are untouched.
    #[test]
    fn rewrite_atproto_json_bytes_replaces_wrapper() {
        let mut doc = parse(r#"{"sig": {"$bytes": "AAECAw"}, "other": 1}"#);
        rewrite_atproto_json_bytes(&mut doc);
        assert_eq!(doc["sig"], json!([0, 1, 2, 3]));
        assert_eq!(doc["other"], json!(1));
    }

    /// Base64 with `=` padding decodes to the same bytes as the unpadded form.
    #[test]
    fn rewrite_atproto_json_bytes_accepts_padded_base64() {
        let mut doc = parse(r#"{"$bytes": "AAECAw=="}"#);
        rewrite_atproto_json_bytes(&mut doc);
        assert_eq!(doc, json!([0, 1, 2, 3]));
    }

    /// An object with keys besides `$bytes` is not a wrapper and stays as-is.
    #[test]
    fn rewrite_atproto_json_bytes_ignores_non_wrapper_objects() {
        let mut doc = parse(r#"{"$bytes": "AAECAw", "extra": true}"#);
        let untouched = doc.clone();
        rewrite_atproto_json_bytes(&mut doc);
        assert_eq!(doc, untouched);
    }

    /// A full response whose label carries a `$bytes` signature decodes, and
    /// the signature holds the decoded bytes.
    #[test]
    fn decode_query_labels_output_handles_dollar_bytes_sig() {
        let body = br#"{"cursor":"c","labels":[{"ver":1,"src":"did:plc:aaa22222222222222222bbbbbb","uri":"at://did:plc:aaa22222222222222222bbbbbb/app.bsky.feed.post/abc","val":"spam","cts":"2026-01-01T00:00:00.000Z","sig":{"$bytes":"AAECAw"}}]}"#;
        let output = decode_query_labels_output(body).expect("should decode");
        assert_eq!(output.labels.len(), 1);
        let sig = output.labels[0].sig.as_ref().expect("sig present");
        assert_eq!(sig, &vec![0u8, 1, 2, 3]);
    }
}