use std::fmt;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use ed25519_dalek::{Signature, Verifier as _, VerifyingKey};
use libcrux_ml_dsa::ml_dsa_65::{
self, MLDSA65Signature, MLDSA65SigningKey, MLDSA65VerificationKey,
};
use rand::RngCore as _;
use sha2::{Digest as _, Sha256};
use crate::ids::{AgentPubkey, NetworkId, Nonce};
/// Name of the HTTP header that carries the request signature.
pub const SIGNATURE_HEADER: &str = "Parley-Signature";
/// Signature scheme version written as `v=` by [`build_header_value`].
pub const SIGNATURE_VERSION: u32 = 2;
/// Value of [`body_sha256_b64url`] for an empty body, precomputed so
/// body-less requests need not hash anything.
pub const EMPTY_BODY_SHA256: &str = "47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU";

/// Exact byte length of an ML-DSA-65 verification key accepted by
/// [`ml_dsa_verify`].
pub const ML_DSA_PUBKEY_BYTES: usize = 1952;
/// Exact byte length of an ML-DSA-65 signature accepted by [`ml_dsa_verify`].
pub const ML_DSA_SIG_BYTES: usize = 3309;
/// Context string passed to both ML-DSA signing and verification; a
/// signature only verifies under the same context it was produced with.
const ML_DSA_CONTEXT: &[u8] = b"parley-auth-v2";
/// Returns the SHA-256 digest of `body` as unpadded URL-safe base64.
///
/// This is the form expected for the final `body_sha256_b64url` argument of
/// [`canonical_string`]; for an empty body it equals [`EMPTY_BODY_SHA256`].
#[must_use]
pub fn body_sha256_b64url(body: &[u8]) -> String {
    let digest = Sha256::digest(body);
    URL_SAFE_NO_PAD.encode(digest)
}
/// Builds the newline-joined canonical string that both signatures cover.
///
/// The eight lines are, in order: the upper-cased HTTP method, the path, the
/// already-canonicalized query (see [`canonical_query_string`]), the
/// timestamp, the nonce, the agent key, the network id, and the body digest
/// (see [`body_sha256_b64url`]). There is no trailing newline. Values are
/// inserted verbatim (via `Display`); callers are responsible for ensuring
/// none of them contains a `\n`.
#[must_use]
// Every argument is a distinct component of the signed string; bundling them
// into a struct would only move the list elsewhere.
#[allow(clippy::too_many_arguments)]
pub fn canonical_string(
    method: &str,
    path: &str,
    canonical_query: &str,
    ts: i64,
    nonce: &Nonce,
    agent: &AgentPubkey,
    network: &NetworkId,
    body_sha256_b64url: &str,
) -> String {
    let lines = [
        method.to_ascii_uppercase(),
        path.to_owned(),
        canonical_query.to_owned(),
        ts.to_string(),
        nonce.to_string(),
        agent.to_string(),
        network.to_string(),
        body_sha256_b64url.to_owned(),
    ];
    lines.join("\n")
}
/// Canonicalizes a raw URL query string (without the leading `?`) for signing.
///
/// The string is processed as follows:
///
/// 1. Each non-empty `&`-separated segment is split at its first `=`. A
///    segment without `=` gets an empty value.
/// 2. Both halves are percent-decoded. If decoding fails (bad escape or
///    non-UTF-8 result), the raw text is kept instead.
/// 3. The pairs are sorted by key, then by value.
/// 4. Each half is re-encoded with [`percent_encode`], and the pairs are
///    joined as `k=v` separated by `&`.
///
/// An empty input yields an empty string. `+` is not treated as a space.
#[must_use]
pub fn canonical_query_string(raw: &str) -> String {
    if raw.is_empty() {
        return String::new();
    }
    // Undecodable text is kept verbatim so both sides still agree on it.
    let decode_or_raw = |text: &str| percent_decode(text).unwrap_or_else(|_| text.to_owned());

    let mut pairs: Vec<(String, String)> = Vec::new();
    for segment in raw.split('&').filter(|segment| !segment.is_empty()) {
        let (key, value) = segment.split_once('=').unwrap_or((segment, ""));
        pairs.push((decode_or_raw(key), decode_or_raw(value)));
    }
    // Equal tuples are indistinguishable, so an unstable sort yields the same order.
    pairs.sort_unstable();

    let mut canonical = String::new();
    for (index, (key, value)) in pairs.iter().enumerate() {
        if index > 0 {
            canonical.push('&');
        }
        canonical.push_str(&percent_encode(key));
        canonical.push('=');
        canonical.push_str(&percent_encode(value));
    }
    canonical
}
/// Percent-encodes `s` byte by byte, leaving only RFC 3986 unreserved
/// characters (`A-Z a-z 0-9 - _ . ~`) as-is and emitting uppercase hex for
/// everything else.
///
/// Multi-byte UTF-8 characters are encoded per byte, e.g. `é` becomes `%C3%A9`.
fn percent_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(s.len());
    for &b in s.as_bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(b));
        } else {
            // Write the escape directly instead of allocating a temporary
            // `String` via `format!` for every escaped byte.
            out.push('%');
            out.push(char::from(HEX_UPPER[usize::from(b >> 4)]));
            out.push(char::from(HEX_UPPER[usize::from(b & 0x0F)]));
        }
    }
    out
}
/// Decodes `%XX` escapes in `s` (hex digits of either case) and requires the
/// result to be valid UTF-8.
///
/// Returns `Err(())` for a `%` not followed by two hex digits (including a
/// truncated escape at the end of input), or if the decoded bytes are not
/// UTF-8. Non-`%` bytes, including `+`, are copied through unchanged.
fn percent_decode(s: &str) -> Result<String, ()> {
    let mut rest = s.as_bytes();
    let mut decoded = Vec::with_capacity(rest.len());
    while let Some((&first, tail)) = rest.split_first() {
        if first == b'%' {
            match tail {
                [hi, lo, after @ ..] => {
                    decoded.push((hex_val(*hi)? << 4) | hex_val(*lo)?);
                    rest = after;
                }
                // Fewer than two bytes remain after the `%`.
                _ => return Err(()),
            }
        } else {
            decoded.push(first);
            rest = tail;
        }
    }
    String::from_utf8(decoded).map_err(|_| ())
}
/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value 0..=15.
///
/// Any other byte yields `Err(())`.
fn hex_val(b: u8) -> Result<u8, ()> {
    // `to_digit` only recognizes ASCII digits/letters, so bytes >= 0x80
    // (Latin-1 chars after `char::from`) are rejected as before.
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
        .ok_or(())
}
/// Formats a `Parley-Signature` header value.
///
/// Output shape: `v=.., agent=.., ts=.., nonce=.., network=.., sig=.., mldsa_sig=..`,
/// with `v` set to [`SIGNATURE_VERSION`]. Both signatures are encoded as
/// unpadded URL-safe base64. [`parse_header_value`] accepts this format.
#[must_use]
pub fn build_header_value(
    agent: &AgentPubkey,
    ts: i64,
    nonce: &Nonce,
    network: &NetworkId,
    sig_bytes: &[u8; 64],
    mldsa_sig: &[u8],
) -> String {
    let ed25519_b64 = URL_SAFE_NO_PAD.encode(sig_bytes);
    let mldsa_b64 = URL_SAFE_NO_PAD.encode(mldsa_sig);
    let version = SIGNATURE_VERSION;
    format!(
        "v={version}, agent={agent}, ts={ts}, nonce={nonce}, network={network}, sig={ed25519_b64}, mldsa_sig={mldsa_b64}"
    )
}
/// Fields of a `Parley-Signature` header, as produced by [`parse_header_value`].
///
/// Parsing only checks syntax; none of these values has been verified yet.
#[derive(Debug, Clone)]
pub struct ParsedSignature {
    /// Scheme version from `v=`. The parser does not compare it with
    /// [`SIGNATURE_VERSION`].
    pub v: u32,
    /// Agent public key from `agent=`; [`verify_signature`] treats these bytes
    /// as an Ed25519 verifying key.
    pub agent: AgentPubkey,
    /// Timestamp from `ts=`. NOTE(review): presumably Unix seconds (tests use
    /// 1715299200) — confirm against the server's freshness check.
    pub ts: i64,
    /// Request nonce from `nonce=`.
    pub nonce: Nonce,
    /// Network id from `network=`.
    pub network: NetworkId,
    /// Ed25519 signature bytes, decoded from unpadded URL-safe base64 `sig=`;
    /// the parser requires exactly 64 bytes.
    pub sig: [u8; 64],
    /// ML-DSA signature bytes from `mldsa_sig=`, if present. The length is not
    /// checked here; [`ml_dsa_verify`] rejects the wrong length.
    pub mldsa_sig: Option<Vec<u8>>,
}
/// Reasons [`parse_header_value`] can reject a `Parley-Signature` header value.
#[derive(Debug, thiserror::Error)]
pub enum SignatureParseError {
    /// A non-empty, trimmed comma-separated segment contained no `=`; holds
    /// that segment.
    #[error("malformed pair: {0}")]
    MalformedPair(String),
    /// A recognized key appeared more than once.
    #[error("duplicate field: {0}")]
    DuplicateField(&'static str),
    /// A recognized key's value could not be parsed or decoded.
    #[error("invalid value for {field}: {reason}")]
    InvalidValue {
        /// Header key whose value was rejected.
        field: &'static str,
        /// Human-readable description of the failure.
        reason: String,
    },
    /// A required key was absent.
    #[error("missing field: {0}")]
    MissingField(&'static str),
}
pub fn parse_header_value(raw: &str) -> Result<ParsedSignature, SignatureParseError> {
let mut v: Option<u32> = None;
let mut agent: Option<AgentPubkey> = None;
let mut ts: Option<i64> = None;
let mut nonce: Option<Nonce> = None;
let mut network: Option<NetworkId> = None;
let mut sig: Option<[u8; 64]> = None;
let mut mldsa_sig: Option<Vec<u8>> = None;
for raw_pair in raw.split(',') {
let pair = raw_pair.trim();
if pair.is_empty() {
continue;
}
let (key, value) = pair
.split_once('=')
.ok_or_else(|| SignatureParseError::MalformedPair(pair.to_owned()))?;
let value = value.trim();
match key.trim() {
"v" => {
if v.is_some() {
return Err(SignatureParseError::DuplicateField("v"));
}
v = Some(value.parse().map_err(|e: std::num::ParseIntError| {
SignatureParseError::InvalidValue {
field: "v",
reason: e.to_string(),
}
})?);
}
"agent" => {
if agent.is_some() {
return Err(SignatureParseError::DuplicateField("agent"));
}
agent = Some(value.parse().map_err(|e: crate::CoreError| {
SignatureParseError::InvalidValue {
field: "agent",
reason: e.to_string(),
}
})?);
}
"ts" => {
if ts.is_some() {
return Err(SignatureParseError::DuplicateField("ts"));
}
ts = Some(value.parse().map_err(|e: std::num::ParseIntError| {
SignatureParseError::InvalidValue {
field: "ts",
reason: e.to_string(),
}
})?);
}
"nonce" => {
if nonce.is_some() {
return Err(SignatureParseError::DuplicateField("nonce"));
}
nonce = Some(value.parse().map_err(|e: crate::CoreError| {
SignatureParseError::InvalidValue {
field: "nonce",
reason: e.to_string(),
}
})?);
}
"network" => {
if network.is_some() {
return Err(SignatureParseError::DuplicateField("network"));
}
network = Some(value.parse().map_err(|e: crate::CoreError| {
SignatureParseError::InvalidValue {
field: "network",
reason: e.to_string(),
}
})?);
}
"sig" => {
if sig.is_some() {
return Err(SignatureParseError::DuplicateField("sig"));
}
let decoded = URL_SAFE_NO_PAD.decode(value).map_err(|e| {
SignatureParseError::InvalidValue {
field: "sig",
reason: e.to_string(),
}
})?;
let arr: [u8; 64] =
decoded
.try_into()
.map_err(|d: Vec<u8>| SignatureParseError::InvalidValue {
field: "sig",
reason: format!("expected 64 bytes, got {}", d.len()),
})?;
sig = Some(arr);
}
"mldsa_sig" => {
if mldsa_sig.is_some() {
return Err(SignatureParseError::DuplicateField("mldsa_sig"));
}
let decoded = URL_SAFE_NO_PAD.decode(value).map_err(|e| {
SignatureParseError::InvalidValue {
field: "mldsa_sig",
reason: e.to_string(),
}
})?;
mldsa_sig = Some(decoded);
}
other => {
let _ = other;
}
}
}
Ok(ParsedSignature {
v: v.ok_or(SignatureParseError::MissingField("v"))?,
agent: agent.ok_or(SignatureParseError::MissingField("agent"))?,
ts: ts.ok_or(SignatureParseError::MissingField("ts"))?,
nonce: nonce.ok_or(SignatureParseError::MissingField("nonce"))?,
network: network.ok_or(SignatureParseError::MissingField("network"))?,
sig: sig.ok_or(SignatureParseError::MissingField("sig"))?,
mldsa_sig,
})
}
/// Verifies the Ed25519 signature `sig` by `agent` over `canonical`.
///
/// Weak (small-order) public keys are rejected before verification. For such
/// a key, signatures can be produced for arbitrary messages without any
/// secret key, so accepting one would let anyone authenticate as that agent.
///
/// # Errors
///
/// - [`SignatureVerifyError::BadKey`] if `agent`'s bytes do not decode to an
///   Ed25519 point, or decode to a weak (small-order) point.
/// - [`SignatureVerifyError::BadSignature`] if the signature does not verify.
pub fn verify_signature(
    agent: &AgentPubkey,
    canonical: &str,
    sig: &[u8; 64],
) -> Result<(), SignatureVerifyError> {
    let key = VerifyingKey::from_bytes(agent.as_bytes())
        .map_err(|e| SignatureVerifyError::BadKey(e.to_string()))?;
    if key.is_weak() {
        return Err(SignatureVerifyError::BadKey(
            "weak (small-order) public key".to_owned(),
        ));
    }
    let signature = Signature::from_bytes(sig);
    key.verify(canonical.as_bytes(), &signature)
        .map_err(|_| SignatureVerifyError::BadSignature)
}
/// Failures reported by [`verify_signature`].
#[derive(Debug, thiserror::Error)]
pub enum SignatureVerifyError {
    /// The Ed25519 signature did not verify over the canonical string.
    #[error("signature does not verify")]
    BadSignature,
    /// The agent's key bytes were rejected as an Ed25519 verifying key; holds
    /// the underlying reason.
    #[error("agent pubkey is not a valid Ed25519 verifying key: {0}")]
    BadKey(String),
}
/// Signs `canonical` with ML-DSA-65 under [`ML_DSA_CONTEXT`] and returns the
/// raw signature bytes.
///
/// 32 fresh bytes from the thread-local RNG are passed as the signing
/// randomness, so signing the same message twice gives different signatures.
///
/// # Errors
///
/// [`MlDsaError::Sign`] if the underlying ML-DSA signing call fails.
pub fn ml_dsa_sign(
    signing_key: &MLDSA65SigningKey,
    canonical: &str,
) -> Result<Vec<u8>, MlDsaError> {
    let randomness = {
        let mut seed = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut seed);
        seed
    };
    ml_dsa_65::sign(signing_key, canonical.as_bytes(), ML_DSA_CONTEXT, randomness)
        .map(|signature| signature.as_slice().to_vec())
        .map_err(|_| MlDsaError::Sign)
}
/// Verifies an ML-DSA-65 signature over `canonical` under [`ML_DSA_CONTEXT`].
///
/// # Errors
///
/// - [`MlDsaError::BadKey`] if `pubkey_bytes` is not exactly
///   [`ML_DSA_PUBKEY_BYTES`] long. This is checked first.
/// - [`MlDsaError::BadSignature`] if `sig_bytes` is not exactly
///   [`ML_DSA_SIG_BYTES`] long, or if the signature does not verify.
pub fn ml_dsa_verify(
    pubkey_bytes: &[u8],
    canonical: &str,
    sig_bytes: &[u8],
) -> Result<(), MlDsaError> {
    let key_array =
        <[u8; ML_DSA_PUBKEY_BYTES]>::try_from(pubkey_bytes).map_err(|_| MlDsaError::BadKey)?;
    let sig_array =
        <[u8; ML_DSA_SIG_BYTES]>::try_from(sig_bytes).map_err(|_| MlDsaError::BadSignature)?;
    let verification_key = MLDSA65VerificationKey::new(key_array);
    let signature = MLDSA65Signature::new(sig_array);
    ml_dsa_65::verify(
        &verification_key,
        canonical.as_bytes(),
        ML_DSA_CONTEXT,
        &signature,
    )
    .map_err(|_| MlDsaError::BadSignature)
}
/// Failures reported by [`ml_dsa_sign`] and [`ml_dsa_verify`].
#[derive(Debug, thiserror::Error)]
pub enum MlDsaError {
    /// The ML-DSA signing call returned an error.
    #[error("ML-DSA signing failed")]
    Sign,
    /// The verification key was not exactly `ML_DSA_PUBKEY_BYTES` long.
    #[error("ML-DSA verification key is malformed (wrong length)")]
    BadKey,
    /// The signature had the wrong length or did not verify.
    #[error("ML-DSA signature does not verify or is malformed")]
    BadSignature,
}
/// Human-readable summary of a parsed header: the identifying fields only,
/// with both signature payloads omitted.
impl fmt::Display for ParsedSignature {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            v,
            agent,
            ts,
            nonce,
            network,
            ..
        } = self;
        write!(
            f,
            "v={v}, agent={agent}, ts={ts}, nonce={nonce}, network={network}"
        )
    }
}
/// Unit tests for canonicalization, header encoding/parsing, and both
/// signature schemes.
#[cfg(test)]
// Tests unwrap freely: a panic is the failure signal.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Parses the fixed agent/nonce/network fixtures used by the header tests.
    fn header_fixtures() -> (AgentPubkey, Nonce, NetworkId) {
        let agent: AgentPubkey = "u9PqJ4gK2mZ8t6nVxR3hB1cW7yE5dF0aQ4sT2lN6oU8"
            .parse()
            .unwrap();
        let nonce: Nonce = "F4Yk8vN2j5QwK3zB1aR9oA".parse().unwrap();
        let network: NetworkId = "parley-mainnet".parse().unwrap();
        (agent, nonce, network)
    }
    #[test]
    fn empty_body_sha_constant_matches_computed() {
        assert_eq!(body_sha256_b64url(b""), EMPTY_BODY_SHA256);
    }
    #[test]
    fn canonical_query_sorts_and_encodes() {
        assert_eq!(canonical_query_string(""), "");
        assert_eq!(canonical_query_string("b=2&a=1"), "a=1&b=2");
        assert_eq!(canonical_query_string("k=hello world"), "k=hello%20world");
        assert_eq!(canonical_query_string("k="), "k=");
    }
    #[test]
    fn canonical_query_normalizes_and_falls_back() {
        // Already-encoded input is stable.
        assert_eq!(canonical_query_string("k=hello%20world"), "k=hello%20world");
        // Lowercase hex escapes are normalized to uppercase.
        assert_eq!(canonical_query_string("k=%2f"), "k=%2F");
        // A segment without '=' gets an empty value.
        assert_eq!(canonical_query_string("flag"), "flag=");
        // Empty segments are dropped.
        assert_eq!(canonical_query_string("&&a=1&"), "a=1");
        // Invalid escapes fall back to the raw text, which is then encoded.
        assert_eq!(canonical_query_string("k=%ZZ"), "k=%25ZZ");
        // Repeated keys are ordered by value.
        assert_eq!(canonical_query_string("a=2&a=1"), "a=1&a=2");
    }
    #[test]
    fn percent_decode_edge_cases() {
        assert_eq!(percent_decode("a%2Fb"), Ok("a/b".to_owned()));
        assert_eq!(percent_decode("%c3%a9"), Ok("é".to_owned()));
        assert!(percent_decode("%").is_err());
        assert!(percent_decode("%4").is_err());
        assert!(percent_decode("%G1").is_err());
        // 0xFF alone is not valid UTF-8.
        assert!(percent_decode("%FF").is_err());
    }
    #[test]
    fn canonical_string_format_is_eight_lines() {
        let agent: AgentPubkey = "u9PqJ4gK2mZ8t6nVxR3hB1cW7yE5dF0aQ4sT2lN6oU8"
            .parse()
            .unwrap();
        let nonce: Nonce = "F4Yk8vN2j5QwK3zB1aR9oA".parse().unwrap();
        let network: NetworkId = "parley-mainnet".parse().unwrap();
        let s = canonical_string(
            "GET",
            "/v1/blobs/abc",
            "",
            1715299200,
            &nonce,
            &agent,
            &network,
            EMPTY_BODY_SHA256,
        );
        assert_eq!(s.lines().count(), 8);
        assert!(s.starts_with("GET\n/v1/blobs/abc\n\n1715299200\n"));
    }
    #[test]
    fn header_roundtrips() {
        let agent: AgentPubkey = "u9PqJ4gK2mZ8t6nVxR3hB1cW7yE5dF0aQ4sT2lN6oU8"
            .parse()
            .unwrap();
        let nonce: Nonce = "F4Yk8vN2j5QwK3zB1aR9oA".parse().unwrap();
        let network: NetworkId = "parley-mainnet".parse().unwrap();
        let sig = [7u8; 64];
        let mldsa = vec![3u8; ML_DSA_SIG_BYTES];
        let header = build_header_value(&agent, 1715299200, &nonce, &network, &sig, &mldsa);
        let parsed = parse_header_value(&header).unwrap();
        assert_eq!(parsed.v, SIGNATURE_VERSION);
        assert_eq!(parsed.agent, agent);
        assert_eq!(parsed.ts, 1715299200);
        assert_eq!(parsed.nonce, nonce);
        assert_eq!(parsed.network, network);
        assert_eq!(parsed.sig, sig);
        assert_eq!(parsed.mldsa_sig.as_deref(), Some(mldsa.as_slice()));
    }
    #[test]
    fn header_tolerates_no_space_after_comma() {
        let agent: AgentPubkey = "u9PqJ4gK2mZ8t6nVxR3hB1cW7yE5dF0aQ4sT2lN6oU8"
            .parse()
            .unwrap();
        let nonce: Nonce = "F4Yk8vN2j5QwK3zB1aR9oA".parse().unwrap();
        let network: NetworkId = "parley-mainnet".parse().unwrap();
        let sig = [7u8; 64];
        let sig_b64 = URL_SAFE_NO_PAD.encode(sig);
        let header =
            format!("v=1,agent={agent},ts=1,nonce={nonce},network={network},sig={sig_b64}");
        let parsed = parse_header_value(&header).unwrap();
        assert_eq!(parsed.v, 1);
    }
    #[test]
    fn header_parse_error_paths() {
        let (agent, nonce, network) = header_fixtures();
        let sig_b64 = URL_SAFE_NO_PAD.encode([7u8; 64]);
        let prefix = format!("v=2, agent={agent}, ts=1, nonce={nonce}, network={network}");
        let base = format!("{prefix}, sig={sig_b64}");

        // mldsa_sig is optional and unknown keys are ignored.
        let parsed = parse_header_value(&format!("{base}, future=x")).unwrap();
        assert!(parsed.mldsa_sig.is_none());

        assert!(matches!(
            parse_header_value(&format!("{base}, ts=2")),
            Err(SignatureParseError::DuplicateField("ts"))
        ));
        assert!(matches!(
            parse_header_value(&prefix),
            Err(SignatureParseError::MissingField("sig"))
        ));
        assert!(matches!(
            parse_header_value("v=2, garbage"),
            Err(SignatureParseError::MalformedPair(p)) if p == "garbage"
        ));
        assert!(matches!(
            parse_header_value("v=abc"),
            Err(SignatureParseError::InvalidValue { field: "v", .. })
        ));
        let short_sig = URL_SAFE_NO_PAD.encode([7u8; 63]);
        assert!(matches!(
            parse_header_value(&format!("{prefix}, sig={short_sig}")),
            Err(SignatureParseError::InvalidValue { field: "sig", .. })
        ));
    }
    #[test]
    fn sign_then_verify_roundtrip() {
        use ed25519_dalek::{Signer as _, SigningKey};
        let signing = SigningKey::from_bytes(&[42u8; 32]);
        let agent = AgentPubkey::from_bytes(*signing.verifying_key().as_bytes());
        let canonical = "GET\n/healthz\n\n0\n_\n_\n_\n_";
        let sig = signing.sign(canonical.as_bytes()).to_bytes();
        verify_signature(&agent, canonical, &sig).unwrap();
        let mut bad = sig;
        bad[0] ^= 1;
        assert!(verify_signature(&agent, canonical, &bad).is_err());
    }
    #[test]
    fn ml_dsa_sign_verify_roundtrip() {
        use crate::keys::derive_auth_mldsa;
        let kp = derive_auth_mldsa(&[42u8; crate::keys::SEED_BYTES]);
        let pk = kp.verification_key.as_slice();
        let canonical = "GET\n/healthz\n\n0\n_\n_\n_\n_";
        let sig = ml_dsa_sign(&kp.signing_key, canonical).unwrap();
        assert_eq!(sig.len(), ML_DSA_SIG_BYTES);
        ml_dsa_verify(pk, canonical, &sig).unwrap();
        assert!(ml_dsa_verify(pk, "GET\n/other\n\n0\n_\n_\n_\n_", &sig).is_err());
        let mut bad = sig.clone();
        bad[0] ^= 1;
        assert!(ml_dsa_verify(pk, canonical, &bad).is_err());
        assert!(ml_dsa_verify(&pk[..10], canonical, &sig).is_err());
        assert!(ml_dsa_verify(pk, canonical, &sig[..10]).is_err());
    }
}