use std::borrow::Cow;
use std::time::Duration;
use miette::Diagnostic;
use thiserror::Error;
use url::Url;
use crate::commands::test::labeler::create_report::self_mint::{SelfMintCurve, SelfMintSigner};
use crate::commands::test::labeler::create_report::{
self, CreateReportTee, PdsXrpcClient, RealCreateReportTee,
};
use crate::commands::test::labeler::crypto;
use crate::commands::test::labeler::http::{self, RealHttpTee};
use crate::commands::test::labeler::identity;
use crate::commands::test::labeler::report::{
CheckResult, CheckStatus, LabelerReport, ReportHeader, Stage,
};
use crate::commands::test::labeler::subscription::{self, RealWebSocketClient};
use crate::common::identity::{
Did, DnsResolver, HttpClient, find_service, is_local_labeler_hostname, resolve_did,
resolve_handle,
};
/// What the user asked to test, as produced by [`parse_target`].
#[derive(Debug, Clone)]
pub enum LabelerTarget {
    /// A handle or DID that has to be resolved through the identity stage.
    Identified {
        /// The handle or DID given as the target itself.
        identifier: AtIdentifier,
        /// DID supplied separately via `--did`. `parse_target` only fills this
        /// in for handle targets; a DID target combined with `--did` is
        /// rejected as ambiguous.
        explicit_did: Option<Did>,
    },
    /// A labeler service URL given directly: HTTPS, or plaintext HTTP when
    /// the host is local.
    Endpoint {
        /// The labeler endpoint URL.
        url: Url,
        /// Optional DID supplied separately via `--did`.
        did: Option<Did>,
    },
}
/// An ATProto identifier given as a target: either a handle or a DID.
#[derive(Debug, Clone)]
pub enum AtIdentifier {
    /// A handle such as `alice.bsky.social`.
    Handle(String),
    /// A DID such as `did:plc:abc123` or `did:web:example.com`.
    Did(Did),
}
/// Dependencies and settings consumed by [`run_pipeline`].
///
/// Most network-facing collaborators are passed in as trait objects or
/// "tee" enums so that tests can substitute doubles for real I/O.
pub struct LabelerOptions<'a> {
    /// HTTP client used by the identity and crypto stages and for resolving
    /// the reporter's PDS.
    pub http: &'a dyn HttpClient,
    /// DNS resolver used by the identity stage and for resolving the
    /// reporter's handle.
    pub dns: &'a dyn DnsResolver,
    /// How the HTTP stage reaches the labeler.
    pub http_tee: HttpTee<'a>,
    /// Injected WebSocket client for the subscription stage; `None` makes
    /// `run_pipeline` use `RealWebSocketClient`.
    pub ws_client: Option<&'a dyn subscription::WebSocketClient>,
    /// Timeout handed to the subscription stage.
    pub subscribe_timeout: Duration,
    /// NOTE(review): not read by `run_pipeline` in this file; presumably
    /// consumed by report rendering elsewhere — confirm.
    pub verbose: bool,
    /// How the createReport stage reaches the labeler.
    pub create_report_tee: CreateReportTeeKind<'a>,
    /// Forwarded unchanged to the createReport stage.
    pub commit_report: bool,
    /// Forwarded unchanged to the createReport stage.
    pub force_self_mint: bool,
    /// Forwarded unchanged to the createReport stage.
    pub self_mint_curve: SelfMintCurve,
    /// Forwarded unchanged to the createReport stage.
    pub report_subject_override: Option<&'a Did>,
    /// Forwarded unchanged to the createReport stage.
    pub self_mint_signer: Option<&'a SelfMintSigner>,
    /// Reporter account credentials. Forwarded to the createReport stage and,
    /// with a real createReport tee and no override client, used to resolve
    /// the reporter's PDS endpoint.
    pub pds_credentials: Option<&'a PdsCredentials>,
    /// NOTE(review): not read by `run_pipeline` in this file — only
    /// `pds_xrpc_client_override` is consulted. Confirm whether this field is
    /// still needed.
    pub pds_xrpc_client: Option<&'a dyn PdsXrpcClient>,
    /// When set, used as the createReport PDS client, and PDS endpoint
    /// resolution for the reporter is skipped.
    pub pds_xrpc_client_override: Option<&'a dyn PdsXrpcClient>,
    /// Identifier for this run, forwarded to the createReport stage.
    pub run_id: &'a str,
}
/// Transport used by the HTTP stage.
pub enum HttpTee<'a> {
    /// Real client; `run_pipeline` wraps a clone of it in a `RealHttpTee`
    /// bound to the resolved labeler endpoint.
    Real(&'a reqwest::Client),
    /// Injected raw HTTP tee (e.g. a test double), used as-is.
    Test(&'a dyn http::RawHttpTee),
}
/// Transport used by the createReport stage.
pub enum CreateReportTeeKind<'a> {
    /// Real client. `run_pipeline` wraps a clone of it in a
    /// `RealCreateReportTee`. When PDS credentials are present and no override
    /// client is set, it is also used to build a `RealPdsXrpcClient`.
    Real(&'a reqwest::Client),
    /// Injected tee (e.g. a test double), used as-is.
    Test(&'a dyn CreateReportTee),
}
/// Login credentials for the account that files test reports.
///
/// `Debug` is implemented by hand so that the app password never ends up in
/// logs or panic messages; only the handle is printed.
#[derive(Clone)]
pub struct PdsCredentials {
    /// Handle of the reporter account; resolved to locate its PDS.
    pub handle: String,
    /// App password for the reporter account. Treat as secret.
    pub app_password: String,
}

impl std::fmt::Debug for PdsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PdsCredentials")
            .field("handle", &self.handle)
            // Redact the secret so `{:?}` is safe to log.
            .field("app_password", &"<redacted>")
            .finish()
    }
}
/// Error returned by [`parse_target`] when the target cannot be interpreted.
///
/// Its `Display` output is exactly `message`.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
pub struct TargetParseError {
    /// Human-readable, user-facing description of the problem.
    pub message: String,
}
impl TargetParseError {
fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
}
}
fn unrecognized_target(raw: &str) -> Self {
Self::new(format!(
"Unrecognized target '{raw}'. Expected one of:\n - ATProto handle (e.g., alice.bsky.social)\n - DID (e.g., did:plc:abc123 or did:web:example.com)\n - HTTPS endpoint URL (e.g., https://labeler.example.com)\n - HTTP endpoint URL with a local hostname (e.g., http://localhost:8080)"
))
}
fn http_not_supported(raw: &str) -> Self {
Self::new(format!(
"HTTP endpoint '{raw}' is not supported for remote hosts. Use HTTPS, or point at a local labeler (localhost / 127.0.0.0/8 / RFC 1918 / .local) to allow plaintext HTTP."
))
}
fn ambiguous_did(raw: &str, explicit: &str) -> Self {
Self::new(format!(
"Ambiguous target specification: target '{raw}' is already a DID, but --did {explicit} was also provided. Please use only one."
))
}
}
/// Returns `true` if `s` is syntactically a valid ATProto handle.
///
/// Enforces the ATProto handle syntax:
/// - only ASCII letters, digits, `-` and `.`;
/// - at least two `.`-separated segments;
/// - every segment is 1–63 characters and does not start or end with `-`;
/// - the whole handle is at most 253 characters;
/// - the final segment (the TLD) does not start with a digit.
///
/// This is a purely syntactic check; it does not resolve the handle.
fn is_valid_handle(s: &str) -> bool {
    const MAX_HANDLE_LEN: usize = 253;
    const MAX_SEGMENT_LEN: usize = 63;

    if s.len() > MAX_HANDLE_LEN {
        return false;
    }

    let mut segment_count = 0usize;
    let mut last_segment = "";
    // `split` yields an empty segment for leading/trailing/doubled dots and
    // for the empty string, which the non-empty check below rejects.
    for segment in s.split('.') {
        let well_formed = !segment.is_empty()
            && segment.len() <= MAX_SEGMENT_LEN
            && !segment.starts_with('-')
            && !segment.ends_with('-')
            && segment
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-');
        if !well_formed {
            return false;
        }
        segment_count += 1;
        last_segment = segment;
    }

    // The TLD may not begin with a digit; this also stops bare IPv4
    // addresses such as `127.0.0.1` from being treated as handles.
    segment_count >= 2 && !last_segment.starts_with(|c: char| c.is_ascii_digit())
}
/// Interpret a raw target string, plus an optional `--did`, as a [`LabelerTarget`].
///
/// Recognition order:
/// 1. Anything starting with `did:` is a DID target; combining it with
///    `explicit_did` is rejected as ambiguous.
/// 2. `https://` URLs become endpoint targets.
/// 3. `http://` URLs become endpoint targets only when
///    `is_local_labeler_hostname` accepts the host.
/// 4. Otherwise, a syntactically valid handle becomes a handle target.
///
/// # Errors
///
/// Returns [`TargetParseError`] for ambiguous DIDs, unparseable URLs,
/// plaintext HTTP to non-local hosts, and unrecognized input.
pub fn parse_target(
    raw: &str,
    explicit_did: Option<&str>,
) -> Result<LabelerTarget, TargetParseError> {
    // Owned copy of the `--did` value, attached to handle/endpoint targets.
    let explicit = || explicit_did.map(|d| Did(d.to_string()));
    // Shared URL parsing so both URL schemes report errors identically.
    let parse_url = |s: &str| {
        Url::parse(s).map_err(|e| TargetParseError::new(format!("Invalid URL '{s}': {e}")))
    };

    if raw.starts_with("did:") {
        return match explicit_did {
            Some(ed) => Err(TargetParseError::ambiguous_did(raw, ed)),
            None => Ok(LabelerTarget::Identified {
                identifier: AtIdentifier::Did(Did(raw.to_string())),
                explicit_did: None,
            }),
        };
    }

    let is_https = raw.starts_with("https://");
    if is_https || raw.starts_with("http://") {
        let url = parse_url(raw)?;
        // Plaintext HTTP is only allowed for local labelers.
        if !is_https && !is_local_labeler_hostname(&url) {
            return Err(TargetParseError::http_not_supported(raw));
        }
        return Ok(LabelerTarget::Endpoint {
            url,
            did: explicit(),
        });
    }

    if !is_valid_handle(raw) {
        return Err(TargetParseError::unrecognized_target(raw));
    }
    Ok(LabelerTarget::Identified {
        identifier: AtIdentifier::Handle(raw.to_string()),
        explicit_did: explicit(),
    })
}
pub async fn run_pipeline(target: LabelerTarget, opts: LabelerOptions<'_>) -> LabelerReport {
let header = ReportHeader {
target: format_target(&target),
resolved_did: None,
pds_endpoint: None,
labeler_endpoint: None,
};
let mut report = LabelerReport::new(header);
let identity_output = identity::run(&target, opts.http, opts.dns).await;
if let Some(ref facts) = identity_output.facts {
report.header.resolved_did = Some(facts.did.to_string());
report.header.pds_endpoint = Some(facts.pds_endpoint.to_string());
report.header.labeler_endpoint = Some(facts.labeler_endpoint.to_string());
}
let is_no_did_supplied = !identity_output.results.is_empty()
&& identity_output.results.iter().all(|r| {
r.status == CheckStatus::Skipped
&& r.skipped_reason
.as_ref()
.map(|reason| reason.contains("no DID supplied"))
.unwrap_or(false)
});
for result in identity_output.results {
report.record(result);
}
let labeler_endpoint = if let Some(ref facts) = identity_output.facts {
Some(facts.labeler_endpoint.clone())
} else if let LabelerTarget::Endpoint { url, .. } = &target {
Some(url.clone())
} else {
None
};
let mut http_facts = None;
if let Some(endpoint) = &labeler_endpoint {
let output = match opts.http_tee {
HttpTee::Test(tee) => {
http::run(tee).await
}
HttpTee::Real(client) => {
let http_client = client.clone();
let real_tee = RealHttpTee::new(http_client, endpoint.clone());
http::run(&real_tee).await
}
};
for result in output.results {
report.record(result);
}
http_facts = output.facts.clone();
} else if identity_output.facts.is_none() && !is_no_did_supplied {
report.record(CheckResult {
id: "http::not_run",
stage: Stage::Http,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("HTTP stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("blocked by identity stage failures")),
});
} else {
report.record(CheckResult {
id: "http::not_run",
stage: Stage::Http,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("HTTP stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("identity stage produced no labeler endpoint")),
});
}
let mut subscription_facts = None;
if let Some(endpoint) = &labeler_endpoint {
let ws: &dyn subscription::WebSocketClient = if let Some(injected_ws) = opts.ws_client {
injected_ws
} else {
&RealWebSocketClient
};
let sub_output = subscription::run(endpoint, ws, opts.subscribe_timeout).await;
for result in sub_output.results {
report.record(result);
}
subscription_facts = sub_output.facts;
} else if identity_output.facts.is_none() && !is_no_did_supplied {
report.record(CheckResult {
id: "subscription::not_run",
stage: Stage::Subscription,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Subscription stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("blocked by identity stage failures")),
});
} else {
report.record(CheckResult {
id: "subscription::not_run",
stage: Stage::Subscription,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Subscription stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("identity stage produced no labeler endpoint")),
});
}
let sub_has_labels = subscription_facts
.as_ref()
.map(|f| !f.sample_labels.is_empty())
.unwrap_or(false);
if let Some(identity_facts) = &identity_output.facts {
if http_facts.is_some() || sub_has_labels {
let mut combined_labels: Vec<atrium_api::com::atproto::label::defs::Label> = Vec::new();
if let Some(h) = &http_facts {
combined_labels.extend(h.first_page.iter().cloned());
}
if let Some(s) = &subscription_facts {
combined_labels.extend(s.sample_labels.iter().cloned());
}
let crypto_output = crypto::run(identity_facts, &combined_labels, opts.http).await;
for result in crypto_output.results {
report.record(result);
}
} else {
report.record(CheckResult {
id: "crypto::not_run",
stage: Stage::Crypto,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Crypto stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed(
"neither HTTP nor subscription stage produced labels to verify",
)),
});
}
} else if identity_output.facts.is_none() && !is_no_did_supplied {
report.record(CheckResult {
id: "crypto::not_run",
stage: Stage::Crypto,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Crypto stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("blocked by upstream stage failures")),
});
} else {
report.record(CheckResult {
id: "crypto::not_run",
stage: Stage::Crypto,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Crypto stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("identity stage produced no labeler endpoint")),
});
}
let mut pds_resolution_error: Option<String> = None;
let pds_xrpc_client_owned: Option<create_report::RealPdsXrpcClient> =
if opts.pds_xrpc_client_override.is_some() {
None
} else if let (Some(creds), CreateReportTeeKind::Real(http_client)) =
(opts.pds_credentials, &opts.create_report_tee)
{
match resolve_reporter_pds_endpoint(&creds.handle, opts.http, opts.dns).await {
Ok(url) => Some(create_report::RealPdsXrpcClient::new(
(*http_client).clone(),
url,
)),
Err(msg) => {
pds_resolution_error = Some(msg);
None
}
}
} else {
None
};
let pds_xrpc_client_ref: Option<&dyn PdsXrpcClient> = pds_xrpc_client_owned
.as_ref()
.map(|c| c as &dyn PdsXrpcClient);
let create_report_run_opts = create_report::CreateReportRunOptions {
commit_report: opts.commit_report,
force_self_mint: opts.force_self_mint,
self_mint_curve: opts.self_mint_curve,
report_subject_override: opts.report_subject_override,
self_mint_signer: opts.self_mint_signer,
pds_credentials: opts.pds_credentials,
pds_xrpc_client: opts.pds_xrpc_client_override.or(pds_xrpc_client_ref),
pds_resolution_error: pds_resolution_error.as_deref(),
run_id: opts.run_id,
};
let labeler_endpoint_for_report = labeler_endpoint.clone();
let report_output = match opts.create_report_tee {
CreateReportTeeKind::Test(tee) => {
create_report::run(identity_output.facts.as_ref(), tee, &create_report_run_opts).await
}
CreateReportTeeKind::Real(client) => {
let endpoint = labeler_endpoint_for_report.unwrap_or_else(|| {
url::Url::parse("http://127.0.0.1:0").expect("dummy URL parses")
});
let real_tee = RealCreateReportTee::new(client.clone(), endpoint);
create_report::run(
identity_output.facts.as_ref(),
&real_tee,
&create_report_run_opts,
)
.await
}
};
for result in report_output.results {
report.record(result);
}
report.finish();
report
}
/// Resolve the PDS endpoint that hosts the reporter account `handle`.
///
/// Resolves the handle to a DID, fetches that DID's document, and looks up
/// its `atproto_pds` / `AtprotoPersonalDataServer` service via `find_service`.
/// It then parses that service's endpoint as a URL.
///
/// # Errors
///
/// Returns a human-readable message naming the step that failed:
/// - handle resolution;
/// - DID document resolution;
/// - a missing `#atproto_pds` service;
/// - an endpoint that is not a valid URL (the parser's reason is included).
async fn resolve_reporter_pds_endpoint(
    handle: &str,
    http: &dyn HttpClient,
    dns: &dyn DnsResolver,
) -> Result<Url, String> {
    let did = resolve_handle(handle, http, dns)
        .await
        .map_err(|e| format!("failed to resolve handle {handle} to a DID: {e}"))?;
    let raw_doc = resolve_did(&did, http)
        .await
        .map_err(|e| format!("failed to resolve DID {did} to a DID document: {e}"))?;
    let service = find_service(&raw_doc.parsed, "atproto_pds", "AtprotoPersonalDataServer")
        .ok_or_else(|| {
            format!("DID document for {did} does not advertise an #atproto_pds service")
        })?;
    let endpoint = &service.service_endpoint;
    // Keep the parser's reason: "malformed" alone doesn't tell the user
    // whether the scheme, host, or port is the problem.
    Url::parse(endpoint).map_err(|e| {
        format!("DID document for {did} has a malformed #atproto_pds endpoint: {endpoint} ({e})")
    })
}
/// Render `target` for the report header.
///
/// Shows the handle, DID, or URL as given, with ` (with explicit DID)`
/// appended when a separate DID was supplied alongside it.
fn format_target(target: &LabelerTarget) -> String {
    let (shown, has_explicit_did) = match target {
        LabelerTarget::Identified {
            identifier: AtIdentifier::Handle(handle),
            explicit_did,
        } => (handle.clone(), explicit_did.is_some()),
        LabelerTarget::Identified {
            identifier: AtIdentifier::Did(did),
            explicit_did,
        } => (did.0.clone(), explicit_did.is_some()),
        LabelerTarget::Endpoint { url, did } => (url.to_string(), did.is_some()),
    };
    if has_explicit_did {
        format!("{shown} (with explicit DID)")
    } else {
        shown
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_target_handle() {
        let target = parse_target("alice.bsky.social", None).expect("should parse");
        match target {
            LabelerTarget::Identified {
                identifier,
                explicit_did,
            } => {
                assert!(
                    matches!(identifier, AtIdentifier::Handle(ref h) if h == "alice.bsky.social")
                );
                assert!(explicit_did.is_none());
            }
            _ => panic!("expected Identified variant"),
        }
    }

    #[test]
    fn parse_target_handle_with_explicit_did() {
        let target =
            parse_target("alice.bsky.social", Some("did:plc:xyz")).expect("should parse");
        match target {
            LabelerTarget::Identified {
                identifier,
                explicit_did,
            } => {
                assert!(
                    matches!(identifier, AtIdentifier::Handle(ref h) if h == "alice.bsky.social")
                );
                assert_eq!(
                    explicit_did.map(|d| d.0.clone()),
                    Some("did:plc:xyz".to_string())
                );
            }
            _ => panic!("expected Identified variant"),
        }
    }

    #[test]
    fn parse_target_did_plc() {
        let target = parse_target("did:plc:abc123", None).expect("should parse");
        match target {
            LabelerTarget::Identified {
                identifier,
                explicit_did,
            } => {
                assert!(matches!(identifier, AtIdentifier::Did(ref d) if d.0 == "did:plc:abc123"));
                assert!(explicit_did.is_none());
            }
            _ => panic!("expected Identified variant"),
        }
    }

    #[test]
    fn parse_target_did_web() {
        let target = parse_target("did:web:example.com", None).expect("should parse");
        match target {
            LabelerTarget::Identified {
                identifier,
                explicit_did,
            } => {
                assert!(
                    matches!(identifier, AtIdentifier::Did(ref d) if d.0 == "did:web:example.com")
                );
                assert!(explicit_did.is_none());
            }
            _ => panic!("expected Identified variant"),
        }
    }

    #[test]
    fn parse_target_endpoint_https() {
        let target = parse_target("https://example.com/labeler", None).expect("should parse");
        match target {
            LabelerTarget::Endpoint { url, did } => {
                assert_eq!(url.as_str(), "https://example.com/labeler");
                assert!(did.is_none());
            }
            _ => panic!("expected Endpoint variant"),
        }
    }

    #[test]
    fn parse_target_endpoint_with_explicit_did() {
        let target =
            parse_target("https://example.com/labeler", Some("did:plc:xyz")).expect("should parse");
        match target {
            LabelerTarget::Endpoint { url, did } => {
                assert_eq!(url.as_str(), "https://example.com/labeler");
                assert_eq!(did.map(|d| d.0.clone()), Some("did:plc:xyz".to_string()));
            }
            _ => panic!("expected Endpoint variant"),
        }
    }

    #[test]
    fn parse_target_invalid_https_url() {
        let err = parse_target("https://", None).expect_err("should reject URL without host");
        assert!(err.message.contains("Invalid URL"));
    }

    #[test]
    fn parse_target_endpoint_http_remote_rejected() {
        let err = parse_target("http://evil.example", None).expect_err("should reject http");
        assert!(err.message.contains("HTTP"));
        assert!(err.message.contains("local"));
    }

    #[test]
    fn parse_target_endpoint_http_local_accepted() {
        let cases = &[
            "http://localhost:8080",
            "http://127.0.0.1:5000",
            "http://127.1.2.3/",
            "http://[::1]:8080/",
            "http://10.0.0.1/",
            "http://192.168.1.100:8080",
            "http://172.16.0.1/",
            "http://mybox.local:8080",
        ];
        for raw in cases {
            let target = parse_target(raw, None)
                .unwrap_or_else(|e| panic!("expected {raw} to parse, got: {}", e.message));
            match target {
                LabelerTarget::Endpoint { url, did } => {
                    assert_eq!(
                        url.as_str().trim_end_matches('/'),
                        raw.trim_end_matches('/')
                    );
                    assert!(did.is_none());
                }
                _ => panic!("expected Endpoint variant for {raw}"),
            }
        }
    }

    #[test]
    fn parse_target_unrecognised() {
        let err = parse_target("not a handle or did", None).expect_err("should fail");
        assert!(err.message.contains("Unrecognized target"));
    }

    #[test]
    fn parse_target_malformed_handles_rejected() {
        let cases = &[
            "",
            "nodot",
            "a..b",
            ".a.b",
            "a.b.",
            "-a.b",
            "a.b-",
            "ali_ce.bsky.social",
        ];
        for raw in cases {
            match parse_target(raw, None) {
                Err(err) => assert!(
                    err.message.contains("Unrecognized target"),
                    "{raw:?}: unexpected message {}",
                    err.message
                ),
                Ok(_) => panic!("expected {raw:?} to be rejected"),
            }
        }
    }

    #[test]
    fn parse_target_did_with_conflicting_flag() {
        let err = parse_target("did:plc:abc", Some("did:web:example.com"))
            .expect_err("should reject ambiguous target");
        assert!(err.message.contains("Ambiguous"));
    }

    #[test]
    fn parse_target_did_with_identical_flag_still_ambiguous() {
        let err = parse_target("did:plc:abc", Some("did:plc:abc"))
            .expect_err("should reject even when the DIDs match");
        assert!(err.message.contains("Ambiguous"));
    }

    #[test]
    fn format_target_variants() {
        let handle = parse_target("alice.bsky.social", None).expect("should parse");
        assert_eq!(format_target(&handle), "alice.bsky.social");

        let handle_with_did =
            parse_target("alice.bsky.social", Some("did:plc:xyz")).expect("should parse");
        assert_eq!(
            format_target(&handle_with_did),
            "alice.bsky.social (with explicit DID)"
        );

        let did = parse_target("did:plc:abc123", None).expect("should parse");
        assert_eq!(format_target(&did), "did:plc:abc123");

        let endpoint = parse_target("https://example.com/labeler", None).expect("should parse");
        assert_eq!(format_target(&endpoint), "https://example.com/labeler");

        let endpoint_with_did =
            parse_target("https://example.com/labeler", Some("did:plc:xyz")).expect("should parse");
        assert_eq!(
            format_target(&endpoint_with_did),
            "https://example.com/labeler (with explicit DID)"
        );
    }
}