use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use std::pin::Pin;
use std::task::{Context, Poll};
use anyhow::{Result, anyhow};
use crate::crypto::{CttpsCrypto, generate_keypair, derive_shared_secret, compute_transcript_hash};
use std::io;
/// An encrypted, length-framed channel over a [`TcpStream`].
///
/// A stream is created by [`CttpsStream::connect`] (client side) or
/// [`CttpsStream::accept`] (server side), both of which run a key exchange
/// before returning. Frames are then exchanged with
/// [`CttpsStream::write_packet`] / [`CttpsStream::read_packet`].
///
/// Wire format of one frame: a big-endian `u32` length `L`, then a 12-byte
/// nonce, then `L - 12` bytes of sealed payload.
pub struct CttpsStream {
    inner: TcpStream,
    crypto: CttpsCrypto,
    /// `true` on the side created by `connect`, `false` on the side created by
    /// `accept`. Both sides seal with the same `CttpsCrypto` state, so this
    /// selects the nonce direction byte that keeps their nonces disjoint.
    is_client: bool,
    /// Counter of the next nonce expected from the peer.
    read_nonce_counter: u64,
    /// Counter of the next nonce this side will seal with.
    write_nonce_counter: u64,
    /// Plaintext drained by the `AsyncRead` impl (not filled anywhere in this file).
    read_buffer: Vec<u8>,
    /// Index of the next unread byte in `read_buffer`.
    read_pos: usize,
}

impl CttpsStream {
    /// Upper bound on a frame's length field (nonce + sealed payload).
    ///
    /// `read_packet` rejects larger frames *before* allocating, so a peer cannot
    /// force a huge allocation with a forged length prefix; `write_packet`
    /// refuses to emit anything the receiving side would reject.
    /// Must stay `<= u32::MAX` because it is encoded as a `u32`.
    pub const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;

    /// Length of the nonce that prefixes every frame.
    const NONCE_LEN: usize = 12;

    /// Smallest sealed payload accepted; presumably the 16-byte AEAD tag of an
    /// empty message (TODO confirm against `CttpsCrypto`).
    const TAG_LEN: usize = 16;

    /// Runs the client side of the handshake over `stream`.
    ///
    /// Sends this side's ephemeral public key, reads the server's 32-byte key,
    /// derives the shared secret, and initialises the cipher with it plus
    /// `compute_transcript_hash(client_pub, server_pub)` (client key first,
    /// matching `accept`).
    ///
    /// Nothing here authenticates the server's public key, so this handshake on
    /// its own does not stop an active man-in-the-middle.
    ///
    /// Like the other async methods here, this is not cancel-safe: dropping the
    /// future mid-way leaves the socket in an unknown protocol state.
    ///
    /// # Errors
    /// Returns an error on any I/O failure or if key derivation / cipher setup fails.
    pub async fn connect(mut stream: TcpStream) -> Result<Self> {
        let (priv_key, pub_key) = generate_keypair();
        stream.write_all(&pub_key).await?;
        let mut server_pub_key = [0u8; 32];
        stream.read_exact(&mut server_pub_key).await?;
        let shared_secret = derive_shared_secret(priv_key, &server_pub_key)?;
        let transcript_hash = compute_transcript_hash(&pub_key, &server_pub_key);
        let crypto = CttpsCrypto::new(&shared_secret, &transcript_hash)?;
        Ok(Self::from_parts(stream, crypto, true))
    }

    /// Runs the server side of the handshake over `stream`.
    ///
    /// Reads the client's 32-byte public key, replies with this side's ephemeral
    /// public key, derives the shared secret, and initialises the cipher with it
    /// plus `compute_transcript_hash(client_pub, server_pub)` (same order as
    /// `connect`, so both peers compute the same value).
    ///
    /// Nothing here authenticates the client's public key.
    ///
    /// # Errors
    /// Returns an error on any I/O failure or if key derivation / cipher setup fails.
    pub async fn accept(mut stream: TcpStream) -> Result<Self> {
        let mut client_pub_key = [0u8; 32];
        stream.read_exact(&mut client_pub_key).await?;
        let (priv_key, pub_key) = generate_keypair();
        stream.write_all(&pub_key).await?;
        let shared_secret = derive_shared_secret(priv_key, &client_pub_key)?;
        let transcript_hash = compute_transcript_hash(&client_pub_key, &pub_key);
        let crypto = CttpsCrypto::new(&shared_secret, &transcript_hash)?;
        Ok(Self::from_parts(stream, crypto, false))
    }

    /// Builds a stream in its initial post-handshake state for the given role.
    fn from_parts(inner: TcpStream, crypto: CttpsCrypto, is_client: bool) -> Self {
        Self {
            inner,
            crypto,
            is_client,
            read_nonce_counter: 0,
            write_nonce_counter: 0,
            read_buffer: Vec::new(),
            read_pos: 0,
        }
    }

    /// Returns the nonce for `*counter` in one direction and advances the counter.
    ///
    /// Layout: byte 0 is the direction (0 = sealed by the client, 1 = sealed by
    /// the server), bytes 1..4 are zero, bytes 4..12 hold the counter in
    /// little-endian order.
    ///
    /// # Panics
    /// Panics if the counter cannot be advanced past `u64::MAX`; wrapping
    /// around would reuse nonces, which is far worse than aborting.
    fn generate_nonce(counter: &mut u64, is_client: bool) -> [u8; 12] {
        let mut nonce = [0u8; 12];
        nonce[0] = if is_client { 0 } else { 1 };
        nonce[4..12].copy_from_slice(&counter.to_le_bytes());
        *counter = counter
            .checked_add(1)
            .expect("CTTPS nonce counter exhausted; refusing to reuse a nonce");
        nonce
    }

    /// Seals `payload` under this side's next nonce and sends it as one frame.
    ///
    /// The whole frame (length, nonce, ciphertext) is written with a single
    /// `write_all`, then the socket is flushed.
    ///
    /// # Errors
    /// Returns an error if sealing fails, if the frame would exceed
    /// [`Self::MAX_FRAME_LEN`], or on any I/O failure.
    pub async fn write_packet(&mut self, payload: &[u8]) -> Result<()> {
        let mut data = payload.to_vec();

        // Work on a copy of the counter and commit it only once a frame is about
        // to be sent. A nonce whose ciphertext never left this function is safe
        // to use again, and the peer expects counters without gaps.
        let mut next_counter = self.write_nonce_counter;
        let nonce = Self::generate_nonce(&mut next_counter, self.is_client);
        self.crypto.seal(nonce, &mut data)?;

        let frame_len = Self::NONCE_LEN + data.len();
        if frame_len > Self::MAX_FRAME_LEN {
            return Err(anyhow!("Packet too long"));
        }
        self.write_nonce_counter = next_counter;

        // One buffer and one write avoids several tiny TCP writes per frame.
        // `frame_len <= MAX_FRAME_LEN <= u32::MAX`, so the cast cannot truncate;
        // big-endian matches `read_u32` on the receiving side.
        let mut frame = Vec::with_capacity(4 + frame_len);
        frame.extend_from_slice(&(frame_len as u32).to_be_bytes());
        frame.extend_from_slice(&nonce);
        frame.extend_from_slice(&data);
        self.inner.write_all(&frame).await?;
        self.inner.flush().await?;
        Ok(())
    }

    /// Receives one frame and returns its opened plaintext.
    ///
    /// The frame's nonce must be exactly the next nonce the peer is expected to
    /// use. Frames carrying any other nonce (replayed, dropped, or reordered)
    /// are rejected.
    ///
    /// # Errors
    /// Returns an error if the length prefix is shorter than nonce + tag or
    /// longer than [`Self::MAX_FRAME_LEN`], if the nonce is not the expected one,
    /// if opening (authentication) fails, or on any I/O failure.
    pub async fn read_packet(&mut self) -> Result<Vec<u8>> {
        let total_len = self.inner.read_u32().await? as usize;
        if total_len < Self::NONCE_LEN + Self::TAG_LEN {
            return Err(anyhow!("Packet too short"));
        }
        // Checked before allocating: the length prefix is unauthenticated.
        if total_len > Self::MAX_FRAME_LEN {
            return Err(anyhow!("Packet too long"));
        }
        let mut nonce_bytes = [0u8; 12];
        self.inner.read_exact(&mut nonce_bytes).await?;
        let mut encrypted_data = vec![0u8; total_len - Self::NONCE_LEN];
        self.inner.read_exact(&mut encrypted_data).await?;

        // The nonce travels in the clear, so it must be checked against our own
        // count of the peer's frames. The peer seals with the opposite direction
        // byte.
        let mut next_counter = self.read_nonce_counter;
        let expected_nonce = Self::generate_nonce(&mut next_counter, !self.is_client);
        if nonce_bytes != expected_nonce {
            return Err(anyhow!("Unexpected nonce (replayed or out-of-order packet)"));
        }

        let plaintext = self.crypto.open(nonce_bytes, &mut encrypted_data)?.to_vec();
        self.read_nonce_counter = next_counter;
        Ok(plaintext)
    }
}
impl AsyncRead for CttpsStream {
    /// Serves plaintext still waiting in `read_buffer`; otherwise refuses.
    ///
    /// When unread bytes exist, copies as many as fit into `buf`, advances
    /// `read_pos`, and reports success. When none remain, fails with
    /// `ErrorKind::Other` directing the caller to the packet API. Never returns
    /// `Pending`, so the waker in `_cx` is never registered.
    ///
    /// NOTE(review): nothing in this file fills `read_buffer`, so as written
    /// every call ends up on the error path.
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = &mut *self;
        // `get` yields `None` (rather than panicking) if `read_pos` ever overshoots.
        match this.read_buffer.get(this.read_pos..) {
            Some(unread) if !unread.is_empty() => {
                let chunk = unread.len().min(buf.remaining());
                buf.put_slice(&unread[..chunk]);
                this.read_pos += chunk;
                Poll::Ready(Ok(()))
            }
            _ => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::Other,
                "Use read_packet/write_packet for CTTPS",
            ))),
        }
    }
}
impl AsyncWrite for CttpsStream {
    /// Always refuses: raw byte writes would bypass encryption and framing.
    /// Callers must use `write_packet` instead.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        _buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let refusal = io::Error::new(io::ErrorKind::Other, "Use write_packet for CTTPS");
        Poll::Ready(Err(refusal))
    }

    /// Forwards the flush straight to the underlying TCP socket.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let socket = &mut self.get_mut().inner;
        Pin::new(socket).poll_flush(cx)
    }

    /// Forwards the shutdown straight to the underlying TCP socket.
    /// No encrypted close message is sent first.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let socket = &mut self.get_mut().inner;
        Pin::new(socket).poll_shutdown(cx)
    }
}