use std::borrow::Cow;
use std::time::Duration;
use miette::Diagnostic;
use thiserror::Error;
use url::Url;
use crate::commands::test::labeler::crypto;
use crate::commands::test::labeler::http::{self, RealHttpTee};
use crate::commands::test::labeler::identity;
use crate::commands::test::labeler::report::{
CheckResult, CheckStatus, LabelerReport, ReportHeader, Stage,
};
use crate::commands::test::labeler::subscription::{self, RealWebSocketClient};
use crate::common::identity::{Did, DnsResolver, HttpClient};
/// What a labeler test run is pointed at, as produced by [`parse_target`].
#[derive(Debug, Clone)]
pub enum LabelerTarget {
/// A target named by an AT identifier (a handle or a DID).
Identified {
/// The handle or DID naming the target.
identifier: AtIdentifier,
/// DID supplied separately from the identifier. `parse_target` fills this
/// from its `explicit_did` argument for handle targets and always leaves
/// it `None` for DID targets.
explicit_did: Option<Did>,
},
/// A target given directly as an endpoint URL (`parse_target` only builds
/// this variant for inputs starting with `https://`).
Endpoint {
/// The parsed endpoint URL.
url: Url,
/// DID supplied alongside the URL, if any.
did: Option<Did>,
},
}
/// The identifier half of [`LabelerTarget::Identified`]: either a handle or a DID.
#[derive(Debug, Clone)]
pub enum AtIdentifier {
/// A handle string (e.g. `alice.bsky.social`), stored as given.
Handle(String),
/// A DID (e.g. `did:plc:abc123`).
Did(Did),
}
/// Collaborators and settings consumed by [`run_pipeline`].
///
/// Everything is borrowed for `'a`, so callers keep ownership of the clients
/// they pass in.
pub struct LabelerOptions<'a> {
/// HTTP client handed to the identity stage and to the crypto stage.
pub http: &'a dyn HttpClient,
/// DNS resolver handed to the identity stage.
pub dns: &'a dyn DnsResolver,
/// Selects how the HTTP stage talks to the labeler (see [`HttpTee`]).
pub http_tee: HttpTee<'a>,
/// Injected WebSocket client for the subscription stage; when `None`,
/// `run_pipeline` falls back to `RealWebSocketClient`.
pub ws_client: Option<&'a dyn subscription::WebSocketClient>,
/// Timeout forwarded to `subscription::run`.
pub subscribe_timeout: Duration,
/// NOTE(review): not read by `run_pipeline` in this file; presumably
/// consumed by the caller or a renderer — confirm.
pub verbose: bool,
}
/// How the HTTP stage of [`run_pipeline`] obtains its transport.
pub enum HttpTee<'a> {
/// Use a real `reqwest` client: `run_pipeline` clones it and wraps it in a
/// `RealHttpTee` together with the labeler endpoint.
Real(&'a reqwest::Client),
/// Use a caller-provided tee as-is; it is passed straight to `http::run`
/// and the labeler endpoint is not applied to it.
Test(&'a dyn http::RawHttpTee),
}
/// Error returned by [`parse_target`] when the raw target cannot be accepted.
///
/// Its `Display` output is exactly `message`.
#[derive(Debug, Error, Diagnostic)]
#[error("{message}")]
pub struct TargetParseError {
/// Fully formatted, user-facing explanation of the failure.
pub message: String,
}
/// Constructors for each message [`parse_target`] can fail with.
impl TargetParseError {
    /// Wraps an already-formatted message.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// `raw` matched none of the accepted target forms; the message lists them.
    fn unrecognized_target(raw: &str) -> Self {
        let message = format!(
            "Unrecognized target '{raw}'. Expected one of:\n - ATProto handle (e.g., alice.bsky.social)\n - DID (e.g., did:plc:abc123 or did:web:example.com)\n - HTTPS endpoint URL (e.g., https://labeler.example.com)"
        );
        Self::new(message)
    }

    /// `raw` is a plain-HTTP URL; only HTTPS endpoints are accepted.
    fn http_not_supported(raw: &str) -> Self {
        let message = format!(
            "HTTP endpoint '{raw}' is not supported. Please use an HTTPS endpoint instead."
        );
        Self::new(message)
    }

    /// `raw` is itself a DID and `--did {explicit}` was passed as well.
    fn ambiguous_did(raw: &str, explicit: &str) -> Self {
        let message = format!(
            "Ambiguous target specification: target '{raw}' is already a DID, but --did {explicit} was also provided. Please use only one."
        );
        Self::new(message)
    }
}
/// Returns `true` when `s` is syntactically a valid AT Protocol handle.
///
/// A handle is a dot-separated domain name. The rules enforced here are:
///
/// - at most 253 bytes overall;
/// - at least two labels (so a bare word such as `localhost` is rejected);
/// - every label is 1–63 bytes of ASCII letters, digits, or `-`;
/// - no label starts or ends with `-`;
/// - the final label (the TLD) does not start with a digit, which also rules
///   out dotted IPv4 addresses such as `127.0.0.1`.
///
/// Validation is applied per label rather than to the string as a whole, so
/// inputs like `alice-.bsky.social` are rejected even though the complete
/// string neither starts nor ends with a hyphen.
fn is_valid_handle(s: &str) -> bool {
    const MAX_HANDLE_LEN: usize = 253;
    const MAX_LABEL_LEN: usize = 63;

    if s.len() > MAX_HANDLE_LEN {
        return false;
    }

    let mut label_count = 0usize;
    let mut last_label = "";
    // Splitting on '.' turns a leading dot, a trailing dot, "..", and the
    // empty string into an empty label, which the per-label check rejects.
    for label in s.split('.') {
        let well_formed = !label.is_empty()
            && label.len() <= MAX_LABEL_LEN
            && !label.starts_with('-')
            && !label.ends_with('-')
            && label
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-');
        if !well_formed {
            return false;
        }
        label_count += 1;
        last_label = label;
    }

    label_count >= 2 && !matches!(last_label.bytes().next(), Some(b'0'..=b'9'))
}
/// Classifies the raw command-line target and pairs it with the optional
/// `--did` value.
///
/// The input is tested against each accepted form in order:
///
/// 1. `did:` prefix — an identified DID target. Supplying `explicit_did` as
///    well is rejected as ambiguous.
/// 2. `https://` prefix — an endpoint target; the URL must parse.
/// 3. `http://` prefix — always rejected.
/// 4. Anything `is_valid_handle` accepts — an identified handle target.
///
/// # Errors
///
/// Returns a [`TargetParseError`] for an ambiguous DID, an unparsable HTTPS
/// URL, a plain-HTTP URL, or input that matches none of the forms above.
pub fn parse_target(
    raw: &str,
    explicit_did: Option<&str>,
) -> Result<LabelerTarget, TargetParseError> {
    // Owned copy of the `--did` value, built only by the branches that keep it.
    let owned_explicit_did = || explicit_did.map(|flag| Did(flag.to_string()));

    if raw.starts_with("did:") {
        return match explicit_did {
            Some(flag) => Err(TargetParseError::ambiguous_did(raw, flag)),
            None => Ok(LabelerTarget::Identified {
                identifier: AtIdentifier::Did(Did(raw.to_string())),
                explicit_did: None,
            }),
        };
    }

    if raw.starts_with("https://") {
        return Url::parse(raw)
            .map(|url| LabelerTarget::Endpoint {
                url,
                did: owned_explicit_did(),
            })
            .map_err(|e| TargetParseError::new(format!("Invalid URL '{raw}': {e}")));
    }

    if raw.starts_with("http://") {
        return Err(TargetParseError::http_not_supported(raw));
    }

    if !is_valid_handle(raw) {
        return Err(TargetParseError::unrecognized_target(raw));
    }

    Ok(LabelerTarget::Identified {
        identifier: AtIdentifier::Handle(raw.to_string()),
        explicit_did: owned_explicit_did(),
    })
}
pub async fn run_pipeline(target: LabelerTarget, opts: LabelerOptions<'_>) -> LabelerReport {
let header = ReportHeader {
target: format_target(&target),
resolved_did: None,
pds_endpoint: None,
labeler_endpoint: None,
};
let mut report = LabelerReport::new(header);
let identity_output = identity::run(&target, opts.http, opts.dns).await;
if let Some(ref facts) = identity_output.facts {
report.header.resolved_did = Some(facts.did.to_string());
report.header.pds_endpoint = Some(facts.pds_endpoint.to_string());
report.header.labeler_endpoint = Some(facts.labeler_endpoint.to_string());
}
let is_no_did_supplied = !identity_output.results.is_empty()
&& identity_output.results.iter().all(|r| {
r.status == CheckStatus::Skipped
&& r.skipped_reason
.as_ref()
.map(|reason| reason.contains("no DID supplied"))
.unwrap_or(false)
});
for result in identity_output.results {
report.record(result);
}
let labeler_endpoint = if let Some(ref facts) = identity_output.facts {
Some(facts.labeler_endpoint.clone())
} else if let LabelerTarget::Endpoint { url, .. } = &target {
Some(url.clone())
} else {
None
};
let mut http_facts = None;
if let Some(endpoint) = &labeler_endpoint {
let output = match opts.http_tee {
HttpTee::Test(tee) => {
http::run(tee).await
}
HttpTee::Real(client) => {
let http_client = client.clone();
let real_tee = RealHttpTee::new(http_client, endpoint.clone());
http::run(&real_tee).await
}
};
for result in output.results {
report.record(result);
}
http_facts = output.facts.clone();
} else if identity_output.facts.is_none() && !is_no_did_supplied {
report.record(CheckResult {
id: "http::not_run",
stage: Stage::Http,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("HTTP stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("blocked by identity stage failures")),
});
} else {
report.record(CheckResult {
id: "http::not_run",
stage: Stage::Http,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("HTTP stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("identity stage produced no labeler endpoint")),
});
}
let mut subscription_facts = None;
if let Some(endpoint) = &labeler_endpoint {
let ws: &dyn subscription::WebSocketClient = if let Some(injected_ws) = opts.ws_client {
injected_ws
} else {
&RealWebSocketClient
};
let sub_output = subscription::run(endpoint, ws, opts.subscribe_timeout).await;
for result in sub_output.results {
report.record(result);
}
subscription_facts = sub_output.facts;
} else if identity_output.facts.is_none() && !is_no_did_supplied {
report.record(CheckResult {
id: "subscription::not_run",
stage: Stage::Subscription,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Subscription stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("blocked by identity stage failures")),
});
} else {
report.record(CheckResult {
id: "subscription::not_run",
stage: Stage::Subscription,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Subscription stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("identity stage produced no labeler endpoint")),
});
}
let sub_has_labels = subscription_facts
.as_ref()
.map(|f| !f.sample_labels.is_empty())
.unwrap_or(false);
if let Some(identity_facts) = &identity_output.facts {
if http_facts.is_some() || sub_has_labels {
let mut combined_labels: Vec<atrium_api::com::atproto::label::defs::Label> = Vec::new();
if let Some(h) = &http_facts {
combined_labels.extend(h.first_page.iter().cloned());
}
if let Some(s) = &subscription_facts {
combined_labels.extend(s.sample_labels.iter().cloned());
}
let crypto_output = crypto::run(identity_facts, &combined_labels, opts.http).await;
for result in crypto_output.results {
report.record(result);
}
} else {
report.record(CheckResult {
id: "crypto::not_run",
stage: Stage::Crypto,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Crypto stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed(
"neither HTTP nor subscription stage produced labels to verify",
)),
});
}
} else if identity_output.facts.is_none() && !is_no_did_supplied {
report.record(CheckResult {
id: "crypto::not_run",
stage: Stage::Crypto,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Crypto stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("blocked by upstream stage failures")),
});
} else {
report.record(CheckResult {
id: "crypto::not_run",
stage: Stage::Crypto,
status: CheckStatus::Skipped,
summary: Cow::Borrowed("Crypto stage (not run)"),
diagnostic: None,
skipped_reason: Some(Cow::Borrowed("identity stage produced no labeler endpoint")),
});
}
report.finish();
report
}
/// Renders the target for the report header.
///
/// The base text is the handle, the DID string, or the URL. When a separate
/// DID accompanies the target, ` (with explicit DID)` is appended; the DID's
/// value itself is not shown.
fn format_target(target: &LabelerTarget) -> String {
    let (label, has_explicit_did) = match target {
        LabelerTarget::Identified {
            identifier: AtIdentifier::Handle(handle),
            explicit_did,
        } => (handle.clone(), explicit_did.is_some()),
        LabelerTarget::Identified {
            identifier: AtIdentifier::Did(did),
            explicit_did,
        } => (did.0.clone(), explicit_did.is_some()),
        LabelerTarget::Endpoint { url, did } => (url.to_string(), did.is_some()),
    };

    if has_explicit_did {
        format!("{label} (with explicit DID)")
    } else {
        label
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps an `Identified` target into its parts, panicking on any other variant.
    fn expect_identified(target: LabelerTarget) -> (AtIdentifier, Option<Did>) {
        match target {
            LabelerTarget::Identified {
                identifier,
                explicit_did,
            } => (identifier, explicit_did),
            LabelerTarget::Endpoint { .. } => panic!("expected Identified variant"),
        }
    }

    /// Unwraps an `Endpoint` target into its parts, panicking on any other variant.
    fn expect_endpoint(target: LabelerTarget) -> (Url, Option<Did>) {
        match target {
            LabelerTarget::Endpoint { url, did } => (url, did),
            LabelerTarget::Identified { .. } => panic!("expected Endpoint variant"),
        }
    }

    // A dotted name with no scheme or `did:` prefix is treated as a handle.
    #[test]
    fn parse_target_handle() {
        let parsed = parse_target("alice.bsky.social", None).expect("should parse");
        let (identifier, explicit_did) = expect_identified(parsed);
        assert!(matches!(identifier, AtIdentifier::Handle(ref h) if h == "alice.bsky.social"));
        assert!(explicit_did.is_none());
    }

    // `did:plc:` targets are kept verbatim as a DID identifier.
    #[test]
    fn parse_target_did_plc() {
        let parsed = parse_target("did:plc:abc123", None).expect("should parse");
        let (identifier, explicit_did) = expect_identified(parsed);
        assert!(matches!(identifier, AtIdentifier::Did(ref d) if d.0 == "did:plc:abc123"));
        assert!(explicit_did.is_none());
    }

    // `did:web:` targets take the same path as `did:plc:`.
    #[test]
    fn parse_target_did_web() {
        let parsed = parse_target("did:web:example.com", None).expect("should parse");
        let (identifier, explicit_did) = expect_identified(parsed);
        assert!(matches!(identifier, AtIdentifier::Did(ref d) if d.0 == "did:web:example.com"));
        assert!(explicit_did.is_none());
    }

    // An HTTPS URL becomes an endpoint target with no DID attached.
    #[test]
    fn parse_target_endpoint_https() {
        let parsed = parse_target("https://example.com/labeler", None).expect("should parse");
        let (url, did) = expect_endpoint(parsed);
        assert_eq!(url.as_str(), "https://example.com/labeler");
        assert!(did.is_none());
    }

    // The `--did` value rides along with an endpoint target.
    #[test]
    fn parse_target_endpoint_with_explicit_did() {
        let parsed =
            parse_target("https://example.com/labeler", Some("did:plc:xyz")).expect("should parse");
        let (url, did) = expect_endpoint(parsed);
        assert_eq!(url.as_str(), "https://example.com/labeler");
        assert_eq!(did.as_ref().map(|d| d.0.as_str()), Some("did:plc:xyz"));
    }

    // Plain HTTP is refused, and the message points the user at HTTPS.
    #[test]
    fn parse_target_endpoint_http_rejected() {
        let err = parse_target("http://evil.example", None).expect_err("should reject http");
        assert!(err.message.contains("HTTPS"));
    }

    // Free text that is neither a handle, a DID, nor a URL is rejected.
    #[test]
    fn parse_target_unrecognised() {
        let err = parse_target("not a handle or did", None).expect_err("should fail");
        assert!(err.message.contains("Unrecognized target"));
    }

    // A DID target combined with `--did` is ambiguous.
    #[test]
    fn parse_target_did_with_conflicting_flag() {
        let err = parse_target("did:plc:abc", Some("did:web:example.com"))
            .expect_err("should reject ambiguous target");
        assert!(err.message.contains("Ambiguous"));
    }
}