use std::collections::VecDeque;
use std::time::Duration;
use async_trait::async_trait;
use calimero_crypto::{Nonce, SharedKey};
use calimero_node_primitives::sync::{EncryptionState, StreamMessage, SyncTransport};
use eyre::{bail, Result};
use tokio::sync::mpsc;
use tokio::time::timeout;
/// Default per-direction capacity, in frames, of the channels created by
/// [`SimStream::pair`] and [`SimStream::one_way`].
const DEFAULT_BUFFER_SIZE: usize = 64;

/// Default time a [`SimStream`] waits in `recv` before reporting a timeout.
const DEFAULT_SIM_TIMEOUT: Duration = Duration::from_millis(5_000);
/// An in-memory [`SyncTransport`] endpoint backed by bounded tokio `mpsc`
/// channels that carry raw byte frames (borsh-encoded, then passed through
/// this endpoint's [`EncryptionState`]).
///
/// Create connected endpoints with [`SimStream::pair`] or
/// [`SimStream::one_way`].
pub struct SimStream {
/// Outbound half. `None` for the receive-only end of `one_way`, and after `close`.
tx: Option<mpsc::Sender<Vec<u8>>>,
/// Inbound half, fed by the peer's sender.
rx: mpsc::Receiver<Vec<u8>>,
/// Raw frames served before the channel is polled.
/// NOTE(review): nothing in the visible code pushes into this; confirm intent.
buffer: VecDeque<Vec<u8>>,
/// Encrypts outgoing frames and decrypts incoming ones.
encryption: EncryptionState,
/// Budget used by `recv`; `recv_timeout` takes its own explicit budget.
recv_timeout: Duration,
/// Set by `close`, or when `recv` observes that the peer dropped its sender.
closed: bool,
}
impl SimStream {
    /// Creates two connected, bidirectional streams whose channels each hold
    /// up to [`DEFAULT_BUFFER_SIZE`] in-flight frames.
    ///
    /// Frames sent on one stream are received by the other.
    #[must_use]
    pub fn pair() -> (Self, Self) {
        Self::pair_with_buffer(DEFAULT_BUFFER_SIZE)
    }

    /// Creates two connected, bidirectional streams whose channels each hold
    /// up to `buffer_size` in-flight frames; `send` waits when a channel is full.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` is zero (tokio bounded channels require a
    /// non-zero capacity).
    #[must_use]
    pub fn pair_with_buffer(buffer_size: usize) -> (Self, Self) {
        // Cross-wire two channels: A's sender feeds B's receiver and vice versa.
        let (tx_a, rx_b) = mpsc::channel(buffer_size);
        let (tx_b, rx_a) = mpsc::channel(buffer_size);
        let stream_a = Self::from_parts(Some(tx_a), rx_a);
        let stream_b = Self::from_parts(Some(tx_b), rx_b);
        (stream_a, stream_b)
    }

    /// Creates a one-directional link using [`DEFAULT_BUFFER_SIZE`]: a
    /// send-only [`SimStreamSender`] and a receive-only `SimStream`.
    #[must_use]
    pub fn one_way() -> (SimStreamSender, Self) {
        Self::one_way_with_buffer(DEFAULT_BUFFER_SIZE)
    }

    /// Creates a one-directional link whose channel holds up to `buffer_size`
    /// in-flight frames.
    ///
    /// The returned `SimStream` has no outbound half, so its `send` fails with
    /// "no sender available". The sender and receiver have independent
    /// encryption states.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` is zero (tokio bounded channels require a
    /// non-zero capacity).
    #[must_use]
    pub fn one_way_with_buffer(buffer_size: usize) -> (SimStreamSender, Self) {
        let (tx, rx) = mpsc::channel(buffer_size);
        let sender = SimStreamSender {
            tx,
            encryption: EncryptionState::new(),
        };
        let receiver = Self::from_parts(None, rx);
        (sender, receiver)
    }

    /// Sets the time `recv` waits for a frame before failing with a timeout.
    ///
    /// Calls to `recv_timeout` use their own explicit budget instead.
    pub fn set_timeout(&mut self, timeout: Duration) {
        self.recv_timeout = timeout;
    }

    /// Returns `true` if raw frames are waiting in the internal buffer.
    ///
    /// NOTE(review): nothing in the visible code pushes into the buffer, so
    /// this currently always reports `false`; confirm whether that is intended.
    #[must_use]
    pub fn has_buffered(&self) -> bool {
        !self.buffer.is_empty()
    }

    /// Returns the number of raw frames waiting in the internal buffer.
    #[must_use]
    pub fn buffered_count(&self) -> usize {
        self.buffer.len()
    }

    /// Builds an open endpoint with an empty buffer, fresh encryption state and
    /// the default receive timeout; shared by all constructors.
    fn from_parts(tx: Option<mpsc::Sender<Vec<u8>>>, rx: mpsc::Receiver<Vec<u8>>) -> Self {
        Self {
            tx,
            rx,
            buffer: VecDeque::new(),
            encryption: EncryptionState::new(),
            recv_timeout: DEFAULT_SIM_TIMEOUT,
            closed: false,
        }
    }

    /// Returns the next raw (still encrypted) frame, waiting at most `budget`.
    ///
    /// Buffered frames are served first. Once the stream is marked closed, the
    /// channel is no longer polled.
    ///
    /// Returns `Ok(None)` when the stream is closed or the peer has dropped its
    /// sender. The peer-dropped case also marks this stream closed.
    ///
    /// # Errors
    ///
    /// Fails with "timeout receiving message" if no frame arrives within
    /// `budget`.
    async fn recv_raw_timeout(&mut self, budget: Duration) -> Result<Option<Vec<u8>>> {
        if let Some(data) = self.buffer.pop_front() {
            return Ok(Some(data));
        }
        if self.closed {
            return Ok(None);
        }
        match timeout(budget, self.rx.recv()).await {
            Ok(Some(data)) => Ok(Some(data)),
            Ok(None) => {
                // Every sender for our inbound channel is gone: end of stream.
                self.closed = true;
                Ok(None)
            }
            Err(_) => bail!("timeout receiving message"),
        }
    }
}
#[async_trait]
impl SyncTransport for SimStream {
    /// Borsh-encodes `message`, encrypts it with this stream's encryption
    /// state, and enqueues the frame for the peer. Waits for capacity if the
    /// channel is full.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the stream is marked closed (by `close`, or after `recv` observed the
    ///   peer hang up);
    /// - this is the receive-only end of [`SimStream::one_way`];
    /// - encoding or encryption fails;
    /// - the peer's receiving half has been dropped or closed.
    async fn send(&mut self, message: &StreamMessage<'_>) -> Result<()> {
        if self.closed {
            bail!("stream is closed");
        }
        let tx = self
            .tx
            .as_ref()
            .ok_or_else(|| eyre::eyre!("no sender available"))?;
        let encoded = borsh::to_vec(message)?;
        let encrypted = self.encryption.encrypt(encoded)?;
        tx.send(encrypted)
            .await
            .map_err(|_| eyre::eyre!("channel closed"))?;
        Ok(())
    }

    /// Receives the next message, waiting at most the budget configured with
    /// [`SimStream::set_timeout`].
    async fn recv(&mut self) -> Result<Option<StreamMessage<'static>>> {
        self.recv_timeout(self.recv_timeout).await
    }

    /// Receives, decrypts and borsh-decodes the next frame, waiting at most
    /// `budget`.
    ///
    /// Returns `Ok(None)` at end of stream: closed locally, or the peer
    /// dropped its sender.
    ///
    /// # Errors
    ///
    /// Fails on timeout, decryption failure, or a frame that does not decode
    /// as a `StreamMessage`.
    async fn recv_timeout(&mut self, budget: Duration) -> Result<Option<StreamMessage<'static>>> {
        let Some(data) = self.recv_raw_timeout(budget).await? else {
            return Ok(None);
        };
        let decrypted = self.encryption.decrypt(data)?;
        let decoded = borsh::from_slice::<StreamMessage<'static>>(&decrypted)?;
        Ok(Some(decoded))
    }

    /// Replaces this endpoint's encryption configuration.
    ///
    /// `None` presumably disables encryption; confirm against
    /// `EncryptionState::set`.
    fn set_encryption(&mut self, encryption: Option<(SharedKey, Nonce)>) {
        self.encryption.set(encryption);
    }

    /// Returns this endpoint's current encryption configuration, if any.
    fn encryption(&self) -> Option<(SharedKey, Nonce)> {
        self.encryption.get()
    }

    /// Closes this end of the stream in both directions.
    ///
    /// - Dropping the outbound sender lets the peer's `recv` drain any frames
    ///   already queued and then return `Ok(None)`.
    /// - Closing the inbound receiver makes the peer's later sends fail. Without
    ///   this, they would queue into a channel nobody reads, and block once it
    ///   filled.
    /// - Afterwards, local `send` fails and local `recv` returns `Ok(None)`.
    async fn close(&mut self) -> Result<()> {
        self.closed = true;
        self.tx = None;
        self.rx.close();
        Ok(())
    }
}
impl std::fmt::Debug for SimStream {
    /// Summarises the endpoint's state. It reports only whether encryption is
    /// configured, never the key or nonce themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffered = self.buffer.len();
        let has_encryption = self.encryption.get().is_some();

        let mut out = f.debug_struct("SimStream");
        out.field("buffered", &buffered);
        out.field("closed", &self.closed);
        out.field("has_encryption", &has_encryption);
        out.finish()
    }
}
/// Send-only half of a [`SimStream::one_way`] link.
///
/// Frames are borsh-encoded and encrypted with this sender's own
/// [`EncryptionState`], which is independent of the receiver's.
pub struct SimStreamSender {
    /// Feeds the paired receive-only `SimStream`.
    tx: mpsc::Sender<Vec<u8>>,
    /// Encrypts outgoing frames.
    encryption: EncryptionState,
}

impl std::fmt::Debug for SimStreamSender {
    /// Summarises the sender's state, mirroring `SimStream`'s `Debug` output.
    /// It reports only whether encryption is configured, never the key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SimStreamSender")
            .field("closed", &self.tx.is_closed())
            .field("has_encryption", &self.encryption.get().is_some())
            .finish()
    }
}
impl SimStreamSender {
    /// Borsh-encodes `message`, encrypts it with this sender's encryption
    /// state, and enqueues the frame for the paired [`SimStream`]. Waits for
    /// capacity if the channel is full.
    ///
    /// # Errors
    ///
    /// Fails if encoding or encryption fails, or with "channel closed" when
    /// the receiving half has been dropped or closed.
    pub async fn send(&self, message: &StreamMessage<'_>) -> Result<()> {
        let frame = self.encryption.encrypt(borsh::to_vec(message)?)?;
        if self.tx.send(frame).await.is_err() {
            return Err(eyre::eyre!("channel closed"));
        }
        Ok(())
    }

    /// Replaces this sender's encryption configuration.
    ///
    /// `None` presumably disables encryption; confirm against
    /// `EncryptionState::set`.
    pub fn set_encryption(&mut self, encryption: Option<(SharedKey, Nonce)>) {
        self.encryption.set(encryption);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use calimero_crypto::NONCE_LEN;
    use calimero_node_primitives::sync::wire::{InitPayload, MessagePayload};
    use calimero_primitives::context::ContextId;
    use calimero_primitives::identity::PublicKey;

    fn test_context_id() -> ContextId {
        ContextId::from([1u8; 32])
    }

    fn test_public_key() -> PublicKey {
        PublicKey::from([2u8; 32])
    }

    /// Builds a minimal `Init` message whose `next_nonce` is filled with `nonce`.
    fn init_message(nonce: u8) -> StreamMessage<'static> {
        StreamMessage::Init {
            context_id: test_context_id(),
            party_id: test_public_key(),
            payload: InitPayload::DagHeadsRequest {
                context_id: test_context_id(),
            },
            next_nonce: [nonce; NONCE_LEN],
        }
    }

    /// `StreamMessage` is compared via its borsh encoding.
    fn assert_same_message(expected: &StreamMessage<'_>, actual: &StreamMessage<'_>) {
        assert_eq!(
            borsh::to_vec(expected).unwrap(),
            borsh::to_vec(actual).unwrap()
        );
    }

    #[tokio::test]
    async fn test_pair_send_recv() {
        let (mut alice, mut bob) = SimStream::pair();
        let msg = init_message(0);

        alice.send(&msg).await.expect("send should succeed");
        let received = bob
            .recv()
            .await
            .expect("recv should succeed")
            .expect("a message should be delivered");

        assert_same_message(&msg, &received);
    }

    #[tokio::test]
    async fn test_bidirectional() {
        let (mut alice, mut bob) = SimStream::pair();
        let msg_from_alice = init_message(1);
        let msg_from_bob = StreamMessage::Message {
            sequence_id: 1,
            payload: MessagePayload::DagHeadsResponse {
                dag_heads: vec![],
                root_hash: [0u8; 32].into(),
            },
            next_nonce: [2; NONCE_LEN],
        };

        alice.send(&msg_from_alice).await.unwrap();
        bob.send(&msg_from_bob).await.unwrap();
        let from_alice = bob.recv().await.unwrap().unwrap();
        let from_bob = alice.recv().await.unwrap().unwrap();

        assert_same_message(&msg_from_alice, &from_alice);
        assert_same_message(&msg_from_bob, &from_bob);
    }

    #[tokio::test]
    async fn test_timeout() {
        // `_alice` must stay alive: dropping it would close bob's inbound
        // channel, turning the expected timeout into an end-of-stream.
        let (_alice, mut bob) = SimStream::pair();

        let result = bob.recv_timeout(Duration::from_millis(10)).await;

        let err = result.expect_err("recv should time out");
        assert!(err.to_string().contains("timeout"));
    }

    #[tokio::test]
    async fn test_close() {
        let (mut alice, mut bob) = SimStream::pair();
        alice.close().await.unwrap();

        // A closed endpoint refuses to send and reports end-of-stream.
        assert!(alice.send(&init_message(0)).await.is_err());
        assert!(alice.recv().await.unwrap().is_none());

        // Closing drops alice's sender, so bob sees a clean hang-up within
        // the (short) timeout rather than a timeout error.
        bob.set_timeout(Duration::from_millis(50));
        assert!(bob.recv().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_one_way() {
        let (sender, mut receiver) = SimStream::one_way();
        let msg = init_message(3);

        sender.send(&msg).await.unwrap();
        let received = receiver.recv().await.unwrap().unwrap();
        assert_same_message(&msg, &received);

        // The receiving end has no outbound half.
        assert!(receiver.send(&msg).await.is_err());

        // Dropping the sender ends the stream.
        drop(sender);
        assert!(receiver.recv().await.unwrap().is_none());
    }
}