use std::io;
use std::path::Path;
use tokio::io::{AsyncRead, AsyncWrite};
/// A byte stream the IPC layer can hand out: readable and writable through
/// tokio's async I/O traits, `Unpin`, and `Send` so it can move across tasks.
pub trait IpcIo: AsyncRead + AsyncWrite + Unpin + Send {}

/// Blanket implementation: any type meeting the supertrait bounds is an [`IpcIo`].
impl<S: AsyncRead + AsyncWrite + Unpin + Send> IpcIo for S {}

/// Type-erased, boxed IPC stream returned by [`connect`] and [`IpcListener::accept`].
pub type IpcStream = Box<dyn IpcIo>;
#[cfg(unix)]
use tokio::net::UnixListener;
#[cfg(unix)]
use tokio::net::UnixStream;
/// A listening IPC endpoint created by [`bind`]; obtain connections with
/// [`IpcListener::accept`].
///
/// On Unix it wraps a Unix domain socket listener. On Windows it stores the
/// pipe name plus a server instance that has been created but not yet handed
/// to a client.
#[derive(Debug)]
pub struct IpcListener {
    /// Listener bound to the socket path.
    #[cfg(unix)]
    inner: UnixListener,
    /// Full named-pipe name (as produced by `pipe_name_from_path`) used when
    /// creating further pipe instances.
    #[cfg(windows)]
    name: String,
    /// A created-but-unconnected server instance, if any; taken by the next
    /// `accept` call.
    #[cfg(windows)]
    first: std::sync::Mutex<Option<tokio::net::windows::named_pipe::NamedPipeServer>>,
}
/// Creates an IPC listener for `path`.
///
/// On Unix, binds a Unix domain socket at `path` and creates any missing parent
/// directories first. A leftover socket (or a symlink) at `path` is removed
/// before binding. Any other kind of entry, such as a regular file or a
/// directory, is left untouched and reported as `ErrorKind::AlreadyExists`, so
/// a mistyped path can never delete user data.
///
/// On Windows, `path` is mapped to a named-pipe name and the first pipe
/// instance is created with `first_pipe_instance(true)`.
///
/// # Errors
///
/// Returns any I/O error from inspecting or removing the stale entry, creating
/// the parent directories, or binding/creating the endpoint.
pub async fn bind(path: &Path) -> io::Result<IpcListener> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::FileTypeExt;

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }

        // Look at the directory entry itself (not a symlink target). Acting on the
        // lookup result instead of a separate `exists()` check avoids a
        // check-then-act race. It also surfaces real failures (e.g. permission
        // denied) instead of a confusing `AddrInUse` from `bind` later.
        match std::fs::symlink_metadata(path) {
            Ok(meta) => {
                let file_type = meta.file_type();
                if file_type.is_socket() || file_type.is_symlink() {
                    match std::fs::remove_file(path) {
                        // Someone else already removed it: nothing left to do.
                        Err(e) if e.kind() != io::ErrorKind::NotFound => return Err(e),
                        _ => {}
                    }
                } else {
                    return Err(io::Error::new(
                        io::ErrorKind::AlreadyExists,
                        format!(
                            "refusing to replace non-socket file at {}",
                            path.display()
                        ),
                    ));
                }
            }
            Err(e) if e.kind() == io::ErrorKind::NotFound => {}
            Err(e) => return Err(e),
        }

        let inner = UnixListener::bind(path)?;
        Ok(IpcListener { inner })
    }
    #[cfg(windows)]
    {
        use tokio::net::windows::named_pipe::ServerOptions;
        let name = pipe_name_from_path(path);
        // `first_pipe_instance(true)` makes creation fail if the pipe name is
        // already owned by another server.
        let first = ServerOptions::new()
            .first_pipe_instance(true)
            .create(&name)?;
        Ok(IpcListener {
            name,
            first: std::sync::Mutex::new(Some(first)),
        })
    }
}
/// Opens a client connection to the IPC endpoint that [`bind`] created for `path`.
///
/// On Unix this connects to the Unix domain socket at `path`. On Windows it
/// opens the named pipe derived from `path`.
///
/// # Errors
///
/// Returns the underlying I/O error if the endpoint cannot be reached.
pub async fn connect(path: &Path) -> io::Result<IpcStream> {
    #[cfg(unix)]
    {
        UnixStream::connect(path)
            .await
            .map(|socket| Box::new(socket) as IpcStream)
    }
    #[cfg(windows)]
    {
        use tokio::net::windows::named_pipe::ClientOptions;
        let client = ClientOptions::new().open(pipe_name_from_path(path))?;
        Ok(Box::new(client) as IpcStream)
    }
}
impl IpcListener {
    /// Waits for the next client and returns the connected stream.
    ///
    /// On Unix this accepts on the underlying Unix socket listener.
    ///
    /// On Windows every client is served by its own named-pipe server instance.
    /// Once an instance is connected, a replacement is created immediately and
    /// parked in `self.first` before the connected one is returned. Without
    /// this, no instance exists between two `accept` calls, and a client
    /// connecting in that window fails with `NotFound`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from accepting, creating a pipe instance, or
    /// waiting for the client to connect.
    pub async fn accept(&self) -> io::Result<IpcStream> {
        #[cfg(unix)]
        {
            let (stream, _addr) = self.inner.accept().await?;
            Ok(Box::new(stream))
        }
        #[cfg(windows)]
        {
            use tokio::net::windows::named_pipe::ServerOptions;

            // Take the parked instance in its own statement so the std mutex
            // guard is dropped before any `.await`. The slot holds no invariant
            // that poisoning could break, so recover instead of panicking.
            let parked = self
                .first
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .take();
            let server = match parked {
                Some(server) => server,
                None => ServerOptions::new()
                    .first_pipe_instance(false)
                    .create(&self.name)?,
            };
            server.connect().await?;

            // Park the next instance before handing this one out. This is
            // best-effort: if creation fails here, the connected client is still
            // returned, and the next `accept` retries the creation and reports
            // the error.
            if let Ok(next) = ServerOptions::new()
                .first_pipe_instance(false)
                .create(&self.name)
            {
                *self
                    .first
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(next);
            }
            Ok(Box::new(server))
        }
    }
}
/// Maps a filesystem path to a Windows named-pipe name under the `gitgrip-` prefix.
///
/// Both `/` and `\` are flattened to `-` so the whole path becomes one pipe-name
/// component. NOTE(review): the mapping is not injective (`a/b` and `a-b`
/// collide), and very long paths may exceed the Windows pipe-name length
/// limit; confirm neither matters for callers.
#[cfg(windows)]
fn pipe_name_from_path(path: &Path) -> String {
    let flattened: String = path
        .to_string_lossy()
        .chars()
        .map(|c| if c == '/' || c == '\\' { '-' } else { c })
        .collect();
    format!(r"\\.\pipe\gitgrip-{}", flattened)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Round-trips one message in each direction over a Unix socket.
    ///
    /// Uses `read_exact` with the exact message length. A plain `read` may
    /// legally return fewer bytes than were written, which would make the
    /// equality assertions flaky.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_unix_socket_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("test.sock");
        let listener = bind(&sock).await.unwrap();

        let client_task = tokio::spawn({
            let sock = sock.clone();
            async move {
                let mut stream = connect(&sock).await.unwrap();
                stream.write_all(b"hello\n").await.unwrap();
                let mut buf = [0u8; 6];
                stream.read_exact(&mut buf).await.unwrap();
                String::from_utf8_lossy(&buf).to_string()
            }
        });

        let mut server_stream = listener.accept().await.unwrap();
        let mut buf = [0u8; 6];
        server_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello\n");
        server_stream.write_all(b"world\n").await.unwrap();

        let response = client_task.await.unwrap();
        assert_eq!(response, "world\n");
    }

    /// `bind` creates missing parent directories and can bind again at a path
    /// that a previous (dropped) listener used, after which clients can connect.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_unix_bind_creates_parents_and_rebinds() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("nested").join("deeper").join("test.sock");

        let first = bind(&sock).await.unwrap();
        drop(first);

        let listener = bind(&sock).await.unwrap();
        let (client, server) = tokio::join!(connect(&sock), listener.accept());
        client.unwrap();
        server.unwrap();
    }
}