use crate::socket::error::{EncryptSendError, Result, SocketError};
use crate::transport::Transport;
use async_channel;
use futures::channel::oneshot;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use wacore::handshake::NoiseCipher;
use wacore::runtime::{AbortHandle, Runtime};
const INLINE_ENCRYPT_THRESHOLD: usize = 16 * 1024;
type SendResult = std::result::Result<(), EncryptSendError>;
struct SendJob {
plaintext_buf: Vec<u8>,
out_buf: Vec<u8>,
response_tx: oneshot::Sender<SendResult>,
}
pub struct NoiseSocket {
#[allow(dead_code)] runtime: Arc<dyn Runtime>,
read_key: Arc<NoiseCipher>,
read_counter: Arc<AtomicU32>,
send_job_tx: async_channel::Sender<SendJob>,
_sender_task_handle: AbortHandle,
}
impl NoiseSocket {
pub fn new(
runtime: Arc<dyn Runtime>,
transport: Arc<dyn Transport>,
write_key: NoiseCipher,
read_key: NoiseCipher,
) -> Self {
let write_key = Arc::new(write_key);
let read_key = Arc::new(read_key);
let (send_job_tx, send_job_rx) = async_channel::bounded::<SendJob>(32);
let transport_clone = transport.clone();
let write_key_clone = write_key.clone();
let rt_clone = runtime.clone();
let sender_task_handle = runtime.spawn(Box::pin(Self::sender_task(
rt_clone,
transport_clone,
write_key_clone,
send_job_rx,
)));
Self {
runtime,
read_key,
read_counter: Arc::new(AtomicU32::new(0)),
send_job_tx,
_sender_task_handle: sender_task_handle,
}
}
async fn sender_task(
runtime: Arc<dyn Runtime>,
transport: Arc<dyn Transport>,
write_key: Arc<NoiseCipher>,
send_job_rx: async_channel::Receiver<SendJob>,
) {
let mut write_counter: u32 = 0;
while let Ok(job) = send_job_rx.recv().await {
let result = Self::process_send_job(
&runtime,
&transport,
&write_key,
&mut write_counter,
job.plaintext_buf,
job.out_buf,
)
.await;
let _ = job.response_tx.send(result);
}
}
async fn process_send_job(
runtime: &Arc<dyn Runtime>,
transport: &Arc<dyn Transport>,
write_key: &Arc<NoiseCipher>,
write_counter: &mut u32,
mut plaintext_buf: Vec<u8>,
mut out_buf: Vec<u8>,
) -> SendResult {
let counter = *write_counter;
if plaintext_buf.len() <= INLINE_ENCRYPT_THRESHOLD {
if let Err(e) = write_key.encrypt_in_place_with_counter(counter, &mut plaintext_buf) {
return Err(EncryptSendError::crypto(anyhow::anyhow!(e.to_string())));
}
out_buf.clear();
if let Err(e) = wacore::framing::encode_frame_into(&plaintext_buf, None, &mut out_buf) {
return Err(EncryptSendError::framing(e));
}
} else {
let write_key = write_key.clone();
let plaintext_arc = Arc::new(plaintext_buf);
let plaintext_arc_for_task = plaintext_arc.clone();
let encrypt_result = wacore::runtime::blocking(&**runtime, move || {
write_key.encrypt_with_counter(counter, &plaintext_arc_for_task[..])
})
.await;
plaintext_buf = Arc::try_unwrap(plaintext_arc).unwrap_or_else(|arc| (*arc).clone());
drop(plaintext_buf);
let ciphertext = match encrypt_result {
Ok(c) => c,
Err(e) => {
return Err(EncryptSendError::crypto(anyhow::anyhow!(e.to_string())));
}
};
out_buf.clear();
if let Err(e) = wacore::framing::encode_frame_into(&ciphertext, None, &mut out_buf) {
return Err(EncryptSendError::framing(e));
}
}
if let Err(e) = transport.send(out_buf).await {
return Err(EncryptSendError::transport(e));
}
*write_counter = write_counter.wrapping_add(1);
Ok(())
}
pub async fn encrypt_and_send(&self, plaintext_buf: Vec<u8>, out_buf: Vec<u8>) -> SendResult {
let (response_tx, response_rx) = oneshot::channel();
let job = SendJob {
plaintext_buf,
out_buf,
response_tx,
};
if let Err(_send_err) = self.send_job_tx.send(job).await {
return Err(EncryptSendError::channel_closed());
}
match response_rx.await {
Ok(result) => result,
Err(_) => {
Err(EncryptSendError::channel_closed())
}
}
}
pub fn decrypt_frame(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
let counter = self.read_counter.fetch_add(1, Ordering::SeqCst);
self.read_key
.decrypt_with_counter(counter, ciphertext)
.map_err(|e| SocketError::Crypto(e.to_string()))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_encrypt_and_send_succeeds() {
let transport = Arc::new(crate::transport::mock::MockTransport);
let key = [0u8; 32];
let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
let socket = NoiseSocket::new(
Arc::new(crate::runtime_impl::TokioRuntime),
transport,
write_key,
read_key,
);
let plaintext_buf = Vec::with_capacity(1024);
let encrypted_buf = Vec::with_capacity(1024);
let result = socket.encrypt_and_send(plaintext_buf, encrypted_buf).await;
assert!(result.is_ok(), "encrypt_and_send should succeed");
}
#[tokio::test]
async fn test_concurrent_sends_maintain_order() {
use async_lock::Mutex;
use async_trait::async_trait;
use std::sync::Arc;
struct RecordingTransport {
recorded_order: Arc<Mutex<Vec<u8>>>,
read_key: NoiseCipher,
counter: std::sync::atomic::AtomicU32,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl crate::transport::Transport for RecordingTransport {
async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
if data.len() > 16 {
let ciphertext = &data[3..];
let counter = self
.counter
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
if let Ok(plaintext) = self.read_key.decrypt_with_counter(counter, ciphertext)
&& !plaintext.is_empty()
{
let index = plaintext[0];
let mut order = self.recorded_order.lock().await;
order.push(index);
}
}
Ok(())
}
async fn disconnect(&self) {}
}
let recorded_order = Arc::new(Mutex::new(Vec::new()));
let key = [0u8; 32];
let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
let transport = Arc::new(RecordingTransport {
recorded_order: recorded_order.clone(),
read_key: NoiseCipher::new(&key).expect("32-byte key should be valid"),
counter: std::sync::atomic::AtomicU32::new(0),
});
let socket = Arc::new(NoiseSocket::new(
Arc::new(crate::runtime_impl::TokioRuntime),
transport,
write_key,
read_key,
));
let mut handles = Vec::new();
for i in 0..10 {
let socket = socket.clone();
handles.push(tokio::spawn(async move {
let mut plaintext = vec![i as u8];
plaintext.extend_from_slice(&[0u8; 99]);
let out_buf = Vec::with_capacity(256);
socket.encrypt_and_send(plaintext, out_buf).await
}));
}
for handle in handles {
let result = handle.await.expect("task should complete");
assert!(result.is_ok(), "All sends should succeed");
}
let order = recorded_order.lock().await;
let expected: Vec<u8> = (0..10).collect();
assert_eq!(*order, expected, "Sends should maintain FIFO order");
}
#[tokio::test]
async fn test_encrypted_buffer_sizing_is_sufficient() {
use async_trait::async_trait;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
struct SizeRecordingTransport {
last_size: Arc<AtomicUsize>,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl crate::transport::Transport for SizeRecordingTransport {
async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
self.last_size.store(data.len(), Ordering::SeqCst);
Ok(())
}
async fn disconnect(&self) {}
}
let last_size = Arc::new(AtomicUsize::new(0));
let transport = Arc::new(SizeRecordingTransport {
last_size: last_size.clone(),
});
let key = [0u8; 32];
let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
let socket = NoiseSocket::new(
Arc::new(crate::runtime_impl::TokioRuntime),
transport,
write_key,
read_key,
);
let test_sizes = [0, 1, 50, 100, 500, 1000, 1024, 2000, 5000, 16384, 20000];
for size in test_sizes {
let plaintext = vec![0xABu8; size];
let buffer_capacity = plaintext.len() + 32;
let encrypted_buf = Vec::with_capacity(buffer_capacity);
let result = socket
.encrypt_and_send(plaintext.clone(), encrypted_buf)
.await;
assert!(
result.is_ok(),
"encrypt_and_send should succeed for payload size {}",
size
);
let actual_encrypted_size = last_size.load(Ordering::SeqCst);
let expected_max = size + 19;
assert_eq!(
actual_encrypted_size, expected_max,
"Encrypted size for {} byte payload should be {} (got {})",
size, expected_max, actual_encrypted_size
);
assert!(
buffer_capacity >= actual_encrypted_size,
"Buffer capacity {} should be >= encrypted size {} for payload size {}",
buffer_capacity,
actual_encrypted_size,
size
);
}
}
#[tokio::test]
async fn test_encrypted_buffer_sizing_edge_cases() {
use async_trait::async_trait;
use std::sync::Arc;
struct NoOpTransport;
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl crate::transport::Transport for NoOpTransport {
async fn send(&self, _data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
Ok(())
}
async fn disconnect(&self) {}
}
let transport = Arc::new(NoOpTransport);
let key = [0u8; 32];
let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
let socket = NoiseSocket::new(
Arc::new(crate::runtime_impl::TokioRuntime),
transport,
write_key,
read_key,
);
let result = socket
.encrypt_and_send(vec![], Vec::with_capacity(32))
.await;
assert!(result.is_ok(), "Empty payload should encrypt successfully");
let at_threshold = vec![0u8; 16 * 1024];
let result = socket
.encrypt_and_send(at_threshold, Vec::with_capacity(16 * 1024 + 32))
.await;
assert!(
result.is_ok(),
"Payload at inline threshold should encrypt successfully"
);
let above_threshold = vec![0u8; 16 * 1024 + 1];
let result = socket
.encrypt_and_send(above_threshold, Vec::with_capacity(16 * 1024 + 33))
.await;
assert!(
result.is_ok(),
"Payload above inline threshold should encrypt successfully"
);
}
}