use crate::primitives::{from_base64, to_base64};
use crate::wallet::{Counterparty, CreateHmacArgs, Protocol, SecurityLevel, WalletInterface};
use crate::{Error, Result};
use rand::RngCore;
/// Protocol name under which nonce HMACs are created and verified.
pub const NONCE_PROTOCOL: &str = "server hmac";
/// Number of random bytes at the start of every nonce (canonical and legacy).
const NONCE_RANDOM_SIZE: usize = 16;
/// Expected length of the HMAC tag appended in full to canonical nonces.
const NONCE_HMAC_SIZE: usize = 32;
/// Decoded length of a canonical nonce: random bytes followed by the full HMAC.
const NONCE_TOTAL_SIZE: usize = NONCE_RANDOM_SIZE + NONCE_HMAC_SIZE;
/// Decoded length of a legacy nonce: random bytes followed by only the first
/// `NONCE_RANDOM_SIZE` (16) bytes of the HMAC.
const NONCE_LEGACY_SIZE: usize = NONCE_RANDOM_SIZE + NONCE_RANDOM_SIZE;
/// Creates a fresh canonical nonce bound to `wallet`.
///
/// Draws `NONCE_RANDOM_SIZE` random bytes and asks the wallet for an HMAC over
/// them. The HMAC uses protocol `NONCE_PROTOCOL` at `SecurityLevel::App`, the
/// base64 of the random bytes as key ID, and the optional `counterparty`.
/// Returns `base64(random || hmac)`.
///
/// # Errors
/// Propagates any error returned by `wallet.create_hmac`.
pub async fn create_nonce<W: WalletInterface>(
    wallet: &W,
    counterparty: Option<&crate::primitives::PublicKey>,
    originator: &str,
) -> Result<String> {
    // The same entropy is both the HMAC input and (base64-encoded) the key ID,
    // so a verifier can recompute the tag from the nonce alone.
    let mut entropy = [0u8; NONCE_RANDOM_SIZE];
    rand::thread_rng().fill_bytes(&mut entropy);

    let hmac_args = CreateHmacArgs {
        data: entropy.to_vec(),
        protocol_id: Protocol::new(SecurityLevel::App, NONCE_PROTOCOL),
        key_id: to_base64(&entropy),
        counterparty: counterparty.map(|pk| Counterparty::Other(pk.clone())),
    };
    let tagged = wallet.create_hmac(hmac_args, originator).await?;

    let mut encoded = Vec::with_capacity(NONCE_TOTAL_SIZE);
    encoded.extend_from_slice(&entropy);
    encoded.extend_from_slice(&tagged.hmac);
    Ok(to_base64(&encoded))
}
/// Verifies a nonce previously produced by [`create_nonce`].
///
/// The nonce is base64-decoded and split into its random prefix and the
/// stored HMAC tag. Two layouts are accepted:
/// - canonical: `NONCE_TOTAL_SIZE` bytes (random prefix + full HMAC);
/// - legacy: `NONCE_LEGACY_SIZE` bytes (random prefix + first 16 HMAC bytes).
///
/// The HMAC over the random prefix is recomputed through `wallet`, using the
/// base64 of the prefix as key ID. It is then compared against the stored tag
/// in constant time.
///
/// # Returns
/// `Ok(true)` if the tag matches. `Ok(false)` if it does not, including when
/// the recomputed HMAC is shorter than the stored tag.
///
/// # Errors
/// - Propagates any error from `from_base64`.
/// - Returns [`Error::InvalidNonce`] if the decoded length matches neither layout.
/// - Propagates any error from `wallet.create_hmac`.
pub async fn verify_nonce<W: WalletInterface>(
    nonce: &str,
    wallet: &W,
    counterparty: Option<&crate::primitives::PublicKey>,
    originator: &str,
) -> Result<bool> {
    let nonce_bytes = from_base64(nonce)?;
    let (random_bytes, hmac_bytes) = match nonce_bytes.len() {
        NONCE_TOTAL_SIZE => (
            &nonce_bytes[..NONCE_RANDOM_SIZE],
            &nonce_bytes[NONCE_RANDOM_SIZE..NONCE_TOTAL_SIZE],
        ),
        NONCE_LEGACY_SIZE => (
            &nonce_bytes[..NONCE_RANDOM_SIZE],
            &nonce_bytes[NONCE_RANDOM_SIZE..NONCE_LEGACY_SIZE],
        ),
        n => {
            return Err(Error::InvalidNonce(format!(
                "Nonce size invalid: expected {} or {} bytes, got {}",
                NONCE_TOTAL_SIZE, NONCE_LEGACY_SIZE, n
            )))
        }
    };
    let protocol = Protocol::new(SecurityLevel::App, NONCE_PROTOCOL);
    let key_id = to_base64(random_bytes);
    let hmac_result = wallet
        .create_hmac(
            CreateHmacArgs {
                data: random_bytes.to_vec(),
                protocol_id: protocol,
                key_id,
                counterparty: counterparty.map(|pk| Counterparty::Other(pk.clone())),
            },
            originator,
        )
        .await?;
    // Legacy nonces carry a truncated tag, so only the matching prefix of the
    // recomputed HMAC is compared. `get` (rather than indexing) turns a
    // too-short recomputed HMAC into a failed verification instead of a panic.
    let is_valid = match hmac_result.hmac.get(..hmac_bytes.len()) {
        Some(expected) => constant_time_eq(expected, hmac_bytes),
        None => false,
    };
    Ok(is_valid)
}

/// Compares two byte slices without stopping at the first mismatch. This
/// keeps the comparison time from revealing how many leading bytes of an
/// attacker-supplied tag were correct.
///
/// Slices of different length compare unequal right away; the tag length is
/// fixed by the nonce layout and is not secret.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR-accumulate the XOR of every byte pair; any difference leaves a
    // non-zero bit, and every byte is always visited.
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
pub fn validate_nonce_format(nonce: &str) -> Result<()> {
let bytes = from_base64(nonce)?;
if bytes.len() < NONCE_LEGACY_SIZE {
return Err(Error::InvalidNonce(format!(
"Nonce too short: expected at least {} bytes, got {}",
NONCE_LEGACY_SIZE,
bytes.len()
)));
}
Ok(())
}
pub fn get_nonce_random(nonce: &str) -> Result<Vec<u8>> {
let bytes = from_base64(nonce)?;
if bytes.len() < NONCE_RANDOM_SIZE {
return Err(Error::InvalidNonce(format!(
"Nonce too short: expected at least {} bytes, got {}",
NONCE_RANDOM_SIZE,
bytes.len()
)));
}
Ok(bytes[..NONCE_RANDOM_SIZE].to_vec())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::primitives::PrivateKey;
    use crate::wallet::ProtoWallet;

    /// Builds a `ProtoWallet` around the private key given as a hex string.
    fn wallet_from_hex(key_hex: &str) -> ProtoWallet {
        ProtoWallet::new(Some(PrivateKey::from_hex(key_hex).unwrap()))
    }

    #[test]
    fn test_validate_nonce_format_canonical() {
        let canonical_nonce = to_base64(&[0u8; 48]);
        assert!(validate_nonce_format(&canonical_nonce).is_ok());
    }

    #[test]
    fn test_validate_nonce_format_legacy_still_ok() {
        let legacy_nonce = to_base64(&[0u8; 32]);
        assert!(validate_nonce_format(&legacy_nonce).is_ok());
    }

    #[test]
    fn test_validate_nonce_format_too_short() {
        assert!(validate_nonce_format(&to_base64(&[0u8; 16])).is_err());
        assert!(validate_nonce_format("not-valid-base64!!!").is_err());
    }

    #[test]
    fn test_get_nonce_random() {
        // 16 bytes of 0x01 (random part) followed by 32 bytes of 0x02 (HMAC part).
        let raw: Vec<u8> = std::iter::repeat(1u8)
            .take(16)
            .chain(std::iter::repeat(2u8).take(32))
            .collect();
        let random = get_nonce_random(&to_base64(&raw)).unwrap();
        assert_eq!(random.len(), 16);
        assert_eq!(random, vec![1u8; 16]);
    }

    #[tokio::test]
    async fn test_create_verify_roundtrip_canonical_48_bytes() {
        let wallet =
            wallet_from_hex("0000000000000000000000000000000000000000000000000000000000000001");
        let nonce = create_nonce(&wallet, None, "test.app").await.unwrap();
        let decoded_len = from_base64(&nonce).unwrap().len();
        assert_eq!(
            decoded_len, NONCE_TOTAL_SIZE,
            "canonical nonce must be 48 bytes"
        );
        let verified = verify_nonce(&nonce, &wallet, None, "test.app")
            .await
            .unwrap();
        assert!(verified);
    }

    #[tokio::test]
    async fn test_verify_legacy_32_byte_nonce_still_passes() {
        let wallet =
            wallet_from_hex("0000000000000000000000000000000000000000000000000000000000000002");
        let mut random = [0u8; NONCE_RANDOM_SIZE];
        rand::thread_rng().fill_bytes(&mut random);
        let hmac = wallet
            .create_hmac(CreateHmacArgs {
                data: random.to_vec(),
                protocol_id: Protocol::new(SecurityLevel::App, NONCE_PROTOCOL),
                key_id: to_base64(&random),
                counterparty: None,
            })
            .unwrap();
        // Legacy layout: random prefix followed by only the first 16 HMAC bytes.
        let legacy: Vec<u8> = random
            .iter()
            .chain(&hmac.hmac[..NONCE_RANDOM_SIZE])
            .copied()
            .collect();
        let nonce = to_base64(&legacy);
        assert_eq!(from_base64(&nonce).unwrap().len(), NONCE_LEGACY_SIZE);
        let verified = verify_nonce(&nonce, &wallet, None, "test.app")
            .await
            .unwrap();
        assert!(verified);
    }
}