use ring::aead::{
self, Aad, BoundKey, Nonce, NonceSequence, OpeningKey,
SealingKey, UnboundKey,
};
use ring::agreement::{self, EphemeralPrivateKey, UnparsedPublicKey};
use ring::hkdf;
use ring::rand::SystemRandom;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
pub const NONCE_LEN: usize = 12;
/// Errors produced while establishing or using a [`CttpsStream`].
#[derive(Debug, thiserror::Error)]
pub enum CttpsError {
    /// An I/O error from the underlying TCP stream.
    #[error("IO Error: {0}")]
    Io(#[from] std::io::Error),
    /// Ephemeral key generation, public-key derivation, or key agreement failed.
    #[error("Handshake failed")]
    HandshakeError,
    /// Building an AEAD key, sealing, or opening (authenticating) a packet failed.
    #[error("Encryption failure")]
    CryptoError,
}
struct CounterNonceSeq(u64);
impl NonceSequence for CounterNonceSeq {
fn advance(&mut self) -> Result<Nonce, ring::error::Unspecified> {
let mut bytes = [0u8; NONCE_LEN];
bytes[4..].copy_from_slice(&self.0.to_be_bytes());
self.0 += 1;
Ok(Nonce::assume_unique_for_key(bytes))
}
}
/// A TCP stream whose packets are protected with AES-256-GCM.
///
/// Created by [`CttpsStream::connect`]. It holds one key for sealing outgoing
/// packets and one for opening incoming packets, each driven by its own
/// counter nonce sequence.
pub struct CttpsStream {
    /// Transport carrying length-prefixed encrypted frames.
    stream: tokio::net::TcpStream,
    /// Seals packets written by `send_packet`.
    sealing_key: SealingKey<CounterNonceSeq>,
    /// Opens packets read by `recv_packet`.
    opening_key: OpeningKey<CounterNonceSeq>,
}
impl CttpsStream {
pub async fn connect(
mut stream: tokio::net::TcpStream,
) -> Result<Self, CttpsError> {
let rng = SystemRandom::new();
let priv_key =
EphemeralPrivateKey::generate(&agreement::X25519, &rng)
.map_err(|_| CttpsError::HandshakeError)?;
let pub_key = priv_key
.compute_public_key()
.map_err(|_| CttpsError::HandshakeError)?;
stream.write_all(pub_key.as_ref()).await?;
let mut peer_pub_bytes = [0u8; 32];
stream.read_exact(&mut peer_pub_bytes).await?;
let peer_pub_key =
UnparsedPublicKey::new(&agreement::X25519, peer_pub_bytes);
let shared_secret = agreement::agree_ephemeral(
priv_key,
&peer_pub_key,
|key_material| key_material.to_vec(),
)
.map_err(|_| CttpsError::HandshakeError)?;
let unbound_key_tx = UnboundKey::new(
&aead::AES_256_GCM,
&shared_secret,
)
.map_err(|_| CttpsError::CryptoError)?;
let sealing_key =
SealingKey::new(unbound_key_tx, CounterNonceSeq(0));
let unbound_key_rx = UnboundKey::new(
&aead::AES_256_GCM,
&shared_secret,
)
.map_err(|_| CttpsError::CryptoError)?;
let opening_key =
OpeningKey::new(unbound_key_rx, CounterNonceSeq(0));
Ok(Self {
stream,
opening_key,
sealing_key,
})
}
pub async fn send_packet(
&mut self,
payload: &mut Vec<u8>,
) -> Result<(), CttpsError> {
self.sealing_key
.seal_in_place_append_tag(Aad::empty(), payload)
.map_err(|_| CttpsError::CryptoError)?;
let total_len = payload.len() as u32;
self.stream.write_u32(total_len).await?;
self.stream.write_all(payload).await?;
Ok(())
}
pub async fn recv_packet(
&mut self,
) -> Result<Vec<u8>, CttpsError> {
let len = self.stream.read_u32().await?;
let mut buf = vec![0u8; len as usize];
self.stream.read_exact(&mut buf).await?;
let plaintext = self
.opening_key
.open_in_place(Aad::empty(), &mut buf)
.map_err(|_| CttpsError::CryptoError)?;
Ok(plaintext.to_vec())
}
}