use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
use crate::frame::{read_length_prefixed, write_length_prefixed};
/// Upper bound, in bytes, on what a single raw `recv` call reads from the socket (64 KiB).
const RECV_BUF: usize = 64 * 1024;
/// A [`Transport`] carried over a single TCP connection.
///
/// The socket is held as `Arc<Mutex<Option<TcpStream>>>`. `None` means `close` has already
/// run, and the I/O methods then fail with `TransportError::ChannelClosed`.
///
/// Reads and writes share the one mutex. A `recv` that is waiting for data therefore delays a
/// concurrent `send` until that read completes.
#[derive(Debug)]
pub struct TcpTransport {
    stream: Arc<Mutex<Option<TcpStream>>>,
}
impl TcpTransport {
    /// Opens a TCP connection to `addr` and wraps it in a new transport.
    ///
    /// # Errors
    ///
    /// Returns `SrxError::Transport(TransportError::ConnectionFailed(msg))` if the connection
    /// cannot be established. `msg` is the text of the underlying I/O error.
    pub async fn connect(addr: impl tokio::net::ToSocketAddrs) -> crate::error::Result<Self> {
        match TcpStream::connect(addr).await {
            Ok(stream) => Ok(Self::from_stream(stream)),
            Err(err) => Err(SrxError::Transport(TransportError::ConnectionFailed(
                err.to_string(),
            ))),
        }
    }

    /// Wraps an already-connected stream (for example one returned by a listener's `accept`).
    #[must_use]
    pub fn from_stream(stream: TcpStream) -> Self {
        let slot = Mutex::new(Some(stream));
        Self {
            stream: Arc::new(slot),
        }
    }

    /// Sends `payload` as one length-prefixed frame, using `write_length_prefixed`.
    ///
    /// # Errors
    ///
    /// - `TransportError::ChannelClosed` if the transport has been closed.
    /// - Otherwise, whatever `write_length_prefixed` reports.
    pub async fn send_framed(&self, payload: &[u8]) -> crate::error::Result<()> {
        let mut slot = self.stream.lock().await;
        let stream = match slot.as_mut() {
            Some(stream) => stream,
            None => return Err(SrxError::Transport(TransportError::ChannelClosed)),
        };
        write_length_prefixed(stream, payload).await
    }

    /// Receives one length-prefixed frame, using `read_length_prefixed`, and returns its payload.
    ///
    /// # Errors
    ///
    /// - `TransportError::ChannelClosed` if the transport has been closed.
    /// - Otherwise, whatever `read_length_prefixed` reports.
    pub async fn recv_framed(&self) -> crate::error::Result<Bytes> {
        let mut slot = self.stream.lock().await;
        let stream = match slot.as_mut() {
            Some(stream) => stream,
            None => return Err(SrxError::Transport(TransportError::ChannelClosed)),
        };
        let frame = read_length_prefixed(stream).await?;
        Ok(Bytes::from(frame))
    }
}
#[async_trait]
impl Transport for TcpTransport {
fn kind(&self) -> TransportKind {
TransportKind::Tcp
}
async fn send(&self, data: Bytes) -> crate::error::Result<()> {
let mut guard = self.stream.lock().await;
let stream = guard
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
stream
.write_all(&data)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
stream
.flush()
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(())
}
async fn recv(&self) -> crate::error::Result<Bytes> {
let mut guard = self.stream.lock().await;
let stream = guard
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
let mut buf = vec![0u8; RECV_BUF];
let n = stream
.read(&mut buf)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
if n == 0 {
return Err(SrxError::Transport(TransportError::ChannelClosed));
}
buf.truncate(n);
Ok(Bytes::from(buf))
}
async fn is_healthy(&self) -> bool {
self.stream.lock().await.is_some()
}
async fn close(&self) -> crate::error::Result<()> {
let mut guard = self.stream.lock().await;
if let Some(mut s) = guard.take() {
let _ = s.shutdown().await;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Raw `send`/`recv` do not preserve message boundaries. This test relies on a short write
    // over loopback arriving in a single read.
    #[tokio::test]
    async fn send_recv_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let t = TcpTransport::from_stream(stream);
            let got = t.recv().await.unwrap();
            assert_eq!(got.as_ref(), b"ping");
            t.send(Bytes::from_static(b"pong")).await.unwrap();
        });
        let client = TcpTransport::connect(addr).await.unwrap();
        client.send(Bytes::from_static(b"ping")).await.unwrap();
        let reply = client.recv().await.unwrap();
        assert_eq!(reply.as_ref(), b"pong");
        client.close().await.unwrap();
        server.await.unwrap();
    }

    #[tokio::test]
    async fn framed_roundtrip_matches_length_prefix_wire() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let payload = b"framed-payload-srx";
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let t = TcpTransport::from_stream(stream);
            let got = t.recv_framed().await.unwrap();
            assert_eq!(got.as_ref(), payload);
            t.send_framed(b"ack").await.unwrap();
        });
        let client = TcpTransport::connect(addr).await.unwrap();
        client.send_framed(payload).await.unwrap();
        let reply = client.recv_framed().await.unwrap();
        assert_eq!(reply.as_ref(), b"ack");
        client.close().await.unwrap();
        server.await.unwrap();
    }

    #[tokio::test]
    async fn close_marks_unhealthy_and_rejects_io() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = TcpTransport::connect(addr).await.unwrap();
        // Keep the accepted socket alive for the duration of the test.
        let (_server_stream, _) = listener.accept().await.unwrap();

        assert!(client.is_healthy().await);
        client.close().await.unwrap();
        assert!(!client.is_healthy().await);

        // Every I/O path must fail fast once the stream has been taken.
        assert!(client.send(Bytes::from_static(b"late")).await.is_err());
        assert!(client.recv().await.is_err());
        assert!(client.send_framed(b"late").await.is_err());
        assert!(client.recv_framed().await.is_err());

        // A second close finds no stream and still succeeds.
        client.close().await.unwrap();
    }

    #[tokio::test]
    async fn recv_errors_after_peer_closes() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            TcpTransport::from_stream(stream).close().await.unwrap();
        });
        let client = TcpTransport::connect(addr).await.unwrap();
        server.await.unwrap();
        // The peer shut down and dropped its socket, so the next read hits end-of-stream.
        assert!(client.recv().await.is_err());
    }
}