use base64::{engine::general_purpose, Engine};
use ciborium::de::from_reader;
use p256::ecdsa::{self, signature::Verifier, VerifyingKey};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::error::Error;
use std::io::Cursor;
use super::{authenticator::AuthenticatorData, error::AppAttestError};
/// An App Attest assertion as decoded from its CBOR encoding.
///
/// The serialized map uses the keys `authenticatorData` and `signature`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Assertion {
    /// Raw authenticator data bytes (serialized key `authenticatorData`).
    #[serde(rename = "authenticatorData")]
    raw_authenticator_data: Vec<u8>,
    /// Signature bytes (serialized key `signature`); `verify` parses these as DER.
    signature: Vec<u8>,
}
/// The part of the client-supplied JSON payload that assertion verification reads.
///
/// Only `challenge` is modelled; serde's default behavior ignores any other keys.
#[derive(Debug, Serialize, Deserialize)]
struct ClientData {
    /// Challenge echoed by the client; compared against the server-stored challenge.
    challenge: String,
}
impl Assertion {
    /// Decodes a standard-alphabet Base64 string and parses the resulting bytes
    /// as a CBOR-encoded [`Assertion`].
    ///
    /// # Errors
    ///
    /// Returns [`AppAttestError::Message`] when the input is not valid Base64
    /// (message prefixed `"Failed to decode Base64"`), or when the decoded bytes
    /// cannot be deserialized into an [`Assertion`] (message prefixed
    /// `"unable to parse assertion"`, followed by the underlying CBOR error).
    pub fn from_base64(base64_assertion: &str) -> Result<Self, AppAttestError> {
        let decoded_bytes = general_purpose::STANDARD
            .decode(base64_assertion)
            .map_err(|e| AppAttestError::Message(format!("Failed to decode Base64: {}", e)))?;
        // Keep the CBOR decoder's error in the message, mirroring the Base64
        // branch above, so malformed assertions can actually be diagnosed.
        from_reader(Cursor::new(decoded_bytes))
            .map_err(|e| AppAttestError::Message(format!("unable to parse assertion: {}", e)))
    }

    /// Verifies this assertion, consuming it.
    ///
    /// The checks run in this order, and the first failure is returned:
    /// 1. `client_data_byte` is parsed as JSON with a string `challenge` field.
    /// 2. The raw authenticator data is parsed via `AuthenticatorData::new`.
    /// 3. `nonce = SHA-256(authenticator_data || SHA-256(client_data_byte))` is computed.
    /// 4. The DER-encoded ECDSA signature is checked over `nonce` using the
    ///    SEC1-encoded P-256 public key in `public_key_byte`.
    /// 5. The app ID is checked with `AuthenticatorData::verify_app_id`.
    /// 6. The authenticator counter must be strictly greater than `previous_counter`.
    /// 7. The client-data challenge must equal `stored_challenge`.
    ///
    /// # Errors
    ///
    /// - JSON parse errors from step 1, and whatever `AuthenticatorData::new` /
    ///   `verify_app_id` return, are propagated as-is.
    /// - [`AppAttestError::Message`] if the public key cannot be parsed, the
    ///   signature is not valid DER, or the challenge does not match.
    /// - [`AppAttestError::InvalidSignature`] if the signature does not verify.
    /// - [`AppAttestError::InvalidCounter`] if the counter did not increase.
    pub fn verify(
        self,
        client_data_byte: Vec<u8>,
        app_id: &str,
        public_key_byte: Vec<u8>,
        previous_counter: u32,
        stored_challenge: &str,
    ) -> Result<(), Box<dyn Error>> {
        let client_data = serde_json::from_slice::<ClientData>(&client_data_byte)?;
        let auth_data = AuthenticatorData::new(self.raw_authenticator_data)?;

        let client_data_hash = Sha256::digest(&client_data_byte);

        let verifying_key = VerifyingKey::from_sec1_bytes(&public_key_byte)
            .map_err(|_| AppAttestError::Message("failed to parse the public key".to_string()))?;

        // nonce = SHA-256(authenticatorData || clientDataHash); this is the
        // message the signature is checked against.
        let mut hasher = Sha256::new();
        hasher.update(auth_data.bytes.as_slice());
        hasher.update(client_data_hash.as_slice());
        let nonce_hash = hasher.finalize();

        let signature = ecdsa::Signature::from_der(&self.signature)
            .map_err(|_| AppAttestError::Message("invalid signature format".to_string()))?;
        if verifying_key
            .verify(nonce_hash.as_slice(), &signature)
            .is_err()
        {
            return Err(Box::new(AppAttestError::InvalidSignature));
        }

        auth_data.verify_app_id(app_id)?;

        // A counter that did not strictly increase indicates a replayed or
        // out-of-order assertion.
        if auth_data.counter <= previous_counter {
            return Err(Box::new(AppAttestError::InvalidCounter));
        }

        if stored_challenge != client_data.challenge {
            return Err(Box::new(AppAttestError::Message(
                "challenge mismatch".to_string(),
            )));
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed Base64-encoded CBOR assertion carrying both
    /// `signature` and `authenticatorData` entries.
    const VALID_CBOR_BASE64: &str = "omlzaWduYXR1cmVYRjBEAiAImFuY4+UbGZ5/ZbjAJpjQ3bd8GxaKFpMEo58WMEUGbwIgaqdDJnVS8/3oJCz16O5Zp4Qga5g6zrFF7eoiYEWkdtNxYXV0aGVudGljYXRvckRhdGFYJaRc2WwGuoniZEqtF+kolObjxcczFdDxbrhJR/nT8ehTQAAAAAI=";

    #[test]
    fn test_from_base64_valid() {
        let result = Assertion::from_base64(VALID_CBOR_BASE64);
        assert!(result.is_ok());
    }

    #[test]
    fn test_from_base64_invalid_base64() {
        let result = Assertion::from_base64("not-valid-base64!!!");
        // Must fail in the Base64 stage, not the CBOR stage.
        assert!(matches!(
            &result,
            Err(AppAttestError::Message(msg)) if msg.starts_with("Failed to decode Base64")
        ));
    }

    #[test]
    fn test_from_base64_valid_base64_invalid_cbor() {
        use base64::{engine::general_purpose::STANDARD, Engine};
        let not_cbor = STANDARD.encode(b"this is not CBOR data");
        let result = Assertion::from_base64(&not_cbor);
        // Base64 decoding succeeds, so the failure must come from CBOR parsing.
        assert!(matches!(
            &result,
            Err(AppAttestError::Message(msg)) if msg.starts_with("unable to parse assertion")
        ));
    }

    #[test]
    fn test_from_base64_empty() {
        // "" is valid Base64 for zero bytes, which is not a valid CBOR assertion.
        let result = Assertion::from_base64("");
        assert!(matches!(
            &result,
            Err(AppAttestError::Message(msg)) if msg.starts_with("unable to parse assertion")
        ));
    }

    #[test]
    fn test_parsed_assertion_has_fields() {
        let assertion = Assertion::from_base64(VALID_CBOR_BASE64).unwrap();
        assert!(!assertion.signature.is_empty());
        assert!(!assertion.raw_authenticator_data.is_empty());
    }
}