#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
#![allow(clippy::module_name_repetitions)]
#![cfg_attr(
test,
allow(
unused,
clippy::all,
clippy::pedantic,
clippy::nursery,
clippy::dbg_macro,
clippy::unwrap_used,
clippy::missing_docs_in_private_items,
)
)]
use std::fs::{metadata, remove_file};
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::time::Instant;
use tokio::net::UnixDatagram;
use gistit_proto::bytes::BytesMut;
use gistit_proto::prost::{self, Message};
use gistit_proto::Instruction;
pub type Result<T> = std::result::Result<T, Error>;
/// File name of the socket the server binds (and the client sends to),
/// relative to the bridge's base directory.
const NAMED_SOCKET_0: &str = "gistit-0";
/// File name of the socket the client binds (and the server sends to),
/// relative to the bridge's base directory.
const NAMED_SOCKET_1: &str = "gistit-1";
/// Size in bytes of the buffer each `recv` reads one datagram into; a
/// datagram larger than this does not fit in the receive buffer.
const READBUF_SIZE: usize = 60_000;
/// Connection timeout, in seconds, used by `__connect_blocking` when retrying.
const CONNECT_TIMEOUT_SECS: u64 = 3;
/// Marker trait for the two ends of a [`Bridge`].
///
/// It carries no behavior; it is only a type-level tag that selects which
/// `impl Bridge<...>` block (and so which sockets) a bridge end uses.
pub trait SockEnd {}

/// Type-level tag for the server end of a [`Bridge`].
#[derive(Debug)]
pub struct Server;

/// Type-level tag for the client end of a [`Bridge`].
#[derive(Debug)]
pub struct Client;

impl SockEnd for Server {}
impl SockEnd for Client {}
/// A pair of Unix datagram sockets linking the gistit server and client.
///
/// The type parameter ([`Server`] or [`Client`]) decides the roles of the two
/// sockets: the server binds and receives on `sock_0` and sends via `sock_1`;
/// the client binds and receives on `sock_1` and sends via `sock_0`.
#[derive(Debug)]
pub struct Bridge<T>
where
    T: SockEnd,
{
    /// Socket tied to the `gistit-0` path (bound by the server end).
    pub sock_0: UnixDatagram,
    /// Socket tied to the `gistit-1` path (bound by the client end).
    pub sock_1: UnixDatagram,
    /// Directory containing both named socket files.
    base: PathBuf,
    /// Zero-sized tag recording which end of the bridge this value is.
    __marker_t: PhantomData<T>,
}
/// Creates the server end of the bridge rooted at directory `base`.
///
/// The server receives on `sock_0`, which is bound to `base/gistit-0`. Any
/// file already present at that path (e.g. left behind by an earlier run) is
/// removed first. `sock_1` starts out unbound and unconnected; call
/// `connect_blocking` (or `alive`) on the returned bridge to aim it at the
/// client's socket.
///
/// # Errors
///
/// Returns [`Error::IO`] if removing a stale socket file, binding `sock_0`,
/// or creating `sock_1` fails.
pub fn server(base: &Path) -> Result<Bridge<Server>> {
    let socket_path = base.join(NAMED_SOCKET_0);

    // Clear a stale socket file so that `bind` does not fail with
    // "address in use".
    if metadata(&socket_path).is_ok() {
        remove_file(&socket_path)?;
    }

    log::trace!("Bind sock_0 (server) at {:?}", &socket_path);
    let receiver = UnixDatagram::bind(&socket_path)?;
    let sender = UnixDatagram::unbound()?;

    Ok(Bridge {
        sock_0: receiver,
        sock_1: sender,
        base: base.to_path_buf(),
        __marker_t: PhantomData,
    })
}
/// Creates the client end of the bridge rooted at directory `base`.
///
/// The client receives on `sock_1`, which is bound to `base/gistit-1`. Any
/// file already present at that path (e.g. left behind by an earlier run) is
/// removed first. `sock_0` starts out unbound and unconnected; call
/// `connect_blocking` (or `alive`) on the returned bridge to aim it at the
/// server's socket.
///
/// # Errors
///
/// Returns [`Error::IO`] if removing a stale socket file, binding `sock_1`,
/// or creating `sock_0` fails.
pub fn client(base: &Path) -> Result<Bridge<Client>> {
    let socket_path = base.join(NAMED_SOCKET_1);

    // Clear a stale socket file so that `bind` does not fail with
    // "address in use".
    if metadata(&socket_path).is_ok() {
        remove_file(&socket_path)?;
    }

    log::trace!("Bind sock_1 (client) at {:?}", &socket_path);
    let receiver = UnixDatagram::bind(&socket_path)?;

    Ok(Bridge {
        sock_0: UnixDatagram::unbound()?,
        sock_1: receiver,
        base: base.to_path_buf(),
        __marker_t: PhantomData,
    })
}
/// Makes a single attempt to connect `dgram` to the socket named `sock_name`
/// inside `base` and reports whether it succeeded.
///
/// This is not a side-effect-free probe: when the attempt succeeds, `dgram`
/// is left connected to that path.
fn __alive(base: &Path, dgram: &UnixDatagram, sock_name: &str) -> bool {
    dgram.connect(base.join(sock_name)).is_ok()
}
/// Connects `dgram` to the socket named `sock_name` inside `base`, retrying
/// until the connect succeeds or `CONNECT_TIMEOUT_SECS` seconds have elapsed.
///
/// This blocks the calling thread. Between failed attempts it sleeps for a
/// short interval rather than spinning, so waiting for a peer that has not
/// bound its socket yet does not burn a CPU core.
///
/// # Errors
///
/// Returns [`Error::IO`] carrying the most recent connection error once the
/// timeout has elapsed.
fn __connect_blocking(base: &Path, dgram: &UnixDatagram, sock_name: &str) -> Result<()> {
    // Pause between attempts while the peer socket is not reachable yet.
    const RETRY_INTERVAL: std::time::Duration = std::time::Duration::from_millis(10);

    // The target path never changes, so build it once.
    let target = base.join(sock_name);
    let timeout = std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS);
    let started = Instant::now();
    loop {
        match dgram.connect(&target) {
            Ok(()) => break,
            // Compare full durations: `as_secs()` truncates, which previously
            // stretched a 3 s timeout to roughly 4 s.
            Err(err) if started.elapsed() >= timeout => return Err(err.into()),
            Err(_) => std::thread::sleep(RETRY_INTERVAL),
        }
    }
    log::trace!("Connecting to {:?}", sock_name);
    Ok(())
}
impl Bridge<Server> {
    /// Makes one attempt to connect `sock_1` to the client's socket
    /// (`gistit-1` in the bridge's base directory) and returns whether it
    /// succeeded.
    ///
    /// On success `sock_1` is left connected, so subsequent `send`s reach the
    /// client.
    pub fn alive(&self) -> bool {
        __alive(&self.base, &self.sock_1, NAMED_SOCKET_1)
    }

    /// Connects `sock_1` to the client's socket, retrying until it succeeds
    /// or the connection timeout elapses. Blocks the calling thread meanwhile.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IO`] carrying the last connection error if the client
    /// socket could not be reached in time.
    pub fn connect_blocking(&mut self) -> Result<()> {
        __connect_blocking(&self.base, &self.sock_1, NAMED_SOCKET_1)
    }

    /// Encodes `instruction` and sends it to the client as a single datagram
    /// on `sock_1`, which is expected to be connected already (see
    /// `connect_blocking`).
    ///
    /// # Errors
    ///
    /// Returns [`Error::IO`] with kind `InvalidInput` if the encoded
    /// instruction is larger than `READBUF_SIZE`. The receiver reads into a
    /// buffer of exactly that size, so a bigger datagram would be silently
    /// truncated. Also returns [`Error::Encode`] if encoding fails, or
    /// [`Error::IO`] if the send itself fails.
    pub async fn send(&self, instruction: Instruction) -> Result<()> {
        let len = instruction.encoded_len();
        if len > READBUF_SIZE {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!(
                    "encoded instruction is {len} bytes, exceeding the {READBUF_SIZE} byte receive buffer"
                ),
            )
            .into());
        }
        // Allocate exactly what this message needs rather than a fixed
        // READBUF_SIZE buffer on every send.
        let mut buf = BytesMut::with_capacity(len);
        instruction.encode(&mut buf)?;
        log::trace!("Sending to client {} bytes", buf.len());
        self.sock_1.send(&buf).await?;
        Ok(())
    }

    /// Waits for the next datagram on `sock_0` and decodes it as an
    /// [`Instruction`].
    ///
    /// At most `READBUF_SIZE` bytes of the datagram are read.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IO`] if receiving fails, or [`Error::Decode`] if the
    /// received bytes are not a valid `Instruction`.
    pub async fn recv(&self) -> Result<Instruction> {
        let mut buf = vec![0u8; READBUF_SIZE];
        let read = self.sock_0.recv(&mut buf).await?;
        buf.truncate(read);
        let target = Instruction::decode(&*buf)?;
        Ok(target)
    }
}
impl Bridge<Client> {
    /// Makes one attempt to connect `sock_0` to the server's socket
    /// (`gistit-0` in the bridge's base directory) and returns whether it
    /// succeeded.
    ///
    /// On success `sock_0` is left connected, so subsequent `send`s reach the
    /// server.
    pub fn alive(&self) -> bool {
        __alive(&self.base, &self.sock_0, NAMED_SOCKET_0)
    }

    /// Connects `sock_0` to the server's socket, retrying until it succeeds
    /// or the connection timeout elapses. Blocks the calling thread meanwhile.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IO`] carrying the last connection error if the server
    /// socket could not be reached in time.
    pub fn connect_blocking(&mut self) -> Result<()> {
        __connect_blocking(&self.base, &self.sock_0, NAMED_SOCKET_0)
    }

    /// Encodes `instruction` and sends it to the server as a single datagram
    /// on `sock_0`, which is expected to be connected already (see
    /// `connect_blocking`).
    ///
    /// # Errors
    ///
    /// Returns [`Error::IO`] with kind `InvalidInput` if the encoded
    /// instruction is larger than `READBUF_SIZE`. The receiver reads into a
    /// buffer of exactly that size, so a bigger datagram would be silently
    /// truncated. Also returns [`Error::Encode`] if encoding fails, or
    /// [`Error::IO`] if the send itself fails.
    pub async fn send(&self, instruction: Instruction) -> Result<()> {
        let len = instruction.encoded_len();
        if len > READBUF_SIZE {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!(
                    "encoded instruction is {len} bytes, exceeding the {READBUF_SIZE} byte receive buffer"
                ),
            )
            .into());
        }
        // Allocate exactly what this message needs rather than a fixed
        // READBUF_SIZE buffer on every send.
        let mut buf = BytesMut::with_capacity(len);
        instruction.encode(&mut buf)?;
        log::trace!("Sending to server {} bytes", buf.len());
        self.sock_0.send(&buf).await?;
        Ok(())
    }

    /// Waits for the next datagram on `sock_1` and decodes it as an
    /// [`Instruction`].
    ///
    /// At most `READBUF_SIZE` bytes of the datagram are read.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IO`] if receiving fails, or [`Error::Decode`] if the
    /// received bytes are not a valid `Instruction`.
    pub async fn recv(&self) -> Result<Instruction> {
        let mut buf = vec![0u8; READBUF_SIZE];
        let read = self.sock_1.recv(&mut buf).await?;
        buf.truncate(read);
        let target = Instruction::decode(&*buf)?;
        Ok(target)
    }
}
/// Errors produced by the IPC bridge.
///
/// Each variant wraps its underlying error via `#[from]`, so `?` converts
/// those errors into this type automatically.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// A socket or filesystem operation (bind, connect, send, recv, removing
    /// a stale socket file) failed.
    #[error("io error {0}")]
    IO(#[from] std::io::Error),
    /// Received bytes could not be decoded as an `Instruction`.
    #[error("decode error {0}")]
    Decode(#[from] prost::DecodeError),
    /// An `Instruction` could not be encoded for sending.
    #[error("encode error {0}")]
    Encode(#[from] prost::EncodeError),
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert_fs::prelude::*;
    use std::sync::Arc;

    pub fn test_instruction_1() -> Instruction {
        Instruction::request_status()
    }

    pub fn test_instruction_2() -> Instruction {
        Instruction::request_shutdown()
    }

    /// Counts `got` as a first or second test instruction in `counts`;
    /// panics on any other instruction.
    fn tally(counts: &mut (usize, usize), got: Instruction) {
        if got == test_instruction_1() {
            counts.0 += 1;
        } else if got == test_instruction_2() {
            counts.1 += 1;
        } else {
            panic!("unexpected instruction: {:?}", got);
        }
    }

    #[tokio::test]
    async fn ipc_named_socket_spawn() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let _ = server(&tmp).unwrap();
        let _ = client(&tmp).unwrap();
        assert!(tmp.child("gistit-0").exists());
        assert!(tmp.child("gistit-1").exists());
    }

    #[tokio::test]
    async fn ipc_socket_spawn_is_alive() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let server = server(&tmp).unwrap();
        let client = client(&tmp).unwrap();
        assert!(server.alive());
        assert!(client.alive());
    }

    #[tokio::test]
    async fn ipc_socket_server_recv_traffic() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let server = server(&tmp).unwrap();
        let mut client = client(&tmp).unwrap();
        client.connect_blocking().unwrap();
        client.send(test_instruction_1()).await.unwrap();
        client.send(test_instruction_2()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), test_instruction_1());
        assert_eq!(server.recv().await.unwrap(), test_instruction_2());
    }

    #[tokio::test]
    async fn ipc_socket_client_recv_traffic() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let mut server = server(&tmp).unwrap();
        let client = client(&tmp).unwrap();
        server.connect_blocking().unwrap();
        server.send(test_instruction_1()).await.unwrap();
        server.send(test_instruction_2()).await.unwrap();
        assert_eq!(client.recv().await.unwrap(), test_instruction_1());
        assert_eq!(client.recv().await.unwrap(), test_instruction_2());
    }

    #[tokio::test]
    async fn ipc_socket_alternate_traffic() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let mut server = server(&tmp).unwrap();
        let mut client = client(&tmp).unwrap();
        client.connect_blocking().unwrap();
        server.connect_blocking().unwrap();
        client.send(test_instruction_1()).await.unwrap();
        client.send(test_instruction_2()).await.unwrap();
        server.send(test_instruction_1()).await.unwrap();
        server.send(test_instruction_2()).await.unwrap();
        assert_eq!(client.recv().await.unwrap(), test_instruction_1());
        assert_eq!(server.recv().await.unwrap(), test_instruction_1());
        assert_eq!(client.recv().await.unwrap(), test_instruction_2());
        assert_eq!(server.recv().await.unwrap(), test_instruction_2());
    }

    #[tokio::test]
    async fn ipc_socket_alternate_traffic_rerun() {
        let tmp = assert_fs::TempDir::new().unwrap();
        let mut server = server(&tmp).unwrap();
        let mut client = client(&tmp).unwrap();
        client.connect_blocking().unwrap();
        server.connect_blocking().unwrap();
        client.send(test_instruction_1()).await.unwrap();
        client.send(test_instruction_2()).await.unwrap();
        server.send(test_instruction_1()).await.unwrap();
        server.send(test_instruction_2()).await.unwrap();
        assert_eq!(client.recv().await.unwrap(), test_instruction_1());
        assert_eq!(server.recv().await.unwrap(), test_instruction_1());
        assert_eq!(client.recv().await.unwrap(), test_instruction_2());
        assert_eq!(server.recv().await.unwrap(), test_instruction_2());
        client.send(test_instruction_1()).await.unwrap();
        client.send(test_instruction_2()).await.unwrap();
        server.send(test_instruction_1()).await.unwrap();
        server.send(test_instruction_2()).await.unwrap();
        assert_eq!(client.recv().await.unwrap(), test_instruction_1());
        assert_eq!(server.recv().await.unwrap(), test_instruction_1());
        assert_eq!(client.recv().await.unwrap(), test_instruction_2());
        assert_eq!(server.recv().await.unwrap(), test_instruction_2());
    }

    /// Several tasks send concurrently in both directions and every datagram
    /// must arrive intact. Concurrent senders interleave, so only per-kind
    /// counts are checked, not arrival order. The senders are bounded so the
    /// test ends cleanly instead of leaving infinite loops running.
    #[tokio::test]
    async fn ipc_socket_traffic_under_load() {
        const TASKS: usize = 8;
        const ROUNDS: usize = 16;
        // Each side receives one instruction of each kind per round per task.
        const PER_KIND: usize = TASKS * ROUNDS;

        let tmp = assert_fs::TempDir::new().unwrap();
        let mut server = server(&tmp).unwrap();
        let mut client = client(&tmp).unwrap();
        client.connect_blocking().unwrap();
        server.connect_blocking().unwrap();
        let server = Arc::new(server);
        let client = Arc::new(client);

        let senders: Vec<_> = (0..TASKS)
            .map(|_| {
                let s = Arc::clone(&server);
                let c = Arc::clone(&client);
                tokio::spawn(async move {
                    for _ in 0..ROUNDS {
                        c.send(test_instruction_1()).await.unwrap();
                        c.send(test_instruction_2()).await.unwrap();
                        s.send(test_instruction_1()).await.unwrap();
                        s.send(test_instruction_2()).await.unwrap();
                    }
                })
            })
            .collect();

        // Drain both directions concurrently: a sender blocked on a full queue
        // in one direction would otherwise never reach its sends in the other.
        let server_side = async {
            let mut counts = (0, 0);
            for _ in 0..2 * PER_KIND {
                tally(&mut counts, server.recv().await.unwrap());
            }
            counts
        };
        let client_side = async {
            let mut counts = (0, 0);
            for _ in 0..2 * PER_KIND {
                tally(&mut counts, client.recv().await.unwrap());
            }
            counts
        };
        let (server_counts, client_counts) = tokio::join!(server_side, client_side);

        for sender in senders {
            sender.await.unwrap();
        }
        assert_eq!(server_counts, (PER_KIND, PER_KIND));
        assert_eq!(client_counts, (PER_KIND, PER_KIND));
    }
}