mod handshake;
mod replay;
mod session;
use ring::aead::{Aad, CHACHA20_POLY1305, LessSafeKey, Nonce, UnboundKey};
use std::fmt;
use thiserror::Error;
pub use handshake::HandshakeState;
pub use replay::ReplayWindow;
pub use session::NoiseSession;
/// Noise protocol name for the IK handshake pattern.
pub(crate) const PROTOCOL_NAME_IK: &[u8] = b"Noise_IK_secp256k1_ChaChaPoly_SHA256";
/// Noise protocol name for the XK handshake pattern.
pub(crate) const PROTOCOL_NAME_XK: &[u8] = b"Noise_XK_secp256k1_ChaChaPoly_SHA256";

/// Largest Noise message in bytes, AEAD tag included (encryption caps plaintext at
/// `MAX_MESSAGE_SIZE - TAG_SIZE`).
pub const MAX_MESSAGE_SIZE: usize = 65_535;
/// Length of a ChaCha20-Poly1305 authentication tag.
pub const TAG_SIZE: usize = 16;
/// Length of a serialized (compressed) secp256k1 public key.
pub const PUBKEY_SIZE: usize = 33;
/// Plaintext length of the epoch payload carried in handshake messages.
/// NOTE(review): presumably a `u64` epoch — confirm in `handshake.rs`.
pub const EPOCH_SIZE: usize = 8;
/// Epoch payload after AEAD encryption: plaintext followed by its tag.
pub const EPOCH_ENCRYPTED_SIZE: usize = TAG_SIZE + EPOCH_SIZE;

/// IK message 1: two public keys, one tag (presumably the encrypted static key),
/// and the encrypted epoch.
pub const HANDSHAKE_MSG1_SIZE: usize = 2 * PUBKEY_SIZE + TAG_SIZE + EPOCH_ENCRYPTED_SIZE;
/// IK message 2: one public key plus the encrypted epoch.
pub const HANDSHAKE_MSG2_SIZE: usize = EPOCH_ENCRYPTED_SIZE + PUBKEY_SIZE;

/// XK message 1: a single public key.
/// NOTE(review): a standard Noise XK message 1 (`-> e, es`) also carries an
/// encrypted (possibly empty) payload, adding at least `TAG_SIZE`; confirm that
/// `handshake.rs` deliberately omits it.
pub const XK_HANDSHAKE_MSG1_SIZE: usize = PUBKEY_SIZE;
/// XK message 2: one public key plus the encrypted epoch.
pub const XK_HANDSHAKE_MSG2_SIZE: usize = EPOCH_ENCRYPTED_SIZE + PUBKEY_SIZE;
/// XK message 3: one public key with its tag, plus the encrypted epoch.
pub const XK_HANDSHAKE_MSG3_SIZE: usize = PUBKEY_SIZE + TAG_SIZE + EPOCH_ENCRYPTED_SIZE;

/// Capacity of the replay window (number of counters tracked).
pub const REPLAY_WINDOW_SIZE: usize = 2048;
/// Errors produced by the Noise handshake and transport cipher code.
#[derive(Debug, Error)]
pub enum NoiseError {
/// An operation needing a finished handshake was attempted before it completed.
#[error("handshake not complete")]
HandshakeNotComplete,
/// A handshake step was attempted after the handshake had already completed.
#[error("handshake already complete")]
HandshakeAlreadyComplete,
/// A handshake step arrived in the wrong state; `expected`/`got` are textual
/// state descriptions (presumably `HandshakeProgress`/`HandshakeRole` display
/// strings — confirm in `handshake.rs`).
#[error("wrong handshake state: expected {expected}, got {got}")]
WrongState { expected: String, got: String },
/// A received public key could not be used.
#[error("invalid public key")]
InvalidPublicKey,
/// AEAD opening failed: no cipher was available or authentication failed
/// (raised by `open` / `open_in_place`).
#[error("decryption failed")]
DecryptionFailed,
/// AEAD sealing failed: no cipher was available or `ring` rejected the seal
/// (raised by `seal`).
#[error("encryption failed")]
EncryptionFailed,
/// A plaintext exceeds `max` bytes (`MAX_MESSAGE_SIZE - TAG_SIZE` in `CipherState`).
#[error("message too large: {size} > {max}")]
MessageTooLarge { size: usize, max: usize },
/// Input is shorter than required, e.g. a ciphertext shorter than `TAG_SIZE`.
#[error("message too short: expected at least {expected}, got {got}")]
MessageTooShort { expected: usize, got: usize },
/// The internal nonce reached `u64::MAX` and may not be used.
#[error("nonce overflow")]
NonceOverflow,
/// The given counter was already accepted or falls outside the replay window
/// (presumably raised by `ReplayWindow` — not produced in this file).
#[error("replay detected: counter {0} already seen or too old")]
ReplayDetected(u64),
/// Wraps an error from the `secp256k1` crate; convertible via `?` thanks to `#[from]`.
#[error("secp256k1 error: {0}")]
Secp256k1(#[from] secp256k1::Error),
}
/// Which side of the Noise handshake this endpoint plays.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HandshakeRole {
    /// The initiating side of the handshake.
    Initiator,
    /// The responding side of the handshake.
    Responder,
}

impl fmt::Display for HandshakeRole {
    /// Writes the lowercase role name: `initiator` or `responder`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Initiator => "initiator",
            Self::Responder => "responder",
        };
        f.write_str(label)
    }
}
/// Selects the Noise handshake pattern (see `PROTOCOL_NAME_IK` / `PROTOCOL_NAME_XK`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NoisePattern {
/// The `IK` pattern (`Noise_IK_secp256k1_ChaChaPoly_SHA256`).
Ik,
/// The `XK` pattern (`Noise_XK_secp256k1_ChaChaPoly_SHA256`).
Xk,
}
/// How far a handshake has advanced, tracked per handshake message.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HandshakeProgress {
    /// No handshake message has been processed yet.
    Initial,
    /// The first handshake message has been processed.
    Message1Done,
    /// The second handshake message has been processed.
    Message2Done,
    /// The handshake is finished.
    Complete,
}

impl fmt::Display for HandshakeProgress {
    /// Writes a snake_case label such as `message1_done`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Initial => "initial",
            Self::Message1Done => "message1_done",
            Self::Message2Done => "message2_done",
            Self::Complete => "complete",
        };
        f.write_str(label)
    }
}
/// Noise `CipherState`: an optional ChaCha20-Poly1305 key plus a nonce counter.
///
/// `Debug` is implemented by hand so that the key bytes are never printed.
pub struct CipherState {
/// Raw key bytes; kept so the `ring` key can be rebuilt when cloning.
key: [u8; 32],
/// Prepared AEAD key; `None` when unkeyed or if `ring` rejected the key.
cipher: Option<LessSafeKey>,
/// Counter to use as the nonce for the next `encrypt*`/`decrypt` call.
pub(super) nonce: u64,
/// Whether a key is installed; when `false`, encrypt/decrypt pass data through unchanged.
has_key: bool,
}
impl Clone for CipherState {
    /// Copies the key, nonce and key flag, and rebuilds the AEAD key from the
    /// stored key bytes when one is installed.
    fn clone(&self) -> Self {
        // Every field except `cipher` is `Copy`, so it can be pulled out by value.
        let Self {
            key,
            nonce,
            has_key,
            ..
        } = *self;
        let cipher = if has_key {
            Self::build_cipher(&key)
        } else {
            None
        };
        Self {
            key,
            cipher,
            nonce,
            has_key,
        }
    }
}
impl CipherState {
pub(crate) fn new(key: [u8; 32]) -> Self {
let cipher = Self::build_cipher(&key);
Self {
key,
cipher,
nonce: 0,
has_key: true,
}
}
pub(super) fn empty() -> Self {
Self {
key: [0u8; 32],
cipher: None,
nonce: 0,
has_key: false,
}
}
pub(super) fn initialize_key(&mut self, key: [u8; 32]) {
self.key = key;
self.cipher = Self::build_cipher(&key);
self.nonce = 0;
self.has_key = true;
}
fn build_cipher(key: &[u8; 32]) -> Option<LessSafeKey> {
UnboundKey::new(&CHACHA20_POLY1305, key)
.ok()
.map(LessSafeKey::new)
}
pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, NoiseError> {
if !self.has_key {
return Ok(plaintext.to_vec());
}
if plaintext.len() > MAX_MESSAGE_SIZE - TAG_SIZE {
return Err(NoiseError::MessageTooLarge {
size: plaintext.len(),
max: MAX_MESSAGE_SIZE - TAG_SIZE,
});
}
let counter = self.advance_nonce()?;
seal(self.cipher.as_ref(), counter, &[], plaintext)
}
pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, NoiseError> {
if !self.has_key {
return Ok(ciphertext.to_vec());
}
if ciphertext.len() < TAG_SIZE {
return Err(NoiseError::MessageTooShort {
expected: TAG_SIZE,
got: ciphertext.len(),
});
}
let counter = self.advance_nonce()?;
open(self.cipher.as_ref(), counter, &[], ciphertext)
}
pub fn decrypt_with_counter(
&self,
ciphertext: &[u8],
counter: u64,
) -> Result<Vec<u8>, NoiseError> {
if !self.has_key {
return Ok(ciphertext.to_vec());
}
if ciphertext.len() < TAG_SIZE {
return Err(NoiseError::MessageTooShort {
expected: TAG_SIZE,
got: ciphertext.len(),
});
}
open(self.cipher.as_ref(), counter, &[], ciphertext)
}
pub fn encrypt_with_aad(
&mut self,
plaintext: &[u8],
aad: &[u8],
) -> Result<Vec<u8>, NoiseError> {
if !self.has_key {
return Ok(plaintext.to_vec());
}
if plaintext.len() > MAX_MESSAGE_SIZE - TAG_SIZE {
return Err(NoiseError::MessageTooLarge {
size: plaintext.len(),
max: MAX_MESSAGE_SIZE - TAG_SIZE,
});
}
let counter = self.advance_nonce()?;
seal(self.cipher.as_ref(), counter, aad, plaintext)
}
pub fn encrypt_with_counter(
&self,
plaintext: &[u8],
counter: u64,
) -> Result<Vec<u8>, NoiseError> {
if !self.has_key {
return Ok(plaintext.to_vec());
}
if plaintext.len() > MAX_MESSAGE_SIZE - TAG_SIZE {
return Err(NoiseError::MessageTooLarge {
size: plaintext.len(),
max: MAX_MESSAGE_SIZE - TAG_SIZE,
});
}
seal(self.cipher.as_ref(), counter, &[], plaintext)
}
pub fn encrypt_with_counter_and_aad(
&self,
plaintext: &[u8],
counter: u64,
aad: &[u8],
) -> Result<Vec<u8>, NoiseError> {
if !self.has_key {
return Ok(plaintext.to_vec());
}
if plaintext.len() > MAX_MESSAGE_SIZE - TAG_SIZE {
return Err(NoiseError::MessageTooLarge {
size: plaintext.len(),
max: MAX_MESSAGE_SIZE - TAG_SIZE,
});
}
seal(self.cipher.as_ref(), counter, aad, plaintext)
}
pub fn cipher_clone(&self) -> Option<LessSafeKey> {
if self.has_key {
Self::build_cipher(&self.key)
} else {
None
}
}
pub fn decrypt_with_counter_and_aad(
&self,
ciphertext: &[u8],
counter: u64,
aad: &[u8],
) -> Result<Vec<u8>, NoiseError> {
if !self.has_key {
return Ok(ciphertext.to_vec());
}
if ciphertext.len() < TAG_SIZE {
return Err(NoiseError::MessageTooShort {
expected: TAG_SIZE,
got: ciphertext.len(),
});
}
open(self.cipher.as_ref(), counter, aad, ciphertext)
}
pub fn decrypt_with_counter_and_aad_in_place(
&self,
buf: &mut [u8],
counter: u64,
aad: &[u8],
) -> Result<usize, NoiseError> {
if !self.has_key {
return Ok(buf.len());
}
open_in_place(self.cipher.as_ref(), counter, aad, buf)
}
pub(crate) fn counter_to_nonce(counter: u64) -> Nonce {
let mut nonce_bytes = [0u8; 12];
nonce_bytes[4..12].copy_from_slice(&counter.to_le_bytes());
Nonce::assume_unique_for_key(nonce_bytes)
}
fn advance_nonce(&mut self) -> Result<u64, NoiseError> {
if self.nonce == u64::MAX {
return Err(NoiseError::NonceOverflow);
}
let n = self.nonce;
self.nonce += 1;
Ok(n)
}
pub fn nonce(&self) -> u64 {
self.nonce
}
pub fn has_key(&self) -> bool {
self.has_key
}
}
impl fmt::Debug for CipherState {
    /// Prints the nonce and key flag; the key itself is always shown as `"[redacted]"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("CipherState");
        out.field("nonce", &self.nonce);
        out.field("has_key", &self.has_key);
        // Never leak key material into logs.
        out.field("key", &"[redacted]");
        out.finish()
    }
}
/// Encrypts `plaintext` under `cipher` using `counter` as the nonce, and
/// authenticates `aad`.
///
/// Returns `plaintext.len() + TAG_SIZE` bytes: the ciphertext followed by the tag.
///
/// # Errors
/// `EncryptionFailed` if `cipher` is `None` or `ring` fails to seal.
pub(crate) fn seal(
    cipher: Option<&LessSafeKey>,
    counter: u64,
    aad: &[u8],
    plaintext: &[u8],
) -> Result<Vec<u8>, NoiseError> {
    let key = match cipher {
        Some(key) => key,
        None => return Err(NoiseError::EncryptionFailed),
    };
    // Reserve room for the tag up front so appending it does not reallocate.
    let mut sealed = Vec::with_capacity(plaintext.len() + TAG_SIZE);
    sealed.extend_from_slice(plaintext);
    let outcome = key.seal_in_place_append_tag(
        CipherState::counter_to_nonce(counter),
        Aad::from(aad),
        &mut sealed,
    );
    outcome
        .map(|()| sealed)
        .map_err(|_| NoiseError::EncryptionFailed)
}
/// Decrypts and authenticates `ciphertext` (ciphertext followed by tag) under
/// `cipher`, using `counter` as the nonce and authenticating `aad`.
///
/// Returns the recovered plaintext in a new buffer.
///
/// # Errors
/// `DecryptionFailed` if `cipher` is `None` or authentication fails.
pub(crate) fn open(
    cipher: Option<&LessSafeKey>,
    counter: u64,
    aad: &[u8],
    ciphertext: &[u8],
) -> Result<Vec<u8>, NoiseError> {
    let key = match cipher {
        Some(key) => key,
        None => return Err(NoiseError::DecryptionFailed),
    };
    let mut opened = ciphertext.to_vec();
    let outcome = key.open_in_place(
        CipherState::counter_to_nonce(counter),
        Aad::from(aad),
        &mut opened,
    );
    let kept = match outcome {
        Ok(plain) => plain.len(),
        Err(_) => return Err(NoiseError::DecryptionFailed),
    };
    // Drop the trailing tag bytes; only the plaintext prefix remains.
    opened.truncate(kept);
    Ok(opened)
}
/// Decrypts `buf` (ciphertext followed by tag) in place under `cipher`, using
/// `counter` as the nonce and authenticating `aad`.
///
/// Returns the plaintext length; the plaintext occupies `buf[..len]`.
///
/// # Errors
/// `DecryptionFailed` if `cipher` is `None` (checked first) or authentication
/// fails, and `MessageTooShort` if `buf` is shorter than `TAG_SIZE`.
pub(crate) fn open_in_place(
    cipher: Option<&LessSafeKey>,
    counter: u64,
    aad: &[u8],
    buf: &mut [u8],
) -> Result<usize, NoiseError> {
    let key = match cipher {
        Some(key) => key,
        None => return Err(NoiseError::DecryptionFailed),
    };
    let got = buf.len();
    if got < TAG_SIZE {
        return Err(NoiseError::MessageTooShort {
            expected: TAG_SIZE,
            got,
        });
    }
    key.open_in_place(CipherState::counter_to_nonce(counter), Aad::from(aad), buf)
        .map(|plain| plain.len())
        .map_err(|_| NoiseError::DecryptionFailed)
}
#[cfg(test)]
mod tests;