use std::collections::HashMap;
use std::sync::{Arc, Mutex, PoisonError, RwLock};
use std::time::{Duration, Instant};
use futures_util::StreamExt;
use tokio::net::TcpStream;
use crate::errors::RequestError;
use crate::{PubkyHttpClient, PublicKey, Result, cross_log};
use reqwest::{IntoUrl, Method, RequestBuilder};
use url::Url;
/// Maximum age of a cached transport decision; older cache entries are ignored by the
/// resolver's cache lookup and the key is resolved again.
const TRANSPORT_CACHE_TTL: Duration = Duration::from_secs(60);
/// Timeout handed to `probe_reachable` when testing direct endpoints; it bounds each
/// individual TCP connect attempt.
const PROBE_TIMEOUT: Duration = Duration::from_millis(1500);
/// Transport chosen for requests addressed to a public key.
#[derive(Debug, Clone)]
pub(crate) enum ResolvedTransport {
/// Send the request unchanged through the client's `http` handle.
PubkyTls,
/// Rewrite the URL host to `domain` (and the port, when `port` is `Some`), send it through
/// the `icann_http` handle and attach the key in a `pubky-host` header.
Icann { domain: String, port: Option<u16> },
}
/// Resolves and caches which transport to use for each public key.
///
/// Both maps sit behind `Arc`, so clones of the resolver share the same state.
#[derive(Debug, Clone)]
pub(crate) struct TransportResolver {
// Key string -> (time the decision was stored, decision).
cache: Arc<RwLock<HashMap<String, (Instant, ResolvedTransport)>>>,
// Key string -> async mutex held while that key is being resolved, so concurrent callers
// for the same key wait for one resolution instead of each starting their own.
guards: Arc<Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
}
impl TransportResolver {
    /// Creates a resolver with an empty cache and no per-key guards.
    pub(crate) fn new() -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            guards: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the transport to use for `pk`.
    ///
    /// A cached decision younger than `TRANSPORT_CACHE_TTL` is returned directly; otherwise
    /// the key is resolved through `pkarr` and the result is cached.
    pub(crate) async fn resolve(&self, pk: &str, pkarr: &pkarr::Client) -> ResolvedTransport {
        if let Some(t) = self.cached(pk) {
            return t;
        }
        self.resolve_and_cache(pk, pkarr).await
    }

    /// Looks up a non-expired cache entry for `pk`.
    ///
    /// A poisoned lock is recovered rather than propagated: the map holds plain values, so
    /// its contents stay usable after another thread panicked.
    fn cached(&self, pk: &str) -> Option<ResolvedTransport> {
        let cache = self.cache.read().unwrap_or_else(PoisonError::into_inner);
        cache
            .get(pk)
            .filter(|(ts, _)| ts.elapsed() < TRANSPORT_CACHE_TTL)
            .map(|(_, t)| t.clone())
    }

    /// Resolves `pk` while holding its per-key async lock, then stores the result.
    ///
    /// Callers racing on the same key queue on the lock; each re-checks the cache once it
    /// gets the lock, so only the first one performs the actual resolution.
    async fn resolve_and_cache(&self, pk: &str, pkarr: &pkarr::Client) -> ResolvedTransport {
        let guard = {
            let mut guards = self.guards.lock().unwrap_or_else(PoisonError::into_inner);
            // Drop guards that only the map still references. Guards are cloned exclusively
            // while this mutex is held, so a strong count of 1 observed here means no task
            // holds or can be waiting on that guard. Without this the map would keep one
            // entry per key ever resolved.
            guards.retain(|_, g| Arc::strong_count(g) > 1);
            Arc::clone(guards.entry(pk.to_string()).or_default())
        };
        let _lock = guard.lock().await;
        // Another task may have filled the cache while we waited for the lock.
        if let Some(t) = self.cached(pk) {
            return t;
        }
        let t = Self::resolve_from_pkarr(pkarr, pk).await;
        {
            let mut cache = self.cache.write().unwrap_or_else(PoisonError::into_inner);
            // Expired entries are never served by `cached`; evict them here so the cache
            // does not grow without bound in a long-lived client.
            cache.retain(|_, (ts, _)| ts.elapsed() < TRANSPORT_CACHE_TTL);
            cache.insert(pk.to_string(), (Instant::now(), t.clone()));
        }
        t
    }

    /// Decides the transport for `qname` from its HTTPS endpoints as reported by `pkarr`.
    ///
    /// Endpoints that report a domain are treated as ICANN endpoints (only the first one is
    /// kept); all others are treated as direct endpoints and their socket addresses are
    /// collected. The decision is:
    ///
    /// * no ICANN endpoint: `PubkyTls`;
    /// * ICANN endpoint and no direct endpoint: `Icann`;
    /// * both: `PubkyTls` if any collected direct address accepts a TCP connection within
    ///   `PROBE_TIMEOUT`, otherwise `Icann` (logged as a warning).
    async fn resolve_from_pkarr(pkarr: &pkarr::Client, qname: &str) -> ResolvedTransport {
        let stream = pkarr.resolve_https_endpoints(qname);
        futures_util::pin_mut!(stream);
        let mut has_direct = false;
        let mut direct_addrs = Vec::new();
        let mut icann: Option<(String, Option<u16>)> = None;
        while let Some(ep) = stream.next().await {
            if let Some(domain) = ep.domain() {
                if icann.is_none() {
                    icann = Some((domain.to_string(), ep.port()));
                }
            } else {
                has_direct = true;
                direct_addrs.extend(ep.to_socket_addrs());
            }
        }
        let Some((domain, port)) = icann else {
            return ResolvedTransport::PubkyTls;
        };
        if !has_direct {
            return ResolvedTransport::Icann { domain, port };
        }
        if probe_reachable(&direct_addrs, PROBE_TIMEOUT).await {
            ResolvedTransport::PubkyTls
        } else {
            cross_log!(
                warn,
                "Direct endpoint unreachable for {qname}; ICANN fallback to {domain}"
            );
            ResolvedTransport::Icann { domain, port }
        }
    }
}
async fn probe_reachable(addrs: &[std::net::SocketAddr], timeout: Duration) -> bool {
for addr in addrs {
if let Ok(Ok(_)) = tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
return true;
}
}
false
}
/// Classification of a URL host as produced by `classify_host`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HostKind {
/// Host is `_pubky.` followed by a string that parses as a z32 public key.
ResolvedPubky,
/// Host is not a bare z32 public key (or its key part carries the pubky prefix).
Icann,
/// Host parses as a bare z32 public key; also the fallback for `_pubky.` hosts whose
/// remainder matches neither of the cases above.
Pubky,
}
/// Classifies `host` so the caller can choose the matching HTTP client.
///
/// * `_pubky.<key>` where `<key>` carries the pubky prefix: `Icann`
/// * `_pubky.<key>` where `<key>` parses as z32: `ResolvedPubky`
/// * any other `_pubky.`-prefixed host: `Pubky`
/// * no `_pubky.` label and the host is pubky-prefixed or does not parse as z32: `Icann`
/// * no `_pubky.` label and the host parses as z32: `Pubky`
///
/// NOTE(review): a `_pubky.` host whose remainder is not a key (for example
/// `_pubky.example.com`) ends up as `Pubky`, not `Icann` — confirm this is intended.
fn classify_host(host: &str) -> HostKind {
    match host.strip_prefix("_pubky.") {
        Some(key) if PublicKey::is_pubky_prefixed(key) => HostKind::Icann,
        Some(key) if PublicKey::try_from_z32(key).is_ok() => HostKind::ResolvedPubky,
        Some(_) => HostKind::Pubky,
        None if PublicKey::is_pubky_prefixed(host) || PublicKey::try_from_z32(host).is_err() => {
            HostKind::Icann
        }
        None => HostKind::Pubky,
    }
}
impl PubkyHttpClient {
    /// Builds a request for `url`, resolving the transport first when the host is a public key.
    ///
    /// Hosts that `prepare_request` does not recognise as a key are delegated to
    /// [`Self::request`].
    ///
    /// # Errors
    ///
    /// Propagates the validation error from `prepare_request` and any URL rewrite error from
    /// `build_pubky_request`.
    pub(crate) async fn cross_request(
        &self,
        method: Method,
        mut url: Url,
    ) -> Result<RequestBuilder> {
        let target = self.prepare_request(&mut url).await?;
        match target {
            Some(pk) => {
                let transport = self.transport.resolve(&pk, &self.pkarr).await;
                self.build_pubky_request(method, &url, &pk, &transport)
            }
            None => Ok(self.request(method, &url)),
        }
    }

    /// Creates the request builder for `url` according to the resolved `transport`.
    ///
    /// `PubkyTls` sends the URL untouched through `self.http`. `Icann` clones the URL, swaps
    /// in the ICANN domain (and its port when one is given), sends it through
    /// `self.icann_http` and adds a `pubky-host` header carrying `pk`.
    ///
    /// # Errors
    ///
    /// Fails when the domain or port cannot be applied to the URL.
    fn build_pubky_request(
        &self,
        method: Method,
        url: &Url,
        pk: &str,
        transport: &ResolvedTransport,
    ) -> Result<RequestBuilder> {
        let (domain, port) = match transport {
            ResolvedTransport::PubkyTls => return Ok(self.http.request(method, url.as_str())),
            ResolvedTransport::Icann { domain, port } => (domain, *port),
        };

        let mut rewritten = url.clone();
        rewritten.set_host(Some(domain.as_str()))?;
        if let Some(port) = port {
            // `set_port` reports failure as `()`; surface it as a parse error instead.
            rewritten
                .set_port(Some(port))
                .map_err(|_err| url::ParseError::InvalidPort)?;
        }
        cross_log!(debug, "ICANN fallback for {pk} via {domain}");

        let builder = self
            .icann_http
            .request(method, rewritten.as_str())
            .header("pubky-host", pk);
        Ok(builder)
    }

    /// Extracts the public key addressed by `url`'s host, if any.
    ///
    /// A leading `_pubky.` label is stripped before the check. Returns `Ok(Some(key))` when
    /// the remaining host parses as a z32 key and `Ok(None)` when it does not. `url` itself
    /// is left unmodified.
    ///
    /// # Errors
    ///
    /// Returns a validation error when the (stripped) host carries the pubky prefix.
    #[allow(
        clippy::unused_async,
        reason = "keep async signature aligned with WASM build"
    )]
    pub async fn prepare_request(&self, url: &mut Url) -> Result<Option<String>> {
        let host = url.host_str().unwrap_or("");
        // With or without the `_pubky.` label, the same two checks apply to the key part.
        let candidate = host.strip_prefix("_pubky.").unwrap_or(host);

        if PublicKey::is_pubky_prefixed(candidate) {
            let rejection = RequestError::Validation {
                message: "pubky prefix is not allowed in transport hosts; use raw z32".to_string(),
            };
            return Err(rejection.into());
        }

        let pk = PublicKey::try_from_z32(candidate)
            .is_ok()
            .then(|| candidate.to_string());
        Ok(pk)
    }

    /// Creates a request builder on the client that matches the URL's host kind.
    ///
    /// ICANN hosts go through `self.icann_http`; every other case, including URLs that fail
    /// to parse or have no host, goes through `self.http`.
    pub fn request<U: IntoUrl>(&self, method: Method, url: &U) -> RequestBuilder {
        let url_str = url.as_str();
        let parsed = Url::parse(url_str).ok();
        let Some(host) = parsed.as_ref().and_then(Url::host_str) else {
            return self.http.request(method, url_str);
        };

        match classify_host(host) {
            HostKind::ResolvedPubky => {
                cross_log!(debug, "PubkyTLS request for resolved _pubky host {}", host);
                self.http.request(method, url_str)
            }
            HostKind::Icann => {
                cross_log!(debug, "Standard TLS request for ICANN host {}", host);
                self.icann_http.request(method, url_str)
            }
            HostKind::Pubky => {
                cross_log!(debug, "PubkyTLS request for pubky host {}", host);
                self.http.request(method, url_str)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pkarr::dns::rdata::SVCB;
    use pkarr::{Keypair, SignedPacket};

    /// Sample z32 key string shared by the host-classification and request tests.
    const SAMPLE_Z32: &str = "o4dksfbqk85ogzdb5osziw6befigbuxmuxkuxq8434q89uj56uyy";

    /// Builds a pkarr client whose cache already contains `packet` for `keypair`'s public key.
    fn pkarr_with_packet(keypair: &Keypair, packet: &SignedPacket) -> pkarr::Client {
        let mut builder = PubkyHttpClient::builder();
        builder.pkarr(|b| b.no_default_network().bootstrap(&["127.0.0.1:1"]));
        let client = builder.build().unwrap();
        let key: pkarr::CacheKey = keypair.public_key().into();
        client.pkarr.cache().unwrap().put(&key, packet);
        client.pkarr
    }

    /// Seeds a pkarr client with `packet` and runs the transport resolution for `kp`.
    async fn resolve_seeded(kp: &Keypair, packet: &SignedPacket) -> ResolvedTransport {
        let pkarr = pkarr_with_packet(kp, packet);
        let qname = kp.public_key().to_string();
        TransportResolver::resolve_from_pkarr(&pkarr, &qname).await
    }

    #[test]
    fn classify_hosts() {
        let prefixed = format!("_pubky.{SAMPLE_Z32}");
        assert_eq!(classify_host("example.com"), HostKind::Icann);
        assert_eq!(classify_host(&prefixed), HostKind::ResolvedPubky);
        assert_eq!(classify_host(SAMPLE_Z32), HostKind::Pubky);
    }

    #[tokio::test]
    async fn probe_unreachable_returns_false() {
        let targets: [std::net::SocketAddr; 1] = ["192.0.2.1:1".parse().unwrap()];
        let reachable = probe_reachable(&targets, Duration::from_millis(100)).await;
        assert!(!reachable);
    }

    #[test]
    fn build_pubky_request_icann_rewrites_url_and_sets_header() {
        let client = PubkyHttpClient::builder()
            .pkarr(|b| b.no_default_network().bootstrap(&["127.0.0.1:1"]))
            .build()
            .unwrap();
        let source = Url::parse(&format!("https://{SAMPLE_Z32}/pub/app/file.txt")).unwrap();
        let fallback = ResolvedTransport::Icann {
            domain: "example.com".to_string(),
            port: Some(8443),
        };

        let built = client
            .build_pubky_request(Method::GET, &source, SAMPLE_Z32, &fallback)
            .unwrap()
            .build()
            .unwrap();

        let target = built.url();
        assert_eq!(target.host_str(), Some("example.com"));
        assert_eq!(target.port(), Some(8443));
        assert_eq!(target.path(), "/pub/app/file.txt");
        assert_eq!(built.headers().get("pubky-host").unwrap(), SAMPLE_Z32);
    }

    #[tokio::test]
    async fn resolve_transport_direct_only() {
        let kp = Keypair::random();
        let mut direct = SVCB::new(1, ".".try_into().unwrap());
        direct.set_port(6881);
        let packet = SignedPacket::builder()
            .https(".".try_into().unwrap(), direct, 3600)
            .address(".".try_into().unwrap(), "192.0.2.1".parse().unwrap(), 3600)
            .sign(&kp)
            .unwrap();

        let t = resolve_seeded(&kp, &packet).await;

        assert!(matches!(t, ResolvedTransport::PubkyTls));
    }

    #[tokio::test]
    async fn resolve_transport_icann_only() {
        let kp = Keypair::random();
        let icann = SVCB::new(1, "example.com".try_into().unwrap());
        let packet = SignedPacket::builder()
            .https(".".try_into().unwrap(), icann, 3600)
            .sign(&kp)
            .unwrap();

        let t = resolve_seeded(&kp, &packet).await;

        assert!(matches!(t, ResolvedTransport::Icann { .. }));
        if let ResolvedTransport::Icann { domain, .. } = t {
            assert_eq!(domain, "example.com");
        }
    }

    #[tokio::test]
    async fn resolve_transport_both_unreachable_direct_falls_back() {
        let kp = Keypair::random();
        let mut direct = SVCB::new(1, ".".try_into().unwrap());
        direct.set_port(6881);
        let icann = SVCB::new(10, "example.com".try_into().unwrap());
        let packet = SignedPacket::builder()
            .https(".".try_into().unwrap(), direct, 3600)
            .https(".".try_into().unwrap(), icann, 3600)
            .address(".".try_into().unwrap(), "192.0.2.1".parse().unwrap(), 3600)
            .sign(&kp)
            .unwrap();

        let t = resolve_seeded(&kp, &packet).await;

        assert!(
            matches!(t, ResolvedTransport::Icann { ref domain, .. } if domain == "example.com"),
            "expected ICANN fallback, got {t:?}"
        );
    }
}