use crate::error::{ErpcResult, TransportError};
use crate::transport::FramedTransport;
use async_trait::async_trait;
use std::path::Path;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{UnixListener, UnixStream};
use tokio::time::timeout;
/// Framed transport over a connected Unix domain stream socket.
///
/// The [`FramedTransport`] implementation bounds each send and each
/// individual read call by `timeout`. `connected` is cleared by `close`, and
/// also when the peer closes the connection during a receive.
#[derive(Debug)]
pub struct SocketTransport {
    /// The underlying connected socket.
    stream: UnixStream,
    /// Per-operation I/O deadline (30 seconds unless changed via `set_timeout`).
    timeout: Duration,
    /// Whether the transport is still considered usable for send/receive.
    connected: bool,
}
impl SocketTransport {
    /// I/O timeout applied to every newly created transport.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Connects to the Unix socket listening at `path`.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::ConnectionFailed` if the connection cannot be
    /// established.
    pub async fn connect<P: AsRef<Path>>(path: P) -> ErpcResult<Self> {
        let stream = UnixStream::connect(path)
            .await
            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
        Ok(Self::from_stream(stream))
    }

    /// Wraps an already-connected stream, using the default 30-second timeout.
    pub fn from_stream(stream: UnixStream) -> Self {
        Self {
            stream,
            timeout: Self::DEFAULT_TIMEOUT,
            connected: true,
        }
    }

    /// Binds a listener at `path`, first deleting whatever file already exists
    /// there (typically a stale socket from a previous run).
    ///
    /// NOTE(review): any existing non-directory entry at `path` is removed, not
    /// only sockets — confirm callers never pass a path to a regular file.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::ConnectionFailed` if the existing entry cannot
    /// be removed or the socket cannot be bound.
    pub async fn listen<P: AsRef<Path>>(path: P) -> ErpcResult<UnixListener> {
        let path = path.as_ref();
        // Remove unconditionally and tolerate NotFound rather than checking
        // `exists()` first. The check-then-remove form races with other
        // processes, and `exists()` is false for dangling symlinks, which
        // would then make `bind` fail with "address in use".
        match std::fs::remove_file(path) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => {
                return Err(TransportError::ConnectionFailed(format!(
                    "Failed to remove existing socket: {}",
                    e
                ))
                .into());
            }
        }
        Ok(UnixListener::bind(path)
            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?)
    }

    /// Waits for the next incoming connection on `listener` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::ConnectionFailed` if `accept` fails.
    pub async fn accept(listener: &UnixListener) -> ErpcResult<Self> {
        let (stream, _) = listener
            .accept()
            .await
            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
        Ok(Self::from_stream(stream))
    }

    /// Returns the local address of the underlying socket.
    pub fn local_addr(&self) -> std::io::Result<tokio::net::unix::SocketAddr> {
        self.stream.local_addr()
    }

    /// Returns the peer address of the underlying socket.
    pub fn peer_addr(&self) -> std::io::Result<tokio::net::unix::SocketAddr> {
        self.stream.peer_addr()
    }
}
#[async_trait]
impl FramedTransport for SocketTransport {
    /// Writes all of `data` to the socket within the configured timeout.
    ///
    /// # Errors
    ///
    /// - `TransportError::Closed` if the transport is not connected.
    /// - `TransportError::SendFailed` on an I/O error.
    /// - `TransportError::Timeout` if the write does not finish in time.
    async fn base_send(&mut self, data: &[u8]) -> ErpcResult<()> {
        if !self.connected {
            return Err(TransportError::Closed.into());
        }
        let send_future = self.stream.write_all(data);
        match timeout(self.timeout, send_future).await {
            Ok(result) => result.map_err(|e| TransportError::SendFailed(e.to_string()).into()),
            Err(_) => Err(TransportError::Timeout.into()),
        }
    }

    /// Reads exactly `length` bytes from the socket.
    ///
    /// The timeout applies to each underlying `read` call, not to the whole
    /// operation.
    ///
    /// # Errors
    ///
    /// - `TransportError::Closed` if the transport is not connected.
    /// - `TransportError::ConnectionFailed` if the peer closes mid-read; this
    ///   also marks the transport disconnected.
    /// - `TransportError::ReceiveFailed` on an I/O error.
    /// - `TransportError::Timeout` if a read does not complete in time.
    async fn base_receive(&mut self, length: usize) -> ErpcResult<Vec<u8>> {
        if !self.connected {
            return Err(TransportError::Closed.into());
        }
        let mut buffer = vec![0u8; length];
        let mut total_read = 0;
        while total_read < length {
            let read_future = self.stream.read(&mut buffer[total_read..]);
            match timeout(self.timeout, read_future).await {
                // The slice is non-empty here, so a 0-byte read means EOF.
                Ok(Ok(0)) => {
                    self.connected = false;
                    return Err(TransportError::ConnectionFailed(
                        "Connection closed by peer".to_string(),
                    )
                    .into());
                }
                Ok(Ok(bytes_read)) => {
                    total_read += bytes_read;
                }
                Ok(Err(e)) => {
                    return Err(TransportError::ReceiveFailed(e.to_string()).into());
                }
                // NOTE(review): bytes already read into `buffer` are discarded
                // here, so a later receive may start mid-message. Consider
                // treating a Timeout from this method as fatal for the stream.
                Err(_) => {
                    return Err(TransportError::Timeout.into());
                }
            }
        }
        Ok(buffer)
    }

    /// Returns whether the transport is still considered usable.
    fn is_connected(&self) -> bool {
        self.connected
    }

    /// Marks the transport closed and shuts down the socket's write half.
    ///
    /// The shutdown lets the peer observe EOF immediately, without waiting for
    /// this value to be dropped. Calling `close` again is a no-op.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::ConnectionFailed` if the shutdown fails for any
    /// reason other than the socket already being disconnected.
    async fn close(&mut self) -> ErpcResult<()> {
        if !self.connected {
            return Ok(());
        }
        self.connected = false;
        match self.stream.shutdown().await {
            Ok(()) => Ok(()),
            // The peer is already gone; there is nothing left to signal.
            Err(e) if e.kind() == std::io::ErrorKind::NotConnected => Ok(()),
            Err(e) => Err(TransportError::ConnectionFailed(format!(
                "Failed to shut down socket: {}",
                e
            ))
            .into()),
        }
    }

    /// Sets the deadline used for each subsequent send and read operation.
    fn set_timeout(&mut self, timeout: Duration) {
        self.timeout = timeout;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::Transport;
    use tempfile::tempdir;
    use tokio::time::Duration;

    /// Round-trips one message client -> server -> client over a real socket.
    #[tokio::test]
    async fn test_socket_transport_basic() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test_socket");
        // `listen` has already bound and started listening, so the client can
        // connect right away; no startup sleep is needed.
        let listener = SocketTransport::listen(&socket_path).await.unwrap();
        let server_handle = tokio::spawn(async move {
            let mut server_transport = SocketTransport::accept(&listener).await.unwrap();
            let received = server_transport.receive().await.unwrap();
            assert_eq!(received, b"Hello, Socket!");
            server_transport.send(b"Echo: Hello, Socket!").await.unwrap();
        });
        let mut client_transport = SocketTransport::connect(&socket_path).await.unwrap();
        client_transport.send(b"Hello, Socket!").await.unwrap();
        let response = client_transport.receive().await.unwrap();
        assert_eq!(response, b"Echo: Hello, Socket!");
        Transport::close(&mut client_transport).await.unwrap();
        server_handle.await.unwrap();
    }

    /// A receive with no reply from the peer fails once the timeout expires.
    #[tokio::test]
    async fn test_socket_transport_timeout() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test_timeout_socket");
        let listener = SocketTransport::listen(&socket_path).await.unwrap();
        // The server accepts but never replies; the task is cancelled when the
        // test runtime shuts down.
        let _server_handle = tokio::spawn(async move {
            let _server_transport = SocketTransport::accept(&listener).await.unwrap();
            tokio::time::sleep(Duration::from_secs(60)).await;
        });
        let mut client_transport = SocketTransport::connect(&socket_path).await.unwrap();
        Transport::set_timeout(&mut client_transport, Duration::from_millis(100));
        client_transport.send(b"Hello").await.unwrap();
        let result = client_transport.receive().await;
        assert!(result.is_err());
    }
}