1use std::borrow::Cow;
7use std::sync::Arc;
8
9use atrium_api::app::bsky::labeler::defs::LabelerPolicies;
10use atrium_api::app::bsky::labeler::service::RecordData as LabelerServiceRecordData;
11use miette::{Diagnostic, NamedSource, SourceSpan};
12use thiserror::Error;
13use url::Url;
14
15use crate::commands::test::labeler::pipeline::{AtIdentifier, LabelerTarget};
16use crate::commands::test::labeler::report::{CheckResult, CheckStatus, Stage};
17use crate::common::diagnostics::{
18 pretty_json_for_display, span_at_line_column, span_for_quoted_literal,
19};
20use crate::common::identity::{
21 AnyVerifyingKey, Did, DidDocument, DnsResolver, HttpClient, IdentityError, RawDidDocument,
22 find_service, is_local_labeler_hostname, parse_multikey, resolve_did, resolve_handle,
23};
24
/// Labeler service record as returned by [`fetch_labeler_record`].
struct FetchedLabelerRecord {
    /// Raw body of the `com.atproto.repo.getRecord` response (the whole
    /// envelope, not just the record value).
    bytes: Arc<[u8]>,
    /// The record's `policies` object.
    policies: LabelerPolicies,
    /// The record's optional `reasonTypes` list.
    reason_types: Option<Vec<String>>,
    /// The record's optional `subjectTypes` list.
    subject_types: Option<Vec<String>>,
    /// The record's optional `subjectCollections` list, with each NSID stringified.
    subject_collections: Option<Vec<String>>,
}
38
/// Everything the identity stage learned about a labeler, handed to later
/// stages only when every identity check passed (advisories allowed).
#[derive(Debug, Clone)]
pub struct IdentityFacts {
    /// The labeler's DID (resolved from a handle, or supplied directly).
    pub did: Did,
    /// The fetched DID document, both raw bytes and parsed form.
    pub raw_did_doc: RawDidDocument,
    /// Endpoint from the `#atproto_labeler` service entry, or the local
    /// override URL when that override differed from the DID document.
    pub labeler_endpoint: Url,
    /// Endpoint from the `#atproto_pds` service entry.
    pub pds_endpoint: Url,
    /// Verification-method id of the `#atproto_label` signing key.
    pub signing_key_id: String,
    /// The signing key's `publicKeyMultibase` string as it appears in the document.
    pub signing_key_multikey: String,
    /// The signing key decoded from `signing_key_multikey`.
    pub signing_key: AnyVerifyingKey,
    /// Raw `com.atproto.repo.getRecord` response body for the labeler
    /// service record (the full envelope, not only the record value).
    pub labeler_record_bytes: Arc<[u8]>,
    /// The labeler record's `policies`.
    pub labeler_policies: LabelerPolicies,
    /// The labeler record's optional `reasonTypes`.
    pub reason_types: Option<Vec<String>>,
    /// The labeler record's optional `subjectTypes`.
    pub subject_types: Option<Vec<String>>,
    /// The labeler record's optional `subjectCollections`, as strings.
    pub subject_collections: Option<Vec<String>>,
}
78
/// Result of [`run`]: per-check outcomes plus, when every check succeeded,
/// the collected facts.
#[derive(Debug)]
pub struct IdentityStageOutput {
    /// `Some` only when no identity check failed and every fact was gathered.
    pub facts: Option<IdentityFacts>,
    /// One entry per identity check, in the order they were evaluated.
    pub results: Vec<CheckResult>,
}
87
/// Diagnostic for `identity::labeler_service_present`: the DID document has
/// no `#atproto_labeler` service entry.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_service_present")]
struct ServiceMissingError {
    /// Text shown as the error message.
    message: String,
    /// The (pretty-printed) DID document the span refers to.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the `"service"` key, when it could be found.
    #[label("service array")]
    span: Option<SourceSpan>,
}
102
/// Diagnostic for `identity::labeler_endpoint_parseable`: the
/// `#atproto_labeler` service endpoint is not a valid URL.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_endpoint_parseable")]
struct LabelerEndpointParseError {
    /// Text shown as the error message.
    message: String,
    /// The (pretty-printed) DID document the span refers to.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the quoted endpoint string, when it could be found.
    #[label("endpoint value")]
    span: Option<SourceSpan>,
}
117
/// Diagnostic for `identity::labeler_endpoint_is_https`: the labeler endpoint
/// is neither HTTPS nor HTTP on a local hostname.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_endpoint_is_https")]
struct NonHttpsLabelerEndpointError {
    /// Text shown as the error message.
    message: String,
    /// The (pretty-printed) DID document the span refers to.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the quoted endpoint string, when it could be found.
    #[label("endpoint value")]
    span: Option<SourceSpan>,
}
132
/// Diagnostic for `identity::resolved_did_matches_flag` when the endpoint
/// given on the command line differs from the DID document's labeler
/// endpoint. Used both for the advisory (local override) and the violation.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::resolved_did_matches_flag")]
struct EndpointMismatchError {
    /// Text shown as the error message.
    message: String,
    /// The (pretty-printed) DID document the span refers to.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the DID document's quoted endpoint string, when found.
    #[label("endpoint value")]
    span: Option<SourceSpan>,
}
147
/// Diagnostic for `identity::signing_key_present` when no usable
/// `#atproto_label` verification method exists in the DID document.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::signing_key_present")]
struct SigningKeyMissingError {
    /// Text shown as the error message.
    message: String,
    /// The (pretty-printed) DID document the span refers to.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the `"verificationMethod"` key, when it could be found.
    #[label("verificationMethod array")]
    span: Option<SourceSpan>,
}
162
/// Diagnostic for `identity::signing_key_present` when the `#atproto_label`
/// key exists but its multikey fails to parse. Shares its code with
/// `SigningKeyMissingError` because both belong to the same check.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::signing_key_present")]
struct SigningKeyUnparseableError {
    /// Text shown as the error message (includes the parse error).
    message: String,
    /// The (pretty-printed) DID document the span refers to.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the quoted multikey string, when it could be found.
    #[label("multikey")]
    span: Option<SourceSpan>,
}
177
/// Diagnostic for `identity::pds_endpoint_present`. Despite the name, it is
/// used both when the `#atproto_pds` service is missing and when its
/// endpoint is not a valid URL; the message distinguishes the two.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::pds_endpoint_present")]
struct PdsServiceMissingError {
    /// Text shown as the error message.
    message: String,
    /// The (pretty-printed) DID document the span refers to.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the `"service"` key or the bad endpoint string, when found.
    #[label("service array")]
    span: Option<SourceSpan>,
}
192
/// Envelope of a `com.atproto.repo.getRecord` response whose `value` is a
/// labeler service record.
#[derive(serde::Deserialize)]
struct GetRecordResponse {
    /// The record itself.
    value: LabelerServiceRecordData,
    // Accepted so that well-formed envelopes deserialize, but not used.
    #[serde(default)]
    #[expect(dead_code)]
    uri: Option<String>,
    // Accepted so that well-formed envelopes deserialize, but not used.
    #[serde(default)]
    #[expect(dead_code)]
    cid: Option<String>,
}
211
212#[derive(Debug, Error)]
214enum FetchRecordError {
215 #[error("Network failure fetching labeler record")]
217 Network(#[from] IdentityError),
218
219 #[error("PDS returned 404: labeler record not found")]
221 NotFound,
222
223 #[error("PDS returned HTTP {status}")]
225 HttpStatus { status: u16, body: Arc<[u8]> },
226
227 #[error("Failed to parse PDS getRecord envelope: {source}")]
234 ParseEnvelope {
235 display_body: Arc<[u8]>,
236 display_line: usize,
237 display_column: usize,
238 #[source]
239 source: serde_json::Error,
240 },
241}
242
/// Diagnostic for `identity::labeler_record_fetched`.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_record_fetched")]
struct LabelerRecordFetchError {
    /// Text shown as the error message.
    message: String,
    /// The PDS response body; in `run` only envelope-parse failures supply one.
    #[source_code]
    named_source: Option<NamedSource<Arc<[u8]>>>,
    /// Position of the parse error within `named_source`, when present.
    #[label("response")]
    span: Option<SourceSpan>,
}
257
/// Diagnostic for `identity::labeler_record_policies_nonempty`: the labeler
/// record declares no label values.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::labeler_record_policies_nonempty")]
struct EmptyPoliciesError {
    /// Text shown as the error message.
    message: String,
    /// Pretty-printed copy of the fetched getRecord response.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Location of the `"labelValues"` key, when it could be found.
    #[label("labelValues is empty")]
    span: Option<SourceSpan>,
}
272
/// Identity-stage checks; each produces exactly one `CheckResult` per run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Check {
    /// The target resolved to a DID (or a DID was supplied directly).
    TargetResolved,
    /// The DID document was fetched and decoded.
    DidDocumentFetched,
    /// The DID document has a `#atproto_labeler` service entry.
    LabelerServicePresent,
    /// That service's endpoint parses as a URL.
    LabelerEndpointParseable,
    /// That endpoint is HTTPS, or HTTP on a local hostname.
    LabelerEndpointIsHttps,
    /// The resolved DID matches an explicit `--did`, or the endpoint given
    /// alongside `--did` matches the DID document's labeler endpoint.
    ResolvedDidMatchesFlag,
    /// The DID document has a parseable `#atproto_label` signing key.
    SigningKeyPresent,
    /// The DID document has a `#atproto_pds` service with a parseable URL.
    PdsEndpointPresent,
    /// The `app.bsky.labeler.service/self` record was fetched from the PDS.
    LabelerRecordFetched,
    /// The fetched record's `policies.labelValues` is non-empty.
    LabelerRecordPoliciesNonempty,
}
297
298impl Check {
299 pub const ALL: &[Check] = &[
301 Check::TargetResolved,
302 Check::DidDocumentFetched,
303 Check::LabelerServicePresent,
304 Check::LabelerEndpointParseable,
305 Check::LabelerEndpointIsHttps,
306 Check::ResolvedDidMatchesFlag,
307 Check::SigningKeyPresent,
308 Check::PdsEndpointPresent,
309 Check::LabelerRecordFetched,
310 Check::LabelerRecordPoliciesNonempty,
311 ];
312
313 pub fn id(self) -> &'static str {
315 match self {
316 Check::TargetResolved => "identity::target_resolved",
317 Check::DidDocumentFetched => "identity::did_document_fetched",
318 Check::LabelerServicePresent => "identity::labeler_service_present",
319 Check::LabelerEndpointParseable => "identity::labeler_endpoint_parseable",
320 Check::LabelerEndpointIsHttps => "identity::labeler_endpoint_is_https",
321 Check::ResolvedDidMatchesFlag => "identity::resolved_did_matches_flag",
322 Check::SigningKeyPresent => "identity::signing_key_present",
323 Check::PdsEndpointPresent => "identity::pds_endpoint_present",
324 Check::LabelerRecordFetched => "identity::labeler_record_fetched",
325 Check::LabelerRecordPoliciesNonempty => "identity::labeler_record_policies_nonempty",
326 }
327 }
328
329 fn summary_str(self) -> &'static str {
331 match self {
332 Check::TargetResolved => "target resolution",
333 Check::DidDocumentFetched => "DID document fetch",
334 Check::LabelerServicePresent => "labeler service entry",
335 Check::LabelerEndpointParseable => "labeler endpoint URL",
336 Check::LabelerEndpointIsHttps => "labeler endpoint scheme",
337 Check::ResolvedDidMatchesFlag => "resolved DID matches --did flag",
338 Check::SigningKeyPresent => "signing key entry",
339 Check::PdsEndpointPresent => "PDS endpoint entry",
340 Check::LabelerRecordFetched => "labeler record fetch",
341 Check::LabelerRecordPoliciesNonempty => "labeler record policy list",
342 }
343 }
344
345 pub fn pass(self) -> CheckResult {
346 CheckResult {
347 id: self.id(),
348 stage: Stage::Identity,
349 status: CheckStatus::Pass,
350 summary: Cow::Borrowed(self.summary_str()),
351 diagnostic: None,
352 skipped_reason: None,
353 }
354 }
355
356 pub fn spec_violation(
357 self,
358 diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
359 ) -> CheckResult {
360 CheckResult {
361 id: self.id(),
362 stage: Stage::Identity,
363 status: CheckStatus::SpecViolation,
364 summary: Cow::Borrowed(self.summary_str()),
365 diagnostic,
366 skipped_reason: None,
367 }
368 }
369
370 pub fn network_error(
371 self,
372 diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
373 ) -> CheckResult {
374 CheckResult {
375 id: self.id(),
376 stage: Stage::Identity,
377 status: CheckStatus::NetworkError,
378 summary: Cow::Borrowed(self.summary_str()),
379 diagnostic,
380 skipped_reason: None,
381 }
382 }
383
384 pub fn skip(self, reason: impl Into<Cow<'static, str>>) -> CheckResult {
385 CheckResult {
386 id: self.id(),
387 stage: Stage::Identity,
388 status: CheckStatus::Skipped,
389 summary: Cow::Borrowed(self.summary_str()),
390 diagnostic: None,
391 skipped_reason: Some(reason.into()),
392 }
393 }
394
395 pub fn advisory(
396 self,
397 diagnostic: Option<Box<dyn miette::Diagnostic + Send + Sync>>,
398 ) -> CheckResult {
399 CheckResult {
400 id: self.id(),
401 stage: Stage::Identity,
402 status: CheckStatus::Advisory,
403 summary: Cow::Borrowed(self.summary_str()),
404 diagnostic,
405 skipped_reason: None,
406 }
407 }
408
409 pub fn blocked_by(self, prerequisite: Check) -> CheckResult {
411 self.skip(format!("blocked by {}", prerequisite.id()))
412 }
413}
414
415pub async fn run(
417 target: &LabelerTarget,
418 http: &dyn HttpClient,
419 dns: &dyn DnsResolver,
420) -> IdentityStageOutput {
421 let mut results = Vec::new();
422 let mut block_facts = false;
423
424 if matches!(target, LabelerTarget::Endpoint { did: None, .. }) {
426 for check in Check::ALL {
427 results.push(check.skip("no DID supplied; run with a handle, a DID, or --did <did>"));
428 }
429 return IdentityStageOutput {
430 facts: None,
431 results,
432 };
433 }
434
435 let resolved_did: Option<Did> = match target {
437 LabelerTarget::Identified {
438 identifier,
439 explicit_did: _,
440 } => resolve_identifier(identifier, http, dns, &mut results).await,
441 LabelerTarget::Endpoint { did, .. } => {
442 results.push(Check::TargetResolved.pass());
444 did.clone()
445 }
446 };
447
448 let Some(did) = resolved_did else {
449 for check in &Check::ALL[1..] {
451 results.push(check.blocked_by(Check::TargetResolved));
452 }
453 return IdentityStageOutput {
454 facts: None,
455 results,
456 };
457 };
458
459 let raw_did_doc: Option<RawDidDocument> = match resolve_did(&did, http).await {
461 Ok(doc) => {
462 results.push(Check::DidDocumentFetched.pass());
463 Some(doc)
464 }
465 Err(e) => {
466 let result = match &e {
467 IdentityError::DidDocumentDecodeFailed {
468 source_name,
469 source_bytes,
470 cause,
471 } => {
472 let display_body = pretty_json_for_display(source_bytes.as_ref());
476 let (line, column) =
477 match serde_json::from_slice::<serde_json::Value>(display_body.as_ref()) {
478 Err(pretty_err) => (pretty_err.line(), pretty_err.column()),
479 Ok(_) => (cause.line(), cause.column()),
480 };
481 let span = span_at_line_column(display_body.as_ref(), line, column);
482 let diag: Box<dyn Diagnostic + Send + Sync> =
483 Box::new(DidDocumentDecodeError {
484 message: format!("DID document JSON decode failed: {e}"),
485 named_source: NamedSource::new(source_name.clone(), display_body),
486 span,
487 });
488 Check::DidDocumentFetched.spec_violation(Some(diag))
489 }
490 IdentityError::HttpTransport(_) => Check::DidDocumentFetched.network_error(None),
491 _ => Check::DidDocumentFetched.spec_violation(None),
492 };
493 block_facts = true;
494 results.push(result);
495 None
496 }
497 };
498
499 let raw_did_doc = match raw_did_doc {
500 Some(doc) => doc,
501 None => {
502 for check in &Check::ALL[2..] {
504 results.push(check.blocked_by(Check::DidDocumentFetched));
505 }
506 return IdentityStageOutput {
507 facts: None,
508 results,
509 };
510 }
511 };
512
513 let display_doc_bytes = pretty_json_for_display(raw_did_doc.source_bytes.as_ref());
519
520 let labeler_service =
522 match find_service(&raw_did_doc.parsed, "atproto_labeler", "AtprotoLabeler") {
523 Some(svc) => {
524 results.push(Check::LabelerServicePresent.pass());
525 Some(svc.clone())
526 }
527 None => {
528 let span = span_for_quoted_literal(display_doc_bytes.as_ref(), "service");
529 let diag = Box::new(ServiceMissingError {
530 message: "DID document is missing the #atproto_labeler service entry"
531 .to_string(),
532 named_source: NamedSource::new(
533 raw_did_doc.source_name.clone(),
534 display_doc_bytes.clone(),
535 ),
536 span,
537 });
538 block_facts = true;
539 results.push(Check::LabelerServicePresent.spec_violation(Some(diag)));
540 None
541 }
542 };
543
544 let mut labeler_endpoint: Option<Url> = match labeler_service {
552 None => {
553 results.push(Check::LabelerEndpointParseable.blocked_by(Check::LabelerServicePresent));
554 results.push(Check::LabelerEndpointIsHttps.blocked_by(Check::LabelerServicePresent));
555 None
556 }
557 Some(svc) => match Url::parse(&svc.service_endpoint) {
558 Ok(url) => {
559 results.push(Check::LabelerEndpointParseable.pass());
560 let is_https = url.scheme() == "https";
565 let is_http_local = url.scheme() == "http" && is_local_labeler_hostname(&url);
566 if !is_https && !is_http_local {
567 let span =
568 span_for_quoted_literal(display_doc_bytes.as_ref(), &svc.service_endpoint);
569 let diag = Box::new(NonHttpsLabelerEndpointError {
570 message: format!(
571 "Labeler endpoint must use HTTPS (or HTTP with a local hostname), got: {}",
572 svc.service_endpoint
573 ),
574 named_source: NamedSource::new(
575 raw_did_doc.source_name.clone(),
576 display_doc_bytes.clone(),
577 ),
578 span,
579 });
580 block_facts = true;
581 results.push(Check::LabelerEndpointIsHttps.spec_violation(Some(diag)));
582 None
583 } else {
584 results.push(Check::LabelerEndpointIsHttps.pass());
585 Some(url)
586 }
587 }
588 Err(_) => {
589 let span =
590 span_for_quoted_literal(display_doc_bytes.as_ref(), &svc.service_endpoint);
591 let diag = Box::new(LabelerEndpointParseError {
592 message: format!(
593 "Labeler endpoint is not a valid URL: {}",
594 svc.service_endpoint
595 ),
596 named_source: NamedSource::new(
597 raw_did_doc.source_name.clone(),
598 display_doc_bytes.clone(),
599 ),
600 span,
601 });
602 block_facts = true;
603 results.push(Check::LabelerEndpointParseable.spec_violation(Some(diag)));
604 results.push(
605 Check::LabelerEndpointIsHttps.blocked_by(Check::LabelerEndpointParseable),
606 );
607 None
608 }
609 },
610 };
611
612 match (target, &labeler_endpoint) {
616 (
617 LabelerTarget::Endpoint {
618 url: flag_url,
619 did: Some(_),
620 },
621 Some(resolved_endpoint),
622 ) => {
623 if endpoints_match(flag_url, resolved_endpoint) {
624 results.push(Check::ResolvedDidMatchesFlag.pass());
625 } else {
626 let service =
628 find_service(&raw_did_doc.parsed, "atproto_labeler", "AtprotoLabeler");
629 let span = service.and_then(|svc| {
630 span_for_quoted_literal(display_doc_bytes.as_ref(), &svc.service_endpoint)
631 });
632
633 if is_local_labeler_hostname(flag_url) {
634 let diag = Box::new(EndpointMismatchError {
641 message: format!(
642 "DID document endpoint ({resolved_endpoint}) does not match local override ({flag_url}); using the local URL for the remaining stages"
643 ),
644 named_source: NamedSource::new(
645 raw_did_doc.source_name.clone(),
646 display_doc_bytes.clone(),
647 ),
648 span,
649 });
650 results.push(Check::ResolvedDidMatchesFlag.advisory(Some(diag)));
651 labeler_endpoint = Some(flag_url.clone());
652 } else {
653 let diag = Box::new(EndpointMismatchError {
654 message: format!(
655 "DID document endpoint ({resolved_endpoint}) does not match provided endpoint ({flag_url})"
656 ),
657 named_source: NamedSource::new(
658 raw_did_doc.source_name.clone(),
659 display_doc_bytes.clone(),
660 ),
661 span,
662 });
663 block_facts = true;
664 results.push(Check::ResolvedDidMatchesFlag.spec_violation(Some(diag)));
665 }
666 }
667 }
668 (
669 LabelerTarget::Identified {
670 identifier: _,
671 explicit_did: Some(explicit),
672 },
673 _,
674 ) => {
675 if explicit != &did {
676 block_facts = true;
677 results.push(Check::ResolvedDidMatchesFlag.spec_violation(None));
678 } else {
679 results.push(Check::ResolvedDidMatchesFlag.pass());
680 }
681 }
682 _ => {
683 results.push(Check::ResolvedDidMatchesFlag.skip("no endpoint override provided"));
685 }
686 }
687
688 let signing_key_ids: Option<(String, String)> = match find_signing_key(&raw_did_doc.parsed) {
690 Some((id, multikey_str)) => match parse_multikey(&multikey_str) {
691 Ok(_) => {
692 results.push(Check::SigningKeyPresent.pass());
693 Some((id, multikey_str))
694 }
695 Err(e) => {
696 let span = span_for_quoted_literal(display_doc_bytes.as_ref(), &multikey_str);
697 let diag = Box::new(SigningKeyUnparseableError {
698 message: format!("Failed to parse signing key multikey: {e}"),
699 named_source: NamedSource::new(
700 raw_did_doc.source_name.clone(),
701 display_doc_bytes.clone(),
702 ),
703 span,
704 });
705 block_facts = true;
706 results.push(Check::SigningKeyPresent.spec_violation(Some(diag)));
707 None
708 }
709 },
710 None => {
711 let span = span_for_quoted_literal(display_doc_bytes.as_ref(), "verificationMethod");
712 let diag = Box::new(SigningKeyMissingError {
713 message: "DID document is missing the #atproto_label signing key".to_string(),
714 named_source: NamedSource::new(
715 raw_did_doc.source_name.clone(),
716 display_doc_bytes.clone(),
717 ),
718 span,
719 });
720 block_facts = true;
721 results.push(Check::SigningKeyPresent.spec_violation(Some(diag)));
722 None
723 }
724 };
725
726 let signing_key = signing_key_ids.as_ref().and_then(|_| {
728 raw_did_doc
729 .parsed
730 .verification_method
731 .as_ref()
732 .and_then(|vms| {
733 vms.iter()
734 .find(|vm| {
735 vm.id.rsplit_once('#').map(|(_, f)| f).unwrap_or("") == "atproto_label"
736 })
737 .and_then(|vm| vm.public_key_multibase.as_deref())
738 })
739 .and_then(|mk| parse_multikey(mk).ok().map(|parsed| parsed.verifying_key))
740 });
741
742 let pds_endpoint: Option<Url> = match find_service(
744 &raw_did_doc.parsed,
745 "atproto_pds",
746 "AtprotoPersonalDataServer",
747 ) {
748 Some(svc) => match Url::parse(&svc.service_endpoint) {
749 Ok(url) => {
750 results.push(Check::PdsEndpointPresent.pass());
751 Some(url)
752 }
753 Err(_) => {
754 let span =
755 span_for_quoted_literal(display_doc_bytes.as_ref(), &svc.service_endpoint);
756 let diag = Box::new(PdsServiceMissingError {
757 message: format!("PDS endpoint is not a valid URL: {}", svc.service_endpoint),
758 named_source: NamedSource::new(
759 raw_did_doc.source_name.clone(),
760 display_doc_bytes.clone(),
761 ),
762 span,
763 });
764 block_facts = true;
765 results.push(Check::PdsEndpointPresent.spec_violation(Some(diag)));
766 None
767 }
768 },
769 None => {
770 let span = span_for_quoted_literal(display_doc_bytes.as_ref(), "service");
771 let diag = Box::new(PdsServiceMissingError {
772 message: "DID document is missing the #atproto_pds service entry".to_string(),
773 named_source: NamedSource::new(
774 raw_did_doc.source_name.clone(),
775 display_doc_bytes.clone(),
776 ),
777 span,
778 });
779 block_facts = true;
780 results.push(Check::PdsEndpointPresent.spec_violation(Some(diag)));
781 None
782 }
783 };
784
785 let fetched_record: Option<FetchedLabelerRecord> = match &pds_endpoint {
787 None => {
788 results.push(Check::LabelerRecordFetched.blocked_by(Check::PdsEndpointPresent));
789 None
790 }
791 Some(pds_url) => match fetch_labeler_record(&did, pds_url, http).await {
792 Ok(record) => {
793 results.push(Check::LabelerRecordFetched.pass());
794 Some(record)
795 }
796 Err(e) => {
797 let (check_status, message, named_source, span) = match &e {
798 FetchRecordError::Network(_) => {
799 (CheckStatus::NetworkError, e.to_string(), None, None)
800 }
801 FetchRecordError::NotFound => {
802 (CheckStatus::SpecViolation, e.to_string(), None, None)
803 }
804 FetchRecordError::HttpStatus { .. } => {
805 (CheckStatus::SpecViolation, e.to_string(), None, None)
806 }
807 FetchRecordError::ParseEnvelope {
808 display_body,
809 display_line,
810 display_column,
811 source,
812 } => {
813 let src = NamedSource::new("PDS response", display_body.clone());
814 let span = span_at_line_column(
815 display_body.as_ref(),
816 *display_line,
817 *display_column,
818 );
819 (
820 CheckStatus::SpecViolation,
821 format!("Failed to parse PDS getRecord envelope: {source}"),
822 Some(src),
823 Some(span),
824 )
825 }
826 };
827 let diag = Box::new(LabelerRecordFetchError {
828 message,
829 named_source,
830 span,
831 });
832 block_facts = true;
833 let base = match check_status {
834 CheckStatus::NetworkError => {
835 Check::LabelerRecordFetched.network_error(Some(diag))
836 }
837 _ => Check::LabelerRecordFetched.spec_violation(Some(diag)),
838 };
839 results.push(base);
840 None
841 }
842 },
843 };
844
845 match &fetched_record {
847 None => {
848 results
849 .push(Check::LabelerRecordPoliciesNonempty.blocked_by(Check::LabelerRecordFetched));
850 }
851 Some(record) => {
852 if record.policies.label_values.is_empty() {
853 let display_bytes = pretty_json_for_display(record.bytes.as_ref());
854 let span = span_for_quoted_literal(display_bytes.as_ref(), "labelValues");
855 let diag = Box::new(EmptyPoliciesError {
856 message: "Labeler record policies.labelValues is empty".to_string(),
857 named_source: NamedSource::new("labeler record", display_bytes),
858 span,
859 });
860 block_facts = true;
861 results.push(Check::LabelerRecordPoliciesNonempty.spec_violation(Some(diag)));
862 } else {
863 results.push(Check::LabelerRecordPoliciesNonempty.pass());
864 }
865 }
866 }
867
868 let facts = if !block_facts {
870 match (
871 labeler_endpoint,
872 pds_endpoint,
873 signing_key_ids,
874 signing_key,
875 fetched_record,
876 ) {
877 (Some(le), Some(pe), Some((ski, skm)), Some(sk), Some(record)) => Some(IdentityFacts {
878 did,
879 raw_did_doc,
880 labeler_endpoint: le,
881 pds_endpoint: pe,
882 signing_key_id: ski,
883 signing_key_multikey: skm,
884 signing_key: sk,
885 labeler_record_bytes: record.bytes,
886 labeler_policies: record.policies,
887 reason_types: record.reason_types,
888 subject_types: record.subject_types,
889 subject_collections: record.subject_collections,
890 }),
891 _ => None,
892 }
893 } else {
894 None
895 };
896
897 IdentityStageOutput { facts, results }
898}
899
900async fn resolve_identifier(
902 identifier: &AtIdentifier,
903 http: &dyn HttpClient,
904 dns: &dyn DnsResolver,
905 results: &mut Vec<CheckResult>,
906) -> Option<Did> {
907 match identifier {
910 AtIdentifier::Handle(handle) => match resolve_handle(handle, http, dns).await {
911 Ok(did) => {
912 results.push(Check::TargetResolved.pass());
913 Some(did)
914 }
915 Err(e) => {
916 let is_network = matches!(
917 e,
918 IdentityError::HttpTransport(_)
919 | IdentityError::DnsLookupFailed { .. }
920 | IdentityError::HandleUnresolvable { .. }
921 );
922 if is_network {
923 results.push(Check::TargetResolved.network_error(None));
924 } else {
925 results.push(Check::TargetResolved.spec_violation(None));
926 }
927 None
928 }
929 },
930 AtIdentifier::Did(did) => {
931 results.push(Check::TargetResolved.pass());
933 Some(did.clone())
934 }
935 }
936}
937
938fn find_signing_key(doc: &DidDocument) -> Option<(String, String)> {
940 let vms = doc.verification_method.as_ref()?;
941 for vm in vms {
942 if vm.id.rsplit_once('#').map(|(_, f)| f).unwrap_or("") == "atproto_label" {
943 let multikey = vm.public_key_multibase.as_ref()?;
944 return Some((vm.id.clone(), multikey.clone()));
945 }
946 }
947 None
948}
949
950fn endpoints_match(url1: &Url, url2: &Url) -> bool {
952 url1.scheme() == url2.scheme()
953 && url1.host_str() == url2.host_str()
954 && url1.port() == url2.port()
955}
956
957async fn fetch_labeler_record(
960 did: &Did,
961 pds_endpoint: &Url,
962 http: &dyn HttpClient,
963) -> Result<FetchedLabelerRecord, FetchRecordError> {
964 let mut url = pds_endpoint.clone();
966 url.set_path("/xrpc/com.atproto.repo.getRecord");
967 let query = format!(
968 "repo={}&collection=app.bsky.labeler.service&rkey=self",
969 did.0
970 );
971 url.set_query(Some(&query));
972
973 let (status, body) = match http.get_bytes(&url).await {
975 Ok((status, body)) => (status, body),
976 Err(e) => return Err(FetchRecordError::Network(e)),
977 };
978
979 let body_arc: Arc<[u8]> = Arc::from(body);
980
981 match status {
983 404 => Err(FetchRecordError::NotFound),
984 200 => {
985 match serde_json::from_slice::<GetRecordResponse>(body_arc.as_ref()) {
989 Ok(response) => {
990 let reason_types = response.value.reason_types.as_ref().map(|v| v.to_vec());
991 let subject_types = response.value.subject_types.as_ref().map(|v| v.to_vec());
992 let subject_collections = response
993 .value
994 .subject_collections
995 .as_ref()
996 .map(|v| v.iter().map(|n| n.to_string()).collect::<Vec<String>>());
997 Ok(FetchedLabelerRecord {
998 bytes: body_arc,
999 policies: response.value.policies,
1000 reason_types,
1001 subject_types,
1002 subject_collections,
1003 })
1004 }
1005 Err(raw_err) => {
1006 let display_body = pretty_json_for_display(body_arc.as_ref());
1013 let (display_line, display_column, source) =
1014 match serde_json::from_slice::<GetRecordResponse>(display_body.as_ref()) {
1015 Err(pretty_err) => (pretty_err.line(), pretty_err.column(), pretty_err),
1016 Ok(_) => (raw_err.line(), raw_err.column(), raw_err),
1017 };
1018 Err(FetchRecordError::ParseEnvelope {
1019 display_body,
1020 display_line,
1021 display_column,
1022 source,
1023 })
1024 }
1025 }
1026 }
1027 _ => Err(FetchRecordError::HttpStatus {
1028 status,
1029 body: body_arc,
1030 }),
1031 }
1032}
1033
/// Diagnostic for `identity::did_document_fetched` when the DID document body
/// could not be decoded.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
#[diagnostic(code = "labeler::identity::did_document_fetched")]
struct DidDocumentDecodeError {
    /// Text shown as the error message (includes the underlying error).
    message: String,
    /// Pretty-printed copy of the DID document body.
    #[source_code]
    named_source: NamedSource<Arc<[u8]>>,
    /// Position of the decode error; unlike the other diagnostics here, it is always present.
    #[label("JSON parse error")]
    span: SourceSpan,
}
1045
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// `HttpClient` double that serves canned `(status, body)` pairs keyed by
    /// the exact request URL string.
    struct FakeHttpClient {
        responses: std::collections::HashMap<String, (u16, Vec<u8>)>,
    }

    impl FakeHttpClient {
        fn new() -> Self {
            Self {
                responses: std::collections::HashMap::new(),
            }
        }

        /// Registers the response returned for `url` (exact string match).
        fn add_response(&mut self, url: impl Into<String>, status: u16, body: Vec<u8>) {
            self.responses.insert(url.into(), (status, body));
        }
    }

    #[async_trait]
    impl HttpClient for FakeHttpClient {
        /// Returns the registered response. Unregistered URLs yield
        /// `DidResolutionFailed` with status 404.
        async fn get_bytes(&self, url: &Url) -> Result<(u16, Vec<u8>), IdentityError> {
            self.responses
                .get(url.as_str())
                .cloned()
                .ok_or_else(|| IdentityError::DidResolutionFailed {
                    status: 404,
                    body: "Not found".to_string(),
                })
        }
    }

    /// `fetch_labeler_record` must carry the record's optional `reasonTypes`,
    /// `subjectTypes` and `subjectCollections` through to the result.
    #[tokio::test]
    async fn identity_retains_reason_and_subject_types() {
        let fixture_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(
            "tests/fixtures/labeler/identity/report_stage_contract_present/labeler_record.json",
        );
        let labeler_record_bytes = std::fs::read(&fixture_path).expect("fixture file exists");

        let mut http = FakeHttpClient::new();
        let pds_url = Url::parse("https://pds.example.com").unwrap();
        let did = Did("did:plc:test123456789012345678901234".to_string());

        // Mirror the URL construction in `fetch_labeler_record` so the fake
        // client's exact-match lookup hits.
        let query = format!(
            "repo={}&collection=app.bsky.labeler.service&rkey=self",
            did.0
        );
        let mut fetch_url = pds_url.clone();
        fetch_url.set_path("/xrpc/com.atproto.repo.getRecord");
        fetch_url.set_query(Some(&query));

        http.add_response(fetch_url.as_str(), 200, labeler_record_bytes);

        // `expect` (rather than `is_ok` + `unwrap`) prints the error on failure.
        let record = fetch_labeler_record(&did, &pds_url, &http)
            .await
            .expect("fetch_labeler_record should succeed");

        let rt = record.reason_types.expect("reason_types should be Some");
        assert_eq!(rt.len(), 2, "reason_types should have 2 entries");
        assert!(
            rt.iter().any(|r| r.contains("reasonSpam")),
            "should include reasonSpam"
        );

        let st = record.subject_types.expect("subject_types should be Some");
        assert_eq!(st.len(), 2, "subject_types should have 2 entries");
        assert!(st.iter().any(|s| s == "account"), "should include account");
        assert!(st.iter().any(|s| s == "record"), "should include record");

        let sc = record
            .subject_collections
            .expect("subject_collections should be Some");
        assert_eq!(sc.len(), 2, "subject_collections should have 2 entries");
        assert!(
            sc.iter().any(|s| s.contains("bsky.feed.post")),
            "should include app.bsky.feed.post"
        );
    }
}