use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::TcpStream;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::sync::Mutex;
use super::command_message::CommandMessage;
use super::error::IpcError;
use super::message::IpcMessageHeader;
/// Largest JSON payload, in bytes, that may be sent or accepted in one frame (16 MiB).
const MAX_MESSAGE_SIZE: u32 = 16 << 20;
/// Capacity, in bytes, of the buffered reader/writer wrapped around each TCP half (64 KiB).
const BUFFER_SIZE: usize = 64 << 10;
/// Length-prefixed JSON framing over a buffered, split TCP stream.
///
/// Each frame is an [`IpcMessageHeader`] (`IpcMessageHeader::SIZE` bytes) followed by
/// `header.length` bytes of JSON-encoded [`CommandMessage`].
pub struct FramedStream {
    reader: BufReader<OwnedReadHalf>,
    writer: BufWriter<OwnedWriteHalf>,
    /// Bytes already received from `reader` that do not yet form a complete frame.
    /// Partial frames live here (not in a future's locals) so that `recv` can be
    /// cancelled, e.g. by the timeout in `recv_timeout`, without losing data.
    pending: Vec<u8>,
}

impl FramedStream {
    /// Splits `stream` and wraps each half in a `BUFFER_SIZE`-byte buffer.
    pub fn new(stream: TcpStream) -> Self {
        let (read_half, write_half) = stream.into_split();
        Self {
            reader: BufReader::with_capacity(BUFFER_SIZE, read_half),
            writer: BufWriter::with_capacity(BUFFER_SIZE, write_half),
            pending: Vec::new(),
        }
    }

    /// Serialises `msg` as JSON and writes it as one header + payload frame, then flushes.
    ///
    /// # Errors
    /// - `IpcError::Framing` if the payload exceeds `MAX_MESSAGE_SIZE` bytes.
    /// - Serialisation or I/O errors converted into `IpcError`.
    pub async fn send(&mut self, msg: &CommandMessage) -> Result<(), IpcError> {
        let payload = serde_json::to_vec(msg)?;
        // Validate before narrowing: a bare `as u32` would wrap payloads of 4 GiB or
        // more to a small value that passes the size check and corrupts the frame.
        let payload_len = match u32::try_from(payload.len()) {
            Ok(len) if len <= MAX_MESSAGE_SIZE => len,
            _ => {
                return Err(IpcError::Framing(format!(
                    "Message too large: {} bytes (max {})",
                    payload.len(),
                    MAX_MESSAGE_SIZE
                )))
            }
        };
        let header = IpcMessageHeader::new(msg.message_type, payload_len);
        self.writer.write_all(&header.to_bytes()).await?;
        self.writer.write_all(&payload).await?;
        self.writer.flush().await?;
        Ok(())
    }

    /// Receives and decodes the next complete frame.
    ///
    /// # Cancel safety
    /// Cancel safe: received bytes are kept in `self.pending`, so dropping this future
    /// (for example on timeout) loses nothing and a later call resumes the same frame.
    ///
    /// # Errors
    /// - `IpcError::Framing` if a header announces more than `MAX_MESSAGE_SIZE` bytes.
    ///   The header is left buffered, so later calls report the same error; the stream
    ///   cannot be resynchronised and should be closed.
    /// - I/O errors, including `UnexpectedEof` when the peer closes the connection.
    /// - JSON decoding errors; the malformed frame is consumed, so framing stays intact.
    pub async fn recv(&mut self) -> Result<CommandMessage, IpcError> {
        loop {
            if let Some(msg) = self.take_frame()? {
                return Ok(msg);
            }
            self.pending.reserve(BUFFER_SIZE);
            // `read_buf` is cancel safe: bytes are appended to `pending` only when a
            // read completes, so an aborted call never discards data.
            let read = self.reader.read_buf(&mut self.pending).await?;
            if read == 0 {
                return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into());
            }
        }
    }

    /// Removes one complete frame from the front of `pending` and decodes it.
    /// Returns `Ok(None)` when more bytes are needed.
    fn take_frame(&mut self) -> Result<Option<CommandMessage>, IpcError> {
        if self.pending.len() < IpcMessageHeader::SIZE {
            return Ok(None);
        }
        let mut header_bytes = [0u8; IpcMessageHeader::SIZE];
        header_bytes.copy_from_slice(&self.pending[..IpcMessageHeader::SIZE]);
        let header = IpcMessageHeader::from_bytes(&header_bytes);
        if header.length > MAX_MESSAGE_SIZE {
            return Err(IpcError::Framing(format!(
                "Message too large: {} bytes (max {})",
                header.length, MAX_MESSAGE_SIZE
            )));
        }
        let frame_len = IpcMessageHeader::SIZE + header.length as usize;
        if self.pending.len() < frame_len {
            // Size the buffer for the rest of the frame so reads land in one allocation.
            self.pending.reserve(frame_len - self.pending.len());
            return Ok(None);
        }
        let decoded = serde_json::from_slice::<CommandMessage>(
            &self.pending[IpcMessageHeader::SIZE..frame_len],
        );
        // Consume the frame even if decoding failed, so the next frame starts cleanly.
        self.pending.drain(..frame_len);
        Ok(Some(decoded?))
    }

    /// Like [`recv`](Self::recv), but returns `Ok(None)` if no complete frame arrives
    /// within `timeout`. A partially received frame is kept and finished by a later call.
    pub async fn recv_timeout(
        &mut self,
        timeout: std::time::Duration,
    ) -> Result<Option<CommandMessage>, IpcError> {
        match tokio::time::timeout(timeout, self.recv()).await {
            Ok(result) => result.map(Some),
            Err(_) => Ok(None),
        }
    }

    /// Sends `msg` and waits up to `timeout` for a message with the same `transaction_id`.
    ///
    /// Messages with any other transaction id that arrive while waiting are discarded.
    ///
    /// # Errors
    /// `IpcError::Timeout` if no matching response arrives in time, plus any error from
    /// `send` or `recv`.
    pub async fn send_recv(
        &mut self,
        msg: &CommandMessage,
        timeout: std::time::Duration,
    ) -> Result<CommandMessage, IpcError> {
        let transaction_id = msg.transaction_id;
        let timed_out = || {
            IpcError::Timeout(format!(
                "Timeout waiting for response to transaction {}",
                transaction_id
            ))
        };
        self.send(msg).await?;
        // Track elapsed time rather than computing `now + timeout`, which panics when
        // `timeout` is too large to add to an `Instant`.
        let started = tokio::time::Instant::now();
        loop {
            let remaining = timeout.saturating_sub(started.elapsed());
            if remaining.is_zero() {
                return Err(timed_out());
            }
            match self.recv_timeout(remaining).await? {
                Some(response) if response.transaction_id == transaction_id => {
                    return Ok(response)
                }
                Some(_) => continue,
                None => return Err(timed_out()),
            }
        }
    }

    /// Flushes buffered output and shuts down the write half of the connection.
    pub async fn close(mut self) -> Result<(), IpcError> {
        self.writer.flush().await?;
        self.writer.shutdown().await?;
        Ok(())
    }
}
/// Cloneable handle to a single [`FramedStream`].
///
/// All clones share one `Arc<Mutex<FramedStream>>`, so every operation on any handle
/// serialises on the same async mutex.
pub struct IpcTransport {
    inner: Arc<Mutex<FramedStream>>,
}
impl IpcTransport {
pub fn new(stream: TcpStream) -> Self {
Self {
inner: Arc::new(Mutex::new(FramedStream::new(stream))),
}
}
pub fn from_framed(stream: FramedStream) -> Self {
Self {
inner: Arc::new(Mutex::new(stream)),
}
}
pub async fn send(&self, msg: &CommandMessage) -> Result<(), IpcError> {
let mut guard = self.inner.lock().await;
guard.send(msg).await
}
pub async fn recv(&self) -> Result<CommandMessage, IpcError> {
let mut guard = self.inner.lock().await;
guard.recv().await
}
pub async fn recv_timeout(
&self,
timeout: std::time::Duration,
) -> Result<Option<CommandMessage>, IpcError> {
let mut guard = self.inner.lock().await;
guard.recv_timeout(timeout).await
}
pub async fn send_recv(
&self,
msg: &CommandMessage,
timeout: std::time::Duration,
) -> Result<CommandMessage, IpcError> {
let mut guard = self.inner.lock().await;
guard.send_recv(msg, timeout).await
}
pub fn clone_handle(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl Clone for IpcTransport {
fn clone(&self) -> Self {
self.clone_handle()
}
}
/// Framed transport whose read and write halves sit behind separate mutexes.
///
/// `send` locks only `writer` and `recv` locks only `reader`. A task waiting in `recv`
/// therefore does not block another task calling `send`.
pub struct SplitTransport {
    reader: Arc<Mutex<BufReader<OwnedReadHalf>>>,
    writer: Arc<Mutex<BufWriter<OwnedWriteHalf>>>,
}
impl std::fmt::Debug for SplitTransport {
    /// Prints only the type name. The mutex-guarded halves are not shown, since
    /// reading them would require taking the async locks.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("SplitTransport");
        builder.finish_non_exhaustive()
    }
}
impl SplitTransport {
pub fn new(stream: TcpStream) -> Self {
let (read_half, write_half) = stream.into_split();
Self {
reader: Arc::new(Mutex::new(BufReader::with_capacity(BUFFER_SIZE, read_half))),
writer: Arc::new(Mutex::new(BufWriter::with_capacity(BUFFER_SIZE, write_half))),
}
}
pub async fn send(&self, msg: &CommandMessage) -> Result<(), IpcError> {
let payload = serde_json::to_vec(msg)?;
let payload_len = payload.len() as u32;
if payload_len > MAX_MESSAGE_SIZE {
return Err(IpcError::Framing(format!(
"Message too large: {} bytes (max {})",
payload_len, MAX_MESSAGE_SIZE
)));
}
let header = IpcMessageHeader::new(msg.message_type, payload_len);
let mut writer = self.writer.lock().await;
writer.write_all(&header.to_bytes()).await?;
writer.write_all(&payload).await?;
writer.flush().await?;
Ok(())
}
pub async fn recv(&self) -> Result<CommandMessage, IpcError> {
let mut reader = self.reader.lock().await;
let mut header_bytes = [0u8; IpcMessageHeader::SIZE];
reader.read_exact(&mut header_bytes).await?;
let header = IpcMessageHeader::from_bytes(&header_bytes);
if header.length > MAX_MESSAGE_SIZE {
return Err(IpcError::Framing(format!(
"Message too large: {} bytes (max {})",
header.length, MAX_MESSAGE_SIZE
)));
}
let mut payload = vec![0u8; header.length as usize];
reader.read_exact(&mut payload).await?;
let msg: CommandMessage = serde_json::from_slice(&payload)?;
Ok(msg)
}
pub fn clone_reader(&self) -> Arc<Mutex<BufReader<OwnedReadHalf>>> {
Arc::clone(&self.reader)
}
pub fn clone_writer(&self) -> Arc<Mutex<BufWriter<OwnedWriteHalf>>> {
Arc::clone(&self.writer)
}
}
impl Clone for SplitTransport {
fn clone(&self) -> Self {
Self {
reader: Arc::clone(&self.reader),
writer: Arc::clone(&self.writer),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ipc::MessageType;
    use tokio::net::TcpListener;

    #[tokio::test]
    async fn test_framed_stream_roundtrip() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut framed = FramedStream::new(stream);
            let msg = framed.recv().await.unwrap();
            assert_eq!(msg.message_type, MessageType::Request);
            assert_eq!(msg.topic, "test.function");
            let response = msg.into_response(serde_json::json!({"status": "ok"}));
            framed.send(&response).await.unwrap();
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let mut framed = FramedStream::new(stream);
        let request = CommandMessage::request("test.function", serde_json::json!({"key": "value"}));
        let request_id = request.transaction_id;
        framed.send(&request).await.unwrap();
        let response = framed.recv().await.unwrap();
        assert_eq!(response.message_type, MessageType::Response);
        assert_eq!(response.transaction_id, request_id);
        assert!(response.success);
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_send_recv_with_timeout() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut framed = FramedStream::new(stream);
            let msg = framed.recv().await.unwrap();
            let response = msg.into_response(serde_json::json!(42));
            framed.send(&response).await.unwrap();
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let mut framed = FramedStream::new(stream);
        let request = CommandMessage::read("test.value");
        let response = framed
            .send_recv(&request, std::time::Duration::from_secs(5))
            .await
            .unwrap();
        assert!(response.success);
        assert_eq!(response.data, serde_json::json!(42));
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_large_message_rejection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut framed = FramedStream::new(stream);
            let result = framed.recv().await;
            assert!(matches!(result, Err(IpcError::Framing(_))));
        });
        let mut stream = TcpStream::connect(addr).await.unwrap();
        let bad_header = IpcMessageHeader::new(MessageType::Request, MAX_MESSAGE_SIZE + 1);
        stream.write_all(&bad_header.to_bytes()).await.unwrap();
        // Await the server so a failed assertion actually fails this test, and keep the
        // client open until then so the server sees the header rather than an early EOF.
        server_task.await.unwrap();
        drop(stream);
    }

    #[tokio::test]
    async fn test_recv_timeout_returns_none_when_idle() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = TcpStream::connect(addr).await.unwrap();
        // Hold the server side open (and silent) for the whole test.
        let (_server_side, _) = listener.accept().await.unwrap();
        let mut framed = FramedStream::new(client);
        let received = framed
            .recv_timeout(std::time::Duration::from_millis(50))
            .await
            .unwrap();
        assert!(received.is_none());
    }

    #[tokio::test]
    async fn test_split_transport() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let transport = SplitTransport::new(stream);
            let msg = transport.recv().await.unwrap();
            assert_eq!(msg.topic, "split.test");
            let response = msg.into_response(serde_json::json!("pong"));
            transport.send(&response).await.unwrap();
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let transport = SplitTransport::new(stream);
        let request = CommandMessage::request("split.test", serde_json::json!("ping"));
        transport.send(&request).await.unwrap();
        let response = transport.recv().await.unwrap();
        assert!(response.success);
        server_task.await.unwrap();
    }
}