use alloc::{sync::Arc, vec::Vec};
use async_lock::Mutex;
use future_form::{FutureForm, Local, Sendable, future_form};
use rand::{RngCore, rngs::OsRng};
use subduction_core::{peer::id::PeerId, transport::Transport};
use crate::error::{DisconnectionError, RecvError, SendError};
/// Capacity of the outbound queue: bytes handed to `send_bytes` wait here until
/// `pull_outbound` takes them. `send_bytes` waits while the queue is full.
const OUTBOUND_CHANNEL_CAPACITY: usize = 1024;
/// Capacity of the inbound queue: bytes supplied via `push_inbound` wait here until
/// `recv_bytes` takes them. `push_inbound` waits while the queue is full.
const INBOUND_CHANNEL_CAPACITY: usize = 128;
/// State shared (through an `Arc`) by every clone of an [`HttpLongPollTransport`].
#[derive(Debug)]
struct Inner {
/// Random identifier drawn from `OsRng` in `new`; attached as a field to the
/// `recv_bytes` debug log so log lines from one transport can be correlated.
chan_id: u64,
/// Peer this transport exchanges bytes with (reported in the send/recv/disconnect logs).
peer_id: PeerId,
/// Sending half of the outbound queue; written by `send_bytes`.
outbound_tx: async_channel::Sender<Vec<u8>>,
/// Sending half of the inbound queue; written by `push_inbound`.
inbound_writer: async_channel::Sender<Vec<u8>>,
/// Receiving half of the inbound queue; read by `recv_bytes`.
inbound_reader: async_channel::Receiver<Vec<u8>>,
/// Optional guard installed by `set_cancel_guard` and dropped by `close`.
/// NOTE(review): presumably whoever holds the matching receiver treats the
/// channel closing (guard dropped) as a cancellation signal — confirm with callers.
cancel_guard: Mutex<Option<async_channel::Sender<()>>>,
}
/// Byte-level transport backed by two bounded in-memory queues.
///
/// Outgoing bytes from `send_bytes` are collected for `pull_outbound`; incoming bytes
/// are injected with `push_inbound` and delivered by `recv_bytes`. The name suggests an
/// HTTP long-poll handler performs that pushing/pulling — that wiring is not in this file.
/// Cloning is cheap: clones share the same `Inner` and the same outbound channel.
#[derive(Debug, Clone)]
pub struct HttpLongPollTransport {
/// Shared queues, peer id and cancel guard.
inner: Arc<Inner>,
/// Receiving half of the outbound queue; read by `pull_outbound`.
outbound_rx: async_channel::Receiver<Vec<u8>>,
}
impl HttpLongPollTransport {
    /// Creates a transport for `peer_id` with empty inbound and outbound queues.
    ///
    /// A random channel id is drawn from [`OsRng`] so that log lines emitted by this
    /// transport can be correlated.
    #[must_use]
    pub fn new(peer_id: PeerId) -> Self {
        let (inbound_writer, inbound_reader) = async_channel::bounded(INBOUND_CHANNEL_CAPACITY);
        let (outbound_tx, outbound_rx) = async_channel::bounded(OUTBOUND_CHANNEL_CAPACITY);
        let chan_id = OsRng.next_u64();
        Self {
            inner: Arc::new(Inner {
                chan_id,
                peer_id,
                outbound_tx,
                inbound_writer,
                inbound_reader,
                cancel_guard: Mutex::new(None),
            }),
            outbound_rx,
        }
    }

    /// Stores `guard`, replacing (and dropping) any previously stored guard.
    ///
    /// [`close`](Self::close) releases the stored guard by dropping it. If the transport
    /// has already been closed — before this call, or while `close` could not obtain the
    /// lock — the guard is dropped here instead of being retained by a transport that is
    /// already shut down.
    pub async fn set_cancel_guard(&self, guard: async_channel::Sender<()>) {
        *self.inner.cancel_guard.lock().await = Some(guard);

        // `close()` is synchronous, so it only *tries* to lock the slot and silently skips
        // clearing it on contention, and it cannot clear a guard stored after it ran. The
        // outbound channel can only become closed through `close()` (`Inner` keeps a sender
        // and `self` keeps a receiver alive), so re-check after the lock above is released
        // and clear the slot ourselves when the transport is already closed.
        if self.inner.outbound_tx.is_closed() {
            *self.inner.cancel_guard.lock().await = None;
        }
    }

    /// Queues `bytes` to be returned by a later `recv_bytes` call, waiting while the
    /// inbound queue is full.
    ///
    /// # Errors
    ///
    /// Returns [`async_channel::SendError`] (which carries `bytes` back) if the inbound
    /// queue has been closed, e.g. by [`close`](Self::close).
    pub async fn push_inbound(
        &self,
        bytes: Vec<u8>,
    ) -> Result<(), async_channel::SendError<Vec<u8>>> {
        self.inner.inbound_writer.send(bytes).await
    }

    /// Waits for the next message queued by `send_bytes`.
    ///
    /// # Errors
    ///
    /// Returns [`async_channel::RecvError`] once the outbound queue is closed and empty.
    pub async fn pull_outbound(&self) -> Result<Vec<u8>, async_channel::RecvError> {
        self.outbound_rx.recv().await
    }

    /// Closes both queues and drops the cancel guard.
    ///
    /// After this, sends on either queue fail; receivers can still drain messages that
    /// were already queued and then get an error. Calling it more than once is harmless.
    /// The guard is cleared through a best-effort `try_lock` because this method is
    /// synchronous; `set_cancel_guard` re-checks for closure to cover the case where the
    /// lock was busy or the guard was installed afterwards.
    pub fn close(&self) {
        self.inner.inbound_writer.close();
        self.inner.outbound_tx.close();
        self.outbound_rx.close();
        self.inner.inbound_reader.close();
        if let Some(mut guard) = self.inner.cancel_guard.try_lock() {
            *guard = None;
        }
    }
}
// Byte transport over the in-memory queues: `send_bytes` feeds the outbound queue that
// `pull_outbound` drains, and `recv_bytes` reads what `push_inbound` supplied.
#[future_form(Sendable, Local)]
impl<K: FutureForm> Transport<K> for HttpLongPollTransport {
    type SendError = SendError;
    type RecvError = RecvError;
    type DisconnectionError = DisconnectionError;

    // Logs immediately; the queues are closed only once the returned future is polled.
    fn disconnect(&self) -> K::Future<'_, Result<(), Self::DisconnectionError>> {
        let peer_id = self.inner.peer_id;
        tracing::info!(peer_id = %peer_id, "HttpLongPoll::disconnect");
        let owned = self.clone();
        K::from_future(async move {
            owned.close();
            Ok(())
        })
    }

    // Copies `bytes` onto the outbound queue, waiting while it is full; fails with
    // `SendError` if the queue has been closed.
    fn send_bytes(&self, bytes: &[u8]) -> K::Future<'_, Result<(), Self::SendError>> {
        let inner = &self.inner;
        tracing::debug!(
            "http-lp: sending {} outbound bytes to peer {}",
            bytes.len(),
            inner.peer_id
        );
        let outbound = inner.outbound_tx.clone();
        let payload = bytes.to_vec();
        K::from_future(async move { outbound.send(payload).await.map_err(|_| SendError) })
    }

    // Waits for the next inbound message; fails with `RecvError` once the inbound queue
    // is closed and drained.
    fn recv_bytes(&self) -> K::Future<'_, Result<Vec<u8>, Self::RecvError>> {
        let inner = &self.inner;
        let inbound = inner.inbound_reader.clone();
        tracing::debug!(
            chan_id = inner.chan_id,
            "waiting on recv {:?}",
            inner.peer_id
        );
        K::from_future(async move {
            match inbound.recv().await {
                Ok(bytes) => {
                    tracing::debug!("recv: inbound {} bytes", bytes.len());
                    Ok(bytes)
                }
                Err(_) => {
                    tracing::error!("inbound channel closed unexpectedly");
                    Err(RecvError)
                }
            }
        })
    }
}
/// Identity comparison: two transports are equal exactly when they share the same
/// `Inner` allocation, i.e. one was cloned (directly or transitively) from the other.
impl PartialEq for HttpLongPollTransport {
    fn eq(&self, other: &Self) -> bool {
        let lhs = Arc::as_ptr(&self.inner);
        let rhs = Arc::as_ptr(&other.inner);
        core::ptr::eq(lhs, rhs)
    }
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use sedimentree_core::id::SedimentreeId;
    use subduction_core::{
        connection::{
            Connection,
            message::{RemoveSubscriptions, SyncMessage},
        },
        transport::message::MessageTransport,
    };

    /// Builds the small `RemoveSubscriptions` message used by the round-trip tests.
    fn sample_message() -> SyncMessage {
        SyncMessage::RemoveSubscriptions(RemoveSubscriptions {
            ids: alloc::vec![SedimentreeId::from_bytes([0u8; 32])],
        })
    }

    /// The peer id passed to `new` is stored unchanged.
    #[tokio::test]
    async fn peer_id_preserved() {
        let peer_id = PeerId::new([1u8; 32]);
        let transport = HttpLongPollTransport::new(peer_id);
        assert_eq!(transport.inner.peer_id, peer_id);
    }

    /// Bytes injected with `push_inbound` decode back into the original message.
    #[tokio::test]
    async fn push_inbound_and_recv() {
        let transport = HttpLongPollTransport::new(PeerId::new([2u8; 32]));
        let encoded = sample_message().encode();
        transport.push_inbound(encoded).await.expect("push ok");

        let framed = MessageTransport::new(transport);
        let received = Connection::<Sendable, SyncMessage>::recv(&framed)
            .await
            .expect("recv ok");
        assert!(matches!(received, SyncMessage::RemoveSubscriptions(_)));
    }

    /// A message sent through the connection can be pulled and decoded on the other side.
    #[tokio::test]
    async fn send_and_pull_outbound() {
        let transport = HttpLongPollTransport::new(PeerId::new([3u8; 32]));
        let message = sample_message();

        let framed = MessageTransport::new(transport.clone());
        Connection::<Sendable, SyncMessage>::send(&framed, &message)
            .await
            .expect("send ok");

        let raw = transport.pull_outbound().await.expect("pull ok");
        let decoded = SyncMessage::try_decode(&raw).expect("decode ok");
        assert!(matches!(decoded, SyncMessage::RemoveSubscriptions(_)));
    }
}