use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::crypto::{
Suite, aes_cbc_256_decrypt, aes_cbc_256_encrypt, aes_gcm_open, aes_gcm_seal, hmac_sha256_128,
};
/// crossfire channel flavor carrying inbound packets (one `Vec<u8>` per packet) into
/// [`EspTunnel`]; `EspTunnel::channels` creates it via `crossfire::mpmc::unbounded_async`.
pub(crate) type InboundFlavor = crossfire::mpmc::List<Vec<u8>>;
/// crossfire channel flavor carrying packets written to [`EspTunnel`] out to the consumer;
/// `EspTunnel::channels` creates it via `crossfire::mpsc::unbounded_async`.
pub(crate) type OutboundFlavor = crossfire::mpsc::List<Vec<u8>>;
/// Output of a successful [`decrypt`] call.
#[derive(Debug)]
pub struct Decrypted {
    /// Inner packet bytes with the ESP padding and the two trailer bytes removed.
    pub payload: Vec<u8>,
    /// Next Header byte from the ESP trailer (per RFC 4303, the protocol of `payload`).
    pub next_header: u8,
    /// Sequence number copied from the ESP header.
    // Within this file `seq` is only read by the tests, hence the dead_code allowance.
    #[allow(dead_code)]
    pub seq: u32,
}
/// Reasons [`decrypt`] can reject an ESP datagram.
#[derive(Debug, thiserror::Error)]
pub enum EspError {
    /// The datagram is shorter than header + IV + minimum ciphertext + ICV; carries its length.
    #[error("packet too short ({0} bytes)")]
    TooShort(usize),
    /// Authentication/decryption failed, or the key/IV had an unusable length.
    #[error("decrypt/verify failed")]
    Crypto,
    /// The decrypted plaintext has an inconsistent ESP trailer (bad pad length).
    #[error("malformed plaintext")]
    Malformed,
}
/// Authenticates and decrypts one ESP datagram.
///
/// `datagram` starts at the SPI and is laid out as
/// `SPI (4) || Seq (4) || IV || ciphertext || ICV`, with the IV and ICV sizes taken
/// from `suite.params()`.
///
/// * AEAD suites: the 8-byte ESP header is the AAD and `key`/`salt` are handed to
///   `aes_gcm_open`; `integ` is unused.
/// * Non-AEAD suites: the HMAC-SHA-256-128 ICV over header, IV and ciphertext (keyed with
///   `integ`) is checked in constant time *before* AES-256-CBC decryption with `key`.
///
/// # Errors
/// * [`EspError::TooShort`] if the datagram cannot hold header, IV, the smallest legal
///   ciphertext and the ICV.
/// * [`EspError::Crypto`] if authentication or decryption fails, or the CBC key/IV does not
///   have its required length (32/16 bytes).
/// * [`EspError::Malformed`] if the trailer's pad length exceeds the plaintext.
pub fn decrypt(
    suite: Suite,
    key: &[u8],
    salt: &[u8],
    integ: &[u8],
    datagram: &[u8],
) -> Result<Decrypted, EspError> {
    let p = suite.params();
    // Smallest ciphertext a conforming sender produces: AEAD ciphertext only has to be
    // 4-byte aligned (2 trailer bytes + 2 padding bytes), while CBC needs a whole AES block.
    // A fixed 16 here would wrongly reject short but valid AEAD packets.
    let min_ct = if p.aead { 4 } else { 16 };
    let min = 4 + 4 + p.encr_iv_bytes + min_ct + p.encr_icv_bytes;
    if datagram.len() < min {
        return Err(EspError::TooShort(datagram.len()));
    }

    // `min >= 8`, so the header slice is always exactly four bytes.
    let seq = u32::from_be_bytes(datagram[4..8].try_into().unwrap());
    let iv_end = 8 + p.encr_iv_bytes;
    let icv_start = datagram.len() - p.encr_icv_bytes;
    let iv = &datagram[8..iv_end];
    let icv = &datagram[icv_start..];
    let ct = &datagram[iv_end..icv_start];

    let mut pt = if p.aead {
        let aad = &datagram[..8];
        let mut buf = ct.to_vec();
        aes_gcm_open(key, salt, iv, aad, &mut buf, icv).map_err(|_| EspError::Crypto)?;
        buf
    } else {
        use subtle::ConstantTimeEq;
        // Encrypt-then-MAC: verify the ICV over everything before it, without copying.
        let expected_icv = hmac_sha256_128(integ, &datagram[..icv_start]);
        if !bool::from(expected_icv.ct_eq(icv)) {
            return Err(EspError::Crypto);
        }
        let iv16: [u8; 16] = iv.try_into().map_err(|_| EspError::Crypto)?;
        let key32: [u8; 32] = key.try_into().map_err(|_| EspError::Crypto)?;
        aes_cbc_256_decrypt(&key32, &iv16, ct).map_err(|_| EspError::Crypto)?
    };

    // ESP trailer: ... || Padding || Pad Length (1) || Next Header (1).
    let &[.., pad_len, next_header] = pt.as_slice() else {
        return Err(EspError::Malformed);
    };
    let pad_len = usize::from(pad_len);
    if pad_len + 2 > pt.len() {
        return Err(EspError::Malformed);
    }
    pt.truncate(pt.len() - 2 - pad_len);
    Ok(Decrypted {
        payload: pt,
        next_header,
        seq,
    })
}
/// Builds one ESP datagram carrying `payload`.
///
/// Output layout: `SPI (4) || Seq (4) || IV || ciphertext || ICV`, both header fields
/// big-endian. The plaintext is `payload || 1, 2, ..., n || n || next_header`, padded so
/// its length is a multiple of 4 (AEAD) or 16 (CBC).
///
/// * AEAD suites: the 8-byte IV is four zero bytes followed by `seq`, the ESP header is the
///   AAD, and `key`/`salt` go to `aes_gcm_seal`. Because the IV is derived from `seq`,
///   the caller must never reuse a `seq` value under the same key/salt.
/// * Non-AEAD suites: a random 16-byte IV is used for AES-256-CBC under `key`, then an
///   HMAC-SHA-256-128 ICV (keyed with `integ`) is appended over header, IV and ciphertext.
///
/// # Panics
/// * For non-AEAD suites, if `key` is not exactly 32 bytes.
/// * If the underlying seal/encrypt helper reports an error.
// The SA parameters are deliberately passed individually rather than bundled.
#[allow(clippy::too_many_arguments)]
pub fn encrypt(
    suite: Suite,
    key: &[u8],
    salt: &[u8],
    integ: &[u8],
    spi: u32,
    seq: u32,
    payload: &[u8],
    next_header: u8,
) -> Vec<u8> {
    let p = suite.params();
    let block = if p.aead { 4 } else { 16 };
    // `pad < block <= 16`, so the `as u8` casts below cannot truncate.
    let pad = (block - (payload.len() + 2) % block) % block;

    let mut plaintext = Vec::with_capacity(payload.len() + pad + 2);
    plaintext.extend_from_slice(payload);
    plaintext.extend(1..=pad as u8); // RFC 4303 default padding: 1, 2, 3, ...
    plaintext.push(pad as u8);
    plaintext.push(next_header);

    // ESP header (SPI || Seq) on the stack; it also serves as the AEAD AAD, so no
    // copy of the output buffer is needed.
    let mut header = [0u8; 8];
    header[..4].copy_from_slice(&spi.to_be_bytes());
    header[4..].copy_from_slice(&seq.to_be_bytes());

    let mut out =
        Vec::with_capacity(header.len() + p.encr_iv_bytes + plaintext.len() + p.encr_icv_bytes);
    out.extend_from_slice(&header);
    if p.aead {
        let mut iv = [0u8; 8];
        iv[4..].copy_from_slice(&seq.to_be_bytes());
        out.extend_from_slice(&iv);
        let tag = aes_gcm_seal(key, salt, &iv, &header, &mut plaintext)
            .expect("AES-GCM seal never fails with valid key/nonce");
        out.extend_from_slice(&plaintext);
        out.extend_from_slice(&tag);
    } else {
        let mut iv = [0u8; 16];
        rand::Rng::fill_bytes(&mut rand::rng(), &mut iv);
        let key32: [u8; 32] = key
            .try_into()
            .expect("AES-CBC-256 requires a 32-byte encryption key");
        let ct = aes_cbc_256_encrypt(&key32, &iv, &plaintext)
            .expect("AES-CBC encrypt never fails when input is block-aligned");
        out.extend_from_slice(&iv);
        out.extend_from_slice(&ct);
        // Encrypt-then-MAC over header || IV || ciphertext.
        let icv = hmac_sha256_128(integ, &out);
        out.extend_from_slice(&icv);
    }
    out
}
/// In-process byte-stream endpoint backed by two crossfire channels.
///
/// Packets received on the inbound channel are exposed through `AsyncRead`; each
/// `AsyncWrite::poll_write` call is forwarded as one `Vec<u8>` on the outbound channel.
/// Construct it with `EspTunnel::channels`.
pub struct EspTunnel {
    /// Stream of inbound packets. Wrapped in a `Mutex` because it is polled through
    /// `Pin<&mut Self>` via shared deref — presumably also to keep `EspTunnel: Sync`; confirm.
    inbound: Mutex<crossfire::stream::AsyncStream<InboundFlavor>>,
    /// Sender for packets written through `AsyncWrite`.
    outbound_tx: crossfire::MTx<OutboundFlavor>,
    /// Tail of the most recent inbound packet that did not fit into the caller's read buffer;
    /// drained before any new packet is taken.
    read_remainder: Mutex<Vec<u8>>,
}
/// Opaque debug representation: prints only the type name, since the channel
/// handles carry nothing useful to show.
impl std::fmt::Debug for EspTunnel {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("EspTunnel")
    }
}
impl EspTunnel {
    /// Creates a tunnel together with the channel ends used to drive it.
    ///
    /// Returns `(tunnel, inbound_sender, outbound_receiver)`:
    /// * push packets into `inbound_sender` to make them readable from the tunnel;
    /// * receive from `outbound_receiver` to collect whatever is written to the tunnel.
    ///
    /// Both channels are created with crossfire's `unbounded_async` constructors.
    pub fn channels() -> (
        Self,
        crossfire::MTx<InboundFlavor>,
        crossfire::AsyncRx<OutboundFlavor>,
    ) {
        let (inbound_sender, inbound_receiver) = crossfire::mpmc::unbounded_async::<Vec<u8>>();
        let (outbound_sender, outbound_receiver) = crossfire::mpsc::unbounded_async::<Vec<u8>>();

        let tunnel = Self {
            inbound: Mutex::new(inbound_receiver.into()),
            outbound_tx: outbound_sender,
            read_remainder: Mutex::new(Vec::new()),
        };

        (tunnel, inbound_sender, outbound_receiver)
    }
}
impl AsyncRead for EspTunnel {
    /// Copies inbound packet bytes into `buf`.
    ///
    /// Bytes left over from a packet that did not fit in a previous read are served first.
    /// Otherwise the next non-empty packet is taken from the inbound channel; whatever does
    /// not fit into `buf` is stashed for the next call, so packet boundaries are not
    /// preserved when the caller's buffer is too small. A closed channel yields `Ok(())`
    /// with nothing written (EOF).
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        {
            let mut rem = self.read_remainder.lock().unwrap();
            if !rem.is_empty() {
                let n = buf.remaining().min(rem.len());
                buf.put_slice(&rem[..n]);
                rem.drain(..n);
                return Poll::Ready(Ok(()));
            }
        }
        let mut stream = self.inbound.lock().unwrap();
        loop {
            match stream.poll_item(cx) {
                // An empty packet carries no data, and returning Ready with nothing written
                // would look like EOF to the reader. Poll the next item directly instead of
                // self-waking and bouncing through the executor.
                Poll::Ready(Some(pkt)) if pkt.is_empty() => continue,
                Poll::Ready(Some(pkt)) => {
                    let n = buf.remaining().min(pkt.len());
                    buf.put_slice(&pkt[..n]);
                    if n < pkt.len() {
                        self.read_remainder
                            .lock()
                            .unwrap()
                            .extend_from_slice(&pkt[n..]);
                    }
                    return Poll::Ready(Ok(()));
                }
                Poll::Ready(None) => return Poll::Ready(Ok(())),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
impl AsyncWrite for EspTunnel {
    /// Forwards all of `buf` as a single packet on the outbound channel.
    ///
    /// `EspTunnel::channels` builds that channel with `unbounded_async`, so the blocking-trait
    /// `send` is expected to return immediately (assumption about crossfire's unbounded
    /// flavor — confirm). It fails, mapped to `BrokenPipe`, once the receiver is dropped.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        use crossfire::BlockingTxTrait;
        let packet = buf.to_vec();
        let outcome = match self.outbound_tx.send(packet) {
            Ok(_) => Ok(buf.len()),
            Err(_) => Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "esp tunnel closed",
            )),
        };
        Poll::Ready(outcome)
    }

    /// Nothing is buffered locally; every write is already handed to the channel.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Reports success immediately; the outbound channel itself is not closed here.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn esp_round_trip_aes_gcm() {
        let suite = Suite::AesGcm256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt = vec![0x11u8; 4];
        let integ: Vec<u8> = vec![];
        let spi: u32 = 0xdead_beef;
        let seq: u32 = 42;
        let payload = [
            0x45, 0x00, 0x00, 0x14, 0, 0, 0, 0, 64, 17, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8,
        ];
        let esp = encrypt(suite, &key, &salt, &integ, spi, seq, &payload, 4);
        let decrypted = decrypt(suite, &key, &salt, &integ, &esp).expect("decrypt");
        assert_eq!(decrypted.payload, payload);
        assert_eq!(decrypted.next_header, 4);
        assert_eq!(decrypted.seq, seq);
    }

    #[test]
    fn esp_round_trip_aes_cbc() {
        let suite = Suite::AesCbc256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt: Vec<u8> = vec![];
        let integ = vec![0x33u8; 32];
        let spi: u32 = 0xcafe_babe;
        let seq: u32 = 7;
        let payload = [0x45, 0x00, 0x00, 0x14, 0, 0, 0, 0, 64, 17, 0, 0, 1, 2, 3, 4];
        let esp = encrypt(suite, &key, &salt, &integ, spi, seq, &payload, 4);
        let decrypted = decrypt(suite, &key, &salt, &integ, &esp).expect("decrypt");
        assert_eq!(decrypted.payload, payload);
        assert_eq!(decrypted.next_header, 4);
        assert_eq!(decrypted.seq, seq);
    }

    /// Flipping a tag bit must fail authentication. The payload is long enough that the
    /// packet clears the length check, so the error really comes from tag verification
    /// (a tiny payload could be rejected as `TooShort` and pass vacuously).
    #[test]
    fn esp_rejects_tamper() {
        let suite = Suite::AesGcm256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt = vec![0x11u8; 4];
        let integ: Vec<u8> = vec![];
        let payload = [0xabu8; 20];
        let mut esp = encrypt(suite, &key, &salt, &integ, 1, 1, &payload, 4);
        let last = esp.len() - 1;
        esp[last] ^= 0x01;
        assert!(matches!(
            decrypt(suite, &key, &salt, &integ, &esp),
            Err(EspError::Crypto)
        ));
    }

    /// The ESP header is AEAD associated data, so altering the SPI must break the tag.
    #[test]
    fn esp_rejects_tampered_header_aes_gcm() {
        let suite = Suite::AesGcm256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt = vec![0x11u8; 4];
        let integ: Vec<u8> = vec![];
        let payload = [0xabu8; 20];
        let mut esp = encrypt(suite, &key, &salt, &integ, 1, 1, &payload, 4);
        esp[0] ^= 0x01;
        assert!(matches!(
            decrypt(suite, &key, &salt, &integ, &esp),
            Err(EspError::Crypto)
        ));
    }

    /// A flipped ciphertext byte must be caught by the HMAC before CBC decryption.
    #[test]
    fn esp_rejects_tamper_aes_cbc() {
        let suite = Suite::AesCbc256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt: Vec<u8> = vec![];
        let integ = vec![0x33u8; 32];
        let payload = [0xabu8; 16];
        let mut esp = encrypt(suite, &key, &salt, &integ, 1, 1, &payload, 4);
        // First ciphertext byte: after the 8-byte ESP header and the 16-byte IV.
        esp[8 + 16] ^= 0x01;
        assert!(matches!(
            decrypt(suite, &key, &salt, &integ, &esp),
            Err(EspError::Crypto)
        ));
    }

    #[test]
    fn esp_rejects_truncated_datagram() {
        let suite = Suite::AesGcm256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt = vec![0x11u8; 4];
        let integ: Vec<u8> = vec![];
        let esp = encrypt(suite, &key, &salt, &integ, 1, 1, &[0u8; 20], 4);
        assert!(matches!(
            decrypt(suite, &key, &salt, &integ, &esp[..7]),
            Err(EspError::TooShort(7))
        ));
    }
}