use crate::error::ViiperError;
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Nonce,
};
use hmac::{Hmac, Mac};
use pbkdf2::pbkdf2_hmac;
use rand::RngCore;
use sha2::{Digest, Sha256};
use std::io::{Read, Write};
use std::net::TcpStream;
#[cfg(feature = "async")]
use std::pin::Pin;
#[cfg(feature = "async")]
use std::task::{Context, Poll};
#[cfg(feature = "async")]
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
#[cfg(feature = "async")]
use tokio::net::TcpStream as AsyncTcpStream;
/// Prefix that opens the client hello sent at the start of the handshake.
const HANDSHAKE_MAGIC: &[u8] = b"eVI1\x00";
/// Length in bytes of the random client nonce and of the server nonce in the handshake reply.
const NONCE_SIZE: usize = 32;
/// Domain-separation label mixed into the client's HMAC authentication tag.
const AUTH_CONTEXT: &[u8] = b"VIIPER-Auth-v1";
/// Domain-separation label appended when hashing the per-connection session key.
const SESSION_CONTEXT: &[u8] = b"VIIPER-Session-v1";
/// PBKDF2-HMAC-SHA256 iteration count used to stretch the password into a key.
const PBKDF2_ITERATIONS: u32 = 100_000;
/// Fixed PBKDF2 salt, identical for every password.
/// NOTE(review): presumably the server derives the key the same way, so changing this or the
/// iteration count would break interoperability — confirm before touching.
const PBKDF2_SALT: &[u8] = b"VIIPER-Key-v1";
/// Stretches `password` into a 32-byte key with PBKDF2-HMAC-SHA256
/// (`PBKDF2_SALT`, `PBKDF2_ITERATIONS` rounds).
///
/// # Errors
///
/// Returns [`ViiperError::UnexpectedResponse`] when `password` is empty.
fn derive_key(password: &str) -> Result<[u8; 32], ViiperError> {
    match password {
        "" => Err(ViiperError::UnexpectedResponse("Password cannot be empty".into())),
        secret => {
            let mut derived = [0u8; 32];
            pbkdf2_hmac::<Sha256>(secret.as_bytes(), PBKDF2_SALT, PBKDF2_ITERATIONS, &mut derived);
            Ok(derived)
        }
    }
}
/// Derives the per-connection session key as
/// `SHA-256(key || server_nonce || client_nonce || SESSION_CONTEXT)`.
fn derive_session_key(key: &[u8], server_nonce: &[u8], client_nonce: &[u8]) -> [u8; 32] {
    let mut digest = Sha256::new();
    // Order matters: it must match the server's derivation byte for byte.
    for part in [key, server_nonce, client_nonce, SESSION_CONTEXT] {
        digest.update(part);
    }
    digest.finalize().into()
}
pub fn perform_handshake(mut stream: TcpStream, password: &str) -> Result<EncryptedStream, ViiperError> {
let key = derive_key(password)?;
let mut client_nonce = [0u8; NONCE_SIZE];
rand::thread_rng().fill_bytes(&mut client_nonce);
let mut mac = <Hmac::<Sha256> as KeyInit>::new_from_slice(&key)
.map_err(|_| ViiperError::UnexpectedResponse("Invalid key length".into()))?;
mac.update(AUTH_CONTEXT);
mac.update(&client_nonce);
let auth_tag = mac.finalize().into_bytes();
let mut handshake_msg = Vec::with_capacity(HANDSHAKE_MAGIC.len() + NONCE_SIZE + 32);
handshake_msg.extend_from_slice(HANDSHAKE_MAGIC);
handshake_msg.extend_from_slice(&client_nonce);
handshake_msg.extend_from_slice(&auth_tag);
stream.write_all(&handshake_msg)?;
let mut response = vec![0u8; 3 + NONCE_SIZE];
stream.read_exact(&mut response)?;
if &response[0..3] != b"OK\x00" {
let mut error_buf = Vec::new();
let _ = stream.read_to_end(&mut error_buf);
let full_response = [response, error_buf].concat();
let error_str = String::from_utf8_lossy(&full_response);
if let Ok(problem) = serde_json::from_str::<crate::error::ProblemJson>(&error_str) {
return Err(ViiperError::Protocol(problem));
}
return Err(ViiperError::UnexpectedResponse(format!("Invalid handshake response: {}", error_str)));
}
let server_nonce = &response[3..];
let session_key = derive_session_key(&key, server_nonce, &client_nonce);
Ok(EncryptedStream::new(stream, session_key)?)
}
#[cfg(feature = "async")]
pub async fn perform_handshake_async(mut stream: AsyncTcpStream, password: &str) -> Result<AsyncEncryptedStream, ViiperError> {
let key = derive_key(password)?;
let mut client_nonce = [0u8; NONCE_SIZE];
rand::thread_rng().fill_bytes(&mut client_nonce);
let mut mac = <Hmac::<Sha256> as KeyInit>::new_from_slice(&key)
.map_err(|_| ViiperError::UnexpectedResponse("Invalid key length".into()))?;
mac.update(AUTH_CONTEXT);
mac.update(&client_nonce);
let auth_tag = mac.finalize().into_bytes();
let mut handshake_msg = Vec::with_capacity(HANDSHAKE_MAGIC.len() + NONCE_SIZE + 32);
handshake_msg.extend_from_slice(HANDSHAKE_MAGIC);
handshake_msg.extend_from_slice(&client_nonce);
handshake_msg.extend_from_slice(&auth_tag);
stream.write_all(&handshake_msg).await?;
let mut response = vec![0u8; 3 + NONCE_SIZE];
stream.read_exact(&mut response).await?;
if &response[0..3] != b"OK\x00" {
let mut error_buf = Vec::new();
let _ = stream.read_to_end(&mut error_buf).await;
let full_response = [response, error_buf].concat();
let error_str = String::from_utf8_lossy(&full_response);
if let Ok(problem) = serde_json::from_str::<crate::error::ProblemJson>(&error_str) {
return Err(ViiperError::Protocol(problem));
}
return Err(ViiperError::UnexpectedResponse(format!("Invalid handshake response: {}", error_str)));
}
let server_nonce = &response[3..];
let session_key = derive_session_key(&key, server_nonce, &client_nonce);
Ok(AsyncEncryptedStream::new(stream, session_key))
}
/// Blocking encrypted stream over a [`TcpStream`], produced by [`perform_handshake`].
///
/// On the wire, each frame is a 4-byte big-endian body length followed by a 12-byte nonce
/// and the ChaCha20-Poly1305 ciphertext+tag. Read and write state sit behind separate
/// mutexes, so one thread can read while another writes. Handles returned by
/// [`EncryptedStream::try_clone`] share both.
///
/// NOTE(review): both directions are keyed with the same session key. If the server also
/// numbers its nonces from 0, nonces repeat across directions — confirm the server's scheme.
pub struct EncryptedStream {
    read: std::sync::Arc<std::sync::Mutex<EncryptedReadState>>,
    write: std::sync::Arc<std::sync::Mutex<EncryptedWriteState>>,
    /// Another handle to the same socket, used only for socket options and shutdown. It is
    /// deliberately kept outside the mutexes, so `shutdown` never waits on a thread that is
    /// blocked inside `read` or `write` while holding its lock.
    control: std::sync::Arc<TcpStream>,
}
/// Receive side: a socket handle, the cipher, and decrypted plaintext not yet returned.
struct EncryptedReadState {
    stream: TcpStream,
    cipher: ChaCha20Poly1305,
    recv_buffer: Vec<u8>,
}
/// Send side: a socket handle, the cipher, and the counter that forms each outgoing nonce.
struct EncryptedWriteState {
    stream: TcpStream,
    cipher: ChaCha20Poly1305,
    send_counter: u64,
}
impl EncryptedStream {
    /// Wraps `inner` using `session_key` for both directions.
    ///
    /// # Errors
    ///
    /// Fails if the socket handle cannot be cloned.
    fn new(inner: TcpStream, session_key: [u8; 32]) -> Result<Self, ViiperError> {
        let read_stream = inner.try_clone()?;
        let control = inner.try_clone()?;
        Ok(Self {
            read: std::sync::Arc::new(std::sync::Mutex::new(EncryptedReadState {
                stream: read_stream,
                cipher: ChaCha20Poly1305::new(&session_key.into()),
                recv_buffer: Vec::new(),
            })),
            write: std::sync::Arc::new(std::sync::Mutex::new(EncryptedWriteState {
                stream: inner,
                cipher: ChaCha20Poly1305::new(&session_key.into()),
                send_counter: 0,
            })),
            control: std::sync::Arc::new(control),
        })
    }

    /// Sets `TCP_NODELAY` on the underlying socket. All handles share that socket.
    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
        self.control.set_nodelay(nodelay)
    }

    /// Returns another handle that shares this stream's read state, write state and socket.
    pub fn try_clone(&self) -> std::io::Result<Self> {
        Ok(Self {
            read: std::sync::Arc::clone(&self.read),
            write: std::sync::Arc::clone(&self.write),
            control: std::sync::Arc::clone(&self.control),
        })
    }

    /// Shuts down the underlying socket in the given direction(s).
    ///
    /// This does not take the read or write lock, so it can be called from another thread
    /// while a reader or writer is blocked.
    pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
        self.control.shutdown(how)
    }
}
impl Read for EncryptedStream {
    /// Reads decrypted plaintext into `buf`.
    ///
    /// When no plaintext is buffered, this blocks until a complete frame (4-byte big-endian
    /// length, 12-byte nonce, ciphertext+tag) has been received and decrypted. Frames that
    /// decrypt to zero bytes are skipped. `Ok(0)` is therefore returned only for an empty
    /// `buf`, or when the peer closes the connection on a frame boundary.
    ///
    /// # Errors
    ///
    /// * `InvalidData` if a frame is larger than 2 MiB, too short to hold a nonce and tag,
    ///   or fails authentication.
    /// * `UnexpectedEof` if the connection closes in the middle of a frame.
    /// * `Other` if the read-state mutex was poisoned by a panicking thread.
    /// * Any error from the underlying socket.
    ///
    /// NOTE(review): the nonce is taken from the frame and not checked against a counter, so
    /// an active attacker could replay or reorder frames. Confirm whether the server's nonce
    /// scheme would allow enforcing a counter here.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        const NONCE_LEN: usize = 12;
        const TAG_LEN: usize = 16;
        const MAX_PACKET_LEN: usize = 2 * 1024 * 1024;

        if buf.is_empty() {
            return Ok(0);
        }
        let mut guard = self.read.lock().map_err(|_| {
            std::io::Error::new(std::io::ErrorKind::Other, "EncryptedStream read state poisoned")
        })?;
        // Reborrow through the guard once, so the stream, cipher and buffer borrow independently.
        let state = &mut *guard;

        while state.recv_buffer.is_empty() {
            let mut len_buf = [0u8; 4];
            // A clean EOF is only acceptable before the first byte of a frame.
            if state.stream.read(&mut len_buf[..1])? == 0 {
                return Ok(0);
            }
            state.stream.read_exact(&mut len_buf[1..])?;
            let packet_len = u32::from_be_bytes(len_buf) as usize;
            if packet_len > MAX_PACKET_LEN {
                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Packet too large"));
            }
            // Without this check, a frame shorter than the nonce would panic when sliced below.
            if packet_len < NONCE_LEN + TAG_LEN {
                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Packet too small"));
            }
            let mut packet = vec![0u8; packet_len];
            state.stream.read_exact(&mut packet)?;
            let (nonce, ciphertext_and_tag) = packet.split_at(NONCE_LEN);
            state.recv_buffer = state
                .cipher
                .decrypt(Nonce::from_slice(nonce), ciphertext_and_tag)
                .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Decryption failed"))?;
        }

        let n = buf.len().min(state.recv_buffer.len());
        buf[..n].copy_from_slice(&state.recv_buffer[..n]);
        state.recv_buffer.drain(..n);
        Ok(n)
    }
}
impl Write for EncryptedStream {
    /// Encrypts up to one frame's worth of `buf` and sends it as a single frame.
    ///
    /// At most `2 MiB - 28` bytes are consumed per call, so the encoded frame never exceeds
    /// the 2 MiB limit that this module's reader enforces. The return value is the number of
    /// bytes taken, so use `write_all` for larger buffers. An empty `buf` sends nothing and
    /// returns `Ok(0)`, following the usual `Write` convention.
    ///
    /// The nonce is the 64-bit send counter, big-endian, in the last 8 bytes of a 12-byte
    /// nonce.
    ///
    /// # Errors
    ///
    /// * `Other` if the write-state mutex is poisoned, the nonce counter is exhausted, or
    ///   encryption fails.
    /// * Any error from the underlying socket. If the failure happens part-way through a
    ///   frame, the peer is left with a truncated frame.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        const NONCE_LEN: usize = 12;
        const TAG_LEN: usize = 16;
        const MAX_PACKET_LEN: usize = 2 * 1024 * 1024;
        const MAX_PLAINTEXT_LEN: usize = MAX_PACKET_LEN - NONCE_LEN - TAG_LEN;

        if buf.is_empty() {
            return Ok(0);
        }
        let chunk = &buf[..buf.len().min(MAX_PLAINTEXT_LEN)];

        let mut guard = self.write.lock().map_err(|_| {
            std::io::Error::new(std::io::ErrorKind::Other, "EncryptedStream write state poisoned")
        })?;
        let state = &mut *guard;

        // Refuse to wrap the counter: a repeated nonce under the same key breaks ChaCha20-Poly1305.
        let counter = state.send_counter;
        state.send_counter = counter.checked_add(1).ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::Other, "Nonce counter exhausted")
        })?;
        let mut nonce_bytes = [0u8; NONCE_LEN];
        nonce_bytes[NONCE_LEN - 8..].copy_from_slice(&counter.to_be_bytes());

        let ciphertext = state
            .cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), chunk)
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Encryption failed"))?;

        // Cannot truncate: the body is at most MAX_PACKET_LEN bytes.
        let body_len = (NONCE_LEN + ciphertext.len()) as u32;
        // Build the whole frame first, so it goes out in one write_all instead of two.
        let mut frame = Vec::with_capacity(4 + NONCE_LEN + ciphertext.len());
        frame.extend_from_slice(&body_len.to_be_bytes());
        frame.extend_from_slice(&nonce_bytes);
        frame.extend_from_slice(&ciphertext);
        state.stream.write_all(&frame)?;
        Ok(chunk.len())
    }

    /// Flushes the underlying socket.
    ///
    /// # Errors
    ///
    /// `Other` if the write-state mutex is poisoned, otherwise any socket error.
    fn flush(&mut self) -> std::io::Result<()> {
        self.write
            .lock()
            .map_err(|_| {
                std::io::Error::new(std::io::ErrorKind::Other, "EncryptedStream write state poisoned")
            })?
            .stream
            .flush()
    }
}
/// Async counterpart of [`EncryptedStream`], produced by [`perform_handshake_async`].
///
/// It uses the same framing: a 4-byte big-endian body length, then a 12-byte nonce and the
/// ChaCha20-Poly1305 ciphertext+tag, with the body capped at 2 MiB.
/// [`AsyncEncryptedStream::into_split`] yields independently usable read and write halves.
#[cfg(feature = "async")]
pub struct AsyncEncryptedStream {
    read: AsyncEncryptedRead,
    write: AsyncEncryptedWrite,
}

/// Size of the nonce that starts every frame body.
#[cfg(feature = "async")]
const FRAME_NONCE_LEN: usize = 12;
/// Size of the Poly1305 authentication tag that ends every frame body.
#[cfg(feature = "async")]
const FRAME_TAG_LEN: usize = 16;
/// Largest frame body (nonce + ciphertext + tag) accepted from, or sent to, the peer.
#[cfg(feature = "async")]
const MAX_FRAME_LEN: usize = 2 * 1024 * 1024;

/// Decrypting read half of an [`AsyncEncryptedStream`].
#[cfg(feature = "async")]
pub struct AsyncEncryptedRead {
    inner: tokio::net::tcp::OwnedReadHalf,
    cipher: ChaCha20Poly1305,
    /// Decrypted plaintext not yet handed to the caller.
    recv_buffer: Vec<u8>,
    /// Progress through the frame currently being received; preserved across `Pending`.
    read_state: ReadState,
}

/// Encrypting write half of an [`AsyncEncryptedStream`].
///
/// Each accepted chunk is encoded into one frame. If the socket cannot take the whole frame
/// at once, the rest is kept in `pending` and finished by the next `poll_write`, `poll_flush`
/// or `poll_shutdown`. As with any buffered writer, call `flush` (or `shutdown`) before
/// relying on the data having been sent, and before dropping the writer.
#[cfg(feature = "async")]
pub struct AsyncEncryptedWrite {
    inner: tokio::net::tcp::OwnedWriteHalf,
    cipher: ChaCha20Poly1305,
    /// Counter that forms the next frame's nonce.
    send_counter: u64,
    /// Encoded frame (length prefix, nonce, ciphertext+tag) not yet fully written.
    pending: Vec<u8>,
    /// Number of bytes of `pending` already written to the socket.
    pending_pos: usize,
}

/// Receive state machine for [`AsyncEncryptedRead`].
#[cfg(feature = "async")]
enum ReadState {
    /// Collecting the 4-byte big-endian body length; `pos` bytes received so far.
    ReadingLength { buf: [u8; 4], pos: usize },
    /// Collecting the frame body (nonce + ciphertext + tag); `pos` bytes received so far.
    ReadingPacket { buf: Vec<u8>, pos: usize },
}

#[cfg(feature = "async")]
impl AsyncEncryptedStream {
    /// Splits `inner` into halves, each keyed with `session_key`.
    fn new(inner: AsyncTcpStream, session_key: [u8; 32]) -> Self {
        let (read_half, write_half) = inner.into_split();
        Self {
            read: AsyncEncryptedRead {
                inner: read_half,
                cipher: ChaCha20Poly1305::new(&session_key.into()),
                recv_buffer: Vec::new(),
                read_state: ReadState::ReadingLength { buf: [0; 4], pos: 0 },
            },
            write: AsyncEncryptedWrite {
                inner: write_half,
                cipher: ChaCha20Poly1305::new(&session_key.into()),
                send_counter: 0,
                pending: Vec::new(),
                pending_pos: 0,
            },
        }
    }

    /// Separates the stream into its read and write halves.
    pub fn into_split(self) -> (AsyncEncryptedRead, AsyncEncryptedWrite) {
        (self.read, self.write)
    }
}

/// Polls `inner` once to fill `dst`. Returns the number of bytes read; 0 means the peer
/// closed the connection.
#[cfg(feature = "async")]
fn poll_read_some(
    inner: &mut tokio::net::tcp::OwnedReadHalf,
    cx: &mut Context<'_>,
    dst: &mut [u8],
) -> Poll<std::io::Result<usize>> {
    let mut read_buf = ReadBuf::new(dst);
    match Pin::new(inner).poll_read(cx, &mut read_buf) {
        Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        Poll::Pending => Poll::Pending,
    }
}

#[cfg(feature = "async")]
impl AsyncRead for AsyncEncryptedRead {
    /// Fills `buf` with decrypted plaintext, receiving and decrypting frames as needed.
    ///
    /// Frames that decrypt to zero bytes are skipped, so returning with nothing filled
    /// signals EOF only when the peer closed on a frame boundary (or `buf` had no room).
    /// Errors are `InvalidData` for oversized, undersized or unauthentic frames, and
    /// `UnexpectedEof` for a connection closed mid-frame.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        loop {
            if !this.recv_buffer.is_empty() {
                let n = buf.remaining().min(this.recv_buffer.len());
                buf.put_slice(&this.recv_buffer[..n]);
                this.recv_buffer.drain(..n);
                return Poll::Ready(Ok(()));
            }
            match &mut this.read_state {
                ReadState::ReadingLength { buf: len_buf, pos } => {
                    let n = match poll_read_some(&mut this.inner, cx, &mut len_buf[*pos..]) {
                        Poll::Ready(Ok(n)) => n,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => return Poll::Pending,
                    };
                    if n == 0 {
                        // A clean EOF is only acceptable before the first byte of a frame.
                        return Poll::Ready(if *pos == 0 {
                            Ok(())
                        } else {
                            Err(std::io::Error::new(
                                std::io::ErrorKind::UnexpectedEof,
                                "Connection closed while reading length",
                            ))
                        });
                    }
                    *pos += n;
                    if *pos < len_buf.len() {
                        continue;
                    }
                    let packet_len = u32::from_be_bytes(*len_buf) as usize;
                    if packet_len > MAX_FRAME_LEN {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "Packet too large",
                        )));
                    }
                    // Without this check, a frame shorter than the nonce would panic when sliced.
                    if packet_len < FRAME_NONCE_LEN + FRAME_TAG_LEN {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "Packet too small",
                        )));
                    }
                    this.read_state = ReadState::ReadingPacket { buf: vec![0u8; packet_len], pos: 0 };
                }
                ReadState::ReadingPacket { buf: packet, pos } => {
                    let n = match poll_read_some(&mut this.inner, cx, &mut packet[*pos..]) {
                        Poll::Ready(Ok(n)) => n,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => return Poll::Pending,
                    };
                    if n == 0 {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "Connection closed while reading packet",
                        )));
                    }
                    *pos += n;
                    if *pos < packet.len() {
                        continue;
                    }
                    let (nonce, ciphertext_and_tag) = packet.split_at(FRAME_NONCE_LEN);
                    this.recv_buffer = this
                        .cipher
                        .decrypt(Nonce::from_slice(nonce), ciphertext_and_tag)
                        .map_err(|_| {
                            std::io::Error::new(std::io::ErrorKind::InvalidData, "Decryption failed")
                        })?;
                    this.read_state = ReadState::ReadingLength { buf: [0; 4], pos: 0 };
                }
            }
        }
    }
}

#[cfg(feature = "async")]
impl AsyncEncryptedWrite {
    /// Largest plaintext that fits in one frame without exceeding `MAX_FRAME_LEN`.
    const MAX_PLAINTEXT_LEN: usize = MAX_FRAME_LEN - FRAME_NONCE_LEN - FRAME_TAG_LEN;

    /// Encrypts `plaintext` into `pending` as one complete frame.
    ///
    /// The caller must ensure `pending` has been fully drained first. The nonce is the
    /// 64-bit send counter, big-endian, in the last 8 bytes.
    fn encode_frame(&mut self, plaintext: &[u8]) -> std::io::Result<()> {
        // Refuse to wrap the counter: a repeated nonce under the same key breaks ChaCha20-Poly1305.
        let counter = self.send_counter;
        self.send_counter = counter.checked_add(1).ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::Other, "Nonce counter exhausted")
        })?;
        let mut nonce_bytes = [0u8; FRAME_NONCE_LEN];
        nonce_bytes[FRAME_NONCE_LEN - 8..].copy_from_slice(&counter.to_be_bytes());
        let ciphertext = self
            .cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Encryption failed"))?;

        // Cannot truncate: the body is at most MAX_FRAME_LEN bytes.
        let body_len = (FRAME_NONCE_LEN + ciphertext.len()) as u32;
        self.pending.clear();
        self.pending.extend_from_slice(&body_len.to_be_bytes());
        self.pending.extend_from_slice(&nonce_bytes);
        self.pending.extend_from_slice(&ciphertext);
        self.pending_pos = 0;
        Ok(())
    }

    /// Writes as much of `pending` as the socket accepts. Returns `Ready(Ok(()))` once the
    /// frame is fully on the wire, and leaves progress recorded on `Pending`.
    fn poll_drain_pending(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        while self.pending_pos < self.pending.len() {
            match Pin::new(&mut self.inner).poll_write(cx, &self.pending[self.pending_pos..]) {
                Poll::Ready(Ok(0)) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "Failed to write complete packet",
                    )));
                }
                Poll::Ready(Ok(n)) => self.pending_pos += n,
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => return Poll::Pending,
            }
        }
        self.pending.clear();
        self.pending_pos = 0;
        Poll::Ready(Ok(()))
    }
}

#[cfg(feature = "async")]
impl AsyncWrite for AsyncEncryptedWrite {
    /// Accepts up to one frame's worth of `buf` (at most `2 MiB - 28` bytes).
    ///
    /// Returns `Pending` only while an earlier frame is still being written. Once the data is
    /// encoded, it counts as accepted even if part of its frame is still buffered.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.get_mut();
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        // Finish the previous frame first; otherwise two frames would interleave on the wire.
        match this.poll_drain_pending(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Pending => return Poll::Pending,
        }
        let chunk = &buf[..buf.len().min(Self::MAX_PLAINTEXT_LEN)];
        this.encode_frame(chunk)?;
        // Push as much of the new frame as the socket takes now. Any remainder, or a socket
        // error, is picked up again by the next poll_write/poll_flush/poll_shutdown, which
        // retry the same bytes.
        let _ = this.poll_drain_pending(cx);
        Poll::Ready(Ok(chunk.len()))
    }

    /// Finishes any buffered frame, then flushes the socket.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        match this.poll_drain_pending(cx) {
            Poll::Ready(Ok(())) => Pin::new(&mut this.inner).poll_flush(cx),
            other => other,
        }
    }

    /// Finishes any buffered frame, then shuts down the write side of the socket.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        match this.poll_drain_pending(cx) {
            Poll::Ready(Ok(())) => Pin::new(&mut this.inner).poll_shutdown(cx),
            other => other,
        }
    }
}
/// Reads are served entirely by the stream's read half.
#[cfg(feature = "async")]
impl AsyncRead for AsyncEncryptedStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let reader = &mut self.get_mut().read;
        Pin::new(reader).poll_read(cx, buf)
    }
}
/// Writes, flushes and shutdown are all delegated to the stream's write half.
#[cfg(feature = "async")]
impl AsyncWrite for AsyncEncryptedStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let writer = &mut self.get_mut().write;
        Pin::new(writer).poll_write(cx, buf)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let writer = &mut self.get_mut().write;
        Pin::new(writer).poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let writer = &mut self.get_mut().write;
        Pin::new(writer).poll_shutdown(cx)
    }
}