use crate::error::AppError;
use reqwest::Client;
use serde::Deserialize;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::{debug, warn, info};
use serde_json::Value;
use reqwest::header::ACCEPT;
/// Body of a successful `com.atproto.identity.resolveHandle` XRPC response.
///
/// Only the `did` field is read; serde ignores any other fields.
#[derive(Deserialize, Debug)]
struct ResolveHandleResponse {
    /// DID the handle resolved to.
    did: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_did_web_to_did_document_url_root() {
        let url = did_web_to_did_document_url("did:web:example.com").unwrap();
        assert_eq!(url, "https://example.com/.well-known/did.json");
    }

    #[test]
    fn test_did_web_to_did_document_url_path() {
        let url = did_web_to_did_document_url("did:web:example.com:user:alice").unwrap();
        assert_eq!(url, "https://example.com/user/alice/did.json");
    }

    #[test]
    fn test_did_web_to_did_document_url_various_formats() {
        let url = did_web_to_did_document_url("did:web:subdomain.example.org").unwrap();
        assert_eq!(url, "https://subdomain.example.org/.well-known/did.json");
        let url = did_web_to_did_document_url("did:web:api.example.com:v1:users:bob").unwrap();
        assert_eq!(url, "https://api.example.com/v1/users/bob/did.json");
        let url = did_web_to_did_document_url("did:web:example.com:alice").unwrap();
        assert_eq!(url, "https://example.com/alice/did.json");
    }

    #[test]
    fn test_did_web_to_did_document_url_invalid() {
        assert!(did_web_to_did_document_url("").is_none());
        assert!(did_web_to_did_document_url("not_a_did").is_none());
        assert!(did_web_to_did_document_url("did:plc:test").is_none());
        assert!(did_web_to_did_document_url("did:web:").is_none());
    }

    #[test]
    fn test_did_validation() {
        let valid_plc = vec![
            "did:plc:abcdefghijklmnopqrstuvwx",
            "did:plc:123456789012345678901234",
            "did:plc:abcdefabcdefabcdefabcdef",
        ];
        for did in valid_plc {
            assert!(did.starts_with("did:plc:"), "{} should start with did:plc:", did);
            assert_eq!(did.len(), 32, "{} should be 32 chars long", did);
            assert!(
                did[8..].chars().all(|c| c.is_ascii_alphanumeric()),
                "{} should be alphanumeric after prefix",
                did
            );
        }
        let invalid = vec![
            "did:web:example.com",
            "did:plc:tooshort",
            "did:plc:toolong123456789012345678901",
            "did:plc:has-invalid-chars123456789!",
            "did:other:abc123def456789012345678901",
            "not_a_did",
            "",
        ];
        for did in invalid {
            let is_valid_plc = did.starts_with("did:plc:")
                && did.len() == 32
                && did[8..].chars().all(|c| c.is_ascii_alphanumeric());
            assert!(!is_valid_plc, "{} should not be a valid PLC DID", did);
        }
    }

    #[test]
    fn test_handle_validation_logic() {
        let invalid_handles = vec!["", "nodot", "empty.", ".empty", "double..dot"];
        for handle in invalid_handles {
            let is_invalid = handle.is_empty()
                || !handle.contains('.')
                || handle.contains("..")
                || handle.starts_with('.')
                || handle.ends_with('.');
            assert!(is_invalid, "{} should be considered invalid", handle);
        }
        let valid_handles = vec![
            "alice.bsky.social",
            "bob.example.com",
            "user.subdomain.example.org",
        ];
        for handle in valid_handles {
            let is_valid = !handle.is_empty()
                && handle.contains('.')
                && !handle.contains("..")
                && !handle.starts_with('.')
                && !handle.ends_with('.');
            assert!(is_valid, "{} should be considered valid", handle);
        }
    }

    #[test]
    fn test_well_known_url_construction() {
        let test_cases = vec![
            ("alice.bsky.social", "https://alice.bsky.social/.well-known/atproto-did"),
            ("bob.example.com", "https://bob.example.com/.well-known/atproto-did"),
            ("user.subdomain.org", "https://user.subdomain.org/.well-known/atproto-did"),
        ];
        for (handle, expected_url) in test_cases {
            let url = format!(
                "https://{}/.well-known/atproto-did",
                handle.trim_start_matches('@')
            );
            assert_eq!(url, expected_url);
        }
    }

    #[test]
    fn test_plc_audit_url_construction() {
        let dids = vec![
            "did:plc:abc123def456789012345678901",
            "did:plc:zyxwvutsrqponmlkjihgfedcba98",
        ];
        for did in dids {
            let url = format!("https://plc.directory/{}/log/audit", did);
            assert!(url.starts_with("https://plc.directory/"));
            assert!(url.contains(did));
            assert!(url.ends_with("/log/audit"));
        }
    }

    #[test]
    fn test_xrpc_url_construction() {
        let endpoints = vec!["https://api.bsky.app", "https://bsky.social"];
        let handle = "alice.bsky.social";
        for endpoint in endpoints {
            let url = format!(
                "{}/xrpc/com.atproto.identity.resolveHandle?handle={}",
                endpoint, handle
            );
            assert!(url.contains("/xrpc/com.atproto.identity.resolveHandle"));
            assert!(url.contains(&format!("handle={}", handle)));
        }
    }

    #[test]
    fn test_pds_endpoint_url_construction() {
        let base_endpoints = vec![
            "https://bsky.social",
            "https://api.example.com",
            "https://pds.internal.org",
        ];
        let did = "did:plc:abc123def456789012345678901";
        for endpoint in base_endpoints {
            let repo_url = format!("{}/xrpc/com.atproto.sync.getRepo?did={}", endpoint, did);
            assert!(repo_url.contains("/xrpc/com.atproto.sync.getRepo"));
            assert!(repo_url.contains(&format!("did={}", did)));
        }
    }

    #[tokio::test]
    async fn test_did_resolver_creation() {
        let resolver = DidResolver::new();
        assert!(resolver.cache_ttl > Duration::from_secs(0));
        let resolver2 = DidResolver::default();
        assert!(resolver2.cache_ttl > Duration::from_secs(0));
    }

    #[tokio::test]
    async fn test_resolve_handle_plc_did_passthrough() {
        let resolver = DidResolver::new();
        let plc_did = "did:plc:abc123def456789012345678901";
        let result = resolver.resolve_handle(plc_did).await.unwrap();
        assert_eq!(result, plc_did);
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let resolver = DidResolver::new();
        let did = "did:plc:test123456789012345678901";
        resolver.cache_result("test.handle", did).await;
        {
            let cache = resolver.cache.lock().await;
            assert_eq!(
                cache.get("test.handle").map(|(d, _)| d.as_str()),
                Some(did)
            );
        }
        // A freshly inserted entry is well inside the TTL, so cleanup must keep it.
        resolver.cleanup_cache().await;
        let cache = resolver.cache.lock().await;
        assert!(cache.contains_key("test.handle"));
    }

    #[tokio::test]
    async fn test_resolve_handle_uses_cache() {
        // A fresh cache entry is served without any network lookup.
        let resolver = DidResolver::new();
        let did = "did:plc:abcdefghijklmnopqrstuvwx";
        resolver.cache_result("alice.example.com", did).await;
        let resolved = resolver.resolve_handle("alice.example.com").await.unwrap();
        assert_eq!(resolved, did);
    }

    #[tokio::test]
    async fn test_pds_operations() {
        // Neither input is a did:plc DID, so no request to plc.directory is made.
        let resolver = DidResolver::new();
        let pds = resolver.discover_pds("invalid:did").await;
        assert!(matches!(pds, Ok(None)));
        let pds = resolver.discover_pds("did:web:example.com").await;
        assert!(matches!(pds, Ok(None)));
    }

    #[tokio::test]
    #[ignore = "may query plc.directory over the network"]
    async fn test_pds_operations_unknown_plc_did() {
        let resolver = DidResolver::new();
        let pds = resolver.discover_pds("did:plc:unknown").await;
        assert!(matches!(pds, Ok(None)));
    }
}
impl DidResolver {
    /// Tries to resolve a handle through `https://<handle>/.well-known/atproto-did`.
    ///
    /// `handle_domain` is expected without a leading `@`. A response is accepted
    /// when it is a `did:plc:` DID (32 bytes, ASCII-alphanumeric suffix). The DID
    /// may be the whole trimmed body, or the string field `did` of a JSON object.
    ///
    /// Each of the following yields `Ok(None)`, so the caller can fall back to
    /// another resolution method:
    /// - an input that does not look like a hostname,
    /// - a transport error or a non-2xx status,
    /// - an unreadable body or an unrecognised payload.
    ///
    /// This function currently never returns `Err`.
    async fn try_well_known(&self, handle_domain: &str) -> Result<Option<String>, AppError> {
        // PLC DID shape check applied to both accepted payload forms.
        fn is_plc_did(candidate: &str) -> bool {
            candidate.starts_with("did:plc:")
                && candidate.len() == 32
                && candidate[8..].chars().all(|c| c.is_ascii_alphanumeric())
        }

        // Only hostname characters may reach the URL. This rejects '@', '/', ':',
        // '?', '#' and similar characters. Otherwise a crafted handle such as
        // "a.example@internal.host" could steer the request to another host
        // (via URL userinfo) or to another path.
        let is_hostname_like = handle_domain.contains('.')
            && handle_domain
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-');
        if !is_hostname_like {
            return Ok(None);
        }

        let url = format!("https://{}/.well-known/atproto-did", handle_domain);
        debug!("Trying .well-known at {}", url);
        let resp = match self
            .client
            .get(&url)
            .header(ACCEPT, "text/plain, application/json")
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                warn!(".well-known request error for {}: {}", handle_domain, e);
                return Ok(None);
            }
        };
        if !resp.status().is_success() {
            debug!(".well-known HTTP {} for {}", resp.status(), handle_domain);
            return Ok(None);
        }
        let body = match resp.text().await {
            Ok(t) => t.trim().to_string(),
            Err(e) => {
                warn!(".well-known read error for {}: {}", handle_domain, e);
                return Ok(None);
            }
        };

        if is_plc_did(&body) {
            return Ok(Some(body));
        }
        // Also accept a JSON object of the form {"did": "did:plc:..."}.
        if body.starts_with('{') {
            let json_did = serde_json::from_str::<serde_json::Value>(&body)
                .ok()
                .and_then(|v| v.get("did").and_then(|d| d.as_str()).map(str::to_string));
            if let Some(did) = json_did.filter(|d| is_plc_did(d)) {
                return Ok(Some(did));
            }
        }
        Ok(None)
    }

    /// Records `account -> did` in the in-memory handle cache.
    ///
    /// The entry is stamped with the current time; callers compare it against
    /// `cache_ttl` when reading.
    async fn cache_result(&self, account: &str, did: &str) {
        let mut cache = self.cache.lock().await;
        cache.insert(account.to_string(), (did.to_string(), Instant::now()));
    }

    /// Fetches the DID document of a `did:web` DID and records its PDS endpoint.
    ///
    /// A service counts as the PDS when either:
    /// - its `type` is `AtprotoPersonalDataServer`, or
    /// - its `id` ends in `#atproto_pds`.
    ///
    /// The first such service that has a string `serviceEndpoint` is stored in
    /// `pds_map` under `did`. A bare host gets an `https://` prefix. A document
    /// without such a service still yields `Ok(())`.
    ///
    /// # Errors
    /// - `AppError::DidResolveFailed` for a malformed `did:web`, a non-2xx
    ///   response, or a body that is not valid JSON.
    /// - Transport and body-read errors are converted with `?`.
    async fn resolve_did_web(&self, did: &str) -> Result<(), AppError> {
        let url = did_web_to_did_document_url(did).ok_or_else(|| {
            AppError::DidResolveFailed("Invalid did:web format".to_string())
        })?;
        debug!("Fetching did:web document at {}", url);
        let resp = self.client.get(&url).send().await?;
        if !resp.status().is_success() {
            return Err(AppError::DidResolveFailed(format!(
                "did:web document fetch failed with status {}",
                resp.status()
            )));
        }
        // Propagate read failures. Parsing an empty string instead would surface
        // as a misleading "Invalid did:web JSON" error.
        let text = resp.text().await?;
        let doc: serde_json::Value = serde_json::from_str(&text)
            .map_err(|e| AppError::DidResolveFailed(format!("Invalid did:web JSON: {}", e)))?;

        let is_pds_service = |svc: &&serde_json::Value| {
            let type_ok = svc.get("type").and_then(|t| t.as_str()) == Some("AtprotoPersonalDataServer");
            // Match the PDS fragment exactly. A loose substring test would also
            // accept other ATProto services such as "#atproto_labeler".
            let id_ok = svc
                .get("id")
                .and_then(|t| t.as_str())
                .map_or(false, |id| id.ends_with("#atproto_pds"));
            type_ok || id_ok
        };
        let endpoint = doc
            .get("service")
            .and_then(|s| s.as_array())
            .into_iter()
            .flatten()
            .filter(is_pds_service)
            .find_map(|svc| svc.get("serviceEndpoint").and_then(|e| e.as_str()));

        if let Some(ep) = endpoint {
            let endpoint = if ep.starts_with("http") {
                ep.to_string()
            } else {
                format!("https://{}", ep)
            };
            let mut map = self.pds_map.lock().await;
            map.insert(did.to_string(), endpoint);
        }
        Ok(())
    }
}
/// Converts a `did:web` DID into the HTTPS URL of its DID document.
///
/// The first colon-separated segment is the host. A percent-encoded port
/// (`example.com%3A8443`) is decoded to `example.com:8443`. The URL then depends
/// on whether further segments follow:
/// - none: the document is at `https://<host>/.well-known/did.json`;
/// - some: they form a path, giving `https://<host>/<a>/<b>/did.json`.
///
/// Returns `None` in any of these cases:
/// - `did` is not a `did:web` DID;
/// - the host is empty or contains characters other than ASCII alphanumerics,
///   `.` and `-`;
/// - the port is not all digits;
/// - a path segment is empty, is `.` or `..`, or contains characters outside
///   `[A-Za-z0-9._~%-]`.
///
/// The character checks keep user-supplied DIDs from smuggling URL userinfo
/// (`@`), extra path separators, queries or fragments into the request.
fn did_web_to_did_document_url(did: &str) -> Option<String> {
    let rest = did.strip_prefix("did:web:")?;
    let mut segments = rest.split(':');

    let host = segments.next()?.replace("%3A", ":").replace("%3a", ":");
    let (name, port) = match host.split_once(':') {
        Some((name, port)) => (name, Some(port)),
        None => (host.as_str(), None),
    };
    let name_ok = !name.is_empty()
        && name
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-');
    let port_ok = match port {
        None => true,
        Some(p) => !p.is_empty() && p.chars().all(|c| c.is_ascii_digit()),
    };
    if !name_ok || !port_ok {
        return None;
    }

    let path: Vec<&str> = segments.collect();
    let path_ok = path.iter().all(|seg| {
        !seg.is_empty()
            && *seg != "."
            && *seg != ".."
            && seg
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_' | '~' | '%'))
    });
    if !path_ok {
        return None;
    }

    if path.is_empty() {
        Some(format!("https://{}/.well-known/did.json", host))
    } else {
        Some(format!("https://{}/{}/did.json", host, path.join("/")))
    }
}
/// Resolves AT Protocol handles to DIDs and locates PDS endpoints,
/// memoizing results in memory.
pub struct DidResolver {
    /// HTTP client used for every outbound lookup.
    client: Client,
    /// Handle -> (DID, time the entry was stored); entries older than
    /// `cache_ttl` are treated as stale.
    cache: Mutex<HashMap<String, (String, Instant)>>,
    /// Lifetime of an entry in `cache`.
    cache_ttl: Duration,
    /// DID -> PDS endpoint URL, filled in from fetched `did:web` documents.
    pds_map: Mutex<HashMap<String, String>>,
}
impl DidResolver {
pub fn new() -> Self {
let client = crate::http::client_with_timeout(Duration::from_secs(10));
Self {
client,
cache: Mutex::new(HashMap::new()),
cache_ttl: Duration::from_secs(3600), pds_map: Mutex::new(HashMap::new()),
}
}
pub async fn discover_pds(&self, did: &str) -> Result<Option<String>, AppError> {
if !did.starts_with("did:plc:") {
return Ok(None);
}
let url = format!("https://plc.directory/{}/log/audit", did);
debug!("Querying PLC audit log for DID {}: {}", did, url);
let resp = self.client.get(&url).send().await?;
if !resp.status().is_success() {
warn!("PLC audit log HTTP {} for {}", resp.status(), did);
return Ok(None);
}
let text = resp.text().await?;
let v: Value = serde_json::from_str(&text)?;
if let Some(entries) = v.as_array() {
for entry in entries.iter().rev() {
if let Some(op) = entry.get("operation") {
if let Some(services) = op.get("services") {
if let Some(atp) = services.get("atproto_pds") {
if let Some(endpoint) = atp.get("endpoint") {
if let Some(endpoint_str) = endpoint.as_str() {
let pds = if endpoint_str.starts_with("http") {
endpoint_str.to_string()
} else {
format!("https://{}", endpoint_str)
};
debug!("Discovered PDS for {}: {}", did, pds);
return Ok(Some(pds));
}
}
}
}
}
}
}
Ok(None)
}
pub async fn resolve_handle(&self, account: &str) -> Result<String, AppError> {
if account.starts_with("did:plc:") {
return Ok(account.to_string());
}
if account.starts_with("did:web:") {
let did = account.to_string();
if let Err(e) = self.resolve_did_web(&did).await {
return Err(AppError::DidResolveFailed(format!(
"did:web resolution failed: {}",
e.message()
)));
}
return Ok(did);
}
{
let mut cache = self.cache.lock().await;
if let Some((did, cached_at)) = cache.get(account) {
if cached_at.elapsed() < self.cache_ttl {
debug!("DID cache hit for handle: {}", account);
return Ok(did.clone());
} else {
cache.remove(account);
}
}
}
let clean_handle = account.strip_prefix('@').unwrap_or(account);
if let Some(did) = self.try_well_known(clean_handle).await? {
info!("Resolved via .well-known for {} -> {}", clean_handle, did);
self.cache_result(account, &did).await;
debug!("Resolved handle {} to DID {}", account, did);
return Ok(did);
}
let endpoints = vec![
"https://api.bsky.app/xrpc/com.atproto.identity.resolveHandle",
"https://bsky.social/xrpc/com.atproto.identity.resolveHandle",
];
let mut did: Option<String> = None;
let mut last_err: Option<AppError> = None;
for base in endpoints.into_iter() {
let url = format!("{}?handle={}", base, clean_handle);
debug!("Resolving handle {} via {}", clean_handle, url);
match self.client.get(&url).send().await {
Ok(resp) => {
if resp.status().is_success() {
match resp.json::<ResolveHandleResponse>().await {
Ok(res) => { did = Some(res.did); break; },
Err(e) => { last_err = Some(AppError::DidResolveFailed(e.to_string())); }
}
} else {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
last_err = Some(AppError::DidResolveFailed(format!("HTTP {} from {}: {}", status, base, body)));
}
}
Err(e) => { last_err = Some(AppError::DidResolveFailed(e.to_string())); }
}
}
let did = did.ok_or_else(|| last_err.unwrap_or(AppError::DidResolveFailed("Unknown handle resolution error".to_string())))?;
if !did.starts_with("did:plc:") || did.len() != 32 {
return Err(AppError::DidResolveFailed(format!(
"Invalid DID format returned: {}",
did
)));
}
self.cache_result(account, &did).await;
debug!("Resolved handle {} to DID {}", account, did);
Ok(did)
}
#[allow(dead_code)]
pub async fn cleanup_cache(&self) {
let mut cache = self.cache.lock().await;
let now = Instant::now();
cache.retain(|_, (_, cached_at)| now.duration_since(*cached_at) < self.cache_ttl);
}
}
impl Default for DidResolver {
fn default() -> Self {
Self::new()
}
}