use std::{borrow::Cow, collections::HashMap, ops::ControlFlow};
use itertools::Itertools;
use passkey_types::webauthn::WellKnown;
use url::Url;
use crate::{Origin, WebauthnError};
#[cfg(doc)]
use crate::Client;
#[cfg(test)]
pub(crate) mod tests;
#[cfg(feature = "reqwest")]
mod reqwest_fetcher;
#[cfg(feature = "android-asset-validation")]
pub(crate) mod android;
#[cfg(feature = "android-asset-validation")]
use android::UnverifiedAssetLink;
/// Validates a Relying Party ID against the origin that is requesting it.
///
/// `P` supplies effective-TLD lookups and `F` is the fetcher used to retrieve a
/// relying party's `/.well-known/webauthn` document.
pub struct RpIdVerifier<P, F> {
/// Computes the effective TLD+1 of a host; an RP ID for which this fails is rejected.
tld_provider: Box<P>,
/// Whether the literal RP ID `localhost` is accepted.
allows_insecure_localhost: bool,
/// Fetcher for the related-origins lookup. When `None`, an RP ID that does not match
/// the origin's domain is rejected without any lookup.
fetcher: Option<F>,
}
impl<P, F> RpIdVerifier<P, F>
where
P: public_suffix::EffectiveTLDProvider + Sync + 'static,
F: Fetcher + Sync,
{
/// Builds a verifier from an effective-TLD provider and an optional fetcher.
///
/// The returned verifier rejects `localhost`; opt in with
/// `allows_insecure_localhost(true)`.
pub fn new(tld_provider: P, fetcher: Option<F>) -> Self {
    let tld_provider = Box::new(tld_provider);
    Self {
        fetcher,
        allows_insecure_localhost: false,
        tld_provider,
    }
}
pub fn allows_insecure_localhost(mut self, is_allowed: bool) -> Self {
self.allows_insecure_localhost = is_allowed;
self
}
/// Validates `rp_id` against `origin` and returns the effective RP ID.
///
/// Web origins are handled by `assert_web_rp_id`. Android origins, available only
/// with the `android-asset-validation` feature, are handled by
/// `assert_android_rp_id`.
///
/// # Errors
///
/// Propagates whichever [`WebauthnError`] the origin-specific check returns.
pub async fn assert_domain<'a>(
&self,
origin: &'a Origin<'a>,
rp_id: Option<&'a str>,
) -> Result<&'a str, WebauthnError> {
match origin {
Origin::Web(url) => self.assert_web_rp_id(url, rp_id).await,
#[cfg(feature = "android-asset-validation")]
Origin::Android(unverified) => self.assert_android_rp_id(unverified, rp_id),
}
}
/// Resolves and validates the effective RP ID for a web `origin`.
///
/// Without an `rp_id` the origin's own domain is used. With an `rp_id`, it is
/// accepted directly when it equals the origin's domain or is a parent domain of it
/// (matched on a label boundary); otherwise the decision is delegated to
/// `validate_related_origins`.
///
/// # Errors
///
/// * [`WebauthnError::OriginMissingDomain`] if `origin` has no domain.
/// * [`WebauthnError::UnprotectedOrigin`] if the scheme is not `https`. This check is
///   skipped for `localhost`, whose verdict from `assert_valid_rp_id` is returned
///   before the scheme is inspected.
/// * Any error produced by `assert_valid_rp_id` or `validate_related_origins`.
async fn assert_web_rp_id<'a>(
    &self,
    origin: &'a Url,
    rp_id: Option<&'a str>,
) -> Result<&'a str, WebauthnError> {
    let mut effective_domain = origin.domain().ok_or(WebauthnError::OriginMissingDomain)?;

    if let Some(rp_id) = rp_id {
        // A bare `ends_with` would let `evilexample.com` claim `example.com`. The RP ID
        // must be the whole domain or be preceded by a `.` so that only genuine
        // sub-domains of the RP ID match.
        let is_domain_suffix = effective_domain == rp_id
            || effective_domain
                .strip_suffix(rp_id)
                .is_some_and(|prefix| prefix.ends_with('.'));

        effective_domain = if is_domain_suffix {
            rp_id
        } else {
            self.validate_related_origins(rp_id, effective_domain)
                .await?
        };
    }

    // `Break` carries a final verdict (the localhost allowance or an invalid RP ID).
    if let ControlFlow::Break(res) = self.assert_valid_rp_id(effective_domain) {
        return res;
    }

    if !origin.scheme().eq_ignore_ascii_case("https") {
        return Err(WebauthnError::UnprotectedOrigin);
    }

    Ok(effective_domain)
}
/// Checks whether `rp_id` is acceptable on its own, independent of any origin.
///
/// * `Break(Ok(_))` — `rp_id` is `localhost` and insecure localhost is allowed.
/// * `Break(Err(_))` — `rp_id` is `localhost` but not allowed, or the host cannot be
///   decoded / has no effective TLD+1.
/// * `Continue(())` — `rp_id` has an effective TLD+1; the caller carries on with its
///   remaining checks.
fn assert_valid_rp_id<'a>(
    &self,
    rp_id: &'a str,
) -> ControlFlow<Result<&'a str, WebauthnError>, ()> {
    if rp_id == "localhost" {
        let verdict = if self.allows_insecure_localhost {
            Ok(rp_id)
        } else {
            Err(WebauthnError::InsecureLocalhostNotAllowed)
        };
        return ControlFlow::Break(verdict);
    }

    let has_registrable_domain = decode_host(rp_id)
        .as_ref()
        .is_some_and(|host| self.tld_provider.effective_tld_plus_one(host).is_ok());

    if has_registrable_domain {
        ControlFlow::Continue(())
    } else {
        ControlFlow::Break(Err(WebauthnError::InvalidRpId))
    }
}
/// Returns `true` unless `assert_valid_rp_id` rejects `rp_id` with an error.
pub fn is_valid_rp_id(&self, rp_id: &str) -> bool {
    let rejected = matches!(
        self.assert_valid_rp_id(rp_id),
        ControlFlow::Break(Err(_))
    );
    !rejected
}
/// Resolves and validates the effective RP ID for an Android asset link.
///
/// Without an `rp_id` the asset link's host is used. With an `rp_id`, it must equal
/// that host or be a parent domain of it (matched on a label boundary). Unlike the
/// web path there is no `localhost` allowance and no related-origins lookup.
///
/// # Errors
///
/// * [`WebauthnError::OriginRpMissmatch`] if `rp_id` is not a domain suffix of the host.
/// * [`WebauthnError::InvalidRpId`] if the effective RP ID cannot be decoded or has
///   no effective TLD+1.
#[cfg(feature = "android-asset-validation")]
fn assert_android_rp_id<'a>(
    &self,
    target_link: &'a UnverifiedAssetLink,
    rp_id: Option<&'a str>,
) -> Result<&'a str, WebauthnError> {
    let mut effective_rp_id = target_link.host();

    if let Some(rp_id) = rp_id {
        // A bare `ends_with` would let host `evilexample.com` claim `example.com`.
        // Require an exact match or a `.` immediately before the RP ID.
        let is_domain_suffix = effective_rp_id == rp_id
            || effective_rp_id
                .strip_suffix(rp_id)
                .is_some_and(|prefix| prefix.ends_with('.'));
        if !is_domain_suffix {
            return Err(WebauthnError::OriginRpMissmatch);
        }
        effective_rp_id = rp_id;
    }

    if decode_host(effective_rp_id)
        .as_ref()
        .and_then(|s| self.tld_provider.effective_tld_plus_one(s).ok())
        .is_none()
    {
        return Err(WebauthnError::InvalidRpId);
    }

    Ok(effective_rp_id)
}
/// Maximum number of distinct eTLD+1 labels a `/.well-known/webauthn` document may
/// list before it is rejected with [`WebauthnError::ExceedsMaxLabelLimit`].
const ORIGIN_LABEL_LIMIT: usize = 5;

/// Decides whether `effective_domain` may use `rp_id` although it is not a sub-domain
/// of it, by consulting `https://{rp_id}/.well-known/webauthn`.
///
/// The listed origins are grouped by the first label of their effective TLD+1. The
/// request is accepted only when the (IDNA-decoded) `effective_domain` appears among
/// the origins that share its label. Only domains are compared; scheme and port of
/// the listed origins are not inspected here.
///
/// Returns `rp_id` on success.
///
/// # Errors
///
/// * [`WebauthnError::OriginRpMissmatch`] if no fetcher is configured, if `rp_id` is
///   `localhost`, or if `effective_domain` is not listed.
/// * [`WebauthnError::InvalidRpId`] if `rp_id` or `effective_domain` is not a valid
///   host, or if no URL can be built from `rp_id`.
/// * [`WebauthnError::RedirectError`] if the response came from a host outside `rp_id`.
/// * [`WebauthnError::ExceedsMaxLabelLimit`] if the document lists too many labels.
/// * Any error returned by the fetcher.
async fn validate_related_origins<'a>(
    &self,
    rp_id: &'a str,
    effective_domain: &'a str,
) -> Result<&'a str, WebauthnError> {
    let Some(ref fetcher) = self.fetcher else {
        return Err(WebauthnError::OriginRpMissmatch);
    };

    match self.assert_valid_rp_id(rp_id) {
        ControlFlow::Continue(()) => {}
        // `Break(Ok(_))` is the "insecure localhost is allowed" verdict. Reaching this
        // function means the origin is not under `rp_id`, so returning that verdict
        // would approve `localhost` for an arbitrary origin with no lookup at all.
        ControlFlow::Break(Ok(_)) => return Err(WebauthnError::OriginRpMissmatch),
        ControlFlow::Break(Err(err)) => return Err(err),
    }

    // `rp_id` is caller-supplied; a value that does not form a URL is an invalid RP ID,
    // not a reason to panic.
    let well_known_url = Url::parse(&format!("https://{rp_id}/.well-known/webauthn"))
        .map_err(|_| WebauthnError::InvalidRpId)?;

    let RelatedOriginResponse { payload, final_url } =
        fetcher.fetch_related_origins(well_known_url).await?;

    // The document must be served by `rp_id` itself or one of its sub-domains. Match on
    // a label boundary so `evilexample.com` is not mistaken for `example.com`.
    let served_by_rp = final_url.domain().is_some_and(|domain| {
        domain == rp_id
            || domain
                .strip_suffix(rp_id)
                .is_some_and(|prefix| prefix.ends_with('.'))
    });
    if !served_by_rp {
        return Err(WebauthnError::RedirectError);
    }

    let WellKnown { origins } = payload;

    // Origins without a domain, or whose host fails IDNA decoding, are ignored.
    let origin_domains: Vec<_> = origins
        .iter()
        .filter_map(|origin| decode_host(origin.domain()?))
        .collect();

    // Group the listed domains by the first label of their effective TLD+1.
    let labels_to_origins: HashMap<_, _> = origin_domains
        .iter()
        .filter_map(|origin| {
            let etld = self.tld_provider.effective_tld_plus_one(origin).ok()?;
            let (label, _) = etld.split_once('.')?;
            if label.is_empty() {
                None
            } else {
                Some((label, origin))
            }
        })
        .into_group_map();

    if labels_to_origins.len() > Self::ORIGIN_LABEL_LIMIT {
        return Err(WebauthnError::ExceedsMaxLabelLimit);
    }

    let decoded_effective_domain =
        decode_host(effective_domain).ok_or(WebauthnError::InvalidRpId)?;

    let Some((requesting_label, _)) = self
        .tld_provider
        .effective_tld_plus_one(&decoded_effective_domain)
        .ok()
        .and_then(|etld| etld.split_once('.'))
    else {
        return Err(WebauthnError::InvalidRpId);
    };

    let Some(matching_origins) = labels_to_origins.get(requesting_label) else {
        return Err(WebauthnError::OriginRpMissmatch);
    };

    if !matching_origins.contains(&&decoded_effective_domain) {
        return Err(WebauthnError::OriginRpMissmatch);
    }

    Ok(rp_id)
}
}
/// Converts a host containing punycode (`xn--`) labels to its Unicode form.
///
/// Hosts without any `xn--` label are returned borrowed and unchanged. Returns `None`
/// when IDNA decoding reports an error.
fn decode_host(host: &str) -> Option<Cow<'_, str>> {
    let has_punycode_label = host.split('.').any(|label| label.starts_with("xn--"));
    if !has_punycode_label {
        return Some(Cow::Borrowed(host));
    }

    let (unicode, outcome) = idna::domain_to_unicode(host);
    match outcome {
        Ok(_) => Some(Cow::Owned(unicode)),
        Err(_) => None,
    }
}
/// Source of a relying party's related-origins document.
///
/// The verifier calls this with the `https://{rp_id}/.well-known/webauthn` URL when an
/// RP ID does not match the requesting origin's domain.
// NOTE(review): the `async_fn_in_trait` lint is expected and suppressed here —
// presumably because implementors are not required to return `Send` futures; confirm.
#[expect(async_fn_in_trait)]
pub trait Fetcher {
/// Fetches and parses the document at `url`, reporting the URL the response was
/// obtained from alongside the payload.
async fn fetch_related_origins(&self, url: Url)
-> Result<RelatedOriginResponse, WebauthnError>;
}
/// Result of a successful [`Fetcher::fetch_related_origins`] call.
pub struct RelatedOriginResponse {
/// The parsed well-known document; its `origins` list is what the verifier inspects.
pub payload: WellKnown,
/// URL the response came from (presumably after following redirects — confirm with
/// the fetcher implementation). The verifier rejects the response with
/// `RedirectError` when this URL's domain does not end with the RP ID.
pub final_url: Url,
}
/// Placeholder fetcher for callers without network access: every lookup fails.
impl Fetcher for () {
/// Always returns [`WebauthnError::FetcherError`]; `_url` is ignored.
async fn fetch_related_origins(
&self,
_url: Url,
) -> Result<RelatedOriginResponse, WebauthnError> {
Err(WebauthnError::FetcherError)
}
}