use crate::error::{Lib3mfError, Result};
use crate::model::crypto::*;
use base64::prelude::*;
use rsa::RsaPublicKey;
use rsa::signature::Verifier;
use sha1::Sha1;
use sha2::{Digest, Sha256};
use x509_parser::prelude::FromDer;
/// Verifies an XML-DSig `Signature` against a caller-supplied RSA public key.
///
/// Every `Reference` in `SignedInfo` is checked first: its content is fetched through
/// `content_resolver` and its digest is compared with the stored `DigestValue`. The
/// base64 `SignatureValue` is then verified over `signed_info_bytes` with RSA
/// PKCS#1 v1.5 / SHA-256.
///
/// `signed_info_bytes` is used verbatim as the signed message. This function does not
/// canonicalize it, so the caller must pass exactly the bytes that were signed.
///
/// # Returns
///
/// `Ok(true)` when all checks pass. Failures are always reported as `Err`, never as
/// `Ok(false)`.
///
/// # Errors
///
/// - Any error returned by `content_resolver` (propagated unchanged).
/// - `Lib3mfError::Validation` when a reference digest is unsupported or does not match,
///   the signature value is not valid base64, the signature method is not `rsa-sha256`,
///   or the RSA verification itself fails.
pub fn verify_signature<F>(
    signature: &Signature,
    public_key: &RsaPublicKey,
    content_resolver: F,
    signed_info_bytes: &[u8],
) -> Result<bool>
where
    F: Fn(&str) -> Result<Vec<u8>>,
{
    for reference in &signature.signed_info.references {
        verify_reference(reference, &content_resolver)?;
    }

    // `SignatureValue` is xs:base64Binary, which may legally contain whitespace; signers
    // commonly wrap it at 76 columns. The strict base64 engine rejects whitespace, so
    // strip it first, as is already done for the X509 certificate.
    let sig_b64: String = signature
        .signature_value
        .value
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();
    let sig_value = BASE64_STANDARD
        .decode(&sig_b64)
        .map_err(|e| Lib3mfError::Validation(format!("Invalid base64 signature: {}", e)))?;

    match signature.signed_info.signature_method.algorithm.as_str() {
        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" => {
            let verifying_key = rsa::pkcs1v15::VerifyingKey::<Sha256>::new(public_key.clone());
            let rsa_signature =
                rsa::pkcs1v15::Signature::try_from(sig_value.as_slice()).map_err(|e| {
                    Lib3mfError::Validation(format!("Invalid RSA signature format: {}", e))
                })?;
            verifying_key
                .verify(signed_info_bytes, &rsa_signature)
                .map_err(|e| {
                    Lib3mfError::Validation(format!("Signature verification failed: {}", e))
                })?;
        }
        _ => {
            return Err(Lib3mfError::Validation(format!(
                "Unsupported signature method: {}",
                signature.signed_info.signature_method.algorithm
            )));
        }
    }
    Ok(true)
}
/// Checks a single `Reference`. It resolves the reference URI and compares the digest of
/// the returned bytes with the reference's base64 `DigestValue`.
///
/// The digest is computed over the bytes exactly as returned by `content_resolver`.
/// Supported digest methods are SHA-256 (`xmlenc#sha256`) and SHA-1 (`xmldsig#sha1`).
///
/// # Errors
///
/// - Any error returned by `content_resolver` (propagated unchanged).
/// - `Lib3mfError::Validation` for an unsupported digest method, an undecodable
///   `DigestValue`, or a digest mismatch.
fn verify_reference<F>(reference: &Reference, content_resolver: &F) -> Result<()>
where
    F: Fn(&str) -> Result<Vec<u8>>,
{
    let content = content_resolver(&reference.uri)?;

    let calculated_digest = match reference.digest_method.algorithm.as_str() {
        "http://www.w3.org/2001/04/xmlenc#sha256" => Sha256::digest(&content).to_vec(),
        "http://www.w3.org/2000/09/xmldsig#sha1" => Sha1::digest(&content).to_vec(),
        other => {
            return Err(Lib3mfError::Validation(format!(
                "Unsupported digest method: {}",
                other
            )));
        }
    };

    // `DigestValue` is xs:base64Binary and may contain whitespace or line breaks, which
    // the strict base64 engine would otherwise reject.
    let digest_b64: String = reference
        .digest_value
        .value
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();
    let stored_digest = BASE64_STANDARD
        .decode(&digest_b64)
        .map_err(|e| Lib3mfError::Validation(format!("Invalid base64 digest: {}", e)))?;

    if calculated_digest != stored_digest {
        return Err(Lib3mfError::Validation(format!(
            "Digest mismatch for URI {}",
            reference.uri
        )));
    }
    Ok(())
}
/// Verifies `signature` using the RSA public key carried in its own `KeyInfo`.
///
/// The key is obtained via [`extract_key_from_signature`], and the actual checking is
/// delegated to [`verify_signature`]. Because the key comes from the signature itself, a
/// successful result shows that the signed content matches that key. It does not by
/// itself establish who the signer is; callers needing that must vet the key separately.
///
/// # Errors
///
/// Any error from key extraction, or any error from [`verify_signature`].
pub fn verify_signature_extended<F>(
    signature: &Signature,
    content_resolver: F,
    signed_info_bytes: &[u8],
) -> Result<bool>
where
    F: Fn(&str) -> Result<Vec<u8>>,
{
    extract_key_from_signature(signature)
        .and_then(|key| verify_signature(signature, &key, content_resolver, signed_info_bytes))
}
/// Extracts the RSA public key embedded in a signature's `KeyInfo`.
///
/// Lookup order:
/// 1. `KeyValue/RSAKeyValue`: the base64 `Modulus` and `Exponent` are decoded as
///    big-endian integers.
/// 2. `X509Data/X509Certificate`: the DER certificate is parsed and its
///    `SubjectPublicKeyInfo` is decoded as an RSA key.
///
/// The certificate is only parsed to obtain its public key. Its validity period, issuer,
/// and chain are not checked here.
///
/// # Errors
///
/// Returns `Lib3mfError::Validation` when a base64 field cannot be decoded, the RSA
/// components or certificate are invalid, or neither source is present.
pub fn extract_key_from_signature(signature: &Signature) -> Result<RsaPublicKey> {
    // Base64 text in XML-DSig (xs:base64Binary / ds:CryptoBinary) may contain whitespace,
    // typically line wrapping. The strict base64 engine rejects it, so strip it first.
    fn strip_whitespace(b64: &str) -> String {
        b64.chars().filter(|c| !c.is_whitespace()).collect()
    }

    if let Some(info) = &signature.key_info {
        if let Some(kv) = &info.key_value
            && let Some(rsa_val) = &kv.rsa_key_value
        {
            let n_bytes = BASE64_STANDARD
                .decode(strip_whitespace(&rsa_val.modulus))
                .map_err(|e| Lib3mfError::Validation(format!("Invalid modulus base64: {}", e)))?;
            let e_bytes = BASE64_STANDARD
                .decode(strip_whitespace(&rsa_val.exponent))
                .map_err(|e| Lib3mfError::Validation(format!("Invalid exponent base64: {}", e)))?;
            let n = rsa::BigUint::from_bytes_be(&n_bytes);
            let e = rsa::BigUint::from_bytes_be(&e_bytes);
            return RsaPublicKey::new(n, e).map_err(|e| {
                Lib3mfError::Validation(format!("Invalid RSA key components: {}", e))
            });
        }
        if let Some(x509) = &info.x509_data
            && let Some(cert_b64) = &x509.certificate
        {
            let cert_der = BASE64_STANDARD
                .decode(strip_whitespace(cert_b64))
                .map_err(|e| Lib3mfError::Validation(format!("Invalid X509 base64: {}", e)))?;
            let (_, cert) = x509_parser::certificate::X509Certificate::from_der(&cert_der)
                .map_err(|e| Lib3mfError::Validation(format!("Invalid X509 certificate: {}", e)))?;
            use rsa::pkcs8::DecodePublicKey;
            // `subject_pki.raw` is the DER-encoded SubjectPublicKeyInfo, which is the
            // format `from_public_key_der` expects.
            return RsaPublicKey::from_public_key_der(cert.tbs_certificate.subject_pki.raw)
                .map_err(|e| Lib3mfError::Validation(format!("Invalid RSA key in cert: {}", e)));
        }
    }
    Err(Lib3mfError::Validation(
        "No usable KeyValue or X509Certificate found in KeyInfo".into(),
    ))
}