use std::path::PathBuf;
use anyhow::Result;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UnixListener, UnixStream};
use tokio::sync::oneshot;
/// Handle to a running Unix-domain-socket → TCP forwarding proxy.
///
/// The accept loop runs in a background task started by [`UnixSocketProxy::new`];
/// dropping this handle signals that task to stop.
#[derive(Debug)]
pub struct UnixSocketProxy {
    /// Filesystem path the Unix listener is bound to.
    socket_path: PathBuf,
    /// Sender half of the shutdown channel watched by the accept loop.
    /// Wrapped in `Option` so `Drop` can move it out of `&mut self`.
    shutdown_tx: Option<oneshot::Sender<()>>,
}
impl UnixSocketProxy {
pub async fn new(socket_path: PathBuf, tcp_endpoint: String) -> Result<Self> {
if socket_path.exists() {
std::fs::remove_file(&socket_path)?;
}
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let listener = UnixListener::bind(&socket_path)?;
let tcp_endpoint_clone = tcp_endpoint.clone();
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
let mut shutdown_rx = shutdown_rx;
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((unix_stream, _)) => {
let tcp_endpoint = tcp_endpoint_clone.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(unix_stream, &tcp_endpoint).await {
eprintln!("Error handling connection: {}", e);
}
});
}
Err(e) => {
eprintln!("Error accepting connection: {}", e);
break;
}
}
}
_ = &mut shutdown_rx => {
break;
}
}
}
let _ = std::fs::remove_file(&socket_path_clone);
});
Ok(Self {
socket_path,
shutdown_tx: Some(shutdown_tx),
})
}
pub fn socket_path(&self) -> &PathBuf {
&self.socket_path
}
async fn handle_connection(unix_stream: UnixStream, tcp_endpoint: &str) -> Result<()> {
let tcp_stream = TcpStream::connect(tcp_endpoint).await?;
let (mut unix_read, mut unix_write) = tokio::io::split(unix_stream);
let (mut tcp_read, mut tcp_write) = tokio::io::split(tcp_stream);
let client_to_server = tokio::spawn(async move {
let mut buf = vec![0u8; 8192];
loop {
match unix_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
if tcp_write.write_all(&buf[..n]).await.is_err() {
break;
}
},
Err(_) => break,
}
}
let _ = tcp_write.shutdown().await;
Ok::<(), std::io::Error>(())
});
let server_to_client = tokio::spawn(async move {
let mut buf = vec![0u8; 8192];
loop {
match tcp_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
if unix_write.write_all(&buf[..n]).await.is_err() {
break;
}
},
Err(_) => break,
}
}
let _ = unix_write.shutdown().await;
Ok::<(), std::io::Error>(())
});
let _ = tokio::try_join!(client_to_server, server_to_client);
Ok(())
}
}
impl Drop for UnixSocketProxy {
    /// Signals the background accept loop to stop and unlinks the socket file.
    fn drop(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            // `send` fails only when the receiver is already gone, i.e. the
            // accept loop has exited on its own; nothing more to do then.
            let _ = tx.send(());
        }
        // Best-effort unlink. No `exists()` pre-check: it is racy, it would skip
        // a dangling symlink, and a missing file already surfaces as an error
        // that is ignored here.
        let _ = std::fs::remove_file(&self.socket_path);
    }
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use tempfile::TempDir;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, UnixStream};
    use tokio::task::JoinHandle;

    use super::*;

    /// TCP endpoint handed to proxies in tests that never open a connection
    /// through them; nothing needs to be listening there.
    const UNUSED_ENDPOINT: &str = "127.0.0.1:9999";

    /// Starts a TCP echo server on an ephemeral loopback port.
    ///
    /// The listener is bound *before* the accept task is spawned, so the
    /// returned address accepts connections as soon as this function returns;
    /// callers need no sleep and there is no window in which another process
    /// could take the port.
    ///
    /// Returns the bound address, the accept task's handle, and a sender that
    /// stops the accept loop.
    async fn spawn_echo_server() -> (SocketAddr, JoinHandle<()>, tokio::sync::oneshot::Sender<()>) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind TCP listener");
        let addr = listener.local_addr().expect("Failed to read echo server address");
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();

        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    result = listener.accept() => {
                        match result {
                            Ok((mut stream, _)) => {
                                // Echo everything back until the peer closes.
                                tokio::spawn(async move {
                                    let mut buf = vec![0u8; 1024];
                                    loop {
                                        match stream.read(&mut buf).await {
                                            Ok(0) | Err(_) => break,
                                            Ok(n) => {
                                                if stream.write_all(&buf[..n]).await.is_err() {
                                                    break;
                                                }
                                            }
                                        }
                                    }
                                });
                            }
                            Err(_) => break,
                        }
                    }
                    _ = &mut shutdown_rx => {
                        break;
                    }
                }
            }
        });

        (addr, handle, shutdown_tx)
    }

    /// Creating a proxy leaves a socket file at the requested path.
    #[tokio::test]
    async fn test_proxy_creates_socket_file() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test_proxy.sock");

        let _proxy = UnixSocketProxy::new(socket_path.clone(), UNUSED_ENDPOINT.to_string())
            .await
            .expect("Failed to create proxy");

        assert!(socket_path.exists(), "Socket file should be created");
    }

    /// A pre-existing file at the socket path is replaced by the listener socket.
    #[tokio::test]
    async fn test_proxy_removes_existing_socket() {
        use std::os::unix::fs::FileTypeExt;

        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test_existing.sock");
        std::fs::File::create(&socket_path).unwrap();
        assert!(socket_path.exists());

        let _proxy = UnixSocketProxy::new(socket_path.clone(), UNUSED_ENDPOINT.to_string())
            .await
            .expect("Failed to create proxy");

        assert!(socket_path.exists(), "Socket file should exist after proxy creation");
        let file_type = std::fs::metadata(&socket_path).unwrap().file_type();
        assert!(file_type.is_socket(), "Stale regular file should be replaced by a socket");
    }

    /// Bytes written to the Unix socket reach the TCP server, and the reply
    /// travels back.
    #[tokio::test]
    async fn test_proxy_forwards_data() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test_forward.sock");
        let (echo_addr, echo_handle, shutdown_tx) = spawn_echo_server().await;

        let _proxy = UnixSocketProxy::new(socket_path.clone(), echo_addr.to_string())
            .await
            .expect("Failed to create proxy");

        let mut unix_stream = UnixStream::connect(&socket_path)
            .await
            .expect("Failed to connect to Unix socket");
        let test_data = b"Hello, proxy!";
        unix_stream.write_all(test_data).await.unwrap();
        unix_stream.flush().await.unwrap();

        let mut response = vec![0u8; test_data.len()];
        unix_stream.read_exact(&mut response).await.unwrap();
        assert_eq!(&response, test_data, "Proxy should echo data back");

        let _ = shutdown_tx.send(());
        echo_handle.await.expect("Echo server task panicked");
    }

    /// Several clients can use the proxy concurrently, each seeing only its own
    /// data.
    #[tokio::test]
    async fn test_proxy_handles_multiple_connections() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test_multiple.sock");
        let (echo_addr, echo_handle, shutdown_tx) = spawn_echo_server().await;

        let _proxy = UnixSocketProxy::new(socket_path.clone(), echo_addr.to_string())
            .await
            .expect("Failed to create proxy");

        let mut handles = Vec::new();
        for i in 0..5 {
            let socket_path = socket_path.clone();
            handles.push(tokio::spawn(async move {
                let mut stream = UnixStream::connect(&socket_path).await.unwrap();
                let data = format!("test-{}", i);
                stream.write_all(data.as_bytes()).await.unwrap();
                stream.flush().await.unwrap();

                let mut response = vec![0u8; data.len()];
                stream.read_exact(&mut response).await.unwrap();
                assert_eq!(response, data.as_bytes());
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }

        let _ = shutdown_tx.send(());
        echo_handle.await.expect("Echo server task panicked");
    }

    /// Dropping the proxy unlinks the socket file.
    #[tokio::test]
    async fn test_proxy_cleans_up_on_drop() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test_cleanup.sock");
        {
            let _proxy = UnixSocketProxy::new(socket_path.clone(), UNUSED_ENDPOINT.to_string())
                .await
                .expect("Failed to create proxy");
            assert!(socket_path.exists());
        }
        // `Drop` unlinks the file synchronously, so no waiting is needed here.
        assert!(!socket_path.exists());
    }

    /// `socket_path()` returns the path the proxy was created with.
    #[tokio::test]
    async fn test_proxy_socket_path_accessor() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test_path.sock");

        let proxy = UnixSocketProxy::new(socket_path.clone(), UNUSED_ENDPOINT.to_string())
            .await
            .expect("Failed to create proxy");

        assert_eq!(proxy.socket_path(), &socket_path);
    }

    /// Runs the shared client test suite against a local test server that is
    /// configured with a Unix socket.
    #[tokio::test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_client_functionality_through_socket_proxy() {
        use std::sync::Arc;

        use crate::cas_client::simulation::client_unit_testing::test_client_functionality;
        use crate::cas_client::simulation::simulation_server::LocalTestServerBuilder;

        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test_socket.sock");

        test_client_functionality(|| async {
            let server = LocalTestServerBuilder::new()
                .with_socket(socket_path.clone())
                .with_ephemeral_disk()
                .start()
                .await;
            assert!(socket_path.exists(), "Unix socket file should exist");
            Arc::new(server) as Arc<dyn crate::cas_client::simulation::DirectAccessClient>
        })
        .await;
    }
}