use std::{collections::HashMap, time::Duration};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tracing::{debug, error, trace};
use crate::{
fido::AuthenticatorData,
ops::webauthn::{
client_data::ClientData,
idl::{
get::{
HmacGetSecretInputJson, LargeBlobInputJson, PrfInputJson,
PublicKeyCredentialRequestOptionsJSON,
},
response::{
AuthenticationExtensionsClientOutputsJSON, AuthenticationResponseJSON,
AuthenticatorAssertionResponseJSON, HMACGetSecretOutputJSON, LargeBlobOutputJSON,
PRFOutputJSON, PRFValuesJSON, ResponseSerializationError, WebAuthnIDLResponse,
},
rp_id_authorised, Base64UrlString, FromIdlModel, JsonError, RequestSettings,
},
Operation,
},
pin::PinUvAuthProtocol,
proto::ctap2::{
Ctap2GetAssertionResponseExtensions, Ctap2PublicKeyCredentialDescriptor,
Ctap2PublicKeyCredentialUserEntity,
},
webauthn::CtapError,
};
use super::timeout::DEFAULT_TIMEOUT;
use super::{
DowngradableRequest, RelyingPartyId, RequestOrigin, SignRequest, UserVerificationRequirement,
};
/// Raw input for one PRF evaluation: a required `first` value and an optional
/// `second` value, both stored as arbitrary-length bytes.
///
/// [`PrfInputValue::to_hmac_secret_input`] turns these into `hmac-secret` salts.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct PrfInputValue {
    /// First PRF input.
    pub first: Vec<u8>,
    /// Optional second PRF input.
    pub second: Option<Vec<u8>>,
}
/// PRF evaluation results, each exactly 32 bytes.
///
/// Both values serialize as byte strings via `serde_bytes`. `second` is left out
/// of the serialized form when it is `None`.
#[derive(Debug, Default, Clone, Serialize, PartialEq, Eq)]
pub struct PrfOutputValue {
    /// Result for the `first` input.
    #[serde(with = "serde_bytes")]
    pub first: [u8; 32],
    /// Result for the `second` input, when present.
    #[serde(skip_serializing_if = "Option::is_none", with = "serde_bytes")]
    pub second: Option<[u8; 32]>,
}
impl PrfInputValue {
    /// Derives the `hmac-secret` salts for this PRF input.
    ///
    /// Each salt is `SHA-256("WebAuthn PRF" || 0x00 || input)`. `salt2` is only
    /// produced when a `second` input is present.
    pub fn to_hmac_secret_input(&self) -> HMACGetSecretInput {
        // Domain-separation prefix: "WebAuthn PRF" followed by a NUL byte.
        const PRF_DOMAIN: &[u8] = b"WebAuthn PRF\x00";

        fn prf_salt(input: &[u8]) -> [u8; 32] {
            let mut digest = Sha256::default();
            digest.update(PRF_DOMAIN);
            digest.update(input);
            digest.finalize().into()
        }

        let salt1 = prf_salt(&self.first);
        let salt2 = self.second.as_deref().map(prf_salt);
        HMACGetSecretInput { salt1, salt2 }
    }
}
/// A WebAuthn assertion (`get()`) request, after the JSON options have been
/// validated.
///
/// The client data hash sent to the authenticator is derived from `challenge`,
/// `origin` and `top_origin`.
#[derive(Debug, Clone, PartialEq)]
pub struct GetAssertionRequest {
    /// Relying party ID. This is the explicit `rpId` when present and authorised;
    /// otherwise it is the host of the calling origin.
    pub relying_party_id: String,
    /// Challenge bytes as supplied by the relying party.
    pub challenge: Vec<u8>,
    /// Serialized calling origin.
    pub origin: String,
    /// Serialized top-level origin, when the caller supplied one.
    pub top_origin: Option<String>,
    /// Credentials the relying party accepts. Empty when `allowCredentials` was
    /// omitted.
    pub allow: Vec<Ctap2PublicKeyCredentialDescriptor>,
    /// Validated extension inputs. `None` when the request had no `extensions`.
    pub extensions: Option<GetAssertionRequestExtensions>,
    /// User verification requirement from the request.
    pub user_verification: UserVerificationRequirement,
    /// Operation timeout. Set to `DEFAULT_TIMEOUT` when the request did not specify one.
    pub timeout: Duration,
}
impl GetAssertionRequest {
    /// Assembles the `webauthn.get` client data for this request.
    fn client_data(&self) -> ClientData {
        let Self {
            challenge,
            origin,
            top_origin,
            ..
        } = self;
        ClientData {
            operation: Operation::GetAssertion,
            challenge: challenge.clone(),
            origin: origin.clone(),
            top_origin: top_origin.clone(),
        }
    }

    /// Returns the hash of the client data. This is the value the authenticator signs over.
    pub fn client_data_hash(&self) -> Vec<u8> {
        let client_data = self.client_data();
        client_data.hash()
    }

    /// Returns the client data serialized as JSON (`clientDataJSON`).
    pub fn client_data_json(&self) -> String {
        let client_data = self.client_data();
        client_data.to_json()
    }
}
/// Errors raised while turning request JSON into a [`GetAssertionRequest`].
#[derive(thiserror::Error, Debug)]
pub enum GetAssertionPrepareError {
    /// The request JSON could not be parsed.
    #[error("Invalid JSON format: {0}")]
    EncodingError(#[from] JsonError),
    /// A byte field had the wrong length. Payload: `(field path, actual length)`.
    #[error("Unexpected length for {0}: {1}")]
    UnexpectedLengthError(String, usize),
    /// The request used an option this implementation rejects.
    #[error("Not supported: {0}")]
    NotSupported(String),
    /// The `rpId` could not be parsed as a relying party identifier.
    #[error("Invalid relying party ID: {0}")]
    InvalidRelyingPartyId(String),
    /// The calling origin is not authorised for the requested `rpId`.
    /// Payload: `(requested rpId, origin host)`.
    #[error("Mismatching relying party ID: {0} != {1}")]
    MismatchingRelyingPartyId(String, String),
    /// The `appid` extension input failed validation.
    #[error("Invalid AppID: {0}")]
    InvalidAppId(String),
}
/// Validates the `appid` extension input (a legacy FIDO U2F AppID).
///
/// The value must be a non-empty `https://` URL whose authority (host part) is
/// non-empty. The accepted string is returned unchanged; `try_downgrade` hashes it
/// with SHA-256 to build U2F sign requests.
///
/// NOTE(review): this does not check that the calling origin is authorised to
/// claim this AppID (FIDO AppID & Facet rules). Confirm that check happens elsewhere.
///
/// # Errors
/// Returns [`GetAssertionPrepareError::InvalidAppId`] in any of these cases:
/// the value is empty, it is not an `https://` URL, or it has no host.
fn validate_appid(appid: &str) -> Result<String, GetAssertionPrepareError> {
    if appid.is_empty() {
        return Err(GetAssertionPrepareError::InvalidAppId(
            "appid must not be empty".to_string(),
        ));
    }
    let rest = match appid.strip_prefix("https://") {
        Some(rest) => rest,
        None => {
            return Err(GetAssertionPrepareError::InvalidAppId(format!(
                "appid must be an https URL, got: {appid}"
            )))
        }
    };
    // The authority ends at the first path, query or fragment delimiter. An empty
    // authority ("https://", "https:///x") would otherwise be hashed as a U2F
    // application parameter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    if authority.is_empty() {
        return Err(GetAssertionPrepareError::InvalidAppId(format!(
            "appid must include a host, got: {appid}"
        )));
    }
    Ok(appid.to_string())
}
impl GetAssertionRequest {
    /// Parses `json` as `PublicKeyCredentialRequestOptionsJSON`, then validates it
    /// for `request_origin` under `settings`.
    ///
    /// # Errors
    /// Returns `EncodingError` when `json` does not parse. Otherwise, returns
    /// whatever `from_idl_model` reports.
    pub async fn prepare(
        request_origin: &RequestOrigin,
        json: &str,
        settings: &RequestSettings<'_>,
    ) -> Result<Self, GetAssertionPrepareError> {
        let options = serde_json::from_str::<PublicKeyCredentialRequestOptionsJSON>(json)?;
        let request = Self::from_idl_model(request_origin, settings, options).await?;
        Ok(request)
    }
}
#[async_trait]
impl FromIdlModel<PublicKeyCredentialRequestOptionsJSON> for GetAssertionRequest {
    type Error = GetAssertionPrepareError;

    /// Converts parsed `PublicKeyCredentialRequestOptionsJSON` into a
    /// [`GetAssertionRequest`] issued by `request_origin`.
    ///
    /// - The RP ID defaults to the origin's host. An explicit `rpId` must parse as
    ///   a [`RelyingPartyId`] and pass `rp_id_authorised` under `settings`.
    /// - The PRF, `appid` and `largeBlob` extension inputs are validated here.
    /// - A `largeBlob` input that requests nothing (e.g. `{"read": false}`) is
    ///   dropped rather than rejected.
    /// - A missing `timeout` falls back to `DEFAULT_TIMEOUT`.
    ///
    /// # Errors
    /// - `InvalidRelyingPartyId` if `rpId` cannot be parsed.
    /// - `MismatchingRelyingPartyId` if the origin is not authorised for `rpId`.
    /// - `InvalidAppId` / `NotSupported` (or any PRF conversion error) from
    ///   extension validation.
    async fn from_idl_model(
        request_origin: &RequestOrigin,
        settings: &RequestSettings<'_>,
        inner: PublicKeyCredentialRequestOptionsJSON,
    ) -> Result<Self, GetAssertionPrepareError> {
        let effective_rp_id = request_origin.origin.host.as_str();
        let resolved_rp_id = if let Some(relying_party_id) = inner.relying_party_id.as_deref() {
            let parsed = RelyingPartyId::try_from(relying_party_id)
                .map_err(|err| GetAssertionPrepareError::InvalidRelyingPartyId(err.to_string()))?;
            if !rp_id_authorised(request_origin, &parsed, settings).await {
                return Err(GetAssertionPrepareError::MismatchingRelyingPartyId(
                    parsed.0,
                    effective_rp_id.to_string(),
                ));
            }
            parsed.0
        } else {
            effective_rp_id.to_string()
        };

        let extensions_json = inner.extensions.as_ref();
        let prf = match extensions_json.and_then(|e| e.prf.as_ref()) {
            Some(prf_json) => Some(PrfInput::try_from(prf_json.clone())?),
            None => None,
        };
        let appid = match extensions_json.and_then(|e| e.appid.as_ref()) {
            Some(s) => Some(validate_appid(s)?),
            None => None,
        };
        let large_blob = match extensions_json.and_then(|e| e.large_blob.as_ref()) {
            // Nothing requested (no support, no read=true, no write): a no-op.
            Some(lb) if lb.support.is_none() && lb.read != Some(true) && lb.write.is_none() => None,
            Some(lb) => Some(GetAssertionLargeBlobExtension::try_from(lb.clone())?),
            None => None,
        };
        // `Option::map` takes an `FnOnce`, so the validated inputs are moved in
        // rather than cloned.
        let extensions = extensions_json.map(|ext| GetAssertionRequestExtensions {
            cred_blob: ext.cred_blob.unwrap_or(false),
            large_blob,
            prf,
            appid,
        });

        let timeout: Duration = inner
            .timeout
            .map(|s| Duration::from_millis(s.into()))
            .unwrap_or(DEFAULT_TIMEOUT);

        Ok(GetAssertionRequest {
            relying_party_id: resolved_rp_id,
            challenge: inner.challenge.to_vec(),
            origin: request_origin.origin.to_string(),
            top_origin: request_origin.top_origin.as_ref().map(|o| o.to_string()),
            allow: inner
                .allow_credentials
                .into_iter()
                .map(|c| c.into())
                .collect(),
            extensions,
            user_verification: inner.user_verification,
            timeout,
        })
    }
}
/// Either raw `hmac-secret` salts or a PRF extension input.
///
/// NOTE(review): this enum is not constructed in this part of the file.
#[derive(Debug, Clone, PartialEq)]
pub enum GetAssertionHmacOrPrfInput {
    /// Salts supplied directly (already 32 bytes each).
    HmacGetSecret(HMACGetSecretInput),
    /// PRF input; hashed into salts via `PrfInputValue::to_hmac_secret_input`.
    Prf(PrfInput),
}
/// Parsed PRF extension input for an assertion request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrfInput {
    /// Inputs used for any credential, if given.
    pub eval: Option<PrfInputValue>,
    /// Per-credential inputs. Keys are the credential ID strings exactly as they
    /// appeared in the request (not decoded; presumably base64url — confirm).
    /// The map is empty when the request had no per-credential inputs.
    pub eval_by_credential: HashMap<String, PrfInputValue>,
}
impl TryFrom<PrfInputJson> for PrfInput {
type Error = GetAssertionPrepareError;
fn try_from(value: PrfInputJson) -> Result<Self, Self::Error> {
let eval = value.eval.map(|v| PrfInputValue {
first: v.first.into(),
second: v.second.map(Into::into),
});
let eval_by_credential = value
.eval_by_credential
.map(|map| {
map.into_iter()
.map(|(k, v)| {
(
k,
PrfInputValue {
first: v.first.into(),
second: v.second.map(Into::into),
},
)
})
.collect()
})
.unwrap_or_default();
Ok(PrfInput {
eval,
eval_by_credential,
})
}
}
/// PRF extension output of an assertion.
#[derive(Debug, Default, Clone, Serialize)]
pub struct GetAssertionPrfOutput {
    /// Evaluation results. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub results: Option<PrfOutputValue>,
}
/// `hmac-secret` salts, each exactly 32 bytes.
///
/// Built either from the JSON `hmacGetSecret` input or by hashing PRF inputs
/// (`PrfInputValue::to_hmac_secret_input`).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct HMACGetSecretInput {
    /// First salt (always present).
    pub salt1: [u8; 32],
    /// Optional second salt.
    pub salt2: Option<[u8; 32]>,
}
impl TryFrom<HmacGetSecretInputJson> for HMACGetSecretInput {
    type Error = GetAssertionPrepareError;

    /// Converts the JSON `hmacGetSecret` extension input into fixed-size salts.
    ///
    /// # Errors
    /// Returns [`GetAssertionPrepareError::UnexpectedLengthError`] when a salt is
    /// not exactly 32 bytes. The error names the offending field
    /// (`extensions.hmacGetSecret.salt1` or `extensions.hmacGetSecret.salt2`)
    /// and its actual length.
    fn try_from(value: HmacGetSecretInputJson) -> Result<Self, Self::Error> {
        // Copies a salt into a 32-byte array. On a length mismatch, it reports
        // `field` and the actual length.
        fn to_salt(bytes: &[u8], field: &str) -> Result<[u8; 32], GetAssertionPrepareError> {
            bytes.try_into().map_err(|_| {
                GetAssertionPrepareError::UnexpectedLengthError(field.to_string(), bytes.len())
            })
        }

        let salt1 = to_salt(value.salt1.as_slice(), "extensions.hmacGetSecret.salt1")?;
        let salt2 = match value.salt2.as_ref() {
            Some(salt) => Some(to_salt(
                salt.as_slice(),
                "extensions.hmacGetSecret.salt2",
            )?),
            None => None,
        };
        Ok(HMACGetSecretInput { salt1, salt2 })
    }
}
/// The requested `largeBlob` operation for an assertion.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GetAssertionLargeBlobExtension {
    /// Read the credential's large blob (`largeBlob.read = true`).
    Read,
    /// Write the given bytes (`largeBlob.write`).
    Write(Vec<u8>),
    /// Delete the credential's large blob. Not produced by the JSON conversion in
    /// this file.
    Delete,
}
impl TryFrom<LargeBlobInputJson> for GetAssertionLargeBlobExtension {
    type Error = GetAssertionPrepareError;

    /// Maps the `largeBlob` JSON input onto a single operation.
    ///
    /// # Errors
    /// Returns `NotSupported` in any of these cases:
    /// - `support` is set (it is a registration-only option);
    /// - both `read` and `write` are set;
    /// - neither `read = true` nor `write` is given.
    fn try_from(value: LargeBlobInputJson) -> Result<Self, Self::Error> {
        if value.support.is_some() {
            return Err(GetAssertionPrepareError::NotSupported(
                "largeBlob.support is only valid at registration".to_string(),
            ));
        }
        match (value.read, value.write) {
            (Some(_), Some(_)) => Err(GetAssertionPrepareError::NotSupported(
                "largeBlob.read and largeBlob.write are mutually exclusive".to_string(),
            )),
            (_, Some(payload)) => Ok(Self::Write(payload.to_vec())),
            (Some(true), None) => Ok(Self::Read),
            _ => Err(GetAssertionPrepareError::NotSupported(
                "largeBlob input must set read=true or write".to_string(),
            )),
        }
    }
}
/// `largeBlob` extension output of an assertion. Unset fields are omitted when
/// serialized.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize)]
pub struct GetAssertionLargeBlobExtensionOutput {
    /// Blob contents returned for a read.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blob: Option<Vec<u8>>,
    /// Whether a requested write succeeded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub written: Option<bool>,
}
/// Validated extension inputs of a [`GetAssertionRequest`].
#[derive(Debug, Default, Clone, PartialEq)]
pub struct GetAssertionRequestExtensions {
    /// `credBlob` request flag. `false` when absent.
    pub cred_blob: bool,
    /// PRF input, if requested.
    pub prf: Option<PrfInput>,
    /// Large-blob operation. `None` when absent or when the input requested nothing.
    pub large_blob: Option<GetAssertionLargeBlobExtension>,
    /// Validated legacy U2F AppID (an `https://` URL), if given.
    pub appid: Option<String>,
}
/// Decrypted `hmac-secret` outputs, one per salt supplied.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HMACGetSecretOutput {
    /// Output for `salt1`.
    pub output1: [u8; 32],
    /// Output for `salt2`. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output2: Option<[u8; 32]>,
}
/// Encrypted `hmac-secret` output as carried in the CTAP2 response.
///
/// Serializes transparently as a single byte string.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Ctap2HMACGetSecretOutput {
    /// Ciphertext, decrypted with the PIN/UV protocol shared secret.
    #[serde(with = "serde_bytes")]
    pub(crate) encrypted_output: Vec<u8>,
}
impl Ctap2HMACGetSecretOutput {
pub(crate) fn decrypt_output(
&self,
shared_secret: &[u8],
uv_proto: &dyn PinUvAuthProtocol,
) -> Option<HMACGetSecretOutput> {
let output = match uv_proto.decrypt(shared_secret, &self.encrypted_output) {
Ok(o) => o,
Err(e) => {
error!("Failed to decrypt HMAC Secret output with the shared secret: {e:?}. Skipping HMAC extension");
return None;
}
};
let mut res = HMACGetSecretOutput::default();
if output.len() == 32 {
res.output1.copy_from_slice(&output);
} else if output.len() == 64 {
let (o1, o2) = output.split_at(32);
res.output1.copy_from_slice(o1);
let mut output2 = [0u8; 32];
output2.copy_from_slice(o2);
res.output2 = Some(output2);
} else {
error!("Failed to split HMAC Secret outputs. Unexpected output length: {}. Skipping HMAC extension", output.len());
return None;
}
Some(res)
}
}
/// Extension outputs carried in an assertion's authenticator data. The CTAP2
/// response extension type is used directly.
pub type GetAssertionResponseExtensions = Ctap2GetAssertionResponseExtensions;
/// Client-side (unsigned) extension outputs of an assertion.
///
/// Typed members are serialized in camelCase and omitted when `None`. Any other
/// outputs in `unsigned_extension_outputs` are flattened into the same object.
#[derive(Debug, Default, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetAssertionResponseUnsignedExtensions {
    /// Decrypted `hmac-secret` output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hmac_get_secret: Option<HMACGetSecretOutput>,
    /// `largeBlob` output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub large_blob: Option<GetAssertionLargeBlobExtensionOutput>,
    /// PRF output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prf: Option<GetAssertionPrfOutput>,
    /// `appid` extension output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub appid: Option<bool>,
    /// Additional outputs keyed by extension identifier. Keys that collide with
    /// the typed members are dropped by `build_client_extension_results`.
    #[serde(flatten)]
    pub unsigned_extension_outputs: serde_json::Map<String, serde_json::Value>,
}
/// Assertions returned for a get-assertion request.
#[derive(Debug, Clone)]
pub struct GetAssertionResponse {
    /// The individual assertions, in the order they were collected.
    pub assertions: Vec<Assertion>,
}
/// A single assertion produced by an authenticator.
#[derive(Debug, Clone)]
pub struct Assertion {
    /// Credential that produced the assertion. `None` when the response did not
    /// carry one.
    pub credential_id: Option<Ctap2PublicKeyCredentialDescriptor>,
    /// Authenticator data, including any signed extension outputs.
    pub authenticator_data: AuthenticatorData<GetAssertionResponseExtensions>,
    /// Assertion signature bytes.
    pub signature: Vec<u8>,
    /// User entity, when the authenticator returned one. Its `id` becomes `userHandle`.
    pub user: Option<Ctap2PublicKeyCredentialUserEntity>,
    /// Number of credentials reported by the authenticator, if any.
    pub credentials_count: Option<u32>,
    /// User-selected flag reported by the authenticator, if any.
    pub user_selected: Option<bool>,
    /// Client-side extension outputs.
    pub unsigned_extensions_output: Option<GetAssertionResponseUnsignedExtensions>,
}
impl WebAuthnIDLResponse for Assertion {
    type IdlModel = AuthenticationResponseJSON;
    type Context = GetAssertionRequest;

    /// Serializes this assertion as an `AuthenticationResponseJSON` for `request`.
    ///
    /// The credential ID comes from `credential_id`. CTAP2 lets an authenticator
    /// omit it when the allow list holds exactly one credential, so in that case
    /// the single allow-list entry is used. Without either, `id`/`rawId` are empty.
    ///
    /// # Errors
    /// Returns `ResponseSerializationError::AuthenticatorDataError` when the
    /// authenticator data cannot be encoded.
    fn to_idl_model(
        &self,
        request: &Self::Context,
    ) -> Result<Self::IdlModel, ResponseSerializationError> {
        let credential = self
            .credential_id
            .as_ref()
            .or_else(|| match request.allow.as_slice() {
                [only] => Some(only),
                _ => None,
            });
        let credential_id_bytes = credential
            .map(|cred| cred.id.to_vec())
            .unwrap_or_default();
        let id = base64_url::encode(&credential_id_bytes);
        let raw_id = Base64UrlString::from(credential_id_bytes);

        let authenticator_data_bytes = self
            .authenticator_data
            .to_response_bytes()
            .map_err(|e| ResponseSerializationError::AuthenticatorDataError(e.to_string()))?;
        let user_handle = self
            .user
            .as_ref()
            .map(|user| Base64UrlString::from(user.id.as_ref()));
        let client_extension_results = self.build_client_extension_results();

        Ok(AuthenticationResponseJSON {
            id,
            raw_id,
            response: AuthenticatorAssertionResponseJSON {
                client_data_json: Base64UrlString::from(request.client_data_json().into_bytes()),
                authenticator_data: Base64UrlString::from(authenticator_data_bytes),
                signature: Base64UrlString::from(self.signature.clone()),
                user_handle,
            },
            authenticator_attachment: None,
            client_extension_results,
            r#type: "public-key".to_string(),
        })
    }
}
impl Assertion {
    /// Builds `clientExtensionResults` from the unsigned extension outputs.
    ///
    /// Typed outputs (appid, hmacGetSecret, largeBlob, prf) are copied into their
    /// dedicated fields. Every other entry of `unsigned_extension_outputs` is passed
    /// through, except keys that name a typed member. Returns the default (empty)
    /// value when there are no unsigned outputs.
    fn build_client_extension_results(&self) -> AuthenticationExtensionsClientOutputsJSON {
        // Keys that have a dedicated typed field; free-form copies are not forwarded.
        const TYPED_MEMBERS: [&str; 6] = [
            "appid",
            "credProps",
            "hmacCreateSecret",
            "hmacGetSecret",
            "largeBlob",
            "prf",
        ];

        let mut results = AuthenticationExtensionsClientOutputsJSON::default();
        let unsigned = match &self.unsigned_extensions_output {
            Some(unsigned) => unsigned,
            None => return results,
        };

        results.appid = unsigned.appid;

        if let Some(hmac) = unsigned.hmac_get_secret.as_ref() {
            let output1 = Base64UrlString::from(hmac.output1.as_slice());
            let output2 = hmac
                .output2
                .as_ref()
                .map(|o| Base64UrlString::from(o.as_slice()));
            results.hmac_get_secret = Some(HMACGetSecretOutputJSON { output1, output2 });
        }

        if let Some(blob_output) = unsigned.large_blob.as_ref() {
            let blob = blob_output
                .blob
                .as_ref()
                .map(|b| Base64UrlString::from(b.as_slice()));
            results.large_blob = Some(LargeBlobOutputJSON {
                supported: None,
                blob,
                written: blob_output.written,
            });
        }

        if let Some(prf) = unsigned.prf.as_ref() {
            let values = prf.results.as_ref().map(|value| PRFValuesJSON {
                first: Base64UrlString::from(value.first.as_slice()),
                second: value
                    .second
                    .as_ref()
                    .map(|s| Base64UrlString::from(s.as_slice())),
            });
            results.prf = Some(PRFOutputJSON {
                enabled: None,
                results: values,
            });
        }

        let untyped = unsigned
            .unsigned_extension_outputs
            .iter()
            .filter(|(name, _)| !TYPED_MEMBERS.contains(&name.as_str()));
        for (name, value) in untyped {
            results
                .unsigned_extension_outputs
                .insert(name.clone(), value.clone());
        }

        results
    }
}
impl From<&[Assertion]> for GetAssertionResponse {
fn from(assertions: &[Assertion]) -> Self {
Self {
assertions: assertions.to_owned(),
}
}
}
impl From<Assertion> for GetAssertionResponse {
fn from(assertion: Assertion) -> Self {
Self {
assertions: vec![assertion],
}
}
}
impl DowngradableRequest<Vec<SignRequest>> for GetAssertionRequest {
    /// Reports whether this request can be served as U2F `SignRequest`s.
    ///
    /// The request is not downgradable in any of these cases:
    /// - user verification is required;
    /// - the allow list is empty;
    /// - a large-blob write or delete was requested.
    fn is_downgradable(&self) -> bool {
        if matches!(self.user_verification, UserVerificationRequirement::Required) {
            debug!("Not downgradable: relying party (RP) requires user verification");
            return false;
        }
        if self.allow.is_empty() {
            debug!("Not downgradable: allowList is empty.");
            return false;
        }
        let large_blob = self
            .extensions
            .as_ref()
            .and_then(|ext| ext.large_blob.as_ref());
        let needs_large_blob_write = matches!(
            large_blob,
            Some(GetAssertionLargeBlobExtension::Write(_))
                | Some(GetAssertionLargeBlobExtension::Delete)
        );
        if needs_large_blob_write {
            debug!("Not downgradable: largeBlob write/delete requires FIDO2");
            return false;
        }
        true
    }

    /// Converts this request into U2F sign requests.
    ///
    /// Each allow-list credential gets a request built against SHA-256 of the RP ID.
    /// When the `appid` extension is set, it also gets a second request built
    /// against SHA-256 of the AppID. The challenge is the client data hash.
    fn try_downgrade(&self) -> Result<Vec<SignRequest>, CtapError> {
        trace!(?self);
        let sha256 = |data: &[u8]| -> Vec<u8> {
            let mut hasher = Sha256::default();
            hasher.update(data);
            hasher.finalize().to_vec()
        };

        let challenge = self.client_data_hash();
        let rp_id_hash = sha256(self.relying_party_id.as_bytes());
        let appid_hash = self
            .extensions
            .as_ref()
            .and_then(|ext| ext.appid.as_deref())
            .map(|appid| sha256(appid.as_bytes()));

        // Application parameters to try for each credential, RP ID hash first.
        let app_id_hashes: Vec<&Vec<u8>> = std::iter::once(&rp_id_hash)
            .chain(appid_hash.as_ref())
            .collect();

        let mut downgraded_requests: Vec<SignRequest> = Vec::new();
        for credential in &self.allow {
            for &app_id_hash in &app_id_hashes {
                downgraded_requests.push(SignRequest::new_upgraded(
                    app_id_hash,
                    &challenge,
                    &credential.id,
                    self.timeout,
                ));
            }
        }
        trace!(?downgraded_requests);
        Ok(downgraded_requests)
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_trait::async_trait;
    use serde_bytes::ByteBuf;

    use crate::ops::webauthn::psl::{MockPublicSuffixList, PublicSuffixList};
    use crate::ops::webauthn::related_origins::{
        HttpClientError, MaxRegistrableLabels, RelatedOrigins, RelatedOriginsError,
        RelatedOriginsSource,
    };
    use crate::ops::webauthn::{GetAssertionRequest, OriginValidation, RequestOrigin};
    use crate::proto::ctap2::Ctap2PublicKeyCredentialType;

    use super::*;

    /// Test double for [`RelatedOriginsSource`].
    ///
    /// It returns a canned result. When built with [`MockSource::panicking`], it
    /// panics if consulted, which proves the lookup is skipped.
    struct MockSource {
        result: Option<Result<Vec<String>, RelatedOriginsError>>,
    }

    impl MockSource {
        fn origins(items: &[&str]) -> Self {
            Self {
                result: Some(Ok(items.iter().map(|s| s.to_string()).collect())),
            }
        }

        fn err(e: RelatedOriginsError) -> Self {
            Self {
                result: Some(Err(e)),
            }
        }

        fn panicking() -> Self {
            Self { result: None }
        }
    }

    #[async_trait]
    impl RelatedOriginsSource for MockSource {
        async fn allowed_origins(
            &self,
            _: &RelyingPartyId,
        ) -> Result<Vec<String>, RelatedOriginsError> {
            match &self.result {
                Some(r) => r.clone(),
                None => panic!("allowed_origins should not be called"),
            }
        }
    }

    /// Parses `json` with origin validation enabled, using the given public suffix
    /// list and related-origins configuration.
    async fn from_json(
        origin: &RequestOrigin,
        psl: &dyn PublicSuffixList,
        related_origins: RelatedOrigins<'_>,
        json: &str,
    ) -> Result<GetAssertionRequest, GetAssertionPrepareError> {
        GetAssertionRequest::prepare(
            origin,
            json,
            &RequestSettings {
                origin: OriginValidation::Validate {
                    public_suffix_list: psl,
                    related_origins,
                },
            },
        )
        .await
    }

    pub const REQUEST_BASE_JSON: &str = r#"
{
"challenge": "Y3JlZGVudGlhbHMtZm9yLWxpbnV4L2xpYndlYmF1dGhu",
"timeout": 30000,
"rpId": "example.org",
"allowCredentials": [
{
"type": "public-key",
"id": "bXktY3JlZGVudGlhbC1pZA"
}
],
"userVerification": "preferred"
}
"#;

    /// The request that [`REQUEST_BASE_JSON`] is expected to parse into.
    fn request_base() -> GetAssertionRequest {
        GetAssertionRequest {
            relying_party_id: "example.org".to_owned(),
            challenge: base64_url::decode("Y3JlZGVudGlhbHMtZm9yLWxpbnV4L2xpYndlYmF1dGhu").unwrap(),
            origin: "https://example.org".to_string(),
            top_origin: None,
            allow: vec![Ctap2PublicKeyCredentialDescriptor {
                r#type: Ctap2PublicKeyCredentialType::PublicKey,
                id: ByteBuf::from(base64_url::decode("bXktY3JlZGVudGlhbC1pZA").unwrap()),
                transports: None,
            }],
            extensions: None,
            user_verification: UserVerificationRequirement::Preferred,
            timeout: Duration::from_secs(30),
        }
    }

    /// Returns `str` with top-level `field` set to the JSON value `value`.
    fn json_field_add(str: &str, field: &str, value: &str) -> String {
        let mut v: serde_json::Value = serde_json::from_str(str).unwrap();
        v.as_object_mut()
            .unwrap()
            .insert(field.to_owned(), serde_json::from_str(value).unwrap());
        serde_json::to_string(&v).unwrap()
    }

    /// Returns `str` with top-level `field` removed.
    fn json_field_rm(str: &str, field: &str) -> String {
        let mut v: serde_json::Value = serde_json::from_str(str).unwrap();
        v.as_object_mut().unwrap().remove(field);
        serde_json::to_string(&v).unwrap()
    }

    #[tokio::test]
    async fn test_request_from_json_base() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req: GetAssertionRequest = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            REQUEST_BASE_JSON,
        )
        .await
        .unwrap();
        assert_eq!(req, request_base());
    }

    #[tokio::test]
    async fn test_request_from_json_ignore_missing_rp_id() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_rm(REQUEST_BASE_JSON, "rpId");
        let req: GetAssertionRequest = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await
        .unwrap();
        assert_eq!(req, request_base());
    }

    #[tokio::test]
    async fn test_request_from_json_invalid_rp_id() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""example.org.""#);
        let result = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await;
        assert!(matches!(
            result,
            Err(GetAssertionPrepareError::InvalidRelyingPartyId(_))
        ));
    }

    #[tokio::test]
    async fn test_request_from_json_mismatching_rp_id() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""other.example.org""#);
        let result = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await;
        assert!(matches!(
            result,
            Err(GetAssertionPrepareError::MismatchingRelyingPartyId(_, _))
        ));
    }

    #[tokio::test]
    async fn origin_trust_accepts_mismatching_rp_id() {
        let request_origin: RequestOrigin = "https://app.example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""example.com""#);
        let req = GetAssertionRequest::prepare(
            &request_origin,
            &req_json,
            &RequestSettings {
                origin: OriginValidation::Trust,
            },
        )
        .await
        .unwrap();
        assert_eq!(req.relying_party_id, "example.com");
    }

    #[tokio::test]
    async fn test_request_from_json_rp_id_is_parent_registrable_suffix() {
        let request_origin: RequestOrigin = "https://login.example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""example.org""#);
        let req = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await
        .unwrap();
        assert_eq!(req.relying_party_id, "example.org");
        assert_eq!(req.origin, "https://login.example.org");
    }

    #[tokio::test]
    async fn test_request_from_json_rp_id_is_etld_rejected() {
        let request_origin: RequestOrigin = "https://example.co.uk".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""co.uk""#);
        let result = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await;
        assert!(matches!(
            result,
            Err(GetAssertionPrepareError::MismatchingRelyingPartyId(_, _))
        ));
    }

    #[tokio::test]
    async fn related_origins_match_resolves_mismatch() {
        let request_origin: RequestOrigin = "https://app.example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""example.com""#);
        let source = MockSource::origins(&["https://app.example.org"]);
        let req = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Enabled {
                source: &source,
                max_labels: MaxRegistrableLabels::default(),
            },
            &req_json,
        )
        .await
        .unwrap();
        assert_eq!(req.relying_party_id, "example.com");
    }

    #[tokio::test]
    async fn related_origins_no_match_keeps_mismatch_error() {
        let request_origin: RequestOrigin = "https://app.example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""example.com""#);
        let source = MockSource::origins(&["https://other.org"]);
        let result = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Enabled {
                source: &source,
                max_labels: MaxRegistrableLabels::default(),
            },
            &req_json,
        )
        .await;
        assert!(matches!(
            result,
            Err(GetAssertionPrepareError::MismatchingRelyingPartyId(_, _))
        ));
    }

    #[tokio::test]
    async fn related_origins_fetch_error_keeps_mismatch_error() {
        let request_origin: RequestOrigin = "https://app.example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""example.com""#);
        let source = MockSource::err(RelatedOriginsError::Http(HttpClientError::Transport(
            "simulated".into(),
        )));
        let result = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Enabled {
                source: &source,
                max_labels: MaxRegistrableLabels::default(),
            },
            &req_json,
        )
        .await;
        assert!(matches!(
            result,
            Err(GetAssertionPrepareError::MismatchingRelyingPartyId(_, _))
        ));
    }

    #[tokio::test]
    async fn related_origins_not_consulted_when_suffix_matches() {
        let request_origin: RequestOrigin = "https://login.example.com".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "rpId", r#""example.com""#);
        let source = MockSource::panicking();
        let req = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Enabled {
                source: &source,
                max_labels: MaxRegistrableLabels::default(),
            },
            &req_json,
        )
        .await
        .unwrap();
        assert_eq!(req.relying_party_id, "example.com");
    }

    #[tokio::test]
    async fn test_request_from_json_ignore_missing_allow_credentials() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_rm(REQUEST_BASE_JSON, "allowCredentials");
        let req: GetAssertionRequest = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await
        .unwrap();
        assert_eq!(
            req,
            GetAssertionRequest {
                allow: vec![],
                ..request_base()
            }
        );
    }

    #[tokio::test]
    async fn test_request_from_json_default_timeout() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_rm(REQUEST_BASE_JSON, "timeout");
        let req: GetAssertionRequest = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await
        .unwrap();
        assert_eq!(req.timeout, DEFAULT_TIMEOUT);
    }

    #[tokio::test]
    async fn test_request_from_json_empty_extensions() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "extensions", r#"{}"#);
        let req: GetAssertionRequest = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await
        .unwrap();
        assert_eq!(
            req.extensions,
            Some(GetAssertionRequestExtensions::default())
        );
    }

    #[tokio::test]
    async fn test_request_from_json_appid_extension() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_add(
            REQUEST_BASE_JSON,
            "extensions",
            r#"{"appid":"https://www.example.org/u2f/origins.json"}"#,
        );
        let req: GetAssertionRequest = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await
        .unwrap();
        let ext = req.extensions.expect("extensions should be present");
        assert_eq!(
            ext.appid.as_deref(),
            Some("https://www.example.org/u2f/origins.json")
        );
    }

    #[tokio::test]
    async fn test_request_from_json_appid_extension_invalid_non_https() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_add(
            REQUEST_BASE_JSON,
            "extensions",
            r#"{"appid":"http://www.example.org/u2f/origins.json"}"#,
        );
        let res = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await;
        assert!(matches!(
            res,
            Err(GetAssertionPrepareError::InvalidAppId(_))
        ));
    }

    #[test]
    fn test_try_downgrade_with_appid_uses_appid_hash() {
        use sha2::{Digest, Sha256};
        let mut req = request_base();
        req.extensions = Some(GetAssertionRequestExtensions {
            cred_blob: false,
            prf: None,
            large_blob: None,
            appid: Some("https://www.example.org/u2f/origins.json".to_string()),
        });
        let sign_requests = req.try_downgrade().expect("downgrade ok");
        assert_eq!(sign_requests.len(), 2);
        let mut rp_hasher = Sha256::default();
        rp_hasher.update(b"example.org");
        let rp_hash = rp_hasher.finalize().to_vec();
        let mut appid_hasher = Sha256::default();
        appid_hasher.update(b"https://www.example.org/u2f/origins.json");
        let appid_hash = appid_hasher.finalize().to_vec();
        assert_eq!(sign_requests[0].app_id_hash, rp_hash);
        assert_eq!(sign_requests[1].app_id_hash, appid_hash);
        assert_eq!(sign_requests[0].key_handle, sign_requests[1].key_handle);
    }

    #[test]
    fn test_try_downgrade_without_appid_uses_rp_hash() {
        use sha2::{Digest, Sha256};
        let req = request_base();
        let sign_requests = req.try_downgrade().expect("downgrade ok");
        assert_eq!(sign_requests.len(), 1);
        let mut rp_hasher = Sha256::default();
        rp_hasher.update(b"example.org");
        let rp_hash = rp_hasher.finalize().to_vec();
        assert_eq!(sign_requests[0].app_id_hash, rp_hash);
    }

    /// Parses the base request with `extensions_json` attached and returns its
    /// PRF input. Panics if any step fails.
    async fn parse_prf(extensions_json: &str) -> PrfInput {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let req_json = json_field_add(REQUEST_BASE_JSON, "extensions", extensions_json);
        let req = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &req_json,
        )
        .await
        .expect("request should parse");
        req.extensions
            .expect("extensions")
            .prf
            .expect("prf extension")
    }

    #[tokio::test]
    async fn test_request_from_json_prf_extension() {
        let prf = parse_prf(r#"{"prf":{"eval":{"first":"AQID","second":"BAUG"}}}"#).await;
        let eval = prf.eval.expect("eval");
        assert_eq!(eval.first, vec![0x01, 0x02, 0x03]);
        assert_eq!(eval.second.as_deref(), Some(&[0x04u8, 0x05, 0x06][..]));
    }

    #[tokio::test]
    async fn test_prf_input_variable_length() {
        for len in [1usize, 16, 31, 33, 64, 256] {
            let bytes = vec![0xABu8; len];
            let b64 = base64_url::encode(&bytes);
            let prf = parse_prf(&format!(r#"{{"prf":{{"eval":{{"first":"{b64}"}}}}}}"#)).await;
            let eval = prf.eval.unwrap();
            assert_eq!(eval.first.len(), len, "len {len}");
            assert_eq!(eval.first, bytes, "len {len}");
            assert!(eval.second.is_none());
        }
    }

    #[tokio::test]
    async fn test_prf_input_short_via_json() {
        let prf = parse_prf(r#"{"prf":{"eval":{"first":"aGk"}}}"#).await;
        let eval = prf.eval.expect("eval");
        assert_eq!(eval.first, b"hi");
        assert!(eval.second.is_none());
    }

    #[tokio::test]
    async fn test_prf_input_empty_allowed() {
        let prf = parse_prf(r#"{"prf":{"eval":{"first":""}}}"#).await;
        let eval = prf.eval.unwrap();
        assert!(eval.first.is_empty());
        assert!(eval.second.is_none());
    }

    #[tokio::test]
    async fn test_prf_eval_by_credential_variable_length() {
        let prf = parse_prf(
            r#"{"prf":{"eval_by_credential":{"Y3JlZDE":{"first":"AQ","second":"AgIC"}}}}"#,
        )
        .await;
        let v = prf.eval_by_credential.get("Y3JlZDE").expect("entry");
        assert_eq!(v.first, vec![0x01]);
        assert_eq!(v.second.as_deref(), Some(&[0x02u8, 0x02, 0x02][..]));
    }

    /// A minimal user-present assertion with a 4-byte credential ID and no extensions.
    fn create_test_assertion() -> Assertion {
        use crate::fido::{AuthenticatorData, AuthenticatorDataFlags};
        let authenticator_data = AuthenticatorData {
            rp_id_hash: [0u8; 32],
            flags: AuthenticatorDataFlags::USER_PRESENT,
            signature_count: 1,
            attested_credential: None,
            extensions: None,
            raw: None,
        };
        Assertion {
            credential_id: Some(Ctap2PublicKeyCredentialDescriptor {
                r#type: Ctap2PublicKeyCredentialType::PublicKey,
                id: ByteBuf::from(vec![0x01, 0x02, 0x03, 0x04]),
                transports: None,
            }),
            authenticator_data,
            signature: vec![0xDE, 0xAD, 0xC0, 0xDE],
            user: None,
            credentials_count: None,
            user_selected: None,
            unsigned_extensions_output: None,
        }
    }

    /// A request with an empty allow list, used as serialization context.
    fn create_test_request() -> GetAssertionRequest {
        GetAssertionRequest {
            relying_party_id: "example.org".to_owned(),
            challenge: b"DEADCODE_challenge".to_vec(),
            origin: "example.org".to_string(),
            top_origin: None,
            allow: vec![],
            extensions: None,
            user_verification: UserVerificationRequirement::Preferred,
            timeout: Duration::from_secs(30),
        }
    }

    #[test]
    fn test_assertion_to_json() {
        use crate::ops::webauthn::idl::response::JsonFormat;
        let assertion = create_test_assertion();
        let request = create_test_request();
        let json = assertion.to_json_string(&request, JsonFormat::default());
        assert!(json.is_ok());
        let json_str = json.unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
        let expected_credential_id = base64_url::encode(&[0x01, 0x02, 0x03, 0x04]);
        assert_eq!(parsed.get("id").unwrap(), &expected_credential_id);
        assert_eq!(parsed.get("rawId").unwrap(), &expected_credential_id);
        assert_eq!(parsed.get("type").unwrap(), "public-key");
        let response_obj = parsed.get("response").unwrap();
        assert!(response_obj.get("clientDataJSON").is_some());
        assert!(response_obj.get("authenticatorData").is_some());
        let expected_signature = base64_url::encode(&[0xDE, 0xAD, 0xC0, 0xDE]);
        assert_eq!(response_obj.get("signature").unwrap(), &expected_signature);
    }

    #[test]
    fn test_assertion_to_idl_model() {
        let assertion = create_test_assertion();
        let request = create_test_request();
        let model = assertion.to_idl_model(&request).unwrap();
        assert_eq!(model.raw_id.0, vec![0x01, 0x02, 0x03, 0x04]);
        assert_eq!(model.r#type, "public-key");
        assert_eq!(model.response.signature.0, vec![0xDE, 0xAD, 0xC0, 0xDE]);
    }

    #[test]
    fn test_assertion_with_user_handle() {
        use crate::proto::ctap2::Ctap2PublicKeyCredentialUserEntity;
        let mut assertion = create_test_assertion();
        assertion.user = Some(Ctap2PublicKeyCredentialUserEntity::new(
            b"test-user-id",
            "testuser",
            "Test User",
        ));
        let request = create_test_request();
        let model = assertion.to_idl_model(&request).unwrap();
        assert!(model.response.user_handle.is_some());
        assert_eq!(
            model.response.user_handle.as_ref().unwrap().0,
            b"test-user-id".to_vec()
        );
    }

    #[test]
    fn test_assertion_with_extensions() {
        let mut assertion = create_test_assertion();
        assertion.unsigned_extensions_output = Some(GetAssertionResponseUnsignedExtensions {
            hmac_get_secret: None,
            large_blob: None,
            prf: Some(GetAssertionPrfOutput {
                results: Some(PrfOutputValue {
                    first: [0x01u8; 32],
                    second: None,
                }),
            }),
            appid: None,
            unsigned_extension_outputs: Default::default(),
        });
        let request = create_test_request();
        let model = assertion.to_idl_model(&request).unwrap();
        let prf = model.client_extension_results.prf.as_ref().unwrap();
        assert!(prf.enabled.is_none());
        let results = prf.results.as_ref().unwrap();
        assert_eq!(results.first.0, vec![0x01u8; 32]);
        assert!(results.second.is_none());
    }

    #[test]
    fn test_assertion_appid_extension_output_true() {
        let mut assertion = create_test_assertion();
        assertion.unsigned_extensions_output = Some(GetAssertionResponseUnsignedExtensions {
            hmac_get_secret: None,
            large_blob: None,
            prf: None,
            appid: Some(true),
            unsigned_extension_outputs: Default::default(),
        });
        let request = create_test_request();
        let model = assertion.to_idl_model(&request).unwrap();
        assert_eq!(model.client_extension_results.appid, Some(true));
        let json = serde_json::to_value(&model.client_extension_results).unwrap();
        assert_eq!(
            json.get("appid").and_then(|v| v.as_bool()),
            Some(true),
            "JSON output should include `appid: true`"
        );
    }

    #[test]
    fn test_assertion_appid_extension_output_omitted_when_none() {
        let mut assertion = create_test_assertion();
        assertion.unsigned_extensions_output = Some(GetAssertionResponseUnsignedExtensions {
            hmac_get_secret: None,
            large_blob: None,
            prf: None,
            appid: None,
            unsigned_extension_outputs: Default::default(),
        });
        let request = create_test_request();
        let model = assertion.to_idl_model(&request).unwrap();
        assert_eq!(model.client_extension_results.appid, None);
        let json = serde_json::to_value(&model.client_extension_results).unwrap();
        assert!(
            json.get("appid").is_none(),
            "JSON output should omit `appid` when not requested"
        );
    }

    #[test]
    fn large_blob_json_read_input_and_identifier() {
        use crate::ops::webauthn::idl::get::{
            GetAssertionRequestExtensionsJSON, LargeBlobInputJson,
        };
        assert_eq!(
            GetAssertionLargeBlobExtension::try_from(LargeBlobInputJson {
                support: None,
                read: Some(true),
                write: None,
            })
            .unwrap(),
            GetAssertionLargeBlobExtension::Read
        );
        assert!(matches!(
            GetAssertionLargeBlobExtension::try_from(LargeBlobInputJson {
                support: Some("required".to_string()),
                read: Some(true),
                write: None,
            }),
            Err(GetAssertionPrepareError::NotSupported(_))
        ));
        let ext: GetAssertionRequestExtensionsJSON =
            serde_json::from_str(r#"{"largeBlob":{"read":true}}"#).unwrap();
        assert!(
            ext.large_blob.is_some(),
            "extension keyed under `largeBlob`"
        );
        let ext: GetAssertionRequestExtensionsJSON =
            serde_json::from_str(r#"{"largeBlobKey":{"read":true}}"#).unwrap();
        assert!(
            ext.large_blob.is_none(),
            "legacy `largeBlobKey` must not bind"
        );
    }

    #[tokio::test]
    async fn test_request_from_json_large_blob_read_false_is_noop() {
        let request_origin: RequestOrigin = "https://example.org".parse().unwrap();
        let json = json_field_add(
            REQUEST_BASE_JSON,
            "extensions",
            r#"{"largeBlob":{"read":false}}"#,
        );
        let req = from_json(
            &request_origin,
            &MockPublicSuffixList,
            RelatedOrigins::Disabled,
            &json,
        )
        .await
        .expect("largeBlob.read=false must be a no-op, not an error");
        assert!(req.extensions.and_then(|e| e.large_blob).is_none());
    }

    #[test]
    fn large_blob_json_write_input_and_mutual_exclusion() {
        use crate::ops::webauthn::idl::get::LargeBlobInputJson;
        let blob = b"blob to write".to_vec();
        assert_eq!(
            GetAssertionLargeBlobExtension::try_from(LargeBlobInputJson {
                support: None,
                read: None,
                write: Some(Base64UrlString::from(blob.clone())),
            })
            .unwrap(),
            GetAssertionLargeBlobExtension::Write(blob)
        );
        assert!(matches!(
            GetAssertionLargeBlobExtension::try_from(LargeBlobInputJson {
                support: None,
                read: Some(true),
                write: Some(Base64UrlString::from(b"x".to_vec())),
            }),
            Err(GetAssertionPrepareError::NotSupported(_))
        ));
    }
}