use crate::error::{NexarError, Result};
use crate::transport::BulkTransport;
use futures::future::BoxFuture;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, mpsc};
/// Map from a frame tag to the shared receiving end of that tag's channel.
///
/// Each receiver is wrapped in `Arc<Mutex<..>>` so it can be handed out to
/// callers while the map keeps its own reference.
type TaggedReceiverMap = HashMap<u64, Arc<Mutex<mpsc::Receiver<Vec<u8>>>>>;
/// Routing state shared between the background receive loop and the code
/// that registers per-tag receivers.
struct RecvState {
/// Per-tag channel senders; the receive loop forwards a tagged payload here
/// when an entry exists for its tag.
senders: HashMap<u64, mpsc::Sender<Vec<u8>>>,
/// Payloads received for a tag that had no sender registered, kept in
/// arrival order until a receiver for that tag is created.
pending: HashMap<u64, Vec<Vec<u8>>>,
}
/// Bulk transport over a byte stream that carries length-prefixed frames,
/// each labelled with a `u64` tag (tag `0` is the untagged channel).
pub struct TcpBulkTransport {
/// Write half of the connection; the mutex is held for a whole frame so
/// concurrent senders cannot interleave their bytes.
writer: Mutex<Box<dyn AsyncWrite + Unpin + Send>>,
/// Receives the payloads of frames whose tag is `0`.
untagged_rx: Mutex<mpsc::Receiver<Vec<u8>>>,
/// Sender registry and backlog shared with the receive loop.
state: Arc<Mutex<RecvState>>,
/// Receivers for non-zero tags, created on first use of each tag.
tagged_rx: Mutex<TaggedReceiverMap>,
/// Handle of the spawned receive loop task; stored but never awaited here.
// NOTE(review): dropping a tokio `JoinHandle` detaches the task rather than
// aborting it — confirm whether the loop should be aborted on drop.
_recv_handle: tokio::task::JoinHandle<()>,
}
/// Extension of [`BulkTransport`] that multiplexes independent byte streams
/// over one connection by labelling every message with a `u64` tag.
pub trait TaggedBulkTransport: BulkTransport {
/// Sends `data` as a single message labelled with `tag`.
fn send_bulk_tagged<'a>(&'a self, tag: u64, data: &'a [u8]) -> BoxFuture<'a, Result<()>>;
/// Receives the next message labelled with `tag`.
///
/// `expected_size` is the size the caller anticipates; the implementation
/// in this file ignores it.
fn recv_bulk_tagged<'a>(
&'a self,
tag: u64,
expected_size: usize,
) -> BoxFuture<'a, Result<Vec<u8>>>;
}
impl TcpBulkTransport {
    /// Capacity of the untagged channel and the minimum capacity of every
    /// per-tag channel.
    const CHANNEL_CAPACITY: usize = 64;

    /// Builds a transport on top of an already connected [`TcpStream`].
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the receive loop is
    /// started with `tokio::spawn`.
    pub fn from_stream(stream: TcpStream) -> Self {
        let (reader, writer) = tokio::io::split(stream);
        Self::from_split(Box::new(reader), Box::new(writer))
    }

    /// Builds a transport from separate read and write halves and spawns the
    /// background task that demultiplexes incoming frames.
    ///
    /// This is the constructor used for TLS streams, where the halves are not
    /// a plain `TcpStream`.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the receive loop is
    /// started with `tokio::spawn`.
    pub fn from_split(
        reader: Box<dyn AsyncRead + Unpin + Send>,
        writer: Box<dyn AsyncWrite + Unpin + Send>,
    ) -> Self {
        let (untagged_tx, untagged_rx) = mpsc::channel(Self::CHANNEL_CAPACITY);
        let state = Arc::new(Mutex::new(RecvState {
            senders: HashMap::new(),
            pending: HashMap::new(),
        }));
        let recv_state = Arc::clone(&state);
        let recv_handle = tokio::spawn(async move {
            recv_loop(reader, untagged_tx, recv_state).await;
        });
        Self {
            writer: Mutex::new(writer),
            untagged_rx: Mutex::new(untagged_rx),
            state,
            tagged_rx: Mutex::new(HashMap::new()),
            _recv_handle: recv_handle,
        }
    }

    /// Writes one frame: `tag` (u64 LE), payload length (u64 LE), payload,
    /// then flushes.
    ///
    /// The writer lock is held for the whole frame so frames from concurrent
    /// callers never interleave.
    ///
    /// # Errors
    ///
    /// Returns a transport error wrapping the underlying I/O error if any of
    /// the writes or the flush fails.
    // NOTE(review): the frame is written in several awaited steps; if the
    // returned future is dropped part-way, a partial frame may remain on the
    // stream — confirm callers never cancel a send.
    async fn write_frame(&self, tag: u64, data: &[u8]) -> Result<()> {
        let mut writer = self.writer.lock().await;
        writer
            .write_all(&tag.to_le_bytes())
            .await
            .map_err(|e| NexarError::transport_with_source("tcp bulk write tag", e))?;
        writer
            .write_all(&(data.len() as u64).to_le_bytes())
            .await
            .map_err(|e| NexarError::transport_with_source("tcp bulk write len", e))?;
        writer
            .write_all(data)
            .await
            .map_err(|e| NexarError::transport_with_source("tcp bulk write payload", e))?;
        writer
            .flush()
            .await
            .map_err(|e| NexarError::transport_with_source("tcp bulk flush", e))?;
        Ok(())
    }

    /// Returns the shared receiver for `tag`, creating and registering its
    /// channel on first use.
    ///
    /// Any payloads that arrived for `tag` before this call are moved into
    /// the new channel, in arrival order, ahead of anything delivered later.
    async fn get_tag_receiver(&self, tag: u64) -> Arc<Mutex<mpsc::Receiver<Vec<u8>>>> {
        // Keep the receiver map locked across lookup *and* creation so two
        // concurrent callers for the same tag cannot each build a channel
        // and overwrite one another's registration.
        let mut receivers = self.tagged_rx.lock().await;
        if let Some(rx) = receivers.get(&tag) {
            return Arc::clone(rx);
        }

        let rx = {
            let mut st = self.state.lock().await;
            let backlog = st.pending.remove(&tag).unwrap_or_default();
            // Nobody can drain the channel until this function returns, so it
            // must have room for the entire backlog or the flush would block
            // forever.
            let (tx, rx) = mpsc::channel(backlog.len().max(Self::CHANNEL_CAPACITY));
            for frame in backlog {
                // Cannot fail: the receiver is alive (held in `rx`) and the
                // capacity is at least the backlog length.
                let queued = tx.try_send(frame);
                debug_assert!(queued.is_ok(), "backlog must fit in a freshly sized channel");
            }
            // Publish the sender only after the backlog is queued, while the
            // state lock is still held, so frames routed through the sender
            // are always behind the buffered ones.
            st.senders.insert(tag, tx);
            rx
        };

        let rx_arc = Arc::new(Mutex::new(rx));
        receivers.insert(tag, Arc::clone(&rx_arc));
        rx_arc
    }
}
impl BulkTransport for TcpBulkTransport {
    /// Sends `data` as one frame on the untagged channel (tag `0`).
    fn send_bulk<'a>(&'a self, data: &'a [u8]) -> BoxFuture<'a, Result<()>> {
        Box::pin(async move { self.write_frame(0, data).await })
    }

    /// Waits for the next untagged payload.
    ///
    /// The size hint is ignored; the payload is returned at whatever length
    /// the peer sent.
    ///
    /// # Errors
    ///
    /// Returns a transport error once the untagged channel has been closed
    /// and drained.
    fn recv_bulk<'a>(&'a self, _expected_size: usize) -> BoxFuture<'a, Result<Vec<u8>>> {
        Box::pin(async move {
            let mut inbox = self.untagged_rx.lock().await;
            match inbox.recv().await {
                Some(payload) => Ok(payload),
                None => Err(NexarError::transport("tcp bulk connection closed")),
            }
        })
    }
}
impl TaggedBulkTransport for TcpBulkTransport {
    /// Sends `data` as one frame labelled with `tag`.
    fn send_bulk_tagged<'a>(&'a self, tag: u64, data: &'a [u8]) -> BoxFuture<'a, Result<()>> {
        Box::pin(async move { self.write_frame(tag, data).await })
    }

    /// Waits for the next payload labelled with `tag`, registering a
    /// receiver for that tag on first use.
    ///
    /// The size hint is ignored.
    ///
    /// # Errors
    ///
    /// Returns a transport error once the tag's channel has been closed and
    /// drained.
    fn recv_bulk_tagged<'a>(
        &'a self,
        tag: u64,
        _expected_size: usize,
    ) -> BoxFuture<'a, Result<Vec<u8>>> {
        Box::pin(async move {
            let shared_rx = self.get_tag_receiver(tag).await;
            let mut inbox = shared_rx.lock().await;
            match inbox.recv().await {
                Some(payload) => Ok(payload),
                None => Err(NexarError::transport("tcp bulk connection closed")),
            }
        })
    }
}
/// Largest payload length (4 GiB) the receive loop accepts; a frame that
/// declares more causes the loop to stop reading from the connection.
// NOTE(review): this expression does not fit in `usize` on 32-bit targets, so
// it would fail constant evaluation there — confirm only 64-bit is supported.
const MAX_TCP_FRAME_SIZE: usize = 4 * 1024 * 1024 * 1024;
/// Reads frames from `reader` until the stream ends or a consumer goes away,
/// routing each payload to the channel that matches its tag.
///
/// Wire format of a frame: tag (`u64`, little-endian), payload length
/// (`u64`, little-endian), then the payload bytes.
///
/// * Tag `0` payloads go to `untagged_tx`.
/// * Other payloads go to the sender registered for that tag in `state`, or
///   are appended to that tag's pending list when no sender is registered.
///
/// The loop returns (logging the reason) on any read error, on a frame whose
/// declared length exceeds `MAX_TCP_FRAME_SIZE`, or when the target channel's
/// receiver has been dropped.
async fn recv_loop(
    mut reader: Box<dyn AsyncRead + Unpin + Send>,
    untagged_tx: mpsc::Sender<Vec<u8>>,
    state: Arc<Mutex<RecvState>>,
) {
    let mut tag_buf = [0u8; 8];
    let mut len_buf = [0u8; 8];
    loop {
        if let Err(e) = reader.read_exact(&mut tag_buf).await {
            tracing::debug!("tcp bulk recv loop ended: {e}");
            return;
        }
        if let Err(e) = reader.read_exact(&mut len_buf).await {
            tracing::debug!("tcp bulk recv loop ended reading len: {e}");
            return;
        }
        let tag = u64::from_le_bytes(tag_buf);
        let raw_len = u64::from_le_bytes(len_buf);
        // Convert with a checked cast: a truncating `as usize` could turn an
        // oversized length into a small one and desynchronise the stream.
        let len = match usize::try_from(raw_len) {
            Ok(len) if len <= MAX_TCP_FRAME_SIZE => len,
            _ => {
                tracing::warn!(
                    len = raw_len,
                    "tcp bulk: frame too large, closing connection"
                );
                return;
            }
        };
        let mut payload = vec![0u8; len];
        if let Err(e) = reader.read_exact(&mut payload).await {
            tracing::debug!("tcp bulk recv loop ended reading payload: {e}");
            return;
        }

        let delivered = if tag == 0 {
            untagged_tx.send(payload).await.is_ok()
        } else {
            // Look up the sender and, if there is none, park the payload
            // under the same lock acquisition. Releasing the lock between the
            // two steps would let a receiver register in the gap and leave
            // this payload stranded in `pending`.
            let mut st = state.lock().await;
            let registered = st.senders.get(&tag).cloned();
            match registered {
                Some(tx) => {
                    // Release the lock before awaiting so a full channel
                    // cannot block other users of the shared state.
                    drop(st);
                    tx.send(payload).await.is_ok()
                }
                None => {
                    st.pending.entry(tag).or_default().push(payload);
                    true
                }
            }
        };
        if !delivered {
            return;
        }
    }
}
/// Binds a TCP listener on `addr` and reports the address it actually bound
/// to (useful when `addr` asks for port `0`).
///
/// # Errors
///
/// Returns a transport error if binding fails or the local address cannot be
/// queried.
pub async fn tcp_bulk_listen(
    addr: std::net::SocketAddr,
) -> Result<(TcpListener, std::net::SocketAddr)> {
    let bound = TcpListener::bind(addr).await;
    let listener =
        bound.map_err(|e| NexarError::transport_with_source("tcp bulk listen", e))?;
    let local_addr = listener
        .local_addr()
        .map_err(|e| NexarError::transport_with_source("tcp bulk local_addr", e))?;
    Ok((listener, local_addr))
}
/// Connects to `addr` over plain TCP, enables `TCP_NODELAY`, and wraps the
/// stream in a [`TcpBulkTransport`].
///
/// # Errors
///
/// Returns a transport error if the connection attempt or the socket option
/// fails.
pub async fn tcp_bulk_connect(addr: std::net::SocketAddr) -> Result<TcpBulkTransport> {
    let attempt = TcpStream::connect(addr).await;
    let socket = attempt.map_err(|e| NexarError::transport_with_source("tcp bulk connect", e))?;
    socket
        .set_nodelay(true)
        .map_err(|e| NexarError::transport_with_source("tcp bulk set_nodelay", e))?;
    let transport = TcpBulkTransport::from_stream(socket);
    Ok(transport)
}
/// Accepts one plain-TCP connection from `listener`, enables `TCP_NODELAY`,
/// and wraps the stream in a [`TcpBulkTransport`].
///
/// The peer address reported by the listener is discarded.
///
/// # Errors
///
/// Returns a transport error if accepting or setting the socket option fails.
pub async fn tcp_bulk_accept(listener: &TcpListener) -> Result<TcpBulkTransport> {
    let accepted = listener.accept().await;
    let (socket, _peer) =
        accepted.map_err(|e| NexarError::transport_with_source("tcp bulk accept", e))?;
    socket
        .set_nodelay(true)
        .map_err(|e| NexarError::transport_with_source("tcp bulk set_nodelay", e))?;
    let transport = TcpBulkTransport::from_stream(socket);
    Ok(transport)
}
/// Connects to `addr`, enables `TCP_NODELAY`, performs a TLS client
/// handshake with `tls_config`, and wraps the encrypted stream in a
/// [`TcpBulkTransport`].
///
/// # Errors
///
/// Returns a transport error if the TCP connection, the socket option, or
/// the TLS handshake fails, and a TLS error if the server name is rejected.
// NOTE(review): the server name presented for the handshake is always
// "localhost", independent of `addr` — confirm that the certificate
// verification configured in `tls_config` is designed around this.
pub async fn tcp_bulk_connect_tls(
    addr: std::net::SocketAddr,
    tls_config: Arc<rustls::ClientConfig>,
) -> Result<TcpBulkTransport> {
    let attempt = TcpStream::connect(addr).await;
    let socket =
        attempt.map_err(|e| NexarError::transport_with_source("tcp bulk tls connect", e))?;
    socket
        .set_nodelay(true)
        .map_err(|e| NexarError::transport_with_source("tcp bulk tls set_nodelay", e))?;

    let connector = tokio_rustls::TlsConnector::from(tls_config);
    let server_name = rustls::pki_types::ServerName::try_from("localhost")
        .map_err(|e| NexarError::Tls(format!("bulk TLS server name: {e}")))?;
    let handshake = connector.connect(server_name, socket).await;
    let tls_stream = handshake
        .map_err(|e| NexarError::transport_with_source("tcp bulk tls handshake (client)", e))?;

    let (read_half, write_half) = tokio::io::split(tls_stream);
    let transport = TcpBulkTransport::from_split(Box::new(read_half), Box::new(write_half));
    Ok(transport)
}
/// Accepts one connection from `listener`, enables `TCP_NODELAY`, performs a
/// TLS server handshake with `tls_config`, and wraps the encrypted stream in
/// a [`TcpBulkTransport`].
///
/// The peer address reported by the listener is discarded.
///
/// # Errors
///
/// Returns a transport error if accepting, setting the socket option, or the
/// TLS handshake fails.
pub async fn tcp_bulk_accept_tls(
    listener: &TcpListener,
    tls_config: Arc<rustls::ServerConfig>,
) -> Result<TcpBulkTransport> {
    let accepted = listener.accept().await;
    let (socket, _peer) =
        accepted.map_err(|e| NexarError::transport_with_source("tcp bulk tls accept", e))?;
    socket
        .set_nodelay(true)
        .map_err(|e| NexarError::transport_with_source("tcp bulk tls set_nodelay", e))?;

    let acceptor = tokio_rustls::TlsAcceptor::from(tls_config);
    let handshake = acceptor.accept(socket).await;
    let tls_stream = handshake
        .map_err(|e| NexarError::transport_with_source("tcp bulk tls handshake (server)", e))?;

    let (read_half, write_half) = tokio::io::split(tls_stream);
    let transport = TcpBulkTransport::from_split(Box::new(read_half), Box::new(write_half));
    Ok(transport)
}