Skip to main content

allframe_core/cqrs/
query_bus.rs

1//! Query Bus for CQRS query dispatch and routing
2//!
3//! The QueryBus provides automatic query routing and error handling,
4//! mirroring the CommandBus pattern for the read side.
5
6use std::{
7    any::{Any, TypeId},
8    collections::HashMap,
9    sync::Arc,
10};
11
12use async_trait::async_trait;
13use tokio::sync::RwLock;
14
/// Marker trait identifying a type as a query that the bus can route.
///
/// It declares no methods of its own; implementors only have to satisfy the
/// `Send + Sync + 'static` supertrait bounds.
pub trait Query: 'static + Send + Sync {}
17
/// Outcome of executing a query: the handler's value `R` or a [`QueryError`].
pub type QueryResult<R> = std::result::Result<R, QueryError>;

/// Failures that query dispatch or a query handler can report.
#[derive(Clone, Debug)]
pub enum QueryError {
    /// No handler could be found; the payload is a human-readable description.
    NotFound(String),
    /// Any other failure; the payload is a human-readable description.
    Internal(String),
}

impl std::fmt::Display for QueryError {
    /// Writes a fixed, variant-specific prefix followed by the carried message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotFound(detail) => write!(f, "Handler not found: {}", detail),
            Self::Internal(detail) => write!(f, "Internal error: {}", detail),
        }
    }
}

// `Debug` and `Display` are both available, so the default `Error` methods suffice.
impl std::error::Error for QueryError {}
40
41/// Query handler trait
42#[async_trait]
43pub trait QueryHandler<Q: Query, R: Send + Sync + 'static>: Send + Sync {
44    /// Execute the query
45    async fn handle(&self, query: Q) -> QueryResult<R>;
46}
47
48/// Type-erased query handler wrapper
49#[async_trait]
50trait ErasedQueryHandler: Send + Sync {
51    async fn handle_erased(&self, query: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, QueryError>;
52}
53
54/// Wrapper to type-erase query handlers
55struct QueryHandlerWrapper<Q: Query, R: Send + Sync + 'static, H: QueryHandler<Q, R>> {
56    handler: Arc<H>,
57    _phantom: std::marker::PhantomData<(Q, R)>,
58}
59
60#[async_trait]
61impl<Q: Query, R: Send + Sync + 'static, H: QueryHandler<Q, R>> ErasedQueryHandler
62    for QueryHandlerWrapper<Q, R, H>
63{
64    async fn handle_erased(&self, query: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, QueryError> {
65        match query.downcast::<Q>() {
66            Ok(q) => {
67                let result = self.handler.handle(*q).await?;
68                Ok(Box::new(result))
69            }
70            Err(_) => Err(QueryError::Internal(
71                "Type mismatch in query dispatch".to_string(),
72            )),
73        }
74    }
75}
76
77/// Type alias for handler storage
78type HandlerMap = HashMap<TypeId, Arc<dyn ErasedQueryHandler>>;
79
80/// Query Bus for dispatching queries to handlers
81pub struct QueryBus {
82    handlers: Arc<RwLock<HandlerMap>>,
83}
84
85impl QueryBus {
86    /// Create a new query bus
87    pub fn new() -> Self {
88        Self {
89            handlers: Arc::new(RwLock::new(HashMap::new())),
90        }
91    }
92
93    /// Register a query handler
94    pub async fn register<Q: Query, R: Send + Sync + 'static, H: QueryHandler<Q, R> + 'static>(
95        &self,
96        handler: H,
97    ) {
98        let type_id = TypeId::of::<Q>();
99        let wrapper = QueryHandlerWrapper {
100            handler: Arc::new(handler),
101            _phantom: std::marker::PhantomData,
102        };
103        let mut handlers = self.handlers.write().await;
104        handlers.insert(type_id, Arc::new(wrapper));
105    }
106
107    /// Dispatch a query
108    pub async fn dispatch<Q: Query, R: Send + Sync + 'static>(&self, query: Q) -> QueryResult<R> {
109        let type_id = TypeId::of::<Q>();
110        let handlers = self.handlers.read().await;
111
112        match handlers.get(&type_id) {
113            Some(handler) => {
114                let boxed_query: Box<dyn Any + Send> = Box::new(query);
115                let result = handler.handle_erased(boxed_query).await?;
116                match result.downcast::<R>() {
117                    Ok(r) => Ok(*r),
118                    Err(_) => Err(QueryError::Internal(
119                        "Type mismatch in query result".to_string(),
120                    )),
121                }
122            }
123            None => Err(QueryError::NotFound(format!(
124                "No handler registered for query type: {}",
125                std::any::type_name::<Q>()
126            ))),
127        }
128    }
129
130    /// Get number of registered handlers
131    pub async fn handlers_count(&self) -> usize {
132        self.handlers.read().await.len()
133    }
134}
135
136impl Default for QueryBus {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142impl Clone for QueryBus {
143    fn clone(&self) -> Self {
144        Self {
145            handlers: Arc::clone(&self.handlers),
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Query asking for a single user by identifier.
    struct GetUserQuery {
        id: String,
    }

    impl Query for GetUserQuery {}

    /// Result produced by [`GetUserHandler`].
    #[derive(Debug, PartialEq)]
    struct UserResult {
        id: String,
        name: String,
    }

    /// Stateless handler that echoes the requested id with a fixed name.
    struct GetUserHandler;

    #[async_trait]
    impl QueryHandler<GetUserQuery, UserResult> for GetUserHandler {
        async fn handle(&self, query: GetUserQuery) -> QueryResult<UserResult> {
            let GetUserQuery { id } = query;
            let name = String::from("Test User");
            Ok(UserResult { id, name })
        }
    }

    /// Builds the query that both tests dispatch.
    fn sample_query() -> GetUserQuery {
        GetUserQuery {
            id: String::from("123"),
        }
    }

    /// A registered handler receives the query and its result is returned.
    #[tokio::test]
    async fn test_query_dispatch() {
        let bus = QueryBus::new();
        bus.register(GetUserHandler).await;

        let outcome: QueryResult<UserResult> = bus.dispatch(sample_query()).await;

        assert!(outcome.is_ok());
        let user = outcome.unwrap();
        assert_eq!(user.id, "123");
        assert_eq!(user.name, "Test User");
    }

    /// Dispatching on an empty bus reports `QueryError::NotFound`.
    #[tokio::test]
    async fn test_query_handler_not_found() {
        let bus = QueryBus::new();

        let outcome: QueryResult<UserResult> = bus.dispatch(sample_query()).await;

        assert!(matches!(outcome, Err(QueryError::NotFound(_))));
    }
}