use crate::{
base64::{DecodeUrlSafeBase64, ToUrlSafeBase64},
http::HeaderName,
jws::JwsHeader,
openssl,
sign::build_v2_signing_payload,
Error,
};
use anyhow::anyhow;
use indexmap::{IndexMap, IndexSet};
use std::fmt;
/// Builder-style checker for a `Tl-Signature` against a request.
///
/// Configure the request parts with the builder methods, then call
/// [`Verifier::verify`].
pub struct Verifier<'a> {
    // --- key material and policy ---
    /// Key material the signature is checked against.
    public_key: PublicKey<'a>,
    /// Whether v1 signatures (header + body only) are accepted.
    allow_v1: bool,
    /// Header names that a v2 signature must cover.
    required_headers: IndexSet<HeaderName<'a>>,

    // --- request parts used to rebuild the signing payload ---
    /// Request method, included in the v2 signing payload.
    method: &'a str,
    /// Request path, included in the v2 signing payload; the `path` setter asserts
    /// a leading `/`.
    path: &'a str,
    /// Request headers from which the signed subset is selected.
    headers: IndexMap<HeaderName<'a>, &'a [u8]>,
    /// Raw request body.
    body: &'a [u8],
}
/// Public key material a [`Verifier`] checks signatures against.
#[derive(Copy, Clone)]
pub(crate) enum PublicKey<'a> {
    /// A single PEM-encoded public key, parsed with `openssl::parse_ec_public_key`.
    Pem(&'a [u8]),
    /// JWKS bytes; `openssl::find_and_parse_ec_jwk` picks the key using the
    /// signature header's `kid`.
    Jwks(&'a [u8]),
}
impl fmt::Debug for Verifier<'_> {
    // Renders only the type name `Verifier`; none of the configured fields
    // (key, body, headers, ...) are printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Verifier")
    }
}
impl<'a> Verifier<'a> {
    /// Creates a verifier for `public_key` with an empty request.
    ///
    /// The request starts with no body, an empty method and path, no headers and no
    /// required headers, and v1 signatures are rejected.
    #[must_use]
    pub(crate) fn new(public_key: PublicKey<'a>) -> Self {
        Self {
            public_key,
            body: &[],
            method: "",
            path: "",
            headers: <_>::default(),
            required_headers: <_>::default(),
            allow_v1: false,
        }
    }

    /// Sets the raw request body.
    #[must_use]
    pub fn body(mut self, body: &'a [u8]) -> Self {
        self.body = body;
        self
    }

    /// Sets the request method, which is included in the v2 signing payload.
    #[must_use]
    pub fn method(mut self, method: &'a str) -> Self {
        self.method = method;
        self
    }

    /// Sets the request path, which is included in the v2 signing payload.
    ///
    /// # Panics
    ///
    /// Panics if `path` does not start with `/`.
    #[must_use]
    pub fn path(mut self, path: &'a str) -> Self {
        assert!(
            path.starts_with('/'),
            "Invalid path \"{path}\" must start with '/'",
        );
        self.path = path;
        self
    }

    /// Adds one request header; builder-style counterpart of [`Verifier::add_header`].
    ///
    /// Any earlier value stored under the same header name is replaced.
    #[must_use]
    pub fn header(mut self, key: &'a str, value: &'a [u8]) -> Self {
        self.add_header(key, value);
        self
    }

    /// Adds one request header in place.
    ///
    /// Any earlier value stored under the same header name is replaced.
    pub fn add_header(&mut self, key: &'a str, value: &'a [u8]) {
        self.headers.insert(HeaderName(key), value);
    }

    /// Adds several request headers.
    ///
    /// Later entries replace earlier values stored under the same header name.
    #[must_use]
    pub fn headers(mut self, headers: impl IntoIterator<Item = (&'a str, &'a [u8])>) -> Self {
        self.headers
            .extend(headers.into_iter().map(|(k, v)| (HeaderName(k), v)));
        self
    }

    /// Requires that a v2 signature covers the header `key`.
    ///
    /// Verification fails if the header is not covered. This check is not applied to
    /// v1 signatures (see [`Verifier::allow_v1`]).
    #[must_use]
    pub fn require_header(mut self, key: &'a str) -> Self {
        self.required_headers.insert(HeaderName(key));
        self
    }

    /// Sets whether v1 signatures (JWS header `tl_version` empty or `"1"`) are accepted.
    ///
    /// A v1 signature covers only the JWS header and the body. Defaults to `false`.
    #[must_use]
    pub fn allow_v1(mut self, allow: bool) -> Self {
        self.allow_v1 = allow;
        self
    }

    /// Verifies `tl_signature` (a `{header}..{signature}` string) against the
    /// configured key and request.
    ///
    /// # Errors
    ///
    /// - [`Error::JwsError`] in any of these cases:
    ///   - the signature cannot be parsed;
    ///   - its `alg` is not `ES512`;
    ///   - it is a v1 signature and v1 is disallowed;
    ///   - a required header is not covered;
    ///   - the ES512 check fails.
    /// - [`Error::InvalidKey`] if the PEM key cannot be parsed, or no usable JWK is
    ///   found for the header's `kid`.
    pub fn verify(&self, tl_signature: &str) -> Result<(), Error> {
        let (jws_header, header_b64, signature) = parse_tl_signature(tl_signature)?;

        let public_key = match self.public_key {
            PublicKey::Pem(pem) => openssl::parse_ec_public_key(pem),
            // Look up the JWK matching the signature header's `kid`.
            PublicKey::Jwks(jwks) => openssl::find_and_parse_ec_jwk(&jws_header.kid, jwks),
        }
        .map_err(Error::InvalidKey)?;

        if jws_header.alg != "ES512" {
            return Err(Error::JwsError(anyhow!("unexpected header alg")));
        }

        // v1 (`tl_version` empty or "1"): the signed payload is the JWS header plus the
        // body only, so required headers cannot be checked here.
        if jws_header.tl_version.is_empty() || jws_header.tl_version == "1" {
            if !self.allow_v1 {
                return Err(Error::JwsError(anyhow!("v1 signature not allowed")));
            }
            let payload = format!("{header_b64}.{}", self.body.to_url_safe_base64());
            return openssl::verify_es512(&public_key, payload.as_bytes(), &signature)
                .map_err(Error::JwsError);
        }

        // v2: the payload also covers method, path and the headers selected by
        // `JwsHeader::filter_headers`.
        let ordered_headers = jws_header
            .filter_headers(&self.headers)
            .map_err(Error::JwsError)?;

        if let Some(header) = self
            .required_headers
            .iter()
            .find(|h| !ordered_headers.contains_key(*h))
        {
            return Err(Error::JwsError(anyhow!(
                "signature is missing required header {}",
                header
            )));
        }

        let signing_payload =
            build_v2_signing_payload(self.method, self.path, &ordered_headers, self.body, false);
        let payload = format!("{header_b64}.{}", signing_payload.to_url_safe_base64());
        openssl::verify_es512(&public_key, payload.as_bytes(), &signature)
            .or_else(|e| {
                // Retry with the trailing slash toggled.
                // - If the path ends in `/`, strip it.
                // - Otherwise pass `true` as the builder's final flag (presumably
                //   "append `/`"; confirm against `build_v2_signing_payload`).
                // If the retry also fails, the first attempt's error is reported.
                let (path, slash) = match self.path.strip_suffix('/') {
                    Some(trimmed) => (trimmed, false),
                    None => (self.path, true),
                };
                let signing_payload =
                    build_v2_signing_payload(self.method, path, &ordered_headers, self.body, slash);
                let payload = format!("{header_b64}.{}", signing_payload.to_url_safe_base64());
                openssl::verify_es512(&public_key, payload.as_bytes(), &signature).map_err(|_| e)
            })
            .map_err(Error::JwsError)
    }
}
pub(crate) fn parse_tl_signature(tl_signature: &str) -> Result<(JwsHeader, &str, Vec<u8>), Error> {
let (header_b64, signature_b64) = tl_signature
.split_once("..")
.ok_or_else(|| Error::JwsError(anyhow!("invalid signature format")))?;
let header: JwsHeader = serde_json::from_slice(
&header_b64
.decode_url_safe_base64()
.map_err(|e| Error::JwsError(anyhow!("header decode failed: {}", e)))?,
)
.map_err(|e| Error::JwsError(anyhow!("header decode failed: {}", e)))?;
let signature = signature_b64
.decode_url_safe_base64()
.map_err(|e| Error::JwsError(anyhow!("signature decode failed: {}", e)))?;
Ok((header, header_b64, signature))
}