use crate::{Error, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
#[cfg(feature = "native")]
use tokio::sync::{broadcast, RwLock};
#[cfg(feature = "wasm")]
use parking_lot::RwLock;
/// A value that can be published on an [`EventBus`].
///
/// The bus keys handlers and streams by the event's concrete type via
/// `TypeId::of::<E>()`, which is why the `'static` bound is required.
pub trait Event: Send + Sync + Clone + std::fmt::Debug + 'static {
    /// Stable string name of the event, e.g. `"session.created"`.
    ///
    /// `EventBus::publish` includes it in its debug trace.
    fn event_type(&self) -> &'static str;

    /// Importance of the event; defaults to [`EventPriority::Normal`].
    ///
    /// NOTE(review): in this module `EventBus` only reports this value in its
    /// debug trace; handler dispatch order comes from `EventHandler::priority`
    /// and `EventHandler::early`, not from the event.
    fn priority(&self) -> EventPriority {
        EventPriority::Normal
    }

    /// Whether the event should be persisted; defaults to `false`.
    ///
    /// NOTE(review): `EventBus` does not read this flag; presumably a storage
    /// layer elsewhere does — confirm.
    fn persistent(&self) -> bool {
        false
    }
}
/// Coarse importance level attached to an [`Event`].
///
/// The derived `Ord` follows declaration order, which here matches the
/// explicit discriminants: `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum EventPriority {
    /// Lowest level (discriminant 0).
    Low = 0,
    /// Level returned by the default `Event::priority` implementation.
    Normal = 1,
    /// Discriminant 2.
    High = 2,
    /// Highest level (discriminant 3).
    Critical = 3,
}
/// Asynchronous subscriber for events of type `E`.
///
/// Register an implementation with [`EventBus::subscribe`].
#[async_trait]
pub trait EventHandler<E: Event>: Send + Sync {
    /// Handles one published event.
    ///
    /// An `Err` is logged by `EventBus::publish` and does not stop the
    /// remaining handlers; it is not returned to the publisher.
    async fn handle(&self, event: &E) -> Result<()>;

    /// Ordering weight among handlers of the same event type: higher values
    /// run first (within the early group and within the non-early group).
    /// Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }

    /// When `true`, the handler runs before every non-early handler,
    /// regardless of `priority`. Defaults to `false`.
    fn early(&self) -> bool {
        false
    }
}
/// Type-erased handler, as stored in the bus's per-`TypeId` handler lists.
type BoxedHandler = Box<dyn EventHandlerDyn + Send + Sync>;

/// Object-safe counterpart of [`EventHandler`] that receives the event as
/// `&dyn Any`, so handlers for different event types can share one map.
#[async_trait]
trait EventHandlerDyn {
    /// Handles a type-erased event.
    async fn handle_dyn(&self, event: &(dyn Any + Send + Sync)) -> Result<()>;
    /// Mirrors `EventHandler::priority`; used to order handler lists.
    fn priority(&self) -> i32;
    /// Mirrors `EventHandler::early`; used to order handler lists.
    fn early(&self) -> bool;
}
/// Adapter that lets a typed `EventHandler<E>` be stored as an `EventHandlerDyn`.
///
/// `PhantomData<E>` records which event type the handler expects without
/// storing a value of it.
struct EventHandlerWrapper<E: Event, H: EventHandler<E>> {
    /// The typed handler being adapted.
    handler: H,
    _phantom: std::marker::PhantomData<E>,
}
#[async_trait]
impl<E: Event, H: EventHandler<E>> EventHandlerDyn for EventHandlerWrapper<E, H> {
async fn handle_dyn(&self, event: &(dyn Any + Send + Sync)) -> Result<()> {
if let Some(typed_event) = event.downcast_ref::<E>() {
self.handler.handle(typed_event).await
} else {
Err(Error::Other(anyhow::anyhow!("Event type mismatch")))
}
}
fn priority(&self) -> i32 {
self.handler.priority()
}
fn early(&self) -> bool {
self.handler.early()
}
}
/// In-process publish/subscribe hub keyed by event type.
///
/// Handlers registered with [`EventBus::subscribe`] are awaited in dispatch
/// order on each [`EventBus::publish`]. With the `native` feature, published
/// events are also fanned out to broadcast streams obtained from
/// `create_stream`.
pub struct EventBus {
    /// Per-event-type handler lists, kept sorted in dispatch order.
    /// The lock type comes from tokio (`native`) or parking_lot (`wasm`).
    #[cfg(any(feature = "native", feature = "wasm"))]
    handlers: RwLock<HashMap<TypeId, Vec<BoxedHandler>>>,
    /// One broadcast sender per event type, created by the first
    /// `create_stream` call for that type.
    #[cfg(feature = "native")]
    broadcast_senders: RwLock<HashMap<TypeId, broadcast::Sender<Arc<dyn Any + Send + Sync>>>>,
    /// Capacity used when creating broadcast channels.
    max_queue_size: usize,
    /// Whether `subscribe`/`publish` emit debug traces.
    tracing_enabled: bool,
}
impl EventBus {
pub fn new() -> Self {
Self {
handlers: RwLock::new(HashMap::new()),
#[cfg(feature = "native")]
broadcast_senders: RwLock::new(HashMap::new()),
max_queue_size: 1000,
tracing_enabled: true,
}
}
pub fn with_config(max_queue_size: usize, tracing_enabled: bool) -> Self {
Self {
handlers: RwLock::new(HashMap::new()),
#[cfg(feature = "native")]
broadcast_senders: RwLock::new(HashMap::new()),
max_queue_size,
tracing_enabled,
}
}
pub async fn subscribe<E: Event, H: EventHandler<E> + 'static>(&self, handler: H) -> Result<()> {
let type_id = TypeId::of::<E>();
let boxed_handler = Box::new(EventHandlerWrapper {
handler,
_phantom: std::marker::PhantomData::<E>,
});
#[cfg(feature = "native")]
{
let mut handlers = self.handlers.write().await;
let handlers_list = handlers.entry(type_id).or_insert_with(Vec::new);
handlers_list.push(boxed_handler);
handlers_list.sort_by(|a, b| {
match (a.early(), b.early()) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => b.priority().cmp(&a.priority()),
}
});
}
#[cfg(feature = "wasm")]
{
let mut handlers = self.handlers.write();
let handlers_list = handlers.entry(type_id).or_insert_with(Vec::new);
handlers_list.push(boxed_handler);
handlers_list.sort_by(|a, b| {
match (a.early(), b.early()) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => b.priority().cmp(&a.priority()),
}
});
}
if self.tracing_enabled {
tracing::debug!("Subscribed to event type: {}", std::any::type_name::<E>());
}
Ok(())
}
pub async fn publish<E: Event>(&self, event: E) -> Result<()> {
let type_id = TypeId::of::<E>();
if self.tracing_enabled {
tracing::debug!(
"Publishing event: {} with priority: {:?}",
event.event_type(),
event.priority()
);
}
#[cfg(feature = "native")]
{
let handlers = self.handlers.read().await;
if let Some(handlers_list) = handlers.get(&type_id) {
for handler in handlers_list {
if let Err(e) = handler.handle_dyn(&event as &(dyn Any + Send + Sync)).await {
tracing::error!("Error handling event: {}", e);
}
}
}
let senders = self.broadcast_senders.read().await;
if let Some(sender) = senders.get(&type_id) {
let arc_event: Arc<dyn Any + Send + Sync> = Arc::new(event.clone());
if sender.send(arc_event).is_err() {
}
}
}
#[cfg(feature = "wasm")]
{
let handlers = self.handlers.read();
if let Some(handlers_list) = handlers.get(&type_id) {
for handler in handlers_list {
if let Err(e) = handler.handle_dyn(&event as &(dyn Any + Send + Sync)).await {
tracing::error!("Error handling event: {}", e);
}
}
}
}
Ok(())
}
#[cfg(feature = "native")]
pub async fn create_stream<E: Event>(&self) -> broadcast::Receiver<Arc<dyn Any + Send + Sync>> {
let type_id = TypeId::of::<E>();
let mut senders = self.broadcast_senders.write().await;
let sender = senders.entry(type_id).or_insert_with(|| {
let (sender, _) = broadcast::channel(self.max_queue_size);
sender
});
sender.subscribe()
}
pub async fn clear(&self) {
#[cfg(feature = "native")]
{
self.handlers.write().await.clear();
self.broadcast_senders.write().await.clear();
}
#[cfg(feature = "wasm")]
{
self.handlers.write().clear();
}
}
pub async fn handler_count<E: Event>(&self) -> usize {
let type_id = TypeId::of::<E>();
#[cfg(feature = "native")]
{
self.handlers.read().await
.get(&type_id)
.map(|h| h.len())
.unwrap_or(0)
}
#[cfg(feature = "wasm")]
{
self.handlers.read()
.get(&type_id)
.map(|h| h.len())
.unwrap_or(0)
}
}
}
impl Default for EventBus {
fn default() -> Self {
Self::new()
}
}
pub mod events {
    //! Predefined event payloads and their `Event` implementations.
    //!
    //! Unless an impl below overrides them, events use the trait defaults:
    //! `EventPriority::Normal` and not persistent.

    use super::*;
    use chrono::{DateTime, Utc};

    // ---------------------------------------------------------------- sessions

    /// `session.created` event; persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct SessionCreated {
        pub session_id: String,
        pub timestamp: DateTime<Utc>,
        pub metadata: HashMap<String, serde_json::Value>,
    }

    impl Event for SessionCreated {
        fn event_type(&self) -> &'static str {
            "session.created"
        }

        fn persistent(&self) -> bool {
            true
        }
    }

    /// `session.ended` event; persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct SessionEnded {
        pub session_id: String,
        pub timestamp: DateTime<Utc>,
        pub reason: String,
    }

    impl Event for SessionEnded {
        fn event_type(&self) -> &'static str {
            "session.ended"
        }

        fn persistent(&self) -> bool {
            true
        }
    }

    // ---------------------------------------------------------------- messages

    /// `message.sent` event; default priority, not persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct MessageSent {
        pub session_id: String,
        pub message_id: String,
        pub role: String,
        pub content: String,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for MessageSent {
        fn event_type(&self) -> &'static str {
            "message.sent"
        }
    }

    /// `message.received` event; default priority, not persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct MessageReceived {
        pub session_id: String,
        pub message_id: String,
        pub role: String,
        pub content: String,
        pub timestamp: DateTime<Utc>,
        pub tokens_used: Option<u32>,
    }

    impl Event for MessageReceived {
        fn event_type(&self) -> &'static str {
            "message.received"
        }
    }

    // ------------------------------------------------------------------- tools

    /// `tool.executed` event; default priority, persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct ToolExecuted {
        pub session_id: String,
        pub tool_id: String,
        pub tool_name: String,
        pub arguments: serde_json::Value,
        pub result: serde_json::Value,
        pub duration_ms: u64,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for ToolExecuted {
        fn event_type(&self) -> &'static str {
            "tool.executed"
        }

        fn persistent(&self) -> bool {
            true
        }
    }

    /// `tool.failed` event; `High` priority, persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct ToolFailed {
        pub session_id: String,
        pub tool_id: String,
        pub tool_name: String,
        pub arguments: serde_json::Value,
        pub error: String,
        pub duration_ms: u64,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for ToolFailed {
        fn event_type(&self) -> &'static str {
            "tool.failed"
        }

        fn priority(&self) -> EventPriority {
            EventPriority::High
        }

        fn persistent(&self) -> bool {
            true
        }
    }

    // --------------------------------------------------------------- providers

    /// `provider.connected` event; `High` priority, not persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct ProviderConnected {
        pub provider_id: String,
        pub provider_name: String,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for ProviderConnected {
        fn event_type(&self) -> &'static str {
            "provider.connected"
        }

        fn priority(&self) -> EventPriority {
            EventPriority::High
        }
    }

    /// `provider.disconnected` event; `High` priority, not persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct ProviderDisconnected {
        pub provider_id: String,
        pub provider_name: String,
        pub reason: String,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for ProviderDisconnected {
        fn event_type(&self) -> &'static str {
            "provider.disconnected"
        }

        fn priority(&self) -> EventPriority {
            EventPriority::High
        }
    }

    // ----------------------------------------------------------------- storage

    /// `storage.stored` event; all trait defaults.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct DataStored {
        pub key: String,
        pub size_bytes: u64,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for DataStored {
        fn event_type(&self) -> &'static str {
            "storage.stored"
        }
    }

    /// `storage.retrieved` event; all trait defaults.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct DataRetrieved {
        pub key: String,
        pub size_bytes: u64,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for DataRetrieved {
        fn event_type(&self) -> &'static str {
            "storage.retrieved"
        }
    }

    // ------------------------------------------------------------ errors/system

    /// `error.occurred` event; `Critical` priority, persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct ErrorOccurred {
        pub error_id: String,
        pub component: String,
        pub error_message: String,
        pub error_code: Option<String>,
        pub context: HashMap<String, serde_json::Value>,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for ErrorOccurred {
        fn event_type(&self) -> &'static str {
            "error.occurred"
        }

        fn priority(&self) -> EventPriority {
            EventPriority::Critical
        }

        fn persistent(&self) -> bool {
            true
        }
    }

    /// `system.started` event; `High` priority, persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct SystemStarted {
        pub version: String,
        pub features: Vec<String>,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for SystemStarted {
        fn event_type(&self) -> &'static str {
            "system.started"
        }

        fn priority(&self) -> EventPriority {
            EventPriority::High
        }

        fn persistent(&self) -> bool {
            true
        }
    }

    /// `system.shutdown` event; `Critical` priority, persistent.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct SystemShutdown {
        pub reason: String,
        pub timestamp: DateTime<Utc>,
    }

    impl Event for SystemShutdown {
        fn event_type(&self) -> &'static str {
            "system.shutdown"
        }

        fn priority(&self) -> EventPriority {
            EventPriority::Critical
        }

        fn persistent(&self) -> bool {
            true
        }
    }
}
/// Builds an `EventHandler` for `$event_type` from a plain synchronous function.
///
/// Expands to a block expression, so the handler can be bound or passed
/// directly, and several invocations can live in the same scope:
///
/// ```ignore
/// let handler = event_handler!(SessionCreated, |e: &SessionCreated| {
///     println!("{}", e.session_id);
///     Ok(())
/// });
/// bus.subscribe(handler).await?;
/// ```
///
/// `$handler_fn` must coerce to `fn(&$event_type) -> Result<()>`, i.e. a
/// function item or a non-capturing closure. The `EventHandler` trait must be
/// in scope at the call site. This crate's `Result` and the `async_trait`
/// attribute are referenced by absolute path, so they need no import.
#[macro_export]
macro_rules! event_handler {
    ($event_type:ty, $handler_fn:expr $(,)?) => {{
        struct SimpleEventHandler {
            handler: fn(&$event_type) -> $crate::Result<()>,
        }

        #[::async_trait::async_trait]
        impl EventHandler<$event_type> for SimpleEventHandler {
            async fn handle(&self, event: &$event_type) -> $crate::Result<()> {
                (self.handler)(event)
            }
        }

        SimpleEventHandler {
            handler: $handler_fn,
        }
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    #[derive(Debug, Clone)]
    struct TestEvent {
        message: String,
    }

    impl Event for TestEvent {
        fn event_type(&self) -> &'static str {
            "test.event"
        }
    }

    /// Counts how many events it has handled.
    struct TestHandler {
        counter: Arc<AtomicU32>,
    }

    #[async_trait]
    impl EventHandler<TestEvent> for TestHandler {
        async fn handle(&self, _event: &TestEvent) -> Result<()> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Appends its `id` to a shared log so dispatch order can be asserted.
    struct OrderedHandler {
        id: u32,
        priority: i32,
        early: bool,
        log: Arc<Mutex<Vec<u32>>>,
    }

    #[async_trait]
    impl EventHandler<TestEvent> for OrderedHandler {
        async fn handle(&self, _event: &TestEvent) -> Result<()> {
            self.log.lock().unwrap().push(self.id);
            Ok(())
        }

        fn priority(&self) -> i32 {
            self.priority
        }

        fn early(&self) -> bool {
            self.early
        }
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn test_event_bus() {
        let bus = EventBus::new();
        let counter = Arc::new(AtomicU32::new(0));
        let handler = TestHandler {
            counter: counter.clone(),
        };
        bus.subscribe(handler).await.unwrap();
        let event = TestEvent {
            message: "Hello, World!".to_string(),
        };
        bus.publish(event).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn test_handlers_run_early_first_then_by_descending_priority() {
        let bus = EventBus::new();
        let log = Arc::new(Mutex::new(Vec::new()));
        for (id, priority, early) in [(1, 0, false), (2, 10, false), (3, -5, true), (4, 10, false)] {
            bus.subscribe(OrderedHandler {
                id,
                priority,
                early,
                log: log.clone(),
            })
            .await
            .unwrap();
        }
        bus.publish(TestEvent {
            message: "order".to_string(),
        })
        .await
        .unwrap();
        // Early handler first; equal priorities keep subscription order (2 before 4).
        assert_eq!(*log.lock().unwrap(), vec![3, 2, 4, 1]);
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn test_handler_count_and_clear() {
        let bus = EventBus::new();
        assert_eq!(bus.handler_count::<TestEvent>().await, 0);
        bus.subscribe(TestHandler {
            counter: Arc::new(AtomicU32::new(0)),
        })
        .await
        .unwrap();
        assert_eq!(bus.handler_count::<TestEvent>().await, 1);
        bus.clear().await;
        assert_eq!(bus.handler_count::<TestEvent>().await, 0);
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn test_stream_receives_published_event() {
        let bus = EventBus::new();
        let mut stream = bus.create_stream::<TestEvent>().await;
        bus.publish(TestEvent {
            message: "streamed".to_string(),
        })
        .await
        .unwrap();
        let received = stream.recv().await.unwrap();
        let event = received
            .downcast_ref::<TestEvent>()
            .expect("stream for TestEvent carries TestEvent values");
        assert_eq!(event.message, "streamed");
    }
}