// metalssh 0.0.1
//
// Experimental SSH implementation
//! `aes128-gcm@openssh.com` cipher.
//!
//! AES-128-GCM as defined in RFC 5647. The packet length is not encrypted
//! and is used as Additional Authenticated Data (AAD). The nonce is constructed
//! from a 4-byte fixed IV and an 8-byte invocation counter (sequence number).

use aws_lc_rs::aead::AES_128_GCM;
use aws_lc_rs::aead::Aad;
use aws_lc_rs::aead::LessSafeKey;
use aws_lc_rs::aead::Nonce;
use aws_lc_rs::aead::UnboundKey;

use crate::crypto::cipher::Cipher;
use crate::types::Error;
use crate::types::Result;
use crate::wire::Packet;

/// Size in bytes of the GCM authentication tag appended to every packet (128 bits).
const TAG_LEN: usize = 128 / 8;

/// Size in bytes of an AES-128 key (128 bits).
const KEY_LEN: usize = 128 / 8;

/// Size in bytes of the fixed field that prefixes every nonce (32 bits).
const FIXED_IV_LEN: usize = 32 / 8;

/// Size in bytes of the invocation counter, written as a big-endian `u64` (64 bits).
const INVOCATION_COUNTER_LEN: usize = 64 / 8;

/// Size in bytes of a complete GCM nonce (96 bits): fixed field plus invocation counter.
const NONCE_LEN: usize = INVOCATION_COUNTER_LEN + FIXED_IV_LEN;

/// AES-128-GCM packet cipher for the `aes128-gcm@openssh.com` SSH algorithm.
///
/// Holds the AEAD key and the 4-byte fixed field of every nonce. The remaining
/// 8 nonce bytes are derived per packet from the packet sequence number.
pub struct Aes128Gcm {
    /// Fixed field copied into the first 4 bytes of every nonce.
    fixed_iv: [u8; FIXED_IV_LEN],
    /// AEAD key used both to seal outgoing and to open incoming packets.
    key: LessSafeKey,
}

impl Aes128Gcm {
    /// Creates a new AES-128-GCM cipher.
    ///
    /// # Arguments
    /// * `key` - 16-byte encryption key
    /// * `fixed_iv` - 4-byte fixed IV (implicit IV as per RFC 5647)
    #[must_use]
    pub fn new(key: [u8; KEY_LEN], fixed_iv: [u8; FIXED_IV_LEN]) -> Result<Self> {
        let key = UnboundKey::new(&AES_128_GCM, &key).map_err(|_| Error::Crypto)?;
        let key = LessSafeKey::new(key);
        Ok(Self { key, fixed_iv })
    }

    /// Constructs a 12-byte nonce from the fixed IV and sequence number.
    fn make_nonce(&self, sequence_number: u32) -> Nonce {
        let mut nonce_bytes = [0u8; NONCE_LEN];

        // First 4 bytes: fixed IV
        nonce_bytes[..FIXED_IV_LEN].copy_from_slice(&self.fixed_iv);

        // Last 8 bytes: invocation counter (sequence number as big-endian u64)
        nonce_bytes[FIXED_IV_LEN..].copy_from_slice(&(sequence_number as u64).to_be_bytes());

        Nonce::assume_unique_for_key(nonce_bytes)
    }
}

impl Cipher for Aes128Gcm {
    const AEAD_LENGTH: Option<usize> = Some(TAG_LEN);

    fn encrypt_packet<'buf, B>(
        &self,
        packet: &'buf mut Packet<&'buf mut B>,
        sequence_number: u32,
    ) -> Result<()>
    where
        B: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
    {
        let nonce = self.make_nonce(sequence_number);

        let packet_len = packet.packet_length()?;
        let packet_len_bytes = packet_len.to_be_bytes();

        // The packet length is used as AAD and remains unencrypted.
        // We need to encrypt only the payload portion (padding_length + payload +
        // padding)
        let full_packet = packet.packet_mut()?;
        let (_, rest) = full_packet.split_at_mut(4); // Skip the 4-byte packet length
        let (payload, tag_space) = rest.split_at_mut(packet_len as usize);

        // Use packet length as Additional Authenticated Data
        let aad = Aad::from(&packet_len_bytes);

        // Encrypt in place and get the tag
        let tag = self
            .key
            .seal_in_place_separate_tag(nonce, aad, payload)
            .map_err(|_| Error::Crypto)?;

        // Copy tag to the tag space
        tag_space[..TAG_LEN].copy_from_slice(tag.as_ref());

        Ok(())
    }

    fn decrypt_packet_length<B>(&self, packet: &Packet<B>, _sequence_number: u32) -> Result<u32>
    where
        B: AsRef<[u8]>,
    {
        packet.packet_length()
    }

    fn decrypt_packet<'buf, B>(
        &self,
        packet: &'buf mut Packet<&'buf mut B>,
        sequence_number: u32,
    ) -> Result<()>
    where
        B: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
    {
        let nonce = self.make_nonce(sequence_number);

        let packet_len = packet.packet_length()?;
        let packet_len_bytes = packet_len.to_be_bytes();

        // Use packet length as Additional Authenticated Data
        let aad = Aad::from(&packet_len_bytes);

        // Get the ciphertext + tag portion (everything after packet length)
        // The ciphertext and tag are already concatenated in the buffer
        let full_packet = packet.packet_mut()?;
        let ciphertext_and_tag = &mut full_packet[4..];

        // Decrypt and verify in place
        // open_in_place expects ciphertext || tag and returns plaintext
        self.key
            .open_in_place(nonce, aad, ciphertext_and_tag)
            .map_err(|_| Error::Crypto)?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by every test; the value itself is arbitrary.
    const TEST_KEY: [u8; KEY_LEN] = [0x42; KEY_LEN];

    /// Fixed IV shared by every test; the value itself is arbitrary.
    const TEST_FIXED_IV: [u8; FIXED_IV_LEN] = [0x01, 0x02, 0x03, 0x04];

    /// Cleartext packet: `[packet_len=9][padding_len=6][payload="hi"][padding=000000]`.
    const PLAINTEXT: [u8; 13] = [
        0x00, 0x00, 0x00, 0x09, // packet length = 9 (1 + 2 + 6)
        0x06, // padding length = 6
        b'h', b'i', // payload
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // padding
    ];

    /// Builds the cipher used by most tests.
    fn test_cipher() -> Aes128Gcm {
        Aes128Gcm::new(TEST_KEY, TEST_FIXED_IV).unwrap()
    }

    /// Returns `PLAINTEXT` followed by `TAG_LEN` zero bytes reserved for the tag.
    fn plaintext_buffer() -> Vec<u8> {
        let mut data = PLAINTEXT.to_vec();
        data.extend_from_slice(&[0u8; TAG_LEN]);
        data
    }

    /// Encrypts `data` in place as packet number `sequence_number`.
    fn seal(cipher: &Aes128Gcm, data: &mut Vec<u8>, sequence_number: u32) {
        let mut packet = Packet::new(&mut *data, TAG_LEN as u8);
        cipher.encrypt_packet(&mut packet, sequence_number).unwrap();
    }

    /// Tries to authenticate and decrypt `data` in place as packet number `sequence_number`.
    fn open(cipher: &Aes128Gcm, data: &mut Vec<u8>, sequence_number: u32) -> Result<()> {
        let mut packet = Packet::new(&mut *data, TAG_LEN as u8);
        cipher.decrypt_packet(&mut packet, sequence_number)
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        // Include u32::MAX to cover the widest counter value.
        for sequence_number in [0, 1, u32::MAX] {
            let cipher = test_cipher();
            let mut data = plaintext_buffer();

            seal(&cipher, &mut data, sequence_number);

            // The packet length stays cleartext; the body no longer matches the plaintext.
            assert_eq!(data[..4], PLAINTEXT[..4]);
            assert_ne!(data[4..PLAINTEXT.len()], PLAINTEXT[4..]);

            open(&cipher, &mut data, sequence_number).unwrap();

            assert_eq!(data[..PLAINTEXT.len()], PLAINTEXT);
        }
    }

    #[test]
    fn sequence_number_changes_ciphertext() {
        let cipher = test_cipher();
        let mut first = plaintext_buffer();
        let mut second = plaintext_buffer();

        seal(&cipher, &mut first, 0);
        seal(&cipher, &mut second, 1);

        // Same key and plaintext, but different nonces, so the output must differ.
        assert_ne!(first[4..], second[4..]);
    }

    #[test]
    fn decrypt_with_wrong_sequence_fails() {
        let cipher = test_cipher();
        let mut data = plaintext_buffer();

        seal(&cipher, &mut data, 0);

        assert!(open(&cipher, &mut data, 1).is_err());
    }

    #[test]
    fn decrypt_with_tampered_aad_fails() {
        let cipher = test_cipher();
        let mut data = plaintext_buffer();

        seal(&cipher, &mut data, 0);

        // Tamper with the packet length (AAD).
        data[3] = 0x0a;

        assert!(open(&cipher, &mut data, 0).is_err());
    }

    #[test]
    fn decrypt_with_tampered_ciphertext_fails() {
        let cipher = test_cipher();
        let mut data = plaintext_buffer();

        seal(&cipher, &mut data, 0);

        // Flip one bit of the encrypted payload.
        data[5] ^= 0x01;

        assert!(open(&cipher, &mut data, 0).is_err());
    }

    #[test]
    fn decrypt_with_tampered_tag_fails() {
        let cipher = test_cipher();
        let mut data = plaintext_buffer();

        seal(&cipher, &mut data, 0);

        // Flip one bit of the authentication tag (the last byte of the buffer).
        let last = data.len() - 1;
        data[last] ^= 0x01;

        assert!(open(&cipher, &mut data, 0).is_err());
    }

    #[test]
    fn decrypt_with_different_fixed_iv_fails() {
        let mut data = plaintext_buffer();

        seal(&test_cipher(), &mut data, 0);

        let other = Aes128Gcm::new(TEST_KEY, [0x04, 0x03, 0x02, 0x01]).unwrap();
        assert!(open(&other, &mut data, 0).is_err());
    }
}