jkipsec 0.1.0

Userspace IKEv2/IPsec VPN responder for terminating iOS VPN tunnels and exposing the inner IP traffic. Pairs with jktcp for a fully userspace TCP/IP stack.
Documentation
//! ESP (RFC 4303) packet codec for child SAs using either AES-GCM-16-256
//! (RFC 4106) or AES-256-CBC with HMAC-SHA-256-128 (RFC 3602 / RFC 4868).
//! Operates on UDP-encapsulated ESP (RFC 3948) as sent by iOS on port 4500.
//!
//! Wire layout (on UDP):
//!
//! ```text
//! | SPI (4) | Sequence Number (4) | IV (8 for GCM, 16 for CBC) | Ciphertext + pad + pad_len + next_hdr | ICV (16) |
//! ```
//!
//! AES-GCM mapping (RFC 4106):
//!
//! - Key   = 32 bytes, derived from `prf+(SK_d, Ni|Nr)`
//! - Salt  = 4 bytes (trailing portion of the keymat block)
//! - Nonce = Salt || IV    (8-byte IV is present in every packet)
//! - AAD   = SPI || Seq    (first 8 bytes of the datagram)

use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::crypto::{
    Suite, aes_cbc_256_decrypt, aes_cbc_256_encrypt, aes_gcm_open, aes_gcm_seal, hmac_sha256_128,
};

/// crossfire channel flavor for the inbound side of the tunnel
/// (server -> jktcp reader).
///
/// Uses **mpmc** because `MAsyncRx` is `Sync`, allowing `EspTunnel` to
/// satisfy jktcp's `ReadWrite: Sync` bound. The `AsyncStream` wrapper around
/// it is `!Sync` on its own, so `EspTunnel` keeps it behind a
/// `std::sync::Mutex`. The mutex is uncontended because only `poll_read`
/// touches it. `EspTunnel::channels` builds an unbounded channel of this
/// flavor; each message is one decrypted IP packet.
pub(crate) type InboundFlavor = crossfire::mpmc::List<Vec<u8>>;
/// crossfire channel flavor for the outbound side of the tunnel
/// (jktcp writer -> ESP encrypt task).
///
/// Uses **mpsc** with an `MTx` sender (which is `Sync`) and a single
/// `AsyncRx` consumer. `EspTunnel::channels` builds an unbounded channel of
/// this flavor, and `poll_write` sends written data into it as `Vec<u8>`
/// messages.
pub(crate) type OutboundFlavor = crossfire::mpsc::List<Vec<u8>>;

/// Fields carved out of an inbound ESP packet by [`decrypt`].
#[derive(Debug)]
pub struct Decrypted {
    /// Inner IP packet (after stripping ESP padding + trailer).
    pub payload: Vec<u8>,
    /// IANA "Next Header" value taken from the last byte of the ESP trailer
    /// (4 = IPv4, 41 = IPv6).
    pub next_header: u8,
    /// ESP sequence number observed in the header. `decrypt` only reports it;
    /// it does not check it against any replay window.
    // NOTE(review): the `allow(dead_code)` suggests nothing outside the tests
    // reads this field yet - confirm before relying on it for anti-replay.
    #[allow(dead_code)]
    pub seq: u32,
}

/// Errors returned by ESP decryption ([`decrypt`]).
///
/// [`encrypt`] does not return a `Result`; it panics on invalid input instead.
#[derive(Debug, thiserror::Error)]
pub enum EspError {
    /// Datagram is too short to contain the required ESP framing. Carries the
    /// datagram length in bytes.
    #[error("packet too short ({0} bytes)")]
    TooShort(usize),
    /// Authentication tag mismatch or AEAD decryption failure. Also used when
    /// the CBC key or IV has the wrong length, or AES-CBC decryption fails.
    #[error("decrypt/verify failed")]
    Crypto,
    /// Decrypted plaintext is not well-formed (bad pad length, missing trailer).
    #[error("malformed plaintext")]
    Malformed,
}

/// Decrypt an ESP packet received UDP-encapsulated. Dispatches on suite
/// (AEAD vs CBC + HMAC).
///
/// - `key`: AES key (32 bytes for AES-256, both suites).
/// - `salt`: 4-byte GCM salt for AEAD; empty slice for CBC.
/// - `integ`: HMAC key for CBC suite; empty for AEAD.
/// - `datagram`: the ESP packet starting at the SPI
///   (`SPI || seq || IV || ciphertext || ICV`).
///
/// The SPI is not inspected here (presumably the caller already used it to
/// pick the SA keys). The sequence number is not replay-checked; it is only
/// reported back in [`Decrypted::seq`].
///
/// # Errors
///
/// - [`EspError::TooShort`] if the datagram cannot hold SPI, sequence number,
///   IV, the smallest ciphertext the suite can produce, and the ICV.
/// - [`EspError::Crypto`] if the ICV does not verify, the CBC key/IV have the
///   wrong length, or the cipher rejects the ciphertext.
/// - [`EspError::Malformed`] if the decrypted trailer is inconsistent.
pub fn decrypt(
    suite: Suite,
    key: &[u8],
    salt: &[u8],
    integ: &[u8],
    datagram: &[u8],
) -> Result<Decrypted, EspError> {
    let p = suite.params();
    // Smallest ciphertext each suite can legitimately produce. AES-GCM only
    // has to carry the 2-byte trailer padded to 4-byte alignment
    // (RFC 4303 §2.4), which is exactly what `encrypt` emits for very short
    // payloads. AES-CBC always needs at least one full 16-byte block.
    let min_ct = if p.aead { 4 } else { 16 };
    let min = 4 + 4 + p.encr_iv_bytes + min_ct + p.encr_icv_bytes; // SPI+seq+IV+ct+ICV
    if datagram.len() < min {
        return Err(EspError::TooShort(datagram.len()));
    }
    let seq = u32::from_be_bytes(datagram[4..8].try_into().expect("slice is 4 bytes"));
    let iv_end = 8 + p.encr_iv_bytes;
    let icv_start = datagram.len() - p.encr_icv_bytes;
    let iv = &datagram[8..iv_end];
    let ct = &datagram[iv_end..icv_start];
    let icv = &datagram[icv_start..];

    let pt = if p.aead {
        // AAD = SPI || seq.
        let aad = &datagram[..8];
        let mut buf = ct.to_vec();
        aes_gcm_open(key, salt, iv, aad, &mut buf, icv).map_err(|_| EspError::Crypto)?;
        buf
    } else {
        // RFC 4303 §3.4.4: HMAC over SPI || seq || IV || ciphertext, i.e.
        // everything before the ICV. Verify in constant time before
        // decrypting anything.
        use subtle::ConstantTimeEq;
        let expected_icv = hmac_sha256_128(integ, &datagram[..icv_start]);
        if !bool::from(expected_icv.ct_eq(icv)) {
            return Err(EspError::Crypto);
        }
        let iv16: &[u8; 16] = iv.try_into().map_err(|_| EspError::Crypto)?;
        let key32: &[u8; 32] = key.try_into().map_err(|_| EspError::Crypto)?;
        aes_cbc_256_decrypt(key32, iv16, ct).map_err(|_| EspError::Crypto)?
    };

    // Trailer layout: payload || padding || Pad Length (1) || Next Header (1).
    let &[.., pad_len, next_header] = pt.as_slice() else {
        return Err(EspError::Malformed);
    };
    let Some(payload_len) = pt.len().checked_sub(2 + usize::from(pad_len)) else {
        return Err(EspError::Malformed);
    };
    let mut payload = pt;
    payload.truncate(payload_len);
    Ok(Decrypted {
        payload,
        next_header,
        seq,
    })
}

/// Encrypt an inner IP packet as an ESP datagram. Dispatches on suite.
///
/// Produces `SPI || seq || IV || ciphertext || ICV`, where the ciphertext
/// covers `payload || padding || pad_len || next_header`. Padding uses the
/// RFC 4303 default byte pattern `1, 2, 3, ...`. It aligns the trailer to
/// 4 bytes for AES-GCM, or to the 16-byte block size for AES-CBC.
///
/// - AES-GCM: the 8-byte IV is `0u32 || seq`, and AAD is `SPI || seq`. The IV
///   is therefore unique only while `seq` is never repeated under the same
///   key. The caller must rekey before the 32-bit sequence number wraps,
///   because a repeated GCM nonce breaks confidentiality and integrity.
/// - AES-CBC: a fresh random 16-byte IV is drawn for every packet. The ICV is
///   HMAC-SHA-256-128, keyed with `integ`, over
///   `SPI || seq || IV || ciphertext`.
///
/// # Panics
///
/// - For the CBC suite, if `key` is not exactly 32 bytes.
/// - If `aes_gcm_seal` / `aes_cbc_256_encrypt` report an error. The `expect`
///   messages assume this cannot happen for well-formed keys and
///   block-aligned input (NOTE(review): depends on `crate::crypto`).
#[allow(clippy::too_many_arguments)]
pub fn encrypt(
    suite: Suite,
    key: &[u8],
    salt: &[u8],
    integ: &[u8],
    spi: u32,
    seq: u32,
    payload: &[u8],
    next_header: u8,
) -> Vec<u8> {
    let p = suite.params();
    // For AES-CBC the trailer must reach a 16-byte boundary; for AES-GCM
    // RFC 4303 only mandates 4-byte alignment.
    let block = if p.aead { 4 } else { 16 };
    let pad = (block - (payload.len() + 2) % block) % block;
    let mut plaintext = Vec::with_capacity(payload.len() + pad + 2);
    plaintext.extend_from_slice(payload);
    // Default padding contents (RFC 4303 §2.4): 1, 2, 3, ... (pad <= 15).
    plaintext.extend(1..=pad as u8);
    plaintext.push(pad as u8);
    plaintext.push(next_header);

    let mut out =
        Vec::with_capacity(4 + 4 + p.encr_iv_bytes + plaintext.len() + p.encr_icv_bytes);
    out.extend_from_slice(&spi.to_be_bytes());
    out.extend_from_slice(&seq.to_be_bytes());

    if p.aead {
        let mut iv = [0u8; 8];
        iv[4..].copy_from_slice(&seq.to_be_bytes());
        out.extend_from_slice(&iv);
        // AAD = SPI || seq, which is exactly the first 8 bytes of `out`;
        // borrow them rather than copying the header into a new Vec.
        let tag = aes_gcm_seal(key, salt, &iv, &out[..8], &mut plaintext)
            .expect("AES-GCM seal never fails with valid key/nonce");
        out.extend_from_slice(&plaintext);
        out.extend_from_slice(&tag);
    } else {
        let key32: [u8; 32] = key
            .try_into()
            .expect("AES-256-CBC key must be exactly 32 bytes");
        // Random 16-byte IV (must be unpredictable for CBC).
        let mut iv = [0u8; 16];
        rand::Rng::fill_bytes(&mut rand::rng(), &mut iv);
        let ct = aes_cbc_256_encrypt(&key32, &iv, &plaintext)
            .expect("AES-CBC encrypt never fails when input is block-aligned");
        out.extend_from_slice(&iv);
        out.extend_from_slice(&ct);
        // HMAC over SPI || seq || IV || ciphertext.
        let icv = hmac_sha256_128(integ, &out);
        out.extend_from_slice(&icv);
    }
    out
}

// --------------------------------------------------------------- EspTunnel

/// Bridges decrypted inbound IP packets and plaintext outbound IP packets
/// between the IKE server and an application (typically jktcp's `Adapter`).
///
/// Uses lockless crossfire channels internally:
///
/// - **Inbound** (read): `crossfire::mpmc` gives an `MAsyncRx` (which is
///   `Sync`). We convert it to an `AsyncStream` for `poll_item` and wrap that
///   in `std::sync::Mutex` - the lock is never contended because only
///   `poll_read` touches it.
/// - **Outbound** (write): `crossfire::mpsc` gives an `MTx` (which is `Sync`
///   and non-blocking for unbounded channels) - perfect for `poll_write`.
///
/// Reads are byte-oriented. If a read buffer is smaller than the next
/// inbound packet, the rest of that packet is returned by the following
/// read(s). Packet boundaries are therefore only preserved when the reader's
/// buffer can hold a whole packet.
pub struct EspTunnel {
    /// Stream of decrypted inbound IP packets, fed by the server's `inbound_tx`.
    inbound: Mutex<crossfire::stream::AsyncStream<InboundFlavor>>,
    /// Sender half of the outbound channel; `poll_write` sends written data here.
    outbound_tx: crossfire::MTx<OutboundFlavor>,
    /// Unread tail of the last inbound packet that did not fit in the caller's
    /// buffer; served before any new packet is pulled.
    read_remainder: Mutex<Vec<u8>>,
}

impl std::fmt::Debug for EspTunnel {
    /// Prints only the type name; the channel handles and buffers are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("EspTunnel")
    }
}

impl EspTunnel {
    /// Builds a tunnel together with the two channel ends the server keeps.
    ///
    /// Returns `(tunnel, inbound_tx, outbound_rx)`:
    ///
    /// - `inbound_tx` - the server sends each decrypted IP packet here; the
    ///   tunnel hands them out through `poll_read`.
    /// - `outbound_rx` - the background ESP-encrypt task receives, from here,
    ///   the data written through `poll_write`.
    ///
    /// Both channels are unbounded.
    pub fn channels() -> (
        Self,
        crossfire::MTx<InboundFlavor>,
        crossfire::AsyncRx<OutboundFlavor>,
    ) {
        let (inbound_tx, inbound_rx) = crossfire::mpmc::unbounded_async::<Vec<u8>>();
        let (outbound_tx, outbound_rx) = crossfire::mpsc::unbounded_async::<Vec<u8>>();
        let tunnel = Self {
            inbound: Mutex::new(inbound_rx.into()),
            outbound_tx,
            read_remainder: Mutex::default(),
        };
        (tunnel, inbound_tx, outbound_rx)
    }
}

impl AsyncRead for EspTunnel {
    /// Copies the next chunk of inbound IP data into `buf`.
    ///
    /// Leftover bytes from a packet that did not fit in an earlier `buf` are
    /// served first. Otherwise the next packet is pulled from the inbound
    /// stream, and whatever does not fit is stashed for the next call.
    ///
    /// An empty packet is skipped by re-waking the task and returning
    /// `Pending`, because `Ready` with no bytes would look like EOF. When the
    /// inbound stream ends (`None`), this returns `Ready(Ok(()))` without
    /// filling `buf`, i.e. EOF.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = &*self;

        // Serve any tail left over from a previously truncated packet first.
        {
            let mut leftover = this.read_remainder.lock().unwrap();
            if !leftover.is_empty() {
                let take = leftover.len().min(buf.remaining());
                buf.put_slice(&leftover[..take]);
                leftover.drain(..take);
                return Poll::Ready(Ok(()));
            }
        }

        let mut inbound = this.inbound.lock().unwrap();
        let packet = match inbound.poll_item(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(None) => return Poll::Ready(Ok(())),
            Poll::Ready(Some(packet)) => packet,
        };

        if packet.is_empty() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        let (head, tail) = packet.split_at(packet.len().min(buf.remaining()));
        buf.put_slice(head);
        if !tail.is_empty() {
            this.read_remainder.lock().unwrap().extend_from_slice(tail);
        }
        Poll::Ready(Ok(()))
    }
}

impl AsyncWrite for EspTunnel {
    /// Hands `buf` to the ESP-encrypt task as one outbound message.
    ///
    /// Always completes immediately. The outbound channel is unbounded (see
    /// `EspTunnel::channels`), so the send never waits and the whole buffer
    /// is reported as written. A zero-length `buf` is a no-op returning
    /// `Ok(0)` and is not forwarded.
    ///
    /// # Errors
    ///
    /// `BrokenPipe` if the send fails because the receiving end of the
    /// outbound channel is gone.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        // A zero-length write is a no-op by `AsyncWrite` convention.
        // Forwarding it would hand the encrypt task an empty "IP packet".
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        use crossfire::BlockingTxTrait;
        self.outbound_tx.send(buf.to_vec()).map_err(|_| {
            std::io::Error::new(std::io::ErrorKind::BrokenPipe, "esp tunnel closed")
        })?;
        Poll::Ready(Ok(buf.len()))
    }

    /// Nothing is buffered locally (`poll_write` sends straight into the
    /// channel), so flushing is a no-op.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shutdown is a no-op; it does not close the outbound channel.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// 20-byte IPv4-style header used as a payload. The resulting datagram
    /// clears `decrypt`'s minimum-length check for both suites, so the tamper
    /// tests exercise ICV verification rather than `TooShort`.
    const IPV4_PAYLOAD: [u8; 20] = [
        0x45, 0x00, 0x00, 0x14, 0, 0, 0, 0, 64, 17, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8,
    ];

    /// Returns a copy of `esp` with the low bit of byte `idx` flipped.
    fn flip_bit(esp: &[u8], idx: usize) -> Vec<u8> {
        let mut out = esp.to_vec();
        out[idx] ^= 0x01;
        out
    }

    #[test]
    fn esp_round_trip_aes_gcm() {
        let suite = Suite::AesGcm256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt = vec![0x11u8; 4];
        let integ: Vec<u8> = vec![];
        let spi: u32 = 0xdead_beef;
        let seq: u32 = 42;
        let esp = encrypt(suite, &key, &salt, &integ, spi, seq, &IPV4_PAYLOAD, 4);
        let decrypted = decrypt(suite, &key, &salt, &integ, &esp).expect("decrypt");
        assert_eq!(decrypted.payload, IPV4_PAYLOAD);
        assert_eq!(decrypted.next_header, 4);
        assert_eq!(decrypted.seq, seq);
    }

    #[test]
    fn esp_round_trip_aes_cbc() {
        let suite = Suite::AesCbc256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt: Vec<u8> = vec![]; // CBC has no salt
        let integ = vec![0x33u8; 32];
        let spi: u32 = 0xcafe_babe;
        let seq: u32 = 7;
        let payload = [0x45, 0x00, 0x00, 0x14, 0, 0, 0, 0, 64, 17, 0, 0, 1, 2, 3, 4];
        let esp = encrypt(suite, &key, &salt, &integ, spi, seq, &payload, 4);
        let decrypted = decrypt(suite, &key, &salt, &integ, &esp).expect("decrypt");
        assert_eq!(decrypted.payload, payload);
        assert_eq!(decrypted.next_header, 4);
        assert_eq!(decrypted.seq, seq);
    }

    #[test]
    fn esp_rejects_tamper() {
        let suite = Suite::AesGcm256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt = vec![0x11u8; 4];
        let integ: Vec<u8> = vec![];
        let esp = encrypt(suite, &key, &salt, &integ, 1, 1, &IPV4_PAYLOAD, 4);
        // The untouched datagram must verify; otherwise the checks below
        // could pass for the wrong reason (e.g. `TooShort`).
        assert!(decrypt(suite, &key, &salt, &integ, &esp).is_ok());
        // Layout: SPI 0..4, seq 4..8 (both AAD), IV 8..16, ciphertext, tag.
        for idx in [0, 5, 10, 20, esp.len() - 1] {
            let tampered = flip_bit(&esp, idx);
            assert!(
                matches!(
                    decrypt(suite, &key, &salt, &integ, &tampered),
                    Err(EspError::Crypto)
                ),
                "tampering with byte {idx} was not detected"
            );
        }
    }

    #[test]
    fn esp_cbc_rejects_tamper() {
        let suite = Suite::AesCbc256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt: Vec<u8> = vec![];
        let integ = vec![0x33u8; 32];
        let esp = encrypt(suite, &key, &salt, &integ, 1, 1, &IPV4_PAYLOAD, 4);
        assert!(decrypt(suite, &key, &salt, &integ, &esp).is_ok());
        // Layout: SPI 0..4, seq 4..8, IV 8..24, ciphertext, ICV (all MAC'd).
        for idx in [0, 5, 12, 30, esp.len() - 1] {
            let tampered = flip_bit(&esp, idx);
            assert!(
                matches!(
                    decrypt(suite, &key, &salt, &integ, &tampered),
                    Err(EspError::Crypto)
                ),
                "tampering with byte {idx} was not detected"
            );
        }
    }

    #[test]
    fn esp_rejects_truncated() {
        let suite = Suite::AesGcm256Sha256Dh19;
        let key = vec![0x42u8; 32];
        let salt = vec![0x11u8; 4];
        let integ: Vec<u8> = vec![];
        let datagram = [0u8; 10];
        assert!(matches!(
            decrypt(suite, &key, &salt, &integ, &datagram),
            Err(EspError::TooShort(10))
        ));
    }
}