use std::borrow::Cow;
use std::sync::Arc;
use atrium_api::app::bsky::labeler::defs::LabelerPolicies;
use atrium_api::app::bsky::labeler::service::RecordData as LabelerServiceRecordData;
use miette::{Diagnostic, NamedSource, SourceSpan};
use thiserror::Error;
use url::Url;
use crate::commands::test::labeler::pipeline::{AtIdentifier, LabelerTarget};
use crate::commands::test::labeler::report::{CheckResult, CheckStatus, Stage};
use crate::common::diagnostics::{
pretty_json_for_display, span_at_line_column, span_for_quoted_literal,
};
use crate::common::identity::{
AnyVerifyingKey, Did, DidDocument, DnsResolver, HttpClient, IdentityError, RawDidDocument,
find_service, is_local_labeler_hostname, parse_multikey, resolve_did, resolve_handle,
};
/// Labeler service record decoded by `fetch_labeler_record` from a PDS `getRecord` response.
struct FetchedLabelerRecord {
    /// The raw response body (the whole `getRecord` envelope, not just the record value).
    bytes: Arc<[u8]>,
    /// The record's label policies.
    policies: LabelerPolicies,
    /// The record's reason types, when the record declares any.
    reason_types: Option<Vec<String>>,
    /// The record's subject types, when the record declares any.
    subject_types: Option<Vec<String>>,
    /// The record's subject collections rendered as strings, when the record declares any.
    subject_collections: Option<Vec<String>>,
}
/// Everything the identity stage learned about a labeler, handed to later pipeline stages.
///
/// `run` only builds this when no identity check reported a blocking failure.
#[derive(Debug, Clone)]
pub struct IdentityFacts {
    /// The DID the target resolved to.
    pub did: Did,
    /// The fetched DID document (raw bytes plus parsed form).
    pub raw_did_doc: RawDidDocument,
    /// The labeler endpoint from the DID document's labeler service or, when the target
    /// supplied a local-hostname URL that differs from it, that local URL.
    pub labeler_endpoint: Url,
    /// The PDS endpoint from the DID document's PDS service entry.
    pub pds_endpoint: Url,
    /// `id` of the verification method whose fragment is `atproto_label`.
    pub signing_key_id: String,
    /// That verification method's multibase-encoded public key, as written in the document.
    pub signing_key_multikey: String,
    /// The parsed form of `signing_key_multikey`.
    pub signing_key: AnyVerifyingKey,
    /// Raw `getRecord` response body for the labeler service record.
    pub labeler_record_bytes: Arc<[u8]>,
    /// The labeler record's policies (its `labelValues` is non-empty).
    pub labeler_policies: LabelerPolicies,
    /// The labeler record's reason types, if declared.
    pub reason_types: Option<Vec<String>>,
    /// The labeler record's subject types, if declared.
    pub subject_types: Option<Vec<String>>,
    /// The labeler record's subject collections, if declared.
    pub subject_collections: Option<Vec<String>>,
}
/// Result of running the identity stage.
#[derive(Debug)]
pub struct IdentityStageOutput {
    /// Gathered facts; `None` when any identity check blocked them or data was missing.
    pub facts: Option<IdentityFacts>,
    /// One result per identity check, in `Check::ALL` order.
    pub results: Vec<CheckResult>,
}
/// Diagnostic for a DID document that lacks the `#atproto_labeler` service entry.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_service_present")]
struct ServiceMissingError {
    /// Human-readable explanation.
    message: String,
    /// The (pretty-printed) DID document the span points into.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the `"service"` key, when it could be found.
    #[label("service array")]
    span: Option<SourceSpan>,
}
/// Diagnostic for a labeler service endpoint that is not a valid URL.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_endpoint_parseable")]
struct LabelerEndpointParseError {
    /// Human-readable explanation, including the offending endpoint string.
    message: String,
    /// The (pretty-printed) DID document the span points into.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the quoted endpoint value, when it could be found.
    #[label("endpoint value")]
    span: Option<SourceSpan>,
}
/// Diagnostic for a labeler endpoint that is neither HTTPS nor HTTP on a local hostname.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_endpoint_is_https")]
struct NonHttpsLabelerEndpointError {
    /// Human-readable explanation, including the offending endpoint string.
    message: String,
    /// The (pretty-printed) DID document the span points into.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the quoted endpoint value, when it could be found.
    #[label("endpoint value")]
    span: Option<SourceSpan>,
}
/// Diagnostic for a DID-document labeler endpoint that differs from the endpoint the user
/// supplied. `run` reports it as an advisory for local overrides and as a spec violation
/// otherwise.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::resolved_did_matches_flag")]
struct EndpointMismatchError {
    /// Human-readable explanation naming both endpoints.
    message: String,
    /// The (pretty-printed) DID document the span points into.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the DID document's quoted endpoint value, when it could be found.
    #[label("endpoint value")]
    span: Option<SourceSpan>,
}
/// Diagnostic for a DID document with no usable `#atproto_label` verification method.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::signing_key_present")]
struct SigningKeyMissingError {
    /// Human-readable explanation.
    message: String,
    /// The (pretty-printed) DID document the span points into.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the `"verificationMethod"` key, when it could be found.
    #[label("verificationMethod array")]
    span: Option<SourceSpan>,
}
/// Diagnostic for an `#atproto_label` key whose multibase string fails to parse.
///
/// Shares its diagnostic code with `SigningKeyMissingError` because both report the same check.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::signing_key_present")]
struct SigningKeyUnparseableError {
    /// Human-readable explanation, including the parser's error.
    message: String,
    /// The (pretty-printed) DID document the span points into.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the quoted multikey string, when it could be found.
    #[label("multikey")]
    span: Option<SourceSpan>,
}
/// Diagnostic for a missing `#atproto_pds` service entry or an unparseable PDS endpoint URL.
///
/// NOTE(review): for the unparseable-URL case the span points at the endpoint value, yet the
/// label still reads "service array".
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::pds_endpoint_present")]
struct PdsServiceMissingError {
    /// Human-readable explanation.
    message: String,
    /// The (pretty-printed) DID document the span points into.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the `"service"` key or of the endpoint value, when it could be found.
    #[label("service array")]
    span: Option<SourceSpan>,
}
/// Envelope of a `com.atproto.repo.getRecord` response for the labeler service record.
#[derive(serde::Deserialize)]
struct GetRecordResponse {
    /// The labeler service record itself.
    value: LabelerServiceRecordData,
    /// Envelope `uri`; decoded when present but never read.
    #[serde(default)]
    #[expect(dead_code)]
    uri: Option<String>,
    /// Envelope `cid`; decoded when present but never read.
    #[serde(default)]
    #[expect(dead_code)]
    cid: Option<String>,
}
/// Ways fetching the labeler service record from the PDS can fail.
#[derive(Debug, Error)]
enum FetchRecordError {
    /// The HTTP client returned an error instead of a response.
    #[error("Network failure fetching labeler record")]
    Network(#[from] IdentityError),
    /// The PDS answered with HTTP 404.
    #[error("PDS returned 404: labeler record not found")]
    NotFound,
    /// The PDS answered with a status other than 200 or 404; `body` is the raw response body.
    #[error("PDS returned HTTP {status}")]
    HttpStatus { status: u16, body: Arc<[u8]> },
    /// A 200 response whose body did not decode as a `getRecord` envelope.
    ///
    /// `display_body` is a pretty-printed copy of the body, and `display_line` and
    /// `display_column` locate the error within it when possible. They come from
    /// `serde_json::Error::line` and `serde_json::Error::column`.
    #[error("Failed to parse PDS getRecord envelope: {source}")]
    ParseEnvelope {
        display_body: Arc<[u8]>,
        display_line: usize,
        display_column: usize,
        #[source]
        source: serde_json::Error,
    },
}
/// Diagnostic for a failed labeler-record fetch.
///
/// Source and span are optional because only some failures (a malformed response body) have
/// content worth pointing into.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_record_fetched")]
struct LabelerRecordFetchError {
    /// Human-readable explanation.
    message: String,
    /// The response body, when one is shown.
    #[source_code]
    named_source: Option<NamedSource<Arc<[u8]>>>,
    /// Location of the problem within `named_source`, when known.
    #[label("response")]
    span: Option<SourceSpan>,
}
/// Diagnostic for a labeler record whose `policies.labelValues` list is empty.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_record_policies_nonempty")]
struct EmptyPoliciesError {
    /// Human-readable explanation.
    message: String,
    /// The (pretty-printed) `getRecord` response body.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the `"labelValues"` key, when it could be found.
    #[label("labelValues is empty")]
    span: Option<SourceSpan>,
}
/// The individual checks performed by the identity stage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Check {
    /// The handle or DID target resolved to a DID.
    TargetResolved,
    /// The DID document was fetched and decoded.
    DidDocumentFetched,
    /// The DID document has a labeler service entry (`atproto_labeler` / `AtprotoLabeler`).
    LabelerServicePresent,
    /// The labeler service endpoint parses as a URL.
    LabelerEndpointParseable,
    /// The labeler endpoint uses HTTPS, or HTTP with a local hostname.
    LabelerEndpointIsHttps,
    /// The user-supplied endpoint or explicit DID agrees with what was resolved.
    ResolvedDidMatchesFlag,
    /// The DID document has an `#atproto_label` key with a parseable multikey.
    SigningKeyPresent,
    /// The DID document has a PDS service entry with a parseable URL.
    PdsEndpointPresent,
    /// The labeler service record was fetched from the PDS and decoded.
    LabelerRecordFetched,
    /// The labeler record's `policies.labelValues` is non-empty.
    LabelerRecordPoliciesNonempty,
}
impl Check {
pub const ALL: &[Check] = &[
Check::TargetResolved,
Check::DidDocumentFetched,
Check::LabelerServicePresent,
Check::LabelerEndpointParseable,
Check::LabelerEndpointIsHttps,
Check::ResolvedDidMatchesFlag,
Check::SigningKeyPresent,
Check::PdsEndpointPresent,
Check::LabelerRecordFetched,
Check::LabelerRecordPoliciesNonempty,
];
pub fn id(self) -> &'static str {
match self {
Check::TargetResolved => "identity::target_resolved",
Check::DidDocumentFetched => "identity::did_document_fetched",
Check::LabelerServicePresent => "identity::labeler_service_present",
Check::LabelerEndpointParseable => "identity::labeler_endpoint_parseable",
Check::LabelerEndpointIsHttps => "identity::labeler_endpoint_is_https",
Check::ResolvedDidMatchesFlag => "identity::resolved_did_matches_flag",
Check::SigningKeyPresent => "identity::signing_key_present",
Check::PdsEndpointPresent => "identity::pds_endpoint_present",
Check::LabelerRecordFetched => "identity::labeler_record_fetched",
Check::LabelerRecordPoliciesNonempty => "identity::labeler_record_policies_nonempty",
}
}
fn summary_str(self) -> &'static str {
match self {
Check::TargetResolved => "target resolution",
Check::DidDocumentFetched => "DID document fetch",
Check::LabelerServicePresent => "labeler service entry",
Check::LabelerEndpointParseable => "labeler endpoint URL",
Check::LabelerEndpointIsHttps => "labeler endpoint scheme",
Check::ResolvedDidMatchesFlag => "resolved DID matches --did flag",
Check::SigningKeyPresent => "signing key entry",
Check::PdsEndpointPresent => "PDS endpoint entry",
Check::LabelerRecordFetched => "labeler record fetch",
Check::LabelerRecordPoliciesNonempty => "labeler record policy list",
}
}
pub fn pass(self) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Identity,
status: CheckStatus::Pass,
summary: Cow::Borrowed(self.summary_str()),
diagnostic: None,
skipped_reason: None,
}
}
pub fn spec_violation(
self,
diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Identity,
status: CheckStatus::SpecViolation,
summary: Cow::Borrowed(self.summary_str()),
diagnostic,
skipped_reason: None,
}
}
pub fn network_error(
self,
diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Identity,
status: CheckStatus::NetworkError,
summary: Cow::Borrowed(self.summary_str()),
diagnostic,
skipped_reason: None,
}
}
pub fn skip(self, reason: impl Into<Cow<'static, str>>) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Identity,
status: CheckStatus::Skipped,
summary: Cow::Borrowed(self.summary_str()),
diagnostic: None,
skipped_reason: Some(reason.into()),
}
}
pub fn advisory(
self,
diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
) -> CheckResult {
CheckResult {
id: self.id(),
stage: Stage::Identity,
status: CheckStatus::Advisory,
summary: Cow::Borrowed(self.summary_str()),
diagnostic,
skipped_reason: None,
}
}
pub fn blocked_by(self, prerequisite: Check) -> CheckResult {
self.skip(format!("blocked by {}", prerequisite.id()))
}
}
pub async fn run(
target: &LabelerTarget,
http: &dyn HttpClient,
dns: &dyn DnsResolver,
) -> IdentityStageOutput {
let mut results = Vec::new();
let mut block_facts = false;
if matches!(target, LabelerTarget::Endpoint { did: None, .. }) {
for check in Check::ALL {
results.push(check.skip("no DID supplied; run with a handle, a DID, or --did <did>"));
}
return IdentityStageOutput {
facts: None,
results,
};
}
let resolved_did: Option<Did> = match target {
LabelerTarget::Identified {
identifier,
explicit_did: _,
} => resolve_identifier(identifier, http, dns, &mut results).await,
LabelerTarget::Endpoint { did, .. } => {
results.push(Check::TargetResolved.pass());
did.clone()
}
};
let Some(did) = resolved_did else {
for check in &Check::ALL[1..] {
results.push(check.blocked_by(Check::TargetResolved));
}
return IdentityStageOutput {
facts: None,
results,
};
};
let raw_did_doc: Option<RawDidDocument> = match resolve_did(&did, http).await {
Ok(doc) => {
results.push(Check::DidDocumentFetched.pass());
Some(doc)
}
Err(e) => {
let result = match &e {
IdentityError::DidDocumentDecodeFailed {
source_name,
source_bytes,
cause,
} => {
let display_body = pretty_json_for_display(source_bytes.as_ref());
let (line, column) =
match serde_json::from_slice::<serde_json::Value>(display_body.as_ref()) {
Err(pretty_err) => (pretty_err.line(), pretty_err.column()),
Ok(_) => (cause.line(), cause.column()),
};
let span = span_at_line_column(display_body.as_ref(), line, column);
let diag: Box<dyn Diagnostic + Send + Sync> =
Box::new(DidDocumentDecodeError {
message: format!("DID document JSON decode failed: {e}"),
named_source: NamedSource::new(source_name.clone(), display_body),
span,
});
Check::DidDocumentFetched.spec_violation(Some(diag))
}
IdentityError::HttpTransport(_) => Check::DidDocumentFetched.network_error(None),
_ => Check::DidDocumentFetched.spec_violation(None),
};
block_facts = true;
results.push(result);
None
}
};
let raw_did_doc = match raw_did_doc {
Some(doc) => doc,
None => {
for check in &Check::ALL[2..] {
results.push(check.blocked_by(Check::DidDocumentFetched));
}
return IdentityStageOutput {
facts: None,
results,
};
}
};
let display_doc_bytes = pretty_json_for_display(raw_did_doc.source_bytes.as_ref());
let labeler_service =
match find_service(&raw_did_doc.parsed, "atproto_labeler", "AtprotoLabeler") {
Some(svc) => {
results.push(Check::LabelerServicePresent.pass());
Some(svc.clone())
}
None => {
let span = span_for_quoted_literal(display_doc_bytes.as_ref(), "service");
let diag = Box::new(ServiceMissingError {
message: "DID document is missing the #atproto_labeler service entry"
.to_string(),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
block_facts = true;
results.push(Check::LabelerServicePresent.spec_violation(Some(diag)));
None
}
};
let mut labeler_endpoint: Option<Url> = match labeler_service {
None => {
results.push(Check::LabelerEndpointParseable.blocked_by(Check::LabelerServicePresent));
results.push(Check::LabelerEndpointIsHttps.blocked_by(Check::LabelerServicePresent));
None
}
Some(svc) => match Url::parse(&svc.service_endpoint) {
Ok(url) => {
results.push(Check::LabelerEndpointParseable.pass());
let is_https = url.scheme() == "https";
let is_http_local = url.scheme() == "http" && is_local_labeler_hostname(&url);
if !is_https && !is_http_local {
let span =
span_for_quoted_literal(display_doc_bytes.as_ref(), &svc.service_endpoint);
let diag = Box::new(NonHttpsLabelerEndpointError {
message: format!(
"Labeler endpoint must use HTTPS (or HTTP with a local hostname), got: {}",
svc.service_endpoint
),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
block_facts = true;
results.push(Check::LabelerEndpointIsHttps.spec_violation(Some(diag)));
None
} else {
results.push(Check::LabelerEndpointIsHttps.pass());
Some(url)
}
}
Err(_) => {
let span =
span_for_quoted_literal(display_doc_bytes.as_ref(), &svc.service_endpoint);
let diag = Box::new(LabelerEndpointParseError {
message: format!(
"Labeler endpoint is not a valid URL: {}",
svc.service_endpoint
),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
block_facts = true;
results.push(Check::LabelerEndpointParseable.spec_violation(Some(diag)));
results.push(
Check::LabelerEndpointIsHttps.blocked_by(Check::LabelerEndpointParseable),
);
None
}
},
};
match (target, &labeler_endpoint) {
(
LabelerTarget::Endpoint {
url: flag_url,
did: Some(_),
},
Some(resolved_endpoint),
) => {
if endpoints_match(flag_url, resolved_endpoint) {
results.push(Check::ResolvedDidMatchesFlag.pass());
} else {
let service =
find_service(&raw_did_doc.parsed, "atproto_labeler", "AtprotoLabeler");
let span = service.and_then(|svc| {
span_for_quoted_literal(display_doc_bytes.as_ref(), &svc.service_endpoint)
});
if is_local_labeler_hostname(flag_url) {
let diag = Box::new(EndpointMismatchError {
message: format!(
"DID document endpoint ({resolved_endpoint}) does not match local override ({flag_url}); using the local URL for the remaining stages"
),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
results.push(Check::ResolvedDidMatchesFlag.advisory(Some(diag)));
labeler_endpoint = Some(flag_url.clone());
} else {
let diag = Box::new(EndpointMismatchError {
message: format!(
"DID document endpoint ({resolved_endpoint}) does not match provided endpoint ({flag_url})"
),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
block_facts = true;
results.push(Check::ResolvedDidMatchesFlag.spec_violation(Some(diag)));
}
}
}
(
LabelerTarget::Identified {
identifier: _,
explicit_did: Some(explicit),
},
_,
) => {
if explicit != &did {
block_facts = true;
results.push(Check::ResolvedDidMatchesFlag.spec_violation(None));
} else {
results.push(Check::ResolvedDidMatchesFlag.pass());
}
}
_ => {
results.push(Check::ResolvedDidMatchesFlag.skip("no endpoint override provided"));
}
}
let signing_key_ids: Option<(String, String)> = match find_signing_key(&raw_did_doc.parsed) {
Some((id, multikey_str)) => match parse_multikey(&multikey_str) {
Ok(_) => {
results.push(Check::SigningKeyPresent.pass());
Some((id, multikey_str))
}
Err(e) => {
let span = span_for_quoted_literal(display_doc_bytes.as_ref(), &multikey_str);
let diag = Box::new(SigningKeyUnparseableError {
message: format!("Failed to parse signing key multikey: {e}"),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
block_facts = true;
results.push(Check::SigningKeyPresent.spec_violation(Some(diag)));
None
}
},
None => {
let span = span_for_quoted_literal(display_doc_bytes.as_ref(), "verificationMethod");
let diag = Box::new(SigningKeyMissingError {
message: "DID document is missing the #atproto_label signing key".to_string(),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
block_facts = true;
results.push(Check::SigningKeyPresent.spec_violation(Some(diag)));
None
}
};
let signing_key = signing_key_ids.as_ref().and_then(|_| {
raw_did_doc
.parsed
.verification_method
.as_ref()
.and_then(|vms| {
vms.iter()
.find(|vm| {
vm.id.rsplit_once('#').map(|(_, f)| f).unwrap_or("") == "atproto_label"
})
.and_then(|vm| vm.public_key_multibase.as_deref())
})
.and_then(|mk| parse_multikey(mk).ok().map(|parsed| parsed.verifying_key))
});
let pds_endpoint: Option<Url> = match find_service(
&raw_did_doc.parsed,
"atproto_pds",
"AtprotoPersonalDataServer",
) {
Some(svc) => match Url::parse(&svc.service_endpoint) {
Ok(url) => {
results.push(Check::PdsEndpointPresent.pass());
Some(url)
}
Err(_) => {
let span =
span_for_quoted_literal(display_doc_bytes.as_ref(), &svc.service_endpoint);
let diag = Box::new(PdsServiceMissingError {
message: format!("PDS endpoint is not a valid URL: {}", svc.service_endpoint),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
block_facts = true;
results.push(Check::PdsEndpointPresent.spec_violation(Some(diag)));
None
}
},
None => {
let span = span_for_quoted_literal(display_doc_bytes.as_ref(), "service");
let diag = Box::new(PdsServiceMissingError {
message: "DID document is missing the #atproto_pds service entry".to_string(),
named_source: NamedSource::new(
raw_did_doc.source_name.clone(),
display_doc_bytes.clone(),
),
span,
});
block_facts = true;
results.push(Check::PdsEndpointPresent.spec_violation(Some(diag)));
None
}
};
let fetched_record: Option<FetchedLabelerRecord> = match &pds_endpoint {
None => {
results.push(Check::LabelerRecordFetched.blocked_by(Check::PdsEndpointPresent));
None
}
Some(pds_url) => match fetch_labeler_record(&did, pds_url, http).await {
Ok(record) => {
results.push(Check::LabelerRecordFetched.pass());
Some(record)
}
Err(e) => {
let (check_status, message, named_source, span) = match &e {
FetchRecordError::Network(_) => {
(CheckStatus::NetworkError, e.to_string(), None, None)
}
FetchRecordError::NotFound => {
(CheckStatus::SpecViolation, e.to_string(), None, None)
}
FetchRecordError::HttpStatus { .. } => {
(CheckStatus::SpecViolation, e.to_string(), None, None)
}
FetchRecordError::ParseEnvelope {
display_body,
display_line,
display_column,
source,
} => {
let src = NamedSource::new("PDS response", display_body.clone());
let span = span_at_line_column(
display_body.as_ref(),
*display_line,
*display_column,
);
(
CheckStatus::SpecViolation,
format!("Failed to parse PDS getRecord envelope: {source}"),
Some(src),
Some(span),
)
}
};
let diag = Box::new(LabelerRecordFetchError {
message,
named_source,
span,
});
block_facts = true;
let base = match check_status {
CheckStatus::NetworkError => {
Check::LabelerRecordFetched.network_error(Some(diag))
}
_ => Check::LabelerRecordFetched.spec_violation(Some(diag)),
};
results.push(base);
None
}
},
};
match &fetched_record {
None => {
results
.push(Check::LabelerRecordPoliciesNonempty.blocked_by(Check::LabelerRecordFetched));
}
Some(record) => {
if record.policies.label_values.is_empty() {
let display_bytes = pretty_json_for_display(record.bytes.as_ref());
let span = span_for_quoted_literal(display_bytes.as_ref(), "labelValues");
let diag = Box::new(EmptyPoliciesError {
message: "Labeler record policies.labelValues is empty".to_string(),
named_source: NamedSource::new("labeler record", display_bytes),
span,
});
block_facts = true;
results.push(Check::LabelerRecordPoliciesNonempty.spec_violation(Some(diag)));
} else {
results.push(Check::LabelerRecordPoliciesNonempty.pass());
}
}
}
let facts = if !block_facts {
match (
labeler_endpoint,
pds_endpoint,
signing_key_ids,
signing_key,
fetched_record,
) {
(Some(le), Some(pe), Some((ski, skm)), Some(sk), Some(record)) => Some(IdentityFacts {
did,
raw_did_doc,
labeler_endpoint: le,
pds_endpoint: pe,
signing_key_id: ski,
signing_key_multikey: skm,
signing_key: sk,
labeler_record_bytes: record.bytes,
labeler_policies: record.policies,
reason_types: record.reason_types,
subject_types: record.subject_types,
subject_collections: record.subject_collections,
}),
_ => None,
}
} else {
None
};
IdentityStageOutput { facts, results }
}
/// Resolves a handle or DID target to a DID and records the `TargetResolved` outcome in
/// `results`.
///
/// A DID target is accepted as-is. A handle is resolved with [`resolve_handle`]:
/// - HTTP transport errors, DNS lookup failures and unresolvable handles are recorded as
///   network errors;
/// - any other failure is recorded as a spec violation.
///
/// Returns `None` when resolution fails.
async fn resolve_identifier(
    identifier: &AtIdentifier,
    http: &dyn HttpClient,
    dns: &dyn DnsResolver,
    results: &mut Vec<CheckResult>,
) -> Option<Did> {
    let resolved = match identifier {
        AtIdentifier::Did(did) => did.clone(),
        AtIdentifier::Handle(handle) => match resolve_handle(handle, http, dns).await {
            Ok(did) => did,
            Err(err) => {
                let failure = match err {
                    IdentityError::HttpTransport(_)
                    | IdentityError::DnsLookupFailed { .. }
                    | IdentityError::HandleUnresolvable { .. } => {
                        Check::TargetResolved.network_error(None)
                    }
                    _ => Check::TargetResolved.spec_violation(None),
                };
                results.push(failure);
                return None;
            }
        },
    };
    results.push(Check::TargetResolved.pass());
    Some(resolved)
}
/// Finds the verification method whose id fragment is `atproto_label`.
///
/// Returns that method's `(id, public_key_multibase)`. Only the first method with that
/// fragment is considered. If it has no multibase key, `None` is returned even when a later
/// entry would match.
fn find_signing_key(doc: &DidDocument) -> Option<(String, String)> {
    let label_method = doc.verification_method.as_ref()?.iter().find(|vm| {
        vm.id
            .rsplit_once('#')
            .is_some_and(|(_, fragment)| fragment == "atproto_label")
    })?;
    let multikey = label_method.public_key_multibase.as_ref()?;
    Some((label_method.id.clone(), multikey.clone()))
}
/// Returns `true` when both URLs share the same scheme, host and port.
///
/// Paths, queries and fragments are not compared. `Url::port` is `None` for a scheme's
/// default port, so an explicit default port and an omitted one compare equal.
fn endpoints_match(url1: &Url, url2: &Url) -> bool {
    fn scheme_host_port(url: &Url) -> (&str, Option<&str>, Option<u16>) {
        (url.scheme(), url.host_str(), url.port())
    }
    scheme_host_port(url1) == scheme_host_port(url2)
}
/// Fetches the labeler's `app.bsky.labeler.service` record (rkey `self`) from its PDS via
/// `com.atproto.repo.getRecord`.
///
/// # Errors
///
/// - [`FetchRecordError::Network`] if the HTTP client fails to perform the request.
/// - [`FetchRecordError::NotFound`] on HTTP 404.
/// - [`FetchRecordError::HttpStatus`] on any other non-200 status (the body is kept).
/// - [`FetchRecordError::ParseEnvelope`] if a 200 body does not decode as a `getRecord`
///   envelope holding a labeler service record. The reported position refers to the
///   pretty-printed body carried in the error whenever that copy also fails to decode.
async fn fetch_labeler_record(
    did: &Did,
    pds_endpoint: &Url,
    http: &dyn HttpClient,
) -> Result<FetchedLabelerRecord, FetchRecordError> {
    let mut url = pds_endpoint.clone();
    url.set_path("/xrpc/com.atproto.repo.getRecord");
    // `set_query` leaves `%` untouched, so the DID must be escaped first. Otherwise a
    // port-bearing did:web (`did:web:localhost%3A8080`) would be decoded by the PDS into a
    // different repo identifier.
    let query = format!(
        "repo={}&collection=app.bsky.labeler.service&rkey=self",
        encode_query_value(&did.0)
    );
    url.set_query(Some(&query));

    let (status, body) = http.get_bytes(&url).await?;
    let body: Arc<[u8]> = Arc::from(body);
    match status {
        200 => match serde_json::from_slice::<GetRecordResponse>(body.as_ref()) {
            Ok(response) => {
                let record = response.value;
                Ok(FetchedLabelerRecord {
                    reason_types: record.reason_types,
                    subject_types: record.subject_types,
                    subject_collections: record.subject_collections.map(|nsids| {
                        nsids.iter().map(|nsid| nsid.to_string()).collect::<Vec<String>>()
                    }),
                    policies: record.policies,
                    bytes: body,
                })
            }
            Err(raw_err) => {
                // Locate the error within the pretty-printed copy that diagnostics display.
                // If that copy unexpectedly decodes, keep the raw body's error instead.
                let display_body = pretty_json_for_display(body.as_ref());
                let source = serde_json::from_slice::<GetRecordResponse>(display_body.as_ref())
                    .err()
                    .unwrap_or(raw_err);
                Err(FetchRecordError::ParseEnvelope {
                    display_line: source.line(),
                    display_column: source.column(),
                    display_body,
                    source,
                })
            }
        },
        404 => Err(FetchRecordError::NotFound),
        _ => Err(FetchRecordError::HttpStatus { status, body }),
    }
}

/// Percent-encodes `value` for use as a single query-string value.
///
/// RFC 3986 unreserved characters and `:` are kept as-is, so ordinary DIDs such as
/// `did:plc:abc123` appear verbatim in the request URL. Every other byte is emitted as `%XX`.
/// That includes `%`, `&`, `#`, `+` and non-ASCII bytes.
fn encode_query_value(value: &str) -> String {
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b':') {
            encoded.push(char::from(byte));
        } else {
            encoded.push_str(&format!("%{byte:02X}"));
        }
    }
    encoded
}
/// Diagnostic for a DID document that was fetched but failed to decode as JSON.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::did_document_fetched")]
struct DidDocumentDecodeError {
    /// Human-readable explanation, including the underlying decode error.
    message: String,
    /// The (pretty-printed) response body, named after the document's source.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Position of the parse error within `named_source` (always present, unlike the other
    /// identity diagnostics).
    #[label("JSON parse error")]
    span: SourceSpan,
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::HashMap;

    /// In-memory [`HttpClient`] serving canned `(status, body)` pairs keyed by full URL.
    ///
    /// Unregistered URLs yield `IdentityError::DidResolutionFailed` with status 404.
    struct FakeHttpClient {
        responses: HashMap<String, (u16, Vec<u8>)>,
    }

    impl FakeHttpClient {
        fn new() -> Self {
            Self {
                responses: HashMap::new(),
            }
        }

        fn add_response(&mut self, url: impl Into<String>, status: u16, body: Vec<u8>) {
            self.responses.insert(url.into(), (status, body));
        }
    }

    #[async_trait]
    impl HttpClient for FakeHttpClient {
        async fn get_bytes(&self, url: &Url) -> Result<(u16, Vec<u8>), IdentityError> {
            match self.responses.get(url.as_str()) {
                Some(canned) => Ok(canned.clone()),
                None => Err(IdentityError::DidResolutionFailed {
                    status: 404,
                    body: "Not found".to_string(),
                }),
            }
        }
    }

    /// The optional reason/subject metadata of the labeler record must survive the fetch.
    #[tokio::test]
    async fn identity_retains_reason_and_subject_types() {
        let fixture_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(
            "tests/fixtures/labeler/identity/report_stage_contract_present/labeler_record.json",
        );
        let labeler_record_bytes = std::fs::read(&fixture_path).expect("fixture file exists");

        let did = Did("did:plc:test123456789012345678901234".to_string());
        let pds_url = Url::parse("https://pds.example.com").unwrap();

        // Mirror the URL `fetch_labeler_record` requests so the fake client can serve it.
        let mut fetch_url = pds_url.clone();
        fetch_url.set_path("/xrpc/com.atproto.repo.getRecord");
        fetch_url.set_query(Some(&format!(
            "repo={}&collection=app.bsky.labeler.service&rkey=self",
            did.0
        )));

        let mut http = FakeHttpClient::new();
        http.add_response(fetch_url.as_str(), 200, labeler_record_bytes);

        let record = match fetch_labeler_record(&did, &pds_url, &http).await {
            Ok(record) => record,
            Err(err) => panic!("fetch_labeler_record should succeed: {err}"),
        };

        let reason_types = record.reason_types.expect("reason_types should be Some");
        assert_eq!(reason_types.len(), 2, "reason_types should have 2 entries");
        assert!(
            reason_types.iter().any(|r| r.contains("reasonSpam")),
            "should include reasonSpam"
        );

        let subject_types = record.subject_types.expect("subject_types should be Some");
        assert_eq!(subject_types.len(), 2, "subject_types should have 2 entries");
        assert!(
            subject_types.iter().any(|s| s == "account"),
            "should include account"
        );
        assert!(
            subject_types.iter().any(|s| s == "record"),
            "should include record"
        );

        let subject_collections = record
            .subject_collections
            .expect("subject_collections should be Some");
        assert_eq!(
            subject_collections.len(),
            2,
            "subject_collections should have 2 entries"
        );
        assert!(
            subject_collections
                .iter()
                .any(|s| s.contains("bsky.feed.post")),
            "should include app.bsky.feed.post"
        );
    }
}