use crate::primitives::{from_base64, to_base64};
use crate::wallet::{Counterparty, CreateHmacArgs, Protocol, SecurityLevel, WalletInterface};
use crate::{Error, Result};
use rand::RngCore;
/// Protocol name passed to the wallet when computing nonce HMACs.
pub const NONCE_PROTOCOL: &str = "server hmac";
/// Number of random bytes at the start of a nonce; the same number of HMAC
/// bytes is appended after them.
const NONCE_RANDOM_SIZE: usize = 16;
/// Decoded length of a nonce built by `create_nonce`: the random prefix
/// followed by an equally long truncated HMAC tag.
const NONCE_TOTAL_SIZE: usize = 2 * NONCE_RANDOM_SIZE;
pub async fn create_nonce<W: WalletInterface>(
wallet: &W,
counterparty: Option<&crate::primitives::PublicKey>,
originator: &str,
) -> Result<String> {
let mut random_bytes = [0u8; NONCE_RANDOM_SIZE];
rand::thread_rng().fill_bytes(&mut random_bytes);
let protocol = Protocol::new(SecurityLevel::App, NONCE_PROTOCOL);
let key_id = to_base64(&random_bytes);
let hmac_result = wallet
.create_hmac(
CreateHmacArgs {
data: random_bytes.to_vec(),
protocol_id: protocol,
key_id,
counterparty: counterparty.map(|pk| Counterparty::Other(pk.clone())),
},
originator,
)
.await?;
let mut nonce = Vec::with_capacity(NONCE_TOTAL_SIZE);
nonce.extend_from_slice(&random_bytes);
nonce.extend_from_slice(&hmac_result.hmac[..NONCE_RANDOM_SIZE]);
Ok(to_base64(&nonce))
}
pub async fn verify_nonce<W: WalletInterface>(
nonce: &str,
wallet: &W,
counterparty: Option<&crate::primitives::PublicKey>,
originator: &str,
) -> Result<bool> {
let nonce_bytes = from_base64(nonce)?;
if nonce_bytes.len() < NONCE_TOTAL_SIZE {
return Err(Error::InvalidNonce(format!(
"Nonce too short: expected {} bytes, got {}",
NONCE_TOTAL_SIZE,
nonce_bytes.len()
)));
}
let random_bytes = &nonce_bytes[..NONCE_RANDOM_SIZE];
let hmac_bytes = &nonce_bytes[NONCE_RANDOM_SIZE..NONCE_TOTAL_SIZE];
let protocol = Protocol::new(SecurityLevel::App, NONCE_PROTOCOL);
let key_id = to_base64(random_bytes);
let hmac_result = wallet
.create_hmac(
CreateHmacArgs {
data: random_bytes.to_vec(),
protocol_id: protocol,
key_id,
counterparty: counterparty.map(|pk| Counterparty::Other(pk.clone())),
},
originator,
)
.await?;
Ok(hmac_result.hmac[..NONCE_RANDOM_SIZE] == *hmac_bytes)
}
pub fn validate_nonce_format(nonce: &str) -> Result<()> {
let bytes = from_base64(nonce)?;
if bytes.len() < NONCE_TOTAL_SIZE {
return Err(Error::InvalidNonce(format!(
"Nonce too short: expected at least {} bytes, got {}",
NONCE_TOTAL_SIZE,
bytes.len()
)));
}
Ok(())
}
pub fn get_nonce_random(nonce: &str) -> Result<Vec<u8>> {
let bytes = from_base64(nonce)?;
if bytes.len() < NONCE_RANDOM_SIZE {
return Err(Error::InvalidNonce(format!(
"Nonce too short: expected at least {} bytes, got {}",
NONCE_RANDOM_SIZE,
bytes.len()
)));
}
Ok(bytes[..NONCE_RANDOM_SIZE].to_vec())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_nonce_format() {
        let valid_nonce = to_base64(&[0u8; 32]);
        assert!(validate_nonce_format(&valid_nonce).is_ok());

        // The length check is a minimum, so longer inputs are accepted.
        let long_nonce = to_base64(&[0u8; 48]);
        assert!(validate_nonce_format(&long_nonce).is_ok());

        let short_nonce = to_base64(&[0u8; 16]);
        assert!(validate_nonce_format(&short_nonce).is_err());

        assert!(validate_nonce_format("not-valid-base64!!!").is_err());
    }

    #[test]
    fn test_get_nonce_random() {
        let mut full_nonce = [0u8; 32];
        full_nonce[..16].copy_from_slice(&[1u8; 16]);
        full_nonce[16..].copy_from_slice(&[2u8; 16]);
        let nonce_str = to_base64(&full_nonce);

        let random = get_nonce_random(&nonce_str).unwrap();
        assert_eq!(random.len(), 16);
        assert_eq!(random, vec![1u8; 16]);
    }

    #[test]
    fn test_get_nonce_random_rejects_bad_input() {
        // One byte short of the random prefix.
        let short = to_base64(&[0u8; 15]);
        assert!(get_nonce_random(&short).is_err());

        assert!(get_nonce_random("not-valid-base64!!!").is_err());
    }
}