use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use crate::error::{Result, RingKernelError};
use crate::hlc::HlcTimestamp;
use crate::k2k::{DeliveryReceipt, DeliveryStatus, K2KBroker, K2KConfig};
use crate::message::MessageEnvelope;
use crate::persistent_message::{DispatchTable, PersistentMessage};
use crate::runtime::KernelId;
/// Tunables for a [`KernelDispatcher`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatcherConfig {
    /// Emit debug-level tracing for each dispatch when `true`.
    pub enable_logging: bool,
    /// Record dispatch counters in [`DispatcherMetrics`] when `true`.
    pub enable_metrics: bool,
    /// Priority passed to the broker for every dispatched message.
    pub default_priority: u8,
}
impl Default for DispatcherConfig {
fn default() -> Self {
Self {
enable_logging: false,
enable_metrics: true,
default_priority: 0,
}
}
}
/// Counters describing dispatch activity of a [`KernelDispatcher`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DispatcherMetrics {
    /// Dispatch attempts, including ones that failed.
    pub messages_dispatched: u64,
    /// Dispatches whose receipt reported `Delivered`.
    pub messages_delivered: u64,
    /// Dispatches rejected because no route existed for the message type.
    pub unknown_type_errors: u64,
    /// Dispatches that failed in, or were rejected by, the broker.
    pub delivery_errors: u64,
}
/// Routes messages to kernels by message type id and forwards them via a [`K2KBroker`].
pub struct KernelDispatcher {
    /// Message type id -> destination kernel.
    routes: RwLock<HashMap<u64, KernelId>>,
    /// Per-kernel handler registrations.
    handler_tables: RwLock<HashMap<KernelId, DispatchTable>>,
    /// Broker used to deliver dispatched envelopes.
    broker: Arc<K2KBroker>,
    /// Dispatcher tunables (priority, metrics, logging).
    config: DispatcherConfig,
    /// Counters updated on each dispatch.
    metrics: RwLock<DispatcherMetrics>,
}
impl KernelDispatcher {
pub fn builder() -> DispatcherBuilder {
DispatcherBuilder::new()
}
pub fn new(broker: Arc<K2KBroker>) -> Self {
Self::with_config(broker, DispatcherConfig::default())
}
pub fn with_config(broker: Arc<K2KBroker>, config: DispatcherConfig) -> Self {
Self {
routes: RwLock::new(HashMap::new()),
handler_tables: RwLock::new(HashMap::new()),
broker,
config,
metrics: RwLock::new(DispatcherMetrics::default()),
}
}
pub fn register<M: PersistentMessage>(&self, kernel_id: KernelId) -> crate::error::Result<()> {
self.register_with_name::<M>(kernel_id, std::any::type_name::<M>())
}
pub fn register_with_name<M: PersistentMessage>(
&self,
kernel_id: KernelId,
handler_name: &str,
) -> crate::error::Result<()> {
let type_id = M::message_type();
self.routes.write().insert(type_id, kernel_id.clone());
let mut handler_tables = self.handler_tables.write();
let table = handler_tables.entry(kernel_id).or_default();
table.register_message::<M>(handler_name)
}
pub fn register_route(&self, type_id: u64, kernel_id: KernelId) {
self.routes.write().insert(type_id, kernel_id);
}
pub fn unregister(&self, type_id: u64) {
self.routes.write().remove(&type_id);
}
pub fn get_route(&self, type_id: u64) -> Option<KernelId> {
self.routes.read().get(&type_id).cloned()
}
pub fn has_route(&self, type_id: u64) -> bool {
self.routes.read().contains_key(&type_id)
}
pub fn routes(&self) -> Vec<(u64, KernelId)> {
self.routes
.read()
.iter()
.map(|(k, v)| (*k, v.clone()))
.collect()
}
pub fn get_dispatch_table(&self, kernel_id: &KernelId) -> Option<DispatchTable> {
self.handler_tables.read().get(kernel_id).cloned()
}
pub async fn dispatch(&self, envelope: MessageEnvelope) -> Result<DeliveryReceipt> {
self.dispatch_from(KernelId::new("host"), envelope).await
}
pub async fn dispatch_from(
&self,
source: KernelId,
envelope: MessageEnvelope,
) -> Result<DeliveryReceipt> {
let type_id = envelope.header.message_type;
let kernel_id = {
let routes = self.routes.read();
routes.get(&type_id).cloned()
};
let kernel_id = match kernel_id {
Some(id) => id,
None => {
{
let mut metrics = self.metrics.write();
metrics.messages_dispatched += 1;
metrics.unknown_type_errors += 1;
}
return Err(RingKernelError::K2KError(format!(
"No route for message type_id: {}",
type_id
)));
}
};
let receipt = self
.broker
.send_priority(source, kernel_id, envelope, self.config.default_priority)
.await?;
{
let mut metrics = self.metrics.write();
metrics.messages_dispatched += 1;
match receipt.status {
DeliveryStatus::Delivered => metrics.messages_delivered += 1,
DeliveryStatus::Pending => {} _ => metrics.delivery_errors += 1,
}
}
Ok(receipt)
}
pub async fn dispatch_message<M: PersistentMessage>(
&self,
message: &M,
timestamp: HlcTimestamp,
) -> Result<DeliveryReceipt> {
let envelope = MessageEnvelope::new(message, 0, 0, timestamp);
self.dispatch(envelope).await
}
pub fn metrics(&self) -> DispatcherMetrics {
let metrics = self.metrics.read();
DispatcherMetrics {
messages_dispatched: metrics.messages_dispatched,
messages_delivered: metrics.messages_delivered,
unknown_type_errors: metrics.unknown_type_errors,
delivery_errors: metrics.delivery_errors,
}
}
pub fn reset_metrics(&self) {
*self.metrics.write() = DispatcherMetrics::default();
}
pub fn broker(&self) -> &Arc<K2KBroker> {
&self.broker
}
}
/// Fluent builder that collects routes and configuration for a [`KernelDispatcher`].
pub struct DispatcherBuilder {
    /// Routes to install, applied in insertion order at build time.
    routes: Vec<Route>,
    /// Configuration handed to the built dispatcher.
    config: DispatcherConfig,
    /// Configuration for the broker created by `build`.
    k2k_config: K2KConfig,
}
/// A pending route recorded by [`DispatcherBuilder`].
struct Route {
    /// Message type id being routed.
    type_id: u64,
    /// Kernel that receives messages of `type_id`.
    kernel_id: KernelId,
    /// Name used when registering the handler.
    handler_name: String,
    /// Handler id to register; `None` for raw routes, which skip table registration.
    handler_id: Option<u32>,
    /// Whether the handler registration is marked as producing a response.
    requires_response: bool,
}
impl DispatcherBuilder {
    /// Creates a builder with no routes and default dispatcher/K2K configurations.
    #[must_use]
    pub fn new() -> Self {
        Self {
            routes: Vec::new(),
            config: DispatcherConfig::default(),
            k2k_config: K2KConfig::default(),
        }
    }

    /// Routes message type `M` to `kernel_id`, naming the handler after the Rust type.
    #[must_use]
    pub fn route<M: PersistentMessage>(self, kernel_id: KernelId) -> Self {
        self.route_named::<M>(kernel_id, std::any::type_name::<M>())
    }

    /// Routes message type `M` to `kernel_id` with an explicit handler name.
    #[must_use]
    pub fn route_named<M: PersistentMessage>(
        mut self,
        kernel_id: KernelId,
        handler_name: &str,
    ) -> Self {
        self.routes.push(Route {
            type_id: M::message_type(),
            kernel_id,
            handler_name: handler_name.to_string(),
            handler_id: Some(M::handler_id()),
            requires_response: M::requires_response(),
        });
        self
    }

    /// Routes a raw `type_id` to `kernel_id` without a dispatch-table handler.
    #[must_use]
    pub fn route_raw(mut self, type_id: u64, kernel_id: KernelId) -> Self {
        self.routes.push(Route {
            type_id,
            kernel_id,
            handler_name: format!("handler_{}", type_id),
            handler_id: None,
            requires_response: false,
        });
        self
    }

    /// Replaces the whole dispatcher configuration, including any earlier
    /// `with_logging`/`with_priority` settings.
    #[must_use]
    pub fn with_config(mut self, config: DispatcherConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the configuration used for the broker created by [`DispatcherBuilder::build`].
    #[must_use]
    pub fn with_k2k_config(mut self, config: K2KConfig) -> Self {
        self.k2k_config = config;
        self
    }

    /// Turns on `enable_logging` in the dispatcher configuration.
    #[must_use]
    pub fn with_logging(mut self) -> Self {
        self.config.enable_logging = true;
        self
    }

    /// Sets the priority used for every dispatched message.
    #[must_use]
    pub fn with_priority(mut self, priority: u8) -> Self {
        self.config.default_priority = priority;
        self
    }

    /// Builds the dispatcher on a new broker created from the K2K configuration.
    #[must_use]
    pub fn build(self) -> KernelDispatcher {
        let broker = K2KBroker::new(self.k2k_config.clone());
        self.build_with_broker(broker)
    }

    /// Builds the dispatcher on an existing `broker`.
    ///
    /// Routes are installed in insertion order, so a later route for the same
    /// type id replaces an earlier one. Handler registration is best-effort:
    /// a failure is logged with `tracing::warn!` and the route is still installed.
    #[must_use]
    pub fn build_with_broker(self, broker: Arc<K2KBroker>) -> KernelDispatcher {
        use crate::persistent_message::HandlerRegistration;

        let dispatcher = KernelDispatcher::with_config(broker, self.config);
        {
            // The dispatcher is not shared yet, so taking both write guards once
            // for the whole loop cannot contend with anyone.
            let mut routes = dispatcher.routes.write();
            let mut handler_tables = dispatcher.handler_tables.write();
            for route in self.routes {
                routes.insert(route.type_id, route.kernel_id.clone());

                // Raw routes carry no handler id and skip table registration.
                let handler_id = match route.handler_id {
                    Some(id) => id,
                    None => continue,
                };

                let table = handler_tables.entry(route.kernel_id).or_default();
                let mut registration =
                    HandlerRegistration::new(handler_id, &route.handler_name, route.type_id);
                if route.requires_response {
                    registration = registration.with_response(0);
                }
                if let Err(e) = table.register(registration) {
                    tracing::warn!("Failed to register handler in dispatcher build: {}", e);
                }
            }
        }
        dispatcher
    }
}
impl Default for DispatcherBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hlc::HlcClock;
    use crate::message::{MessageHeader, RingMessage};

    /// Minimal persistent message carrying a single `u64`.
    #[derive(Clone, Copy, Debug)]
    #[repr(C)]
    struct TestRequest {
        value: u64,
    }

    impl RingMessage for TestRequest {
        fn message_type() -> u64 {
            5001
        }

        fn message_id(&self) -> crate::message::MessageId {
            crate::message::MessageId::new(0)
        }

        fn correlation_id(&self) -> crate::message::CorrelationId {
            crate::message::CorrelationId::none()
        }

        fn priority(&self) -> crate::message::Priority {
            crate::message::Priority::Normal
        }

        fn serialize(&self) -> Vec<u8> {
            self.value.to_le_bytes().to_vec()
        }

        fn deserialize(bytes: &[u8]) -> Result<Self> {
            if bytes.len() < 8 {
                return Err(RingKernelError::DeserializationError(
                    "Too small".to_string(),
                ));
            }
            let value = u64::from_le_bytes(bytes[..8].try_into().unwrap());
            Ok(Self { value })
        }

        fn size_hint(&self) -> usize {
            8
        }
    }

    impl PersistentMessage for TestRequest {
        fn handler_id() -> u32 {
            1
        }

        fn requires_response() -> bool {
            true
        }

        fn payload_size() -> usize {
            8
        }

        fn to_inline_payload(
            &self,
        ) -> Option<[u8; crate::persistent_message::MAX_INLINE_PAYLOAD_SIZE]> {
            // Size the buffer from the shared constant rather than a literal, so the
            // test keeps compiling if the inline payload size changes.
            let mut payload = [0u8; crate::persistent_message::MAX_INLINE_PAYLOAD_SIZE];
            payload[..8].copy_from_slice(&self.value.to_le_bytes());
            Some(payload)
        }

        fn from_inline_payload(payload: &[u8]) -> Result<Self> {
            if payload.len() < 8 {
                return Err(RingKernelError::DeserializationError(
                    "Too small".to_string(),
                ));
            }
            let value = u64::from_le_bytes(payload[..8].try_into().unwrap());
            Ok(Self { value })
        }
    }

    #[test]
    fn test_dispatcher_builder() {
        let kernel_id = KernelId::new("test_kernel");
        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id.clone())
            .build();
        assert!(dispatcher.has_route(5001));
        assert_eq!(dispatcher.get_route(5001), Some(kernel_id));
    }

    #[test]
    fn test_dispatcher_registration() {
        let dispatcher = DispatcherBuilder::new().build();
        let kernel_id = KernelId::new("processor");
        dispatcher
            .register::<TestRequest>(kernel_id.clone())
            .unwrap();
        assert!(dispatcher.has_route(5001));
        assert_eq!(dispatcher.get_route(5001), Some(kernel_id));
    }

    #[test]
    fn test_register_route_without_dispatch_table() {
        let dispatcher = DispatcherBuilder::new().build();
        let kernel_id = KernelId::new("raw_kernel");
        dispatcher.register_route(7, kernel_id.clone());
        assert!(dispatcher.has_route(7));
        assert_eq!(dispatcher.get_route(7), Some(kernel_id.clone()));
        // A bare route does not create a dispatch table for the kernel.
        assert!(dispatcher.get_dispatch_table(&kernel_id).is_none());
    }

    #[test]
    fn test_dispatcher_unregister() {
        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(KernelId::new("processor"))
            .build();
        assert!(dispatcher.has_route(5001));
        dispatcher.unregister(5001);
        assert!(!dispatcher.has_route(5001));
    }

    #[test]
    fn test_dispatcher_routes() {
        let kernel_a = KernelId::new("kernel_a");
        let kernel_b = KernelId::new("kernel_b");
        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_a.clone())
            .route_raw(9999, kernel_b.clone())
            .build();
        let routes = dispatcher.routes();
        assert_eq!(routes.len(), 2);
        assert!(routes.contains(&(5001, kernel_a)));
        assert!(routes.contains(&(9999, kernel_b)));
    }

    #[test]
    fn test_dispatch_table_generation() {
        let kernel_id = KernelId::new("test_kernel");
        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id.clone())
            .build();
        let table = dispatcher.get_dispatch_table(&kernel_id);
        assert!(table.is_some());
        let table = table.unwrap();
        assert_eq!(table.len(), 1);
        let handler = table.get(1).unwrap();
        assert_eq!(handler.handler_id, 1);
        assert_eq!(handler.message_type_id, 5001);
    }

    #[tokio::test]
    async fn test_dispatch_unknown_type() {
        let dispatcher = DispatcherBuilder::new().build();
        let clock = HlcClock::new(1);
        let header = MessageHeader::new(9999, 0, 0, 0, clock.now());
        let envelope = MessageEnvelope {
            header,
            payload: vec![],
        };
        let result = dispatcher.dispatch(envelope).await;
        assert!(result.is_err());
        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 1);
        assert_eq!(metrics.unknown_type_errors, 1);
    }

    #[tokio::test]
    async fn test_dispatch_to_registered_kernel() {
        let kernel_id = KernelId::new("test_kernel");
        let broker = K2KBroker::new(K2KConfig::default());
        let _endpoint = broker.register(kernel_id.clone());
        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id)
            .build_with_broker(broker);
        let clock = HlcClock::new(1);
        let msg = TestRequest { value: 42 };
        let envelope = MessageEnvelope::new(&msg, 0, 0, clock.now());
        let receipt = dispatcher.dispatch(envelope).await.unwrap();
        assert_eq!(receipt.status, DeliveryStatus::Delivered);
        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 1);
        assert_eq!(metrics.messages_delivered, 1);
    }

    #[test]
    fn test_metrics_reset() {
        let dispatcher = DispatcherBuilder::new().build();
        {
            let mut metrics = dispatcher.metrics.write();
            metrics.messages_dispatched = 100;
            metrics.messages_delivered = 50;
        }
        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 100);
        dispatcher.reset_metrics();
        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 0);
        assert_eq!(metrics.messages_delivered, 0);
    }
}