use crate::error::{Error, Result};
use blake2::{Blake2s256, Digest};
use chacha20poly1305::{
aead::{Aead, KeyInit, Payload},
ChaCha20Poly1305, Key, Nonce,
};
use curve25519_dalek::montgomery::MontgomeryPoint;
use curve25519_dalek::scalar::clamp_integer;
use rand::RngCore;
/// Full Noise protocol name. At 33 bytes it is longer than `HASHLEN`, so
/// `SymmetricState::new` hashes it (rather than zero-padding it) to form the
/// initial handshake hash and chaining key.
pub const PROTOCOL: &[u8] = b"Noise_XK_25519_ChaChaPoly_BLAKE2s";
/// Output size of the BLAKE2s-256 hash, in bytes.
pub const HASHLEN: usize = 32;
/// BLAKE2s block size in bytes; also the width of the HMAC key pads in `hmac_blake2s`.
pub const BLOCKLEN: usize = 64;
/// One-shot BLAKE2s-256 digest of `data`.
fn blake2s(data: &[u8]) -> [u8; 32] {
    Blake2s256::digest(data).into()
}
/// HMAC (RFC 2104) instantiated with BLAKE2s-256, as required by Noise's `HKDF`.
///
/// Keys longer than `BLOCKLEN` bytes are hashed first. The (possibly hashed) key is
/// then zero-padded to one block before being XORed with the inner/outer pads.
pub fn hmac_blake2s(key: &[u8], data: &[u8]) -> [u8; 32] {
    // Normalise the key to exactly one block.
    let mut key_block = [0u8; BLOCKLEN];
    if key.len() > BLOCKLEN {
        let hashed = blake2s(key);
        key_block[..hashed.len()].copy_from_slice(&hashed);
    } else {
        key_block[..key.len()].copy_from_slice(key);
    }

    let ipad = key_block.map(|b| b ^ 0x36);
    let opad = key_block.map(|b| b ^ 0x5C);

    // H(opad || H(ipad || data))
    let inner = blake2s(&[&ipad[..], data].concat());
    blake2s(&[&opad[..], &inner[..]].concat())
}
/// Noise `HKDF(ck, ikm, 2)` built on HMAC-BLAKE2s.
///
/// Returns `(output1, output2)`. In this file `mix_key` uses them as the new
/// chaining key and a cipher key, and `split` uses them as the two transport keys.
pub fn hkdf2(ck: &[u8], ikm: &[u8]) -> ([u8; 32], [u8; 32]) {
    let prk = hmac_blake2s(ck, ikm);
    let out1 = hmac_blake2s(&prk, &[0x01]);

    // Second block input is output1 || 0x02.
    let mut second_input = [0u8; 33];
    second_input[..32].copy_from_slice(&out1);
    second_input[32] = 0x02;
    let out2 = hmac_blake2s(&prk, &second_input);

    (out1, out2)
}
/// Encodes the Noise ChaChaPoly nonce: four zero bytes followed by the 64-bit
/// counter `n` in little-endian order.
fn nonce_bytes(n: u64) -> [u8; 12] {
    let mut nonce = [0u8; 12];
    for (dst, src) in nonce[4..].iter_mut().zip(n.to_le_bytes()) {
        *dst = src;
    }
    nonce
}
fn x25519(scalar: &[u8; 32], u: &[u8; 32]) -> [u8; 32] {
let clamped = clamp_integer(*scalar);
let scalar_inner = curve25519_dalek::scalar::Scalar::from_bytes_mod_order(clamped);
let point = MontgomeryPoint(*u);
(point * scalar_inner).to_bytes()
}
/// Derives the X25519 public key for `priv_` by multiplying the base point (u = 9).
pub fn x25519_public_from_private(priv_: &[u8; 32]) -> [u8; 32] {
    // Little-endian encoding of u = 9.
    const BASEPOINT_U: [u8; 32] = {
        let mut u = [0u8; 32];
        u[0] = 9;
        u
    };
    x25519(priv_, &BASEPOINT_U)
}
/// Generates 32 random bytes from the thread-local RNG for use as an X25519 private key.
///
/// The bytes are returned unclamped; `x25519` clamps them at use time.
pub fn x25519_random_private() -> [u8; 32] {
    let mut rng = rand::thread_rng();
    let mut secret = [0u8; 32];
    rng.fill_bytes(&mut secret);
    secret
}
#[derive(Clone)]
pub struct CipherState {
pub k: Option<[u8; 32]>,
pub n: u64,
}
impl CipherState {
pub fn new(k: Option<[u8; 32]>) -> Self {
Self { k, n: 0 }
}
pub fn encrypt_with_ad(&mut self, ad: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> {
let key = match self.k {
Some(k) => k,
None => return Ok(plaintext.to_vec()),
};
let aead = ChaCha20Poly1305::new(Key::from_slice(&key));
let nonce = nonce_bytes(self.n);
let ct = aead
.encrypt(
Nonce::from_slice(&nonce),
Payload {
msg: plaintext,
aad: ad,
},
)
.map_err(|e| Error::Noise(format!("encrypt: {e}")))?;
self.n += 1;
Ok(ct)
}
pub fn decrypt_with_ad(&mut self, ad: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
let key = match self.k {
Some(k) => k,
None => return Ok(ciphertext.to_vec()),
};
let aead = ChaCha20Poly1305::new(Key::from_slice(&key));
let nonce = nonce_bytes(self.n);
let pt = aead
.decrypt(
Nonce::from_slice(&nonce),
Payload {
msg: ciphertext,
aad: ad,
},
)
.map_err(|e| Error::Noise(format!("decrypt: {e}")))?;
self.n += 1;
Ok(pt)
}
}
/// Noise `SymmetricState`: the chaining key, the handshake hash, and the cipher
/// used for handshake payloads.
pub struct SymmetricState {
    /// Chaining key; replaced on every `mix_key`.
    pub ck: [u8; 32],
    /// Handshake transcript hash; used as associated data for handshake encryption.
    pub h: [u8; 32],
    /// Handshake cipher; keyless until the first `mix_key`.
    pub cs: CipherState,
}
impl SymmetricState {
    /// Noise `InitializeSymmetric(PROTOCOL)`.
    ///
    /// `h` is the protocol name zero-padded to `HASHLEN` if it fits, otherwise its
    /// hash. `ck` starts equal to `h`, and the cipher starts without a key.
    pub fn new() -> Self {
        let h = if PROTOCOL.len() > HASHLEN {
            blake2s(PROTOCOL)
        } else {
            let mut padded = [0u8; HASHLEN];
            padded[..PROTOCOL.len()].copy_from_slice(PROTOCOL);
            padded
        };
        Self {
            ck: h,
            h,
            cs: CipherState::new(None),
        }
    }

    /// Noise `MixHash`: `h = HASH(h || data)`.
    pub fn mix_hash(&mut self, data: &[u8]) {
        let preimage = [&self.h[..], data].concat();
        self.h = blake2s(&preimage);
    }

    /// Noise `MixKey`: derives a new chaining key and a fresh handshake cipher key
    /// from `ikm`. The new cipher starts at nonce 0.
    pub fn mix_key(&mut self, ikm: &[u8]) {
        let (next_ck, cipher_key) = hkdf2(&self.ck, ikm);
        self.cs = CipherState::new(Some(cipher_key));
        self.ck = next_ck;
    }

    /// Noise `EncryptAndHash`: encrypts with `h` as associated data, then mixes the
    /// ciphertext into `h`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the handshake cipher.
    pub fn encrypt_and_hash(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
        let ad = self.h;
        let ciphertext = self.cs.encrypt_with_ad(&ad, plaintext)?;
        self.mix_hash(&ciphertext);
        Ok(ciphertext)
    }

    /// Noise `DecryptAndHash`: decrypts with `h` as associated data, then mixes the
    /// ciphertext (not the plaintext) into `h`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the handshake cipher (e.g. authentication failure).
    pub fn decrypt_and_hash(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
        let ad = self.h;
        let plaintext = self.cs.decrypt_with_ad(&ad, ciphertext)?;
        self.mix_hash(ciphertext);
        Ok(plaintext)
    }

    /// Noise `Split`: derives the two transport cipher states from the final
    /// chaining key, in `(first, second)` order.
    pub fn split(self) -> (CipherState, CipherState) {
        let keyed = |k: [u8; 32]| CipherState::new(Some(k));
        let (first, second) = hkdf2(&self.ck, &[]);
        (keyed(first), keyed(second))
    }
}
impl Default for SymmetricState {
    /// Same as [`SymmetricState::new`].
    fn default() -> Self {
        SymmetricState::new()
    }
}
/// Transport-phase session produced by a finished handshake.
///
/// Each direction uses an implicit nonce counter, so messages must be decrypted in
/// the same order they were encrypted.
pub struct HandshakeResult {
    /// Cipher for outgoing transport messages.
    pub send_cs: CipherState,
    /// Cipher for incoming transport messages.
    pub recv_cs: CipherState,
}
impl HandshakeResult {
    /// Encrypts one outgoing transport message with empty associated data.
    ///
    /// # Errors
    ///
    /// Propagates any error from the sending `CipherState`.
    pub fn send(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
        let no_ad: &[u8] = &[];
        self.send_cs.encrypt_with_ad(no_ad, plaintext)
    }

    /// Decrypts one incoming transport message with empty associated data.
    ///
    /// # Errors
    ///
    /// Propagates any error from the receiving `CipherState` (e.g. authentication failure).
    pub fn recv(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
        let no_ad: &[u8] = &[];
        self.recv_cs.decrypt_with_ad(no_ad, ciphertext)
    }
}
/// Builds the Noise prologue that binds both parties' DIDs into the handshake hash.
///
/// Layout: `b"agent-phone/1"` ‖ `u16_be(len(initiator))` ‖ initiator ‖
/// `u16_be(len(responder))` ‖ responder. The length prefixes make the encoding
/// unambiguous, so two different DID pairs can never produce the same prologue.
///
/// # Panics
///
/// Panics if either DID is longer than `u16::MAX` bytes. A truncated length prefix
/// would make the encoding ambiguous, which defeats the purpose of the prologue.
pub fn build_prologue(initiator_did: &str, responder_did: &str) -> Vec<u8> {
    const PREFIX: &[u8] = b"agent-phone/1";
    let fields = [initiator_did.as_bytes(), responder_did.as_bytes()];
    let mut out =
        Vec::with_capacity(PREFIX.len() + 4 + initiator_did.len() + responder_did.len());
    out.extend_from_slice(PREFIX);
    for field in fields {
        // Reject rather than letting `as u16` silently wrap the length prefix.
        assert!(
            field.len() <= usize::from(u16::MAX),
            "DID of {} bytes exceeds the {}-byte prologue length-prefix limit",
            field.len(),
            u16::MAX
        );
        out.extend_from_slice(&(field.len() as u16).to_be_bytes());
        out.extend_from_slice(field);
    }
    out
}
pub struct InitiatorHandshake {
ss: SymmetricState,
static_priv: [u8; 32],
static_pub: [u8; 32],
responder_static_pub: [u8; 32],
e_priv: Option<[u8; 32]>,
re_pub: Option<[u8; 32]>,
}
impl InitiatorHandshake {
pub fn new(
prologue: &[u8],
static_priv: [u8; 32],
static_pub: [u8; 32],
responder_static_pub: [u8; 32],
) -> Self {
let mut ss = SymmetricState::new();
ss.mix_hash(prologue);
ss.mix_hash(&responder_static_pub);
Self {
ss,
static_priv,
static_pub,
responder_static_pub,
e_priv: None,
re_pub: None,
}
}
pub fn write_message_1(&mut self) -> Result<Vec<u8>> {
let e_priv = x25519_random_private();
let e_pub = x25519_public_from_private(&e_priv);
self.ss.mix_hash(&e_pub);
self.ss
.mix_key(&x25519(&e_priv, &self.responder_static_pub));
let enc = self.ss.encrypt_and_hash(&[])?;
self.e_priv = Some(e_priv);
let mut out = Vec::with_capacity(32 + enc.len());
out.extend_from_slice(&e_pub);
out.extend_from_slice(&enc);
Ok(out)
}
pub fn read_message_2(&mut self, msg: &[u8]) -> Result<()> {
let e_priv = self
.e_priv
.ok_or_else(|| Error::Noise("write_message_1 must run first".into()))?;
if msg.len() < 32 {
return Err(Error::Noise("message 2 too short".into()));
}
let mut re_pub = [0u8; 32];
re_pub.copy_from_slice(&msg[..32]);
self.ss.mix_hash(&re_pub);
self.ss.mix_key(&x25519(&e_priv, &re_pub));
self.ss.decrypt_and_hash(&msg[32..])?;
self.re_pub = Some(re_pub);
Ok(())
}
pub fn write_message_3(&mut self) -> Result<Vec<u8>> {
let re_pub = self
.re_pub
.ok_or_else(|| Error::Noise("read_message_2 must run first".into()))?;
let enc_s = self.ss.encrypt_and_hash(&self.static_pub)?;
self.ss.mix_key(&x25519(&self.static_priv, &re_pub));
let enc_payload = self.ss.encrypt_and_hash(&[])?;
let mut out = Vec::with_capacity(enc_s.len() + enc_payload.len());
out.extend_from_slice(&enc_s);
out.extend_from_slice(&enc_payload);
Ok(out)
}
pub fn finish(self) -> HandshakeResult {
let (send_cs, recv_cs) = self.ss.split();
HandshakeResult { send_cs, recv_cs }
}
}
/// Responder side of the Noise XK handshake (`<- s`, `-> e, es`, `<- e, ee`, `-> s, se`).
///
/// Drive it with [`read_message_1`](Self::read_message_1),
/// [`write_message_2`](Self::write_message_2) and [`read_message_3`](Self::read_message_3).
/// Then inspect [`remote_static_pub`](Self::remote_static_pub) to decide whether the
/// initiator is who it claims to be, and only then call [`finish`](Self::finish).
pub struct ResponderHandshake {
    ss: SymmetricState,
    static_priv: [u8; 32],
    // Stored but never read by this impl: the `<- s` pre-message is mixed into the
    // transcript in `new` directly from the constructor argument.
    #[allow(dead_code)]
    static_pub: [u8; 32],
    e_priv: Option<[u8; 32]>,
    re_init_pub: Option<[u8; 32]>,
    // Initiator's static public key; set only after message 3 fully authenticates.
    rs_pub: Option<[u8; 32]>,
}
impl ResponderHandshake {
    /// Starts a handshake by mixing `prologue`, then our own static public key
    /// (the XK `<- s` pre-message), into the handshake hash.
    pub fn new(prologue: &[u8], static_priv: [u8; 32], static_pub: [u8; 32]) -> Self {
        let mut ss = SymmetricState::new();
        ss.mix_hash(prologue);
        ss.mix_hash(&static_pub);
        Self {
            ss,
            static_priv,
            static_pub,
            e_priv: None,
            re_init_pub: None,
            rs_pub: None,
        }
    }

    /// Reads message 1 (`-> e, es`): the initiator's ephemeral key followed by an
    /// encrypted payload, which is authenticated and then discarded.
    ///
    /// # Errors
    ///
    /// Fails if `msg` is shorter than 32 bytes or the payload fails to decrypt.
    pub fn read_message_1(&mut self, msg: &[u8]) -> Result<()> {
        if msg.len() < 32 {
            return Err(Error::Noise("message 1 too short".into()));
        }
        let mut re_init = [0u8; 32];
        re_init.copy_from_slice(&msg[..32]);
        self.ss.mix_hash(&re_init);
        self.ss.mix_key(&x25519(&self.static_priv, &re_init));
        self.ss.decrypt_and_hash(&msg[32..])?;
        self.re_init_pub = Some(re_init);
        Ok(())
    }

    /// Writes message 2 (`<- e, ee`): our fresh ephemeral public key followed by the
    /// encryption of an empty payload.
    ///
    /// # Errors
    ///
    /// Fails if message 1 was not read, or propagates a handshake-cipher error.
    pub fn write_message_2(&mut self) -> Result<Vec<u8>> {
        let re_init = self
            .re_init_pub
            .ok_or_else(|| Error::Noise("read_message_1 must run first".into()))?;
        let e_priv = x25519_random_private();
        let e_pub = x25519_public_from_private(&e_priv);
        self.ss.mix_hash(&e_pub);
        self.ss.mix_key(&x25519(&e_priv, &re_init));
        let enc = self.ss.encrypt_and_hash(&[])?;
        self.e_priv = Some(e_priv);
        let mut out = Vec::with_capacity(32 + enc.len());
        out.extend_from_slice(&e_pub);
        out.extend_from_slice(&enc);
        Ok(out)
    }

    /// Reads message 3 (`-> s, se`): the initiator's encrypted static key (32 bytes
    /// plus a 16-byte tag) followed by an encrypted payload.
    ///
    /// On success, the initiator's static key becomes available via
    /// [`remote_static_pub`](Self::remote_static_pub).
    ///
    /// # Errors
    ///
    /// Fails if message 2 was not written, if `msg` is shorter than 48 bytes, or if
    /// either part fails to decrypt.
    pub fn read_message_3(&mut self, msg: &[u8]) -> Result<()> {
        let e_priv = self
            .e_priv
            .ok_or_else(|| Error::Noise("write_message_2 must run first".into()))?;
        if msg.len() < 48 {
            return Err(Error::Noise("message 3 too short".into()));
        }
        let enc_s = &msg[..48];
        let rest = &msg[48..];
        let ris_bytes = self.ss.decrypt_and_hash(enc_s)?;
        if ris_bytes.len() != 32 {
            return Err(Error::Noise("decrypted static key wrong size".into()));
        }
        let mut ris = [0u8; 32];
        ris.copy_from_slice(&ris_bytes);
        self.ss.mix_key(&x25519(&e_priv, &ris));
        // This final decrypt is keyed by `se`, so success proves possession of `ris`.
        self.ss.decrypt_and_hash(rest)?;
        self.rs_pub = Some(ris);
        Ok(())
    }

    /// Returns the initiator's static X25519 public key learned from message 3.
    ///
    /// This is `None` until [`read_message_3`](Self::read_message_3) succeeds. XK proves
    /// only that the initiator holds *some* static key. Callers must compare this key
    /// against the identity they expect (e.g. the key registered for the initiator's
    /// DID) before trusting the session.
    pub fn remote_static_pub(&self) -> Option<[u8; 32]> {
        self.rs_pub
    }

    /// Splits the final chaining key into transport ciphers: the first key is used
    /// for receiving and the second for sending, mirroring the initiator.
    ///
    /// Call this only after [`read_message_3`](Self::read_message_3) succeeds. Before
    /// that, the `se` key has not been mixed in, so the resulting keys do not
    /// authenticate the initiator.
    pub fn finish(self) -> HandshakeResult {
        let (recv_cs, send_cs) = self.ss.split();
        HandshakeResult { send_cs, recv_cs }
    }
}