use std::sync::Arc;
/// Kind of transport a server is being served over.
///
/// Each variant has a lowercase wire name, available through `as_str`
/// and the `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransportKind {
    /// Pipe transport; `"pipe"` on the wire.
    Pipe,
    /// HTTP transport; `"http"` on the wire.
    Http,
    /// Unix transport; `"unix"` on the wire.
    Unix,
}
impl TransportKind {
    /// Returns the lowercase wire name of this transport: `"pipe"`, `"http"`
    /// or `"unix"`.
    ///
    /// This is a `const fn`, so it can also be used to initialise constants.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            TransportKind::Pipe => "pipe",
            TransportKind::Http => "http",
            TransportKind::Unix => "unix",
        }
    }
}
/// Formats a [`TransportKind`] as its lowercase wire name (the same string as
/// `as_str`).
///
/// Uses `Formatter::pad` instead of `write_str` so that width, fill,
/// alignment and precision flags (e.g. `{:>8}`) are honoured, as they are for
/// `str`. Plain `{}` output is unchanged.
impl std::fmt::Display for TransportKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Capability flags reported alongside a `TransportKind`.
///
/// `Default` yields every flag cleared (`shm: false`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TransportCapabilities {
    /// Capability flag `shm`. NOTE(review): presumably means shared-memory
    /// support (inferred from the name), so confirm against the transports that
    /// set it.
    pub shm: bool,
}
impl TransportCapabilities {
pub const fn none() -> Self {
Self { shm: false }
}
pub const fn shm() -> Self {
Self { shm: true }
}
}
/// Shared callback registered through `on_serve_start`. It receives the
/// transport kind and a borrow of the transport's capabilities.
///
/// The tests in this file expect the server to invoke it once for each
/// distinct `(kind, capabilities)` pair passed to `notify_transport`. The
/// `Send + Sync` bounds let the hook be shared and called across threads.
pub type ServeStartHook =
    Arc<dyn Fn(TransportKind, &TransportCapabilities) + Sync + Send + 'static>;
// Unit tests for the transport types above and for `RpcServer`'s transport
// notification API (`notify_transport`, `transport_kind`,
// `transport_capabilities`, and the `on_serve_start` hook).
#[cfg(test)]
mod tests {
use super::*;
use crate::RpcServer;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
// Repeated notifications with an identical (kind, capabilities) pair must
// fire the hook only once. Each new pair fires it once more, in call order.
#[test]
fn notify_transport_fires_hook_once_per_combination() {
// Shared recorder: the hook must be Send + Sync (ServeStartHook bound),
// hence Arc<Mutex<..>> rather than a plain Vec.
let calls: Arc<Mutex<Vec<(TransportKind, TransportCapabilities)>>> =
Arc::new(Mutex::new(Vec::new()));
let recorder = calls.clone();
let hook: ServeStartHook = Arc::new(move |k, c| {
// `c` is borrowed; TransportCapabilities is Copy, so `*c` copies it out.
recorder.lock().unwrap().push((k, *c));
});
let server = RpcServer::builder().on_serve_start(hook).build();
server.notify_transport(TransportKind::Pipe, TransportCapabilities::none());
server.notify_transport(TransportKind::Pipe, TransportCapabilities::none());
server.notify_transport(TransportKind::Pipe, TransportCapabilities::none());
server.notify_transport(TransportKind::Pipe, TransportCapabilities::shm());
server.notify_transport(TransportKind::Http, TransportCapabilities::none());
let log = calls.lock().unwrap().clone();
assert_eq!(
log,
vec![
(TransportKind::Pipe, TransportCapabilities::none()),
(TransportKind::Pipe, TransportCapabilities::shm()),
(TransportKind::Http, TransportCapabilities::none()),
]
);
}
// Before any notification there is no transport kind, and capabilities read
// as `none()`. Afterwards both getters reflect the most recent notification.
#[test]
fn transport_kind_and_capabilities_observed_after_notify() {
let server = RpcServer::builder().build();
assert!(server.transport_kind().is_none());
assert_eq!(
server.transport_capabilities(),
TransportCapabilities::none()
);
server.notify_transport(TransportKind::Unix, TransportCapabilities::none());
assert_eq!(server.transport_kind(), Some(TransportKind::Unix));
assert_eq!(
server.transport_capabilities(),
TransportCapabilities::none()
);
server.notify_transport(TransportKind::Pipe, TransportCapabilities::shm());
assert_eq!(server.transport_kind(), Some(TransportKind::Pipe));
assert!(server.transport_capabilities().shm);
}
// With no `on_serve_start` hook registered, `notify_transport` must not
// panic and must still record the transport kind.
#[test]
fn notify_without_hook_is_no_op_safe() {
let server = RpcServer::builder().build();
server.notify_transport(TransportKind::Http, TransportCapabilities::none());
assert_eq!(server.transport_kind(), Some(TransportKind::Http));
}
// 32 threads race to notify the same (kind, capabilities) pair. The hook
// must fire exactly once.
#[test]
fn concurrent_notify_fires_hook_once() {
let fire_count = Arc::new(AtomicUsize::new(0));
let counter = fire_count.clone();
let hook: ServeStartHook = Arc::new(move |_, _| {
counter.fetch_add(1, Ordering::Relaxed);
});
let server = Arc::new(RpcServer::builder().on_serve_start(hook).build());
let mut handles = Vec::new();
for _ in 0..32 {
let srv = server.clone();
handles.push(std::thread::spawn(move || {
srv.notify_transport(TransportKind::Http, TransportCapabilities::none());
}));
}
for h in handles {
h.join().unwrap();
}
// Relaxed is sufficient here: join() synchronizes with each thread's
// completion, so every increment is visible to this load.
assert_eq!(fire_count.load(Ordering::Relaxed), 1);
}
// Pins the wire strings returned by `as_str`. NOTE(review): the test name
// implies these must match a Python counterpart. Confirm where it lives and
// keep both in sync.
#[test]
fn as_str_matches_python_wire_form() {
assert_eq!(TransportKind::Pipe.as_str(), "pipe");
assert_eq!(TransportKind::Http.as_str(), "http");
assert_eq!(TransportKind::Unix.as_str(), "unix");
}
}