use hashbrown::HashMap;
use std::io::{Read as IoRead, Write as IoWrite};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use noxu_sync::Mutex;
use super::channel::{Channel, TcpChannel};
use crate::error::{RepError, Result};
/// Largest service name, in bytes, allowed in the connection handshake.
/// Checked both by the dispatchers and by the `connect_to_service*` helpers.
pub const MAX_SERVICE_NAME_LEN: usize = 0x100;
/// A named service that receives ownership of a channel once a dispatcher
/// has routed an inbound connection to it.
pub trait ServiceHandler: Send + Sync {
    /// Name of this service; [`ServiceDispatcher::register`] uses it as the
    /// registry key.
    fn service_name(&self) -> &str;

    /// Serves one connection. The handler owns `channel` for the rest of the
    /// connection's life.
    fn handle(&self, channel: Box<dyn Channel>) -> Result<()>;
}
/// In-process registry that maps service names to handlers and forwards
/// channels to them via [`ServiceDispatcher::dispatch`].
pub struct ServiceDispatcher {
    /// Registered handlers, keyed by service name.
    services: Mutex<HashMap<String, Arc<dyn ServiceHandler>>>,
    /// Flag toggled by `start`/`stop`; `dispatch` does not consult it.
    running: AtomicBool,
}
impl ServiceDispatcher {
pub fn new() -> Self {
Self {
services: Mutex::new(HashMap::new()),
running: AtomicBool::new(false),
}
}
pub fn register(&self, handler: Arc<dyn ServiceHandler>) {
let name = handler.service_name().to_string();
let mut services = self.services.lock();
services.insert(name, handler);
}
pub fn unregister(
&self,
service_name: &str,
) -> Option<Arc<dyn ServiceHandler>> {
let mut services = self.services.lock();
services.remove(service_name)
}
pub fn get_handler(&self, name: &str) -> Option<Arc<dyn ServiceHandler>> {
let services = self.services.lock();
services.get(name).cloned()
}
pub fn list_services(&self) -> Vec<String> {
let services = self.services.lock();
let mut names: Vec<String> = services.keys().cloned().collect();
names.sort();
names
}
pub fn start(&self) {
self.running.store(true, Ordering::SeqCst);
}
pub fn stop(&self) {
self.running.store(false, Ordering::SeqCst);
}
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
pub fn dispatch(
&self,
service_name: &str,
channel: Box<dyn Channel>,
) -> Result<()> {
let handler = self.get_handler(service_name).ok_or_else(|| {
crate::error::RepError::ServiceNotFound(service_name.to_string())
})?;
handler.handle(channel)
}
}
impl Default for ServiceDispatcher {
fn default() -> Self {
Self::new()
}
}
pub struct TcpServiceDispatcher {
services: Arc<Mutex<HashMap<String, Arc<dyn ServiceHandler>>>>,
addr: SocketAddr,
running: Arc<AtomicBool>,
}
impl TcpServiceDispatcher {
pub fn new(addr: SocketAddr) -> Result<Self> {
Ok(Self {
services: Arc::new(Mutex::new(HashMap::new())),
addr,
running: Arc::new(AtomicBool::new(false)),
})
}
pub fn register(
&self,
name: impl Into<String>,
handler: Arc<dyn ServiceHandler>,
) {
self.services.lock().insert(name.into(), handler);
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
pub fn stop(&self) {
self.running.store(false, Ordering::SeqCst);
}
pub fn start(&self) -> Result<SocketAddr> {
use std::net::TcpListener;
let listener = TcpListener::bind(self.addr)
.map_err(|e| RepError::NetworkError(e.to_string()))?;
let bound_addr = listener
.local_addr()
.map_err(|e| RepError::NetworkError(e.to_string()))?;
let services = Arc::clone(&self.services);
let running = Arc::clone(&self.running);
running.store(true, Ordering::SeqCst);
thread::spawn(move || {
while running.load(Ordering::SeqCst) {
match listener.accept() {
Ok((stream, _peer_addr)) => {
let services_clone = Arc::clone(&services);
let running_check = Arc::clone(&running);
thread::spawn(move || {
handle_incoming(
stream,
services_clone,
running_check,
);
});
}
Err(_) => {
break;
}
}
}
running.store(false, Ordering::SeqCst);
});
Ok(bound_addr)
}
}
fn handle_incoming(
stream: std::net::TcpStream,
services: Arc<Mutex<HashMap<String, Arc<dyn ServiceHandler>>>>,
_running: Arc<AtomicBool>,
) {
let mut read_stream = match stream.try_clone() {
Ok(s) => s,
Err(_) => return,
};
let mut len_buf = [0u8; 4];
if read_stream.read_exact(&mut len_buf).is_err() {
return;
}
let name_len = u32::from_le_bytes(len_buf) as usize;
if name_len == 0 || name_len > MAX_SERVICE_NAME_LEN {
log::warn!(
"TcpServiceDispatcher: rejected service-name length {} (max {})",
name_len,
MAX_SERVICE_NAME_LEN
);
return;
}
let mut name_buf = vec![0u8; name_len];
if read_stream.read_exact(&mut name_buf).is_err() {
return;
}
let service_name = match String::from_utf8(name_buf) {
Ok(s) => s,
Err(_) => return,
};
drop(read_stream);
let handler = {
let guard = services.lock();
guard.get(&service_name).cloned()
};
if let Some(h) = handler {
let tcp_ch = TcpChannel::new(stream);
let _ = h.handle(Box::new(tcp_ch));
}
}
pub fn connect_to_service(
addr: SocketAddr,
service_name: &str,
) -> Result<TcpChannel> {
use std::net::TcpStream;
let name_bytes = service_name.as_bytes();
if name_bytes.is_empty() || name_bytes.len() > MAX_SERVICE_NAME_LEN {
return Err(RepError::ConfigError(format!(
"service name length {} out of range [1, {}]",
name_bytes.len(),
MAX_SERVICE_NAME_LEN,
)));
}
let mut stream = TcpStream::connect(addr)
.map_err(|e| RepError::NetworkError(e.to_string()))?;
let len = name_bytes.len() as u32;
stream
.write_all(&len.to_le_bytes())
.map_err(|e| RepError::NetworkError(e.to_string()))?;
stream
.write_all(name_bytes)
.map_err(|e| RepError::NetworkError(e.to_string()))?;
stream.flush().map_err(|e| RepError::NetworkError(e.to_string()))?;
Ok(TcpChannel::new(stream))
}
#[cfg(feature = "tls-rustls")]
/// TLS counterpart of [`connect_to_service`]: establishes a TLS channel to
/// `addr` and sends `service_name` as the first framed message.
///
/// # Errors
/// `RepError::ConfigError` for an empty or over-long name; otherwise any error
/// from connecting or sending on the channel.
pub fn connect_to_service_tls(
    addr: SocketAddr,
    service_name: &str,
    tls: &crate::tls::TlsConfig,
) -> Result<super::channel::TlsTcpChannel> {
    let name = service_name.as_bytes();
    if !(1..=MAX_SERVICE_NAME_LEN).contains(&name.len()) {
        return Err(RepError::ConfigError(format!(
            "service name length {} out of range [1, {}]",
            name.len(),
            MAX_SERVICE_NAME_LEN,
        )));
    }
    let tls_channel =
        super::channel::TlsTcpChannel::connect_with_tls(addr, tls)?;
    tls_channel.send(name)?;
    Ok(tls_channel)
}
#[cfg(feature = "tls-rustls")]
/// TLS front end that routes each accepted connection to a registered
/// [`ServiceHandler`] based on the service name sent as the first message.
pub struct TlsTcpServiceDispatcher {
    /// Registered handlers, keyed by service name.
    services: Arc<Mutex<HashMap<String, Arc<dyn ServiceHandler>>>>,
    /// Local address of the listener, which is bound at construction time.
    bound_addr: SocketAddr,
    /// Shared with the accept thread spawned by `start`.
    listener: Arc<super::channel::TlsTcpChannelListener>,
    /// Accept-loop flag toggled by `start`/`stop`.
    running: Arc<AtomicBool>,
}
#[cfg(feature = "tls-rustls")]
impl TlsTcpServiceDispatcher {
    /// Binds a TLS listener on `addr` that only admits peers in `allowlist`.
    /// Accepting does not begin until [`start`](Self::start).
    pub fn new(
        addr: SocketAddr,
        tls: &crate::tls::TlsConfig,
        allowlist: crate::auth::PeerAllowlist,
    ) -> Result<Self> {
        let listener =
            super::channel::TlsTcpChannelListener::bind_with_tls_and_allowlist(
                addr, tls, allowlist,
            )?;
        let bound_addr = listener.local_addr()?;
        Ok(Self {
            services: Arc::new(Mutex::new(HashMap::new())),
            bound_addr,
            listener: Arc::new(listener),
            running: Arc::new(AtomicBool::new(false)),
        })
    }

    /// Registers `handler` under `name`, replacing any previous handler.
    pub fn register(
        &self,
        name: impl Into<String>,
        handler: Arc<dyn ServiceHandler>,
    ) {
        self.services.lock().insert(name.into(), handler);
    }

    /// Address the listener is bound to.
    pub fn addr(&self) -> SocketAddr {
        self.bound_addr
    }

    /// Whether the accept loop is (still) meant to be running.
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }

    /// Stops accepting new connections. Connections already handed to a
    /// handler are not interrupted.
    ///
    /// The accept thread may be blocked inside `accept()`, so a throwaway TCP
    /// connection is made to make that call return. Whether it then yields
    /// `Ok` or an error, the loop sees the cleared flag and exits.
    pub fn stop(&self) {
        if self.running.swap(false, Ordering::SeqCst) {
            Self::wake_accept_loop(self.bound_addr);
        }
    }

    /// Makes a short-lived plain TCP connection to `bound` so that a blocked
    /// accept returns. Best effort: failures are ignored.
    fn wake_accept_loop(bound: SocketAddr) {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpStream};
        use std::time::Duration;
        let mut target = bound;
        if target.ip().is_unspecified() {
            // A wildcard bind address is not a portable connect target.
            let loopback = match target {
                SocketAddr::V4(_) => IpAddr::V4(Ipv4Addr::LOCALHOST),
                SocketAddr::V6(_) => IpAddr::V6(Ipv6Addr::LOCALHOST),
            };
            target.set_ip(loopback);
        }
        let _ = TcpStream::connect_timeout(&target, Duration::from_millis(500));
    }

    /// Spawns the accept loop and returns the bound address.
    pub fn start(&self) -> Result<SocketAddr> {
        let services = Arc::clone(&self.services);
        let listener = Arc::clone(&self.listener);
        let running = Arc::clone(&self.running);
        running.store(true, Ordering::SeqCst);
        let bound = self.bound_addr;
        thread::spawn(move || {
            while running.load(Ordering::SeqCst) {
                match listener.accept() {
                    Ok(channel) => {
                        // Anything accepted after stop() is dropped, not
                        // dispatched.
                        if !running.load(Ordering::SeqCst) {
                            break;
                        }
                        let svcs = Arc::clone(&services);
                        thread::spawn(move || {
                            handle_tls_incoming(channel, svcs);
                        });
                    }
                    Err(e) => {
                        if running.load(Ordering::SeqCst) {
                            // Handshake/allowlist failures are per-peer;
                            // keep serving.
                            log::debug!(
                                "TlsTcpServiceDispatcher: accept error \
                                 (continuing): {e}"
                            );
                        } else {
                            break;
                        }
                    }
                }
            }
            running.store(false, Ordering::SeqCst);
        });
        Ok(bound)
    }
}
#[cfg(feature = "tls-rustls")]
/// Serves one accepted TLS connection.
///
/// Receives the service name as the first framed message (10 s timeout),
/// validates it and hands the channel to the matching handler. Invalid
/// names and unknown services close the connection.
fn handle_tls_incoming(
    channel: super::channel::TlsTcpChannel,
    services: Arc<Mutex<HashMap<String, Arc<dyn ServiceHandler>>>>,
) {
    use crate::net::channel::Channel;
    let name_bytes = match channel.receive(std::time::Duration::from_secs(10)) {
        Ok(Some(b)) => b,
        Ok(None) => {
            log::warn!("TlsTcpServiceDispatcher: timeout reading service name");
            return;
        }
        Err(e) => {
            log::debug!(
                "TlsTcpServiceDispatcher: error reading service name: {e}"
            );
            return;
        }
    };
    // Same bounds as the plain-TCP path and the client helpers: an empty
    // name is never valid.
    if name_bytes.is_empty() || name_bytes.len() > MAX_SERVICE_NAME_LEN {
        log::warn!(
            "TlsTcpServiceDispatcher: rejected service-name length {} (max {})",
            name_bytes.len(),
            MAX_SERVICE_NAME_LEN
        );
        return;
    }
    let service_name = match String::from_utf8(name_bytes) {
        Ok(s) => s,
        Err(_) => {
            log::warn!(
                "TlsTcpServiceDispatcher: non-UTF-8 service name, closing"
            );
            return;
        }
    };
    let handler = {
        let guard = services.lock();
        guard.get(&service_name).cloned()
    };
    match handler {
        Some(h) => {
            if let Err(e) = h.handle(Box::new(channel)) {
                log::debug!(
                    "TlsTcpServiceDispatcher: handler for '{service_name}' failed: {e}"
                );
            }
        }
        None => log::warn!(
            "TlsTcpServiceDispatcher: no handler for service '{service_name}'"
        ),
    }
}
/// Either a plain-TCP or (with the `tls-rustls` feature) a TLS dispatcher,
/// so callers can hold one without caring which transport is in use.
pub(crate) enum AnyServiceDispatcher {
    /// Unencrypted TCP transport.
    Plain(TcpServiceDispatcher),
    /// TLS transport.
    #[cfg(feature = "tls-rustls")]
    Tls(TlsTcpServiceDispatcher),
}
impl AnyServiceDispatcher {
pub fn register(
&self,
name: impl Into<String>,
handler: Arc<dyn ServiceHandler>,
) {
match self {
Self::Plain(d) => d.register(name, handler),
#[cfg(feature = "tls-rustls")]
Self::Tls(d) => d.register(name, handler),
}
}
pub fn stop(&self) {
match self {
Self::Plain(d) => d.stop(),
#[cfg(feature = "tls-rustls")]
Self::Tls(d) => d.stop(),
}
}
#[allow(dead_code)]
pub fn is_running(&self) -> bool {
match self {
Self::Plain(d) => d.is_running(),
#[cfg(feature = "tls-rustls")]
Self::Tls(d) => d.is_running(),
}
}
#[allow(dead_code)]
pub fn addr(&self) -> SocketAddr {
match self {
Self::Plain(d) => d.addr(),
#[cfg(feature = "tls-rustls")]
Self::Tls(d) => d.addr(),
}
}
pub fn is_tls(&self) -> bool {
match self {
Self::Plain(_) => false,
#[cfg(feature = "tls-rustls")]
Self::Tls(_) => true,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;
    use std::time::Duration;

    /// Handler that only counts how many channels it was handed.
    struct CountingHandler {
        name: String,
        call_count: AtomicU32,
    }
    impl CountingHandler {
        fn new(name: &str) -> Self {
            Self { name: name.to_string(), call_count: AtomicU32::new(0) }
        }
        fn count(&self) -> u32 {
            self.call_count.load(Ordering::SeqCst)
        }
    }
    impl ServiceHandler for CountingHandler {
        fn handle(&self, _channel: Box<dyn Channel>) -> Result<()> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        fn service_name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_register_and_get() {
        let dispatcher = ServiceDispatcher::new();
        let handler = Arc::new(CountingHandler::new("feeder"));
        dispatcher.register(handler);
        let retrieved = dispatcher.get_handler("feeder");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().service_name(), "feeder");
    }

    #[test]
    fn test_get_nonexistent() {
        let dispatcher = ServiceDispatcher::new();
        assert!(dispatcher.get_handler("nope").is_none());
    }

    #[test]
    fn test_unregister() {
        let dispatcher = ServiceDispatcher::new();
        let handler = Arc::new(CountingHandler::new("feeder"));
        dispatcher.register(handler);
        let removed = dispatcher.unregister("feeder");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().service_name(), "feeder");
        assert!(dispatcher.get_handler("feeder").is_none());
    }

    #[test]
    fn test_unregister_nonexistent() {
        let dispatcher = ServiceDispatcher::new();
        assert!(dispatcher.unregister("nope").is_none());
    }

    #[test]
    fn test_list_services() {
        let dispatcher = ServiceDispatcher::new();
        dispatcher.register(Arc::new(CountingHandler::new("feeder")));
        dispatcher.register(Arc::new(CountingHandler::new("election")));
        dispatcher.register(Arc::new(CountingHandler::new("backup")));
        let names = dispatcher.list_services();
        assert_eq!(names, vec!["backup", "election", "feeder"]);
    }

    #[test]
    fn test_list_services_empty() {
        let dispatcher = ServiceDispatcher::new();
        assert!(dispatcher.list_services().is_empty());
    }

    #[test]
    fn test_start_stop() {
        let dispatcher = ServiceDispatcher::new();
        assert!(!dispatcher.is_running());
        dispatcher.start();
        assert!(dispatcher.is_running());
        dispatcher.stop();
        assert!(!dispatcher.is_running());
    }

    #[test]
    fn test_register_replaces_existing() {
        let dispatcher = ServiceDispatcher::new();
        let handler1 = Arc::new(CountingHandler::new("feeder"));
        let handler2 = Arc::new(CountingHandler::new("feeder"));
        dispatcher.register(handler1);
        dispatcher.register(handler2);
        assert_eq!(dispatcher.list_services(), vec!["feeder"]);
    }

    #[test]
    fn test_dispatch_to_handler() {
        use super::super::channel::LocalChannelPair;
        let dispatcher = ServiceDispatcher::new();
        let handler = Arc::new(CountingHandler::new("feeder"));
        dispatcher.register(handler.clone());
        let pair = LocalChannelPair::new();
        dispatcher.dispatch("feeder", Box::new(pair.channel_a)).unwrap();
        assert_eq!(handler.count(), 1);
    }

    #[test]
    fn test_dispatch_unknown_service() {
        use super::super::channel::LocalChannelPair;
        let dispatcher = ServiceDispatcher::new();
        let pair = LocalChannelPair::new();
        let result = dispatcher.dispatch("unknown", Box::new(pair.channel_a));
        assert!(result.is_err());
    }

    #[test]
    fn test_default_trait() {
        let dispatcher = ServiceDispatcher::default();
        assert!(!dispatcher.is_running());
        assert!(dispatcher.list_services().is_empty());
    }

    /// Handler that echoes back the first message it receives.
    struct EchoHandler {
        name: String,
    }
    impl ServiceHandler for EchoHandler {
        fn handle(&self, channel: Box<dyn Channel>) -> Result<()> {
            let msg = channel.receive(Duration::from_secs(5))?.unwrap();
            channel.send(&msg)?;
            Ok(())
        }
        fn service_name(&self) -> &str {
            &self.name
        }
    }

    // No sleep after `start()`: the listener is bound before `start` returns,
    // so connections simply queue in the backlog until accepted.

    #[test]
    fn test_tcp_service_dispatcher_register_and_dispatch() {
        let sd =
            TcpServiceDispatcher::new("127.0.0.1:0".parse().unwrap()).unwrap();
        sd.register("echo", Arc::new(EchoHandler { name: "echo".into() }));
        let bound_addr = sd.start().unwrap();
        let client = connect_to_service(bound_addr, "echo").unwrap();
        client.send(b"hello dispatcher").unwrap();
        let reply = client.receive(Duration::from_secs(5)).unwrap();
        assert_eq!(reply, Some(b"hello dispatcher".to_vec()));
        sd.stop();
    }

    #[test]
    fn test_tcp_service_dispatcher_multiple_clients() {
        let sd =
            TcpServiceDispatcher::new("127.0.0.1:0".parse().unwrap()).unwrap();
        sd.register("echo", Arc::new(EchoHandler { name: "echo".into() }));
        let bound_addr = sd.start().unwrap();
        let handles: Vec<_> = (0u8..3)
            .map(|i| {
                std::thread::spawn(move || {
                    let client = connect_to_service(bound_addr, "echo").unwrap();
                    let msg = vec![i; 8];
                    client.send(&msg).unwrap();
                    let reply =
                        client.receive(Duration::from_secs(5)).unwrap().unwrap();
                    assert_eq!(reply, msg);
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        sd.stop();
    }

    /// Zero-length, over-long and unknown service names must all make the
    /// server close the connection rather than leave the client hanging.
    #[test]
    fn test_tcp_service_dispatcher_closes_bad_handshakes() {
        use std::io::{ErrorKind, Read, Write};
        use std::net::TcpStream;
        let sd =
            TcpServiceDispatcher::new("127.0.0.1:0".parse().unwrap()).unwrap();
        sd.register("echo", Arc::new(EchoHandler { name: "echo".into() }));
        let bound_addr = sd.start().unwrap();

        let too_long = (MAX_SERVICE_NAME_LEN as u32 + 1).to_le_bytes().to_vec();
        let mut unknown = 4u32.to_le_bytes().to_vec();
        unknown.extend_from_slice(b"nope");
        let handshakes = [0u32.to_le_bytes().to_vec(), too_long, unknown];

        for handshake in handshakes {
            let mut raw = TcpStream::connect(bound_addr).unwrap();
            raw.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
            raw.write_all(&handshake).unwrap();
            let mut buf = [0u8; 1];
            match raw.read(&mut buf) {
                Ok(n) => assert_eq!(n, 0, "server should close the connection"),
                Err(e) => assert!(
                    matches!(
                        e.kind(),
                        ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted
                    ),
                    "unexpected error: {e}"
                ),
            }
        }
        sd.stop();
    }
}