use bytes::Bytes;
use dashmap::DashMap;
use flume::{Receiver, Sender};
use std::io;
/// Payload carried over an inproc channel: an ordered list of [`Bytes`] buffers
/// (presumably one entry per message frame — confirm against callers).
pub type InprocMessage = Vec<Bytes>;
/// Sending half of an inproc channel (a `flume` sender of [`InprocMessage`]).
pub type InprocSender = Sender<InprocMessage>;
/// Receiving half of an inproc channel (a `flume` receiver of [`InprocMessage`]).
pub type InprocReceiver = Receiver<InprocMessage>;
/// Process-wide table of bound endpoints, keyed by the endpoint name (the part
/// of the endpoint string after `inproc://`). The value is the sender that
/// `connect_inproc` / `connect_inproc_bidi` clone and hand to connecting peers.
static INPROC_REGISTRY: once_cell::sync::Lazy<DashMap<String, InprocSender>> =
once_cell::sync::Lazy::new(DashMap::new);
/// Process-wide table of reply-direction senders, keyed by the same endpoint
/// name. `bind_inproc_bidi` stores one entry per endpoint here and
/// `connect_inproc_bidi` replaces that entry on every connect; `unbind_inproc`
/// removes it.
static INPROC_REPLY_REGISTRY: once_cell::sync::Lazy<DashMap<String, InprocSender>> =
once_cell::sync::Lazy::new(DashMap::new);
pub fn bind_inproc(endpoint: &str) -> io::Result<(InprocSender, InprocReceiver)> {
let name = validate_and_extract_name(endpoint)?;
let (tx, rx) = flume::unbounded();
if INPROC_REGISTRY
.insert(name.to_string(), tx.clone())
.is_some()
{
return Err(io::Error::new(
io::ErrorKind::AddrInUse,
format!("inproc endpoint '{name}' is already bound"),
));
}
Ok((tx, rx))
}
/// Looks up a previously bound inproc endpoint and returns a clone of its sender.
///
/// `endpoint` must have the form `inproc://<name>` with a non-empty name.
///
/// # Errors
///
/// * [`io::ErrorKind::InvalidInput`] if `endpoint` is malformed.
/// * [`io::ErrorKind::NotFound`] if nothing is registered under that name.
pub fn connect_inproc(endpoint: &str) -> io::Result<InprocSender> {
    let name = validate_and_extract_name(endpoint)?;
    INPROC_REGISTRY
        .get(name)
        .map(|entry| entry.value().clone())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("inproc endpoint '{name}' not found (must bind before connect)"),
            )
        })
}
/// Removes an endpoint from both the main registry and the reply registry.
///
/// Unbinding a name that is not currently registered is not an error; the
/// removals are simply no-ops.
///
/// # Errors
///
/// [`io::ErrorKind::InvalidInput`] if `endpoint` is not of the form
/// `inproc://<name>` with a non-empty name.
pub fn unbind_inproc(endpoint: &str) -> io::Result<()> {
    validate_and_extract_name(endpoint).map(|name| {
        INPROC_REGISTRY.remove(name);
        INPROC_REPLY_REGISTRY.remove(name);
    })
}
pub fn bind_inproc_bidi(
endpoint: &str,
) -> io::Result<(InprocSender, InprocReceiver, InprocSender, InprocReceiver)> {
let name = validate_and_extract_name(endpoint)?;
let (client_to_server_tx, client_to_server_rx) = flume::unbounded::<InprocMessage>();
let (server_to_client_tx, server_to_client_rx) = flume::unbounded::<InprocMessage>();
if INPROC_REGISTRY
.insert(name.to_string(), client_to_server_tx.clone())
.is_some()
{
return Err(io::Error::new(
io::ErrorKind::AddrInUse,
format!("inproc endpoint '{name}' is already bound"),
));
}
INPROC_REPLY_REGISTRY.insert(name.to_string(), server_to_client_tx.clone());
Ok((
server_to_client_tx,
client_to_server_rx,
client_to_server_tx,
server_to_client_rx,
))
}
/// Connects to an endpoint that was bound with [`bind_inproc_bidi`].
///
/// Returns a clone of the registered client→server sender together with the
/// receiver of a freshly created unbounded channel. The sender of that fresh
/// channel replaces whatever was stored for this name in the reply registry.
///
/// NOTE(review): the returned receiver is fed only by the sender stored in the
/// reply registry here, not by the server→client sender that
/// `bind_inproc_bidi` handed back to its caller. Nothing in this module reads
/// the reply registry back out, and each new connect replaces the previous
/// entry — confirm this is the intended reply routing.
///
/// # Errors
///
/// * [`io::ErrorKind::InvalidInput`] if `endpoint` is not of the form
///   `inproc://<name>` with a non-empty name.
/// * [`io::ErrorKind::NotFound`] if the name is not bound, or if it is bound
///   but has no reply-registry entry (e.g. it was bound with `bind_inproc`).
pub fn connect_inproc_bidi(endpoint: &str) -> io::Result<(InprocSender, InprocReceiver)> {
    let name = validate_and_extract_name(endpoint)?;

    let to_server = INPROC_REGISTRY
        .get(name)
        .map(|entry| entry.value().clone())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("inproc endpoint '{name}' not found (must bind before connect)"),
            )
        })?;

    // Only the *presence* of a reply entry matters (it shows the server side
    // used the bidirectional bind); the stored sender itself is never used, so
    // there is no need to clone it.
    if !INPROC_REPLY_REGISTRY.contains_key(name) {
        return Err(io::Error::new(
            io::ErrorKind::NotFound,
            format!(
                "inproc reply channel for '{name}' not found; \
                 use bind_inproc_bidi on the server side"
            ),
        ));
    }

    let (reply_tx, reply_rx) = flume::unbounded::<InprocMessage>();
    INPROC_REPLY_REGISTRY.insert(name.to_string(), reply_tx);
    Ok((to_server, reply_rx))
}
/// Returns the names of all endpoints currently present in the main registry.
///
/// The names are the bare registry keys (without the `inproc://` prefix), in
/// whatever order the registry's iterator yields them.
pub fn list_inproc_endpoints() -> Vec<String> {
    let mut names = Vec::new();
    for entry in INPROC_REGISTRY.iter() {
        names.push(entry.key().clone());
    }
    names
}
/// Checks that `endpoint` looks like `inproc://<name>` and returns `<name>`.
///
/// The returned slice borrows from `endpoint`.
///
/// # Errors
///
/// [`io::ErrorKind::InvalidInput`] when the `inproc://` prefix is missing or
/// when nothing follows it.
fn validate_and_extract_name(endpoint: &str) -> io::Result<&str> {
    const PREFIX: &str = "inproc://";
    match endpoint.strip_prefix(PREFIX) {
        // Wrong (or no) scheme.
        None => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("inproc endpoint must start with '{PREFIX}', got: '{endpoint}'"),
        )),
        // Correct scheme, but no name after it.
        Some("") => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "inproc endpoint name cannot be empty",
        )),
        Some(name) => Ok(name),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed endpoint yields its name; a wrong scheme, a missing name
    /// and an empty string are all rejected.
    #[test]
    fn test_validate_endpoint() {
        let parsed = validate_and_extract_name("inproc://test");
        assert!(parsed.is_ok());
        assert_eq!(parsed.unwrap(), "test");
        for rejected in ["tcp://test", "inproc://", ""] {
            assert!(validate_and_extract_name(rejected).is_err());
        }
    }

    /// Binding the same endpoint twice reports `AddrInUse` on the second call.
    #[test]
    fn test_bind_duplicate() {
        let endpoint = "inproc://test-duplicate";
        // Kept alive for the whole test so the first binding's channel stays open.
        let first = bind_inproc(endpoint);
        assert!(first.is_ok());
        let second = bind_inproc(endpoint);
        assert!(second.is_err());
        assert_eq!(
            second.err().map(|e| e.kind()),
            Some(io::ErrorKind::AddrInUse)
        );
        // Best-effort cleanup of the global registry.
        let _ = unbind_inproc(endpoint);
    }

    /// A message sent through a connected sender arrives at the bound receiver.
    #[test]
    fn test_bind_and_connect() {
        let endpoint = "inproc://test-connect";
        let (_server_tx, server_rx) = bind_inproc(endpoint).unwrap();
        let client_tx = connect_inproc(endpoint).unwrap();
        let payload = vec![Bytes::from("Hello, inproc!")];
        client_tx.send(payload.clone()).unwrap();
        let wait = std::time::Duration::from_millis(100);
        let delivered = server_rx.recv_timeout(wait).unwrap();
        assert_eq!(delivered, payload);
        unbind_inproc(endpoint).unwrap();
    }

    /// Every bound endpoint name shows up in the listing.
    #[test]
    fn test_list_endpoints() {
        let first = "inproc://test-list-1";
        let second = "inproc://test-list-2";
        let _first_binding = bind_inproc(first).unwrap();
        let _second_binding = bind_inproc(second).unwrap();
        let listed = list_inproc_endpoints();
        for expected in ["test-list-1", "test-list-2"] {
            assert!(listed.iter().any(|name| name.as_str() == expected));
        }
        unbind_inproc(first).unwrap();
        unbind_inproc(second).unwrap();
    }
}