allframe_core/cqrs/
query_bus.rs1use std::{
7 any::{Any, TypeId},
8 collections::HashMap,
9 sync::Arc,
10};
11
12use async_trait::async_trait;
13use tokio::sync::RwLock;
14
/// Marker trait for request types that can be routed through a [`QueryBus`].
///
/// The trait declares no methods; implementing it is what allows a type to be
/// used as the `Q` parameter of [`QueryBus::register`] and
/// [`QueryBus::dispatch`]. Every implementor must be `Send + Sync + 'static`.
pub trait Query: Send + Sync + 'static {}
17
18pub type QueryResult<R> = Result<R, QueryError>;
20
/// Errors reported while dispatching or handling a query.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly (e.g. with `assert_eq!`) instead of pattern-matching each variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryError {
    /// No handler could be found; the payload is a human-readable description.
    NotFound(String),
    /// Any other failure, including type mismatches detected during dispatch;
    /// the payload is a human-readable description.
    Internal(String),
}

impl std::fmt::Display for QueryError {
    /// Renders the error as a fixed, variant-specific prefix followed by the payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotFound(msg) => write!(f, "Handler not found: {}", msg),
            Self::Internal(msg) => write!(f, "Internal error: {}", msg),
        }
    }
}

// No underlying source error is carried, so the default `Error` methods suffice.
impl std::error::Error for QueryError {}
40
41#[async_trait]
43pub trait QueryHandler<Q: Query, R: Send + Sync + 'static>: Send + Sync {
44 async fn handle(&self, query: Q) -> QueryResult<R>;
46}
47
48#[async_trait]
50trait ErasedQueryHandler: Send + Sync {
51 async fn handle_erased(&self, query: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, QueryError>;
52}
53
54struct QueryHandlerWrapper<Q: Query, R: Send + Sync + 'static, H: QueryHandler<Q, R>> {
56 handler: Arc<H>,
57 _phantom: std::marker::PhantomData<(Q, R)>,
58}
59
60#[async_trait]
61impl<Q: Query, R: Send + Sync + 'static, H: QueryHandler<Q, R>> ErasedQueryHandler
62 for QueryHandlerWrapper<Q, R, H>
63{
64 async fn handle_erased(&self, query: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, QueryError> {
65 match query.downcast::<Q>() {
66 Ok(q) => {
67 let result = self.handler.handle(*q).await?;
68 Ok(Box::new(result))
69 }
70 Err(_) => Err(QueryError::Internal(
71 "Type mismatch in query dispatch".to_string(),
72 )),
73 }
74 }
75}
76
77type HandlerMap = HashMap<TypeId, Arc<dyn ErasedQueryHandler>>;
79
80pub struct QueryBus {
82 handlers: Arc<RwLock<HandlerMap>>,
83}
84
85impl QueryBus {
86 pub fn new() -> Self {
88 Self {
89 handlers: Arc::new(RwLock::new(HashMap::new())),
90 }
91 }
92
93 pub async fn register<Q: Query, R: Send + Sync + 'static, H: QueryHandler<Q, R> + 'static>(
95 &self,
96 handler: H,
97 ) {
98 let type_id = TypeId::of::<Q>();
99 let wrapper = QueryHandlerWrapper {
100 handler: Arc::new(handler),
101 _phantom: std::marker::PhantomData,
102 };
103 let mut handlers = self.handlers.write().await;
104 handlers.insert(type_id, Arc::new(wrapper));
105 }
106
107 pub async fn dispatch<Q: Query, R: Send + Sync + 'static>(&self, query: Q) -> QueryResult<R> {
109 let type_id = TypeId::of::<Q>();
110 let handlers = self.handlers.read().await;
111
112 match handlers.get(&type_id) {
113 Some(handler) => {
114 let boxed_query: Box<dyn Any + Send> = Box::new(query);
115 let result = handler.handle_erased(boxed_query).await?;
116 match result.downcast::<R>() {
117 Ok(r) => Ok(*r),
118 Err(_) => Err(QueryError::Internal(
119 "Type mismatch in query result".to_string(),
120 )),
121 }
122 }
123 None => Err(QueryError::NotFound(format!(
124 "No handler registered for query type: {}",
125 std::any::type_name::<Q>()
126 ))),
127 }
128 }
129
130 pub async fn handlers_count(&self) -> usize {
132 self.handlers.read().await.len()
133 }
134}
135
136impl Default for QueryBus {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142impl Clone for QueryBus {
143 fn clone(&self) -> Self {
144 Self {
145 handlers: Arc::clone(&self.handlers),
146 }
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Query fixture: carries the id that the handler echoes back.
    struct GetUserQuery {
        id: String,
    }

    impl Query for GetUserQuery {}

    /// Result fixture returned by `GetUserHandler`.
    #[derive(Debug, PartialEq)]
    struct UserResult {
        id: String,
        name: String,
    }

    /// Handler fixture that always succeeds with a fixed user name.
    struct GetUserHandler;

    #[async_trait]
    impl QueryHandler<GetUserQuery, UserResult> for GetUserHandler {
        async fn handle(&self, query: GetUserQuery) -> QueryResult<UserResult> {
            let GetUserQuery { id } = query;
            let name = "Test User".to_string();
            Ok(UserResult { id, name })
        }
    }

    /// A registered handler receives the query and its result is returned to the caller.
    #[tokio::test]
    async fn test_query_dispatch() {
        let bus = QueryBus::new();
        bus.register(GetUserHandler).await;

        let query = GetUserQuery {
            id: "123".to_string(),
        };
        let outcome: QueryResult<UserResult> = bus.dispatch(query).await;

        assert!(outcome.is_ok());
        let UserResult { id, name } = outcome.unwrap();
        assert_eq!(id, "123");
        assert_eq!(name, "Test User");
    }

    /// Dispatching on an empty bus reports `QueryError::NotFound`.
    #[tokio::test]
    async fn test_query_handler_not_found() {
        let bus = QueryBus::new();

        let query = GetUserQuery {
            id: "123".to_string(),
        };
        let outcome: QueryResult<UserResult> = bus.dispatch(query).await;

        assert!(matches!(outcome, Err(QueryError::NotFound(_))));
    }
}