use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Mutex;
use tokio::time::timeout;
use crate::matrixrpc::protocol::JsonRpcMessage;
use super::{Transport, TransportConfig};
/// Size in bytes of the length prefix that precedes every frame on the wire.
///
/// The prefix is a big-endian `u32` holding the byte length of the JSON payload
/// that follows it.
const FRAME_HEADER_SIZE: usize = 4;
/// A JSON-RPC transport over a single TCP connection using length-prefixed frames.
///
/// The underlying stream is split into independent read and write halves at
/// construction time; both halves are dropped when the transport is closed.
pub struct TcpTransport {
// Read half of the split stream; becomes `None` once `close` has taken it.
reader: Arc<Mutex<Option<OwnedReadHalf>>>,
// Write half of the split stream; becomes `None` once `close` has taken it.
writer: Arc<Mutex<Option<OwnedWriteHalf>>>,
// Supplies the read/write timeouts (milliseconds) and the maximum accepted frame size.
config: TransportConfig,
// Peer address captured when the transport was built; `None` if it could not be queried.
remote_addr: Option<SocketAddr>,
// Set to `true` by `close`; checked at the start of `send` and `receive`.
is_closed: bool,
}
impl TcpTransport {
    /// Connects to `addr` using [`TransportConfig::default`].
    ///
    /// # Errors
    /// Returns the error reported while establishing the TCP connection.
    pub async fn connect(addr: &str) -> io::Result<Self> {
        Self::connect_with_config(addr, TransportConfig::default()).await
    }

    /// Connects to `addr` and wraps the resulting stream with the given `config`.
    ///
    /// # Errors
    /// Returns the error reported while establishing the TCP connection.
    pub async fn connect_with_config(addr: &str, config: TransportConfig) -> io::Result<Self> {
        let stream = TokioTcpStream::connect(addr).await?;
        // Single construction path: everything else is shared with `from_stream`.
        Ok(Self::from_stream(stream, config))
    }

    /// Wraps an already-connected stream (for example one returned by `accept`).
    ///
    /// The peer address is captured once here; if it cannot be queried it is
    /// recorded as `None` rather than failing construction.
    pub fn from_stream(stream: TokioTcpStream, config: TransportConfig) -> Self {
        let remote_addr = stream.peer_addr().ok();
        let (reader, writer) = stream.into_split();
        Self {
            reader: Arc::new(Mutex::new(Some(reader))),
            writer: Arc::new(Mutex::new(Some(writer))),
            config,
            remote_addr,
            is_closed: false,
        }
    }

    /// Returns the peer address captured at construction time, if it was available.
    pub fn remote_addr(&self) -> Option<SocketAddr> {
        self.remote_addr
    }

    /// Serializes `message` to JSON and prepends the big-endian `u32` length header.
    ///
    /// # Errors
    /// Returns `InvalidData` if JSON serialization fails, or if the serialized
    /// payload is too large for its length to fit in the 4-byte header.
    fn encode_frame(message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
        let json = message.to_json().map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("JSON encode error: {}", e),
            )
        })?;
        let json_bytes = json.into_bytes();

        // A plain `as u32` cast would silently truncate oversized payloads and
        // produce a header that disagrees with the body. The fully-qualified
        // path keeps this independent of whether `TryFrom` is in the prelude.
        let length = <u32 as std::convert::TryFrom<usize>>::try_from(json_bytes.len())
            .map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "Frame payload of {} bytes does not fit in a {}-byte length header",
                        json_bytes.len(),
                        FRAME_HEADER_SIZE
                    ),
                )
            })?;

        let mut frame = Vec::with_capacity(FRAME_HEADER_SIZE + json_bytes.len());
        frame.extend_from_slice(&length.to_be_bytes());
        frame.extend(json_bytes);
        Ok(frame)
    }

    /// Reads one length-prefixed frame from `reader` and parses it as a JSON-RPC message.
    ///
    /// Returns `Ok(None)` in two situations, which the caller cannot tell apart:
    /// * the stream reached EOF before a full header could be read, or
    /// * a frame with a declared length of zero was received.
    ///
    /// # Errors
    /// * `InvalidData` if the declared length exceeds `max_size`, the payload is
    ///   not valid UTF-8, or the payload is not a valid JSON-RPC message.
    /// * Any I/O error from the underlying reads, including `UnexpectedEof` when
    ///   the stream ends part-way through the payload.
    async fn decode_frame(
        reader: &mut OwnedReadHalf,
        max_size: usize,
    ) -> io::Result<Option<JsonRpcMessage>> {
        let mut header_buf = [0u8; FRAME_HEADER_SIZE];
        match reader.read_exact(&mut header_buf).await {
            Ok(_) => {}
            // EOF while waiting for a header is treated as the peer closing the connection.
            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(e),
        }

        let length = u32::from_be_bytes(header_buf) as usize;
        // Reject oversized frames before allocating the payload buffer.
        if length > max_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Frame size {} exceeds maximum {}", length, max_size),
            ));
        }
        if length == 0 {
            return Ok(None);
        }

        let mut payload_buf = vec![0u8; length];
        reader.read_exact(&mut payload_buf).await?;

        let json_str = String::from_utf8(payload_buf).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("UTF-8 decode error: {}", e),
            )
        })?;
        let message = JsonRpcMessage::from_json(&json_str).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("JSON parse error: {}", e),
            )
        })?;
        Ok(Some(message))
    }
}
#[async_trait]
impl Transport for TcpTransport {
    /// Encodes `message` as a single frame and writes it to the peer.
    ///
    /// The whole operation — waiting for the writer lock and writing the frame —
    /// is bounded by `config.write_timeout_ms`.
    ///
    /// # Errors
    /// * `BrokenPipe` if the transport has been closed or the write half is gone.
    /// * `InvalidData` if the message cannot be encoded.
    /// * `TimedOut` if the write does not complete within the configured timeout.
    /// * Any I/O error reported by the underlying write.
    async fn send(&mut self, message: &JsonRpcMessage) -> io::Result<()> {
        if self.is_closed {
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Transport is closed",
            ));
        }

        // Serialize before taking the lock so encoding never extends the critical section.
        let frame = Self::encode_frame(message)?;
        let write_timeout = Duration::from_millis(self.config.write_timeout_ms);

        // Take the lock exactly once, inside the timed future, so that waiting
        // for the lock also counts against the write timeout.
        let result = timeout(write_timeout, async {
            let mut writer_guard = self.writer.lock().await;
            let writer = writer_guard.as_mut().ok_or_else(|| {
                io::Error::new(io::ErrorKind::BrokenPipe, "No stream available")
            })?;
            writer.write_all(&frame).await
        })
        .await;

        match result {
            Ok(write_result) => write_result,
            Err(_) => Err(io::Error::new(
                io::ErrorKind::TimedOut,
                "Write timeout",
            )),
        }
    }

    /// Waits for the next frame from the peer, bounded by `config.read_timeout_ms`.
    ///
    /// Returns `Ok(None)` if the transport is already closed, or if the frame
    /// decoder reports end-of-stream / an empty frame.
    ///
    /// # Errors
    /// * `BrokenPipe` if the read half is gone.
    /// * `TimedOut` if no complete frame arrives within the configured timeout.
    /// * Any error produced while reading or decoding the frame.
    ///
    /// NOTE(review): if the timeout fires after the header (or part of the
    /// payload) has already been consumed, those bytes are lost and the next
    /// call will presumably read from the middle of a frame — confirm whether
    /// callers treat `TimedOut` as fatal for the connection.
    async fn receive(&mut self) -> io::Result<Option<JsonRpcMessage>> {
        if self.is_closed {
            return Ok(None);
        }

        let read_timeout = Duration::from_millis(self.config.read_timeout_ms);
        let read_result = timeout(read_timeout, async {
            let mut reader_guard = self.reader.lock().await;
            let reader = reader_guard.as_mut().ok_or_else(|| {
                io::Error::new(io::ErrorKind::BrokenPipe, "No stream available")
            })?;
            Self::decode_frame(reader, self.config.max_message_size).await
        })
        .await;

        match read_result {
            Ok(frame_result) => frame_result,
            Err(_) => Err(io::Error::new(
                io::ErrorKind::TimedOut,
                "Read timeout",
            )),
        }
    }

    /// Marks the transport as closed and drops both halves of the stream.
    ///
    /// Calling `close` on an already-closed transport is a no-op.
    async fn close(&mut self) -> io::Result<()> {
        if self.is_closed {
            return Ok(());
        }
        // Flip the flag first so `send`/`receive` reject new work even if this
        // future is dropped while waiting for a lock below.
        self.is_closed = true;

        let mut reader_guard = self.reader.lock().await;
        let mut writer_guard = self.writer.lock().await;
        reader_guard.take();
        writer_guard.take();
        Ok(())
    }

    /// Reports whether `close` has been called on this transport.
    fn is_closed(&self) -> bool {
        self.is_closed
    }
}
/// Accepts incoming TCP connections and wraps each one in a [`TcpTransport`].
pub struct TcpListener {
// The bound tokio listener that connections are accepted from.
listener: TokioTcpListener,
// Address reported by the listener right after binding.
local_addr: SocketAddr,
// Configuration cloned into every transport returned by `accept`.
config: TransportConfig,
}
impl TcpListener {
    /// Binds to `127.0.0.1:port` using [`TransportConfig::default`].
    ///
    /// # Errors
    /// Returns the error reported while binding the socket.
    pub async fn bind(port: u16) -> io::Result<Self> {
        Self::bind_with_config(port, TransportConfig::default()).await
    }

    /// Binds to `127.0.0.1:port`; accepted transports will use a clone of `config`.
    ///
    /// # Errors
    /// Returns the error reported while binding the socket or querying the
    /// bound local address.
    pub async fn bind_with_config(port: u16, config: TransportConfig) -> io::Result<Self> {
        // Build the loopback address directly instead of formatting and
        // re-parsing a string: no allocation and no `unwrap` panic path.
        let addr = SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, port));
        let listener = TokioTcpListener::bind(addr).await?;
        // Query the actual address from the listener rather than reusing `addr`.
        let local_addr = listener.local_addr()?;
        Ok(Self {
            listener,
            local_addr,
            config,
        })
    }

    /// Returns the address the listener reported after binding.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Waits for the next incoming connection and wraps it in a [`TcpTransport`].
    ///
    /// # Errors
    /// Returns the error reported by the underlying `accept` call.
    pub async fn accept(&self) -> io::Result<TcpTransport> {
        let (stream, _addr) = self.listener.accept().await?;
        Ok(TcpTransport::from_stream(stream, self.config.clone()))
    }

    /// Returns the port component of the bound local address.
    pub fn port(&self) -> u16 {
        self.local_addr.port()
    }
}
// NOTE(review): neither constant is referenced in this file; the descriptions
// below are inferred from the names only — confirm against the callers.

/// Presumably the well-known TCP port of the registry service.
pub const REGISTRY_PORT: u16 = 9527;
/// Presumably the well-known TCP port used for callback connections.
pub const CALLBACK_PORT: u16 = 9528;
#[cfg(test)]
mod tests {
    use super::*;

    /// An encoded frame must carry a non-zero length header that matches the
    /// number of payload bytes following it.
    #[test]
    fn test_encode_frame_simple() {
        use crate::matrixrpc::protocol::{JsonRpcId, JsonRpcRequest};

        let id = JsonRpcId::String("test-1".to_string());
        let params = serde_json::json!({"param": "value"});
        let message =
            JsonRpcMessage::Request(JsonRpcRequest::with_id("test.method", id).params(params));

        let frame = TcpTransport::encode_frame(&message).unwrap();
        assert!(frame.len() > FRAME_HEADER_SIZE);

        // Pull the header out of the frame and decode the declared payload length.
        let mut header = [0u8; FRAME_HEADER_SIZE];
        header.copy_from_slice(&frame[..FRAME_HEADER_SIZE]);
        let declared_len = u32::from_be_bytes(header);

        assert!(declared_len > 0);
        assert_eq!(frame.len(), FRAME_HEADER_SIZE + declared_len as usize);
    }

    /// Builder-style setters on `TransportConfig` must store the supplied values.
    #[test]
    fn test_tcp_config() {
        let config = TransportConfig::new()
            .max_message_size(1024)
            .read_timeout(5000);

        assert_eq!(config.max_message_size, 1024);
        assert_eq!(config.read_timeout_ms, 5000);
    }

    /// The wire format relies on a 4-byte length prefix.
    #[test]
    fn test_frame_header_size() {
        assert_eq!(FRAME_HEADER_SIZE, 4);
    }

    /// Guards against accidental changes to the published port numbers.
    #[test]
    fn test_port_constants() {
        assert_eq!(REGISTRY_PORT, 9527);
        assert_eq!(CALLBACK_PORT, 9528);
    }
}