use std::fmt;
use snow::params::NoiseParams;
use snow::resolvers::{CryptoResolver, DefaultResolver};
use snow::{Builder, HandshakeState, TransportState};
use crate::crypto::{PublicKey, SecretKey};
use crate::error::CryptoError;
/// Size in bytes of the scratch buffer allocated for every handshake message that is
/// written or read.
/// NOTE(review): 65535 appears to mirror the maximum message length allowed by the Noise
/// specification — confirm before changing.
pub const MAX_NOISE_MSG_SIZE: usize = 65535;
/// Noise protocol name parsed into `NoiseParams` whenever a session is constructed.
/// NOTE(review): the name advertises BLAKE2s, but `TriglavResolver::resolve_hash` answers
/// that choice with a BLAKE3-backed hasher, so sessions presumably do not interoperate with
/// peers running the standard BLAKE2s suite — confirm this is intended.
const NOISE_PATTERN: &str = "Noise_NK_25519_ChaChaPoly_BLAKE2s";
/// Which side of the handshake a `NoiseSession` plays.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeRole {
/// Session created with `NoiseSession::new_initiator`.
Initiator,
/// Session created with `NoiseSession::new_responder`.
Responder,
}
/// Phase of the underlying snow state machine held by a `NoiseSession`.
pub enum NoiseState {
/// Handshake messages are still being exchanged.
Handshake(Box<HandshakeState>),
/// Handshake finished; transport messages can be encrypted and decrypted.
Transport(Box<TransportState>),
/// Placeholder installed while the handshake state is being converted to transport mode;
/// it remains in place if that conversion fails.
Failed,
}
impl fmt::Debug for NoiseState {
    /// Writes only the variant name; the wrapped snow state is never formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Handshake(_) => "NoiseState::Handshake",
            Self::Transport(_) => "NoiseState::Transport",
            Self::Failed => "NoiseState::Failed",
        };
        f.write_str(label)
    }
}
/// Crypto resolver handed to snow's `Builder`; it forwards to snow's `DefaultResolver`
/// for everything except the hash choice it overrides in `resolve_hash`.
struct TriglavResolver {
/// Fallback resolver used for all primitives that are not overridden.
default: DefaultResolver,
}
impl TriglavResolver {
    /// Creates a resolver backed by snow's `DefaultResolver`.
    fn new() -> Self {
        let default = DefaultResolver;
        Self { default }
    }
}
impl CryptoResolver for TriglavResolver {
    /// Random source: delegated unchanged to the default resolver.
    fn resolve_rng(&self) -> Option<Box<dyn snow::types::Random>> {
        self.default.resolve_rng()
    }

    /// Diffie-Hellman primitive: delegated unchanged to the default resolver.
    fn resolve_dh(&self, choice: &snow::params::DHChoice) -> Option<Box<dyn snow::types::Dh>> {
        self.default.resolve_dh(choice)
    }

    /// Hash primitive: a request for BLAKE2s is answered with the BLAKE3-backed
    /// `Blake3Hash`; every other choice falls through to the default resolver.
    fn resolve_hash(
        &self,
        choice: &snow::params::HashChoice,
    ) -> Option<Box<dyn snow::types::Hash>> {
        if matches!(choice, snow::params::HashChoice::Blake2s) {
            return Some(Box::new(Blake3Hash::default()));
        }
        self.default.resolve_hash(choice)
    }

    /// AEAD cipher: delegated unchanged to the default resolver.
    fn resolve_cipher(
        &self,
        choice: &snow::params::CipherChoice,
    ) -> Option<Box<dyn snow::types::Cipher>> {
        self.default.resolve_cipher(choice)
    }
}
/// Adapter that exposes a `blake3::Hasher` through snow's `Hash` trait.
#[derive(Default)]
struct Blake3Hash {
/// Incremental hasher that accumulates the bytes passed to `input`.
hasher: blake3::Hasher,
}
impl snow::types::Hash for Blake3Hash {
    /// Name reported to snow for this hash.
    fn name(&self) -> &'static str {
        "BLAKE3"
    }

    /// Block length in bytes reported to snow.
    fn block_len(&self) -> usize {
        64
    }

    /// Digest length in bytes reported to snow.
    fn hash_len(&self) -> usize {
        32
    }

    /// Discards everything absorbed so far by swapping in a fresh hasher.
    fn reset(&mut self) {
        self.hasher = blake3::Hasher::new();
    }

    /// Absorbs `data` into the running hash.
    fn input(&mut self, data: &[u8]) {
        self.hasher.update(data);
    }

    /// Writes the digest of everything absorbed so far into the front of `out`.
    ///
    /// At most `hash_len()` bytes are written; a shorter `out` receives a truncated digest.
    fn result(&mut self, out: &mut [u8]) {
        let digest = self.hasher.finalize();
        let count = out.len().min(self.hash_len());
        let source = &digest.as_bytes()[..count];
        out[..count].copy_from_slice(source);
    }
}
/// A Noise session that starts in the handshake phase and, once the handshake finishes,
/// encrypts and decrypts transport messages.
pub struct NoiseSession {
/// Current phase of the underlying snow state machine.
state: NoiseState,
/// Role fixed when the session was constructed.
role: HandshakeRole,
/// Remote static key: supplied up front for initiators; for responders it is filled in
/// when the handshake completes, if snow exposes one.
remote_public: Option<PublicKey>,
/// Set to `true` only after the switch to transport mode succeeded.
handshake_complete: bool,
}
impl NoiseSession {
pub fn new_initiator(
local_secret: &SecretKey,
remote_public: &PublicKey,
) -> Result<Self, CryptoError> {
let params: NoiseParams = NOISE_PATTERN
.parse()
.map_err(|e| CryptoError::NoiseProtocol(format!("invalid pattern: {e}")))?;
let secret_bytes = local_secret.as_bytes();
let builder = Builder::with_resolver(params, Box::new(TriglavResolver::new()))
.local_private_key(&secret_bytes)
.remote_public_key(remote_public.as_bytes());
let handshake = builder
.build_initiator()
.map_err(|e| CryptoError::NoiseProtocol(format!("build initiator failed: {e}")))?;
Ok(Self {
state: NoiseState::Handshake(Box::new(handshake)),
role: HandshakeRole::Initiator,
remote_public: Some(*remote_public),
handshake_complete: false,
})
}
pub fn new_responder(local_secret: &SecretKey) -> Result<Self, CryptoError> {
let params: NoiseParams = NOISE_PATTERN
.parse()
.map_err(|e| CryptoError::NoiseProtocol(format!("invalid pattern: {e}")))?;
let secret_bytes = local_secret.as_bytes();
let builder = Builder::with_resolver(params, Box::new(TriglavResolver::new()))
.local_private_key(&secret_bytes);
let handshake = builder
.build_responder()
.map_err(|e| CryptoError::NoiseProtocol(format!("build responder failed: {e}")))?;
Ok(Self {
state: NoiseState::Handshake(Box::new(handshake)),
role: HandshakeRole::Responder,
remote_public: None,
handshake_complete: false,
})
}
pub fn is_handshake_complete(&self) -> bool {
self.handshake_complete
}
pub fn is_transport(&self) -> bool {
matches!(self.state, NoiseState::Transport(_))
}
pub fn role(&self) -> HandshakeRole {
self.role
}
pub fn remote_public(&self) -> Option<&PublicKey> {
self.remote_public.as_ref()
}
pub fn write_handshake(&mut self, payload: &[u8]) -> Result<Vec<u8>, CryptoError> {
match &mut self.state {
NoiseState::Handshake(hs) => {
let mut buf = vec![0u8; MAX_NOISE_MSG_SIZE];
let len = hs
.write_message(payload, &mut buf)
.map_err(|e| CryptoError::NoiseProtocol(format!("write handshake: {e}")))?;
buf.truncate(len);
if hs.is_handshake_finished() {
self.complete_handshake()?;
}
Ok(buf)
}
NoiseState::Transport(_) => Err(CryptoError::NoiseProtocol(
"already in transport mode".into(),
)),
NoiseState::Failed => Err(CryptoError::NoiseProtocol("session failed".into())),
}
}
pub fn read_handshake(&mut self, message: &[u8]) -> Result<Vec<u8>, CryptoError> {
match &mut self.state {
NoiseState::Handshake(hs) => {
let mut buf = vec![0u8; MAX_NOISE_MSG_SIZE];
let len = hs
.read_message(message, &mut buf)
.map_err(|e| CryptoError::NoiseProtocol(format!("read handshake: {e}")))?;
buf.truncate(len);
if hs.is_handshake_finished() {
self.complete_handshake()?;
}
Ok(buf)
}
NoiseState::Transport(_) => Err(CryptoError::NoiseProtocol(
"already in transport mode".into(),
)),
NoiseState::Failed => Err(CryptoError::NoiseProtocol("session failed".into())),
}
}
fn complete_handshake(&mut self) -> Result<(), CryptoError> {
let state = std::mem::replace(&mut self.state, NoiseState::Failed);
match state {
NoiseState::Handshake(hs) => {
if self.remote_public.is_none() {
if let Some(rs) = hs.get_remote_static() {
let mut key = [0u8; 32];
key.copy_from_slice(rs);
self.remote_public = Some(PublicKey(key));
}
}
let transport = hs
.into_transport_mode()
.map_err(|e| CryptoError::NoiseProtocol(format!("transport mode: {e}")))?;
self.state = NoiseState::Transport(Box::new(transport));
self.handshake_complete = true;
Ok(())
}
_ => Err(CryptoError::NoiseProtocol("not in handshake mode".into())),
}
}
pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
match &mut self.state {
NoiseState::Transport(ts) => {
let mut buf = vec![0u8; plaintext.len() + 16];
let len = ts
.write_message(plaintext, &mut buf)
.map_err(|e| CryptoError::EncryptionFailed(format!("noise encrypt: {e}")))?;
buf.truncate(len);
Ok(buf)
}
NoiseState::Handshake(_) => Err(CryptoError::EncryptionFailed(
"handshake not complete".into(),
)),
NoiseState::Failed => Err(CryptoError::EncryptionFailed("session failed".into())),
}
}
pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
match &mut self.state {
NoiseState::Transport(ts) => {
if ciphertext.len() < 16 {
return Err(CryptoError::InvalidCiphertextLength);
}
let mut buf = vec![0u8; ciphertext.len() - 16];
let len = ts
.read_message(ciphertext, &mut buf)
.map_err(|e| CryptoError::DecryptionFailed(format!("noise decrypt: {e}")))?;
buf.truncate(len);
Ok(buf)
}
NoiseState::Handshake(_) => Err(CryptoError::DecryptionFailed(
"handshake not complete".into(),
)),
NoiseState::Failed => Err(CryptoError::DecryptionFailed("session failed".into())),
}
}
pub fn nonce_counter(&self) -> Option<u64> {
match &self.state {
NoiseState::Transport(ts) => Some(ts.sending_nonce()),
_ => None,
}
}
pub fn rekey_outgoing(&mut self) -> Result<(), CryptoError> {
match &mut self.state {
NoiseState::Transport(ts) => {
ts.rekey_outgoing();
Ok(())
}
_ => Err(CryptoError::NoiseProtocol("not in transport mode".into())),
}
}
pub fn rekey_incoming(&mut self) -> Result<(), CryptoError> {
match &mut self.state {
NoiseState::Transport(ts) => {
ts.rekey_incoming();
Ok(())
}
_ => Err(CryptoError::NoiseProtocol("not in transport mode".into())),
}
}
}
impl fmt::Debug for NoiseSession {
    /// Formats the session as a struct listing role, state, completion flag and remote key.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("NoiseSession");
        out.field("role", &self.role);
        out.field("state", &self.state);
        out.field("handshake_complete", &self.handshake_complete);
        out.field("remote_public", &self.remote_public);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::KeyPair;

    /// Runs the two-message handshake between a new initiator and a new responder and
    /// returns both sessions once each of them reports transport mode.
    fn perform_handshake(
        initiator_secret: &SecretKey,
        responder_secret: &SecretKey,
        responder_public: &PublicKey,
    ) -> Result<(NoiseSession, NoiseSession), CryptoError> {
        let mut initiator = NoiseSession::new_initiator(initiator_secret, responder_public)?;
        let mut responder = NoiseSession::new_responder(responder_secret)?;

        // First message: initiator -> responder.
        let first = initiator.write_handshake(&[])?;
        responder.read_handshake(&first)?;

        // Second message: responder -> initiator.
        let second = responder.write_handshake(&[])?;
        initiator.read_handshake(&second)?;

        assert!(initiator.is_transport());
        assert!(responder.is_transport());
        Ok((initiator, responder))
    }

    /// Generates fresh client and server key pairs and returns the connected
    /// `(client, server)` sessions.
    fn connected_pair() -> (NoiseSession, NoiseSession) {
        let client_kp = KeyPair::generate();
        let server_kp = KeyPair::generate();
        perform_handshake(&client_kp.secret, &server_kp.secret, &server_kp.public).unwrap()
    }

    /// A message survives the round trip in both directions.
    #[test]
    fn test_noise_handshake() {
        let (mut client, mut server) = connected_pair();

        let to_server = b"hello from client";
        let sealed = client.encrypt(to_server).unwrap();
        let opened = server.decrypt(&sealed).unwrap();
        assert_eq!(&to_server[..], &opened[..]);

        let to_client = b"hello from server";
        let sealed = server.encrypt(to_client).unwrap();
        let opened = client.decrypt(&sealed).unwrap();
        assert_eq!(&to_client[..], &opened[..]);
    }

    /// One hundred consecutive messages decrypt in order.
    #[test]
    fn test_multiple_messages() {
        let (mut client, mut server) = connected_pair();

        for msg in (0..100).map(|i| format!("message {i}")) {
            let sealed = client.encrypt(msg.as_bytes()).unwrap();
            let opened = server.decrypt(&sealed).unwrap();
            assert_eq!(msg.as_bytes(), &opened[..]);
        }
    }

    /// An 8 KiB payload survives the round trip.
    #[test]
    fn test_large_message() {
        let (mut client, mut server) = connected_pair();

        let payload = vec![0x42u8; 8192];
        let sealed = client.encrypt(&payload).unwrap();
        let opened = server.decrypt(&sealed).unwrap();
        assert_eq!(payload, opened);
    }
}