#![warn(missing_docs, clippy::pedantic, clippy::use_self)]
use {
interprocess::local_socket::{
self, GenericNamespaced, ListenerNonblockingMode, ListenerOptions, ToNsName,
traits::{Listener as _, Stream as _},
},
std::{
io::{Read, Write},
time::{Duration, Instant},
},
};
use std::io::ErrorKind;
/// Result of [`establish_endpoint`]: either a listener that was just created
/// or a connection to a listener that was already reachable under the name.
pub enum Endpoint {
/// Connecting failed with `NotFound` or `ConnectionRefused`, so a new
/// listener was created under the requested name.
New(Listener),
/// Connecting succeeded; holds the stream to the already-running endpoint.
Existing(Stream),
}
/// Newtype around a [`local_socket::Listener`]; its `accept` method hands out
/// [`Stream`] values.
pub struct Listener(local_socket::Listener);
impl Listener {
pub fn accept(&self) -> Option<Stream> {
match self.0.accept() {
Ok(stream) => Some(Stream(stream)),
Err(e) => {
log::error!("{e:?}");
None
}
}
}
}
/// A message that can be sent over a [`Stream`].
///
/// The enum is `#[repr(u8)]` with the first variant pinned to `0`, so the
/// variants carry the discriminants 0, 1, 2 and 3 in declaration order. That
/// byte is written first on the wire and `Msg::read` matches on the same values.
#[derive(Debug, PartialEq)]
#[repr(u8)]
pub enum Msg {
/// An unsigned integer, written as its little-endian bytes.
Num(usize) = 0,
/// Raw bytes, written as a `usize` length followed by the bytes themselves.
Bytes(Vec<u8>),
/// Text, written as a `usize` byte length followed by the UTF-8 bytes.
/// On receipt, invalid UTF-8 is decoded lossily rather than rejected.
String(String),
/// Carries no payload; only the discriminant byte is written.
Nudge,
}
/// Writes the single byte `num` to `stream`.
///
/// Any error from `write_all` is returned unchanged.
fn write_u8(num: u8, stream: &mut local_socket::Stream) -> std::io::Result<()> {
    stream.write_all(&[num])
}
/// Reads exactly one byte from `stream`.
///
/// Any error from `read_exact` (including end of stream) is returned unchanged.
fn read_u8(stream: &mut local_socket::Stream) -> std::io::Result<u8> {
    let mut byte = [0u8; 1];
    stream.read_exact(&mut byte)?;
    let [value] = byte;
    Ok(value)
}
/// Writes `num` to `stream` as its little-endian byte representation.
///
/// The number of bytes written is `size_of::<usize>()` on the writing side.
fn write_usize(num: usize, stream: &mut local_socket::Stream) -> std::io::Result<()> {
    stream.write_all(&usize::to_le_bytes(num))
}
/// Reads a little-endian `usize` from `stream`.
///
/// Exactly `size_of::<usize>()` bytes are consumed; any error from
/// `read_exact` is returned unchanged.
fn read_usize(stream: &mut local_socket::Stream) -> std::io::Result<usize> {
    const WIDTH: usize = std::mem::size_of::<usize>();
    let mut raw = [0u8; WIDTH];
    stream
        .read_exact(&mut raw)
        .map(|()| usize::from_le_bytes(raw))
}
/// Reads a length-prefixed byte buffer from `stream`.
///
/// The prefix is a `usize` read via `read_usize`, followed by exactly that
/// many payload bytes.
///
/// NOTE(review): the length is taken from the stream as-is and is not bounded
/// here, so the allocation size is whatever the peer sent.
fn read_vec(stream: &mut local_socket::Stream) -> std::io::Result<Vec<u8>> {
    let len = read_usize(stream)?;
    log::debug!("read_vec: length: {len}");
    let mut payload = vec![0; len];
    stream.read_exact(&mut payload).map(|()| payload)
}
impl Msg {
    /// Returns the `u8` discriminant used as the wire tag for this message
    /// (`Num` = 0, `Bytes` = 1, `String` = 2, `Nudge` = 3).
    const fn discriminant(&self) -> u8 {
        // SAFETY: `Msg` is `#[repr(u8)]`. For such enums the layout is a
        // `repr(C)` union of `repr(C)` structs whose first field is the `u8`
        // discriminant, so reading a `u8` through a pointer to the enum is
        // sound and yields the discriminant.
        unsafe { *std::ptr::from_ref(self).cast::<u8>() }
    }

    /// Fallible serializer behind [`Msg::write`].
    ///
    /// Writes the discriminant byte followed by the variant's payload:
    /// a little-endian `usize` for `Num`, a `usize` length plus the bytes for
    /// `Bytes` and `String`, and nothing for `Nudge`.
    fn try_write(&self, stream: &mut local_socket::Stream) -> std::io::Result<()> {
        let discriminant = self.discriminant();
        log::debug!("Writing discriminant {discriminant}");
        write_u8(discriminant, stream)?;
        match self {
            Self::Num(n) => write_usize(*n, stream),
            Self::Bytes(bytes) => {
                write_usize(bytes.len(), stream)?;
                log::debug!("Wrote byte length: {}", bytes.len());
                stream.write_all(bytes)
            }
            Self::String(text) => {
                write_usize(text.len(), stream)?;
                log::debug!("Wrote byte length: {}", text.len());
                stream.write_all(text.as_bytes())
            }
            Self::Nudge => Ok(()),
        }
    }

    /// Serializes this message onto `stream`.
    ///
    /// # Panics
    ///
    /// Panics if any write to `stream` fails. The signature returns `()`, so
    /// the failure cannot be reported to the caller any other way.
    fn write(self, stream: &mut local_socket::Stream) {
        self.try_write(stream)
            .expect("failed to write message to stream");
    }

    /// Reads one message from `stream`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while reading, and an
    /// [`std::io::ErrorKind::InvalidData`] error when the discriminant byte
    /// does not name a known variant. The peer controls that byte, so an
    /// unexpected value is reported instead of panicking.
    fn read(stream: &mut local_socket::Stream) -> std::io::Result<Self> {
        let discriminant = read_u8(stream)?;
        log::debug!("Read discriminant {discriminant}");
        match discriminant {
            0 => read_usize(stream).map(Self::Num),
            1 => read_vec(stream).map(Self::Bytes),
            2 => {
                log::debug!("Reading string...");
                let bytes = read_vec(stream)?;
                // Reuse the buffer when it is already valid UTF-8; otherwise
                // fall back to lossy decoding, which replaces invalid
                // sequences with U+FFFD.
                let text = String::from_utf8(bytes).unwrap_or_else(|err| {
                    String::from_utf8_lossy(err.as_bytes()).into_owned()
                });
                Ok(Self::String(text))
            }
            3 => Ok(Self::Nudge),
            etc => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Unknown message discriminant {etc}"),
            )),
        }
    }
}
/// Checks that every `Msg` variant reports the discriminant the wire format
/// relies on.
#[test]
fn test_msg_discriminant() {
    let expectations = [
        (Msg::Num(42), 0u8),
        (Msg::Bytes(vec![42]), 1),
        (Msg::String("Hello world".into()), 2),
        (Msg::Nudge, 3),
    ];
    for (msg, tag) in &expectations {
        assert_eq!(msg.discriminant(), *tag);
    }
}
/// Newtype around a [`local_socket::Stream`] that sends and receives [`Msg`]
/// values.
pub struct Stream(local_socket::Stream);
impl Stream {
pub fn send(&mut self, msg: Msg) {
msg.write(&mut self.0);
}
pub fn recv(&mut self) -> Option<Msg> {
match Msg::read(&mut self.0) {
Ok(msg) => Some(msg),
Err(e) => {
log::error!("Stream::recv error: {e}");
None
}
}
}
}
/// Connects to the endpoint named `id`, or creates it if nobody is listening.
///
/// A connection attempt is made first. If it succeeds the stream is returned
/// as [`Endpoint::Existing`]. If it fails with `NotFound` or
/// `ConnectionRefused`, a listener is created under the same name and returned
/// as [`Endpoint::New`]. `nonblocking` selects `ListenerNonblockingMode::Both`
/// (true) or `ListenerNonblockingMode::Neither` (false) for that listener.
///
/// # Errors
///
/// Returns an error if `id` cannot be converted to a namespaced name, if the
/// connection attempt fails with any other error kind, or if creating the
/// listener fails.
pub fn establish_endpoint(id: &str, nonblocking: bool) -> std::io::Result<Endpoint> {
    let ns_name = id.to_ns_name::<GenericNamespaced>()?;

    // Fast path: somebody is already listening under this name.
    let connect_err = match local_socket::Stream::connect(ns_name.clone()) {
        Ok(stream) => return Ok(Endpoint::Existing(Stream(stream))),
        Err(e) => e,
    };

    // Only "nobody is there" style failures mean we should become the listener.
    if !matches!(
        connect_err.kind(),
        ErrorKind::NotFound | ErrorKind::ConnectionRefused
    ) {
        return Err(connect_err);
    }

    let nb_mode = if nonblocking {
        ListenerNonblockingMode::Both
    } else {
        ListenerNonblockingMode::Neither
    };
    let listener = ListenerOptions::default()
        .name(ns_name.clone())
        .nonblocking(nb_mode)
        .create_sync()?;
    log::info!("Established new endpoint with name {ns_name:?}");
    Ok(Endpoint::New(Listener(listener)))
}
/// Repeatedly calls [`establish_endpoint`] until it yields a new listener.
///
/// While the endpoint already exists, the connection that was just opened is
/// dropped, the thread sleeps for `sleep_ms` milliseconds, and the attempt is
/// repeated. The timeout is checked after each sleep.
///
/// # Errors
///
/// Returns any error from [`establish_endpoint`], or a `TimedOut` error once
/// more than `timeout_ms` whole milliseconds have elapsed since the call began.
pub fn wait_to_be_new(
    id: &str,
    nonblocking: bool,
    sleep_ms: u64,
    timeout_ms: u64,
) -> std::io::Result<Listener> {
    let started = Instant::now();
    let pause = Duration::from_millis(sleep_ms);
    let limit_ms = u128::from(timeout_ms);
    loop {
        if let Endpoint::New(listener) = establish_endpoint(id, nonblocking)? {
            return Ok(listener);
        }
        std::thread::sleep(pause);
        if started.elapsed().as_millis() > limit_ms {
            return Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout"));
        }
    }
}