1use crate::{Aggregate, CqrsError, EventEnvelope, View};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex};
4use tracing::debug;
5
6pub struct InMemoryViewStore<A, V>
8where
9 A: Aggregate,
10 V: View<A>,
11{
12 views: Arc<Mutex<HashMap<String, V>>>,
13 _phantom: std::marker::PhantomData<A>,
14}
15
16impl<A, V> InMemoryViewStore<A, V>
17where
18 A: Aggregate,
19 V: View<A>,
20{
21 #[must_use]
23 pub fn new() -> Self {
24 Self {
25 views: Arc::new(Mutex::new(HashMap::new())),
26 _phantom: std::marker::PhantomData,
27 }
28 }
29
30 pub fn get_view(&self, view_id: &str) -> Option<V> {
32 let views = self.views.lock().unwrap();
33 views.get(view_id).cloned()
34 }
35
36 pub fn get_all_views(&self) -> HashMap<String, V> {
38 let views = self.views.lock().unwrap();
39 views.clone()
40 }
41
42 pub fn update_view(&self, event: &EventEnvelope<A>) -> Result<(), CqrsError> {
44 debug!("Updating view with event");
45
46 let view_id = V::view_id(event);
47 let mut views = self.views.lock().unwrap();
48
49 let view = views.entry(view_id.clone()).or_default();
50
51 if let Some(updated_view) = view.update(event) {
52 debug!(view_id = %view_id, "View updated successfully");
53 views.insert(view_id, updated_view);
54 } else {
55 debug!(view_id = %view_id, "View not updated (no changes)");
56 }
57
58 Ok(())
59 }
60
61 pub fn clear(&self) {
63 let mut views = self.views.lock().unwrap();
64 views.clear();
65 }
66}
67
68impl<A, V> Default for InMemoryViewStore<A, V>
69where
70 A: Aggregate,
71 V: View<A>,
72{
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Aggregate, Event, ViewElements};
    use chrono::Utc;
    use http::StatusCode;
    use serde::{Deserialize, Serialize};
    use std::error::Error;
    use std::fmt;

    /// Error type for `TestAggregate`; carries only a message.
    #[derive(Debug, Clone)]
    struct TestError(String);

    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl Error for TestError {}

    /// Events accepted by `TestAggregate`.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
    enum TestEvent {
        Created { name: String },
        Updated { name: String },
    }

    impl Event for TestEvent {
        fn event_type(&self) -> String {
            match self {
                TestEvent::Created { .. } => "Created".to_string(),
                TestEvent::Updated { .. } => "Updated".to_string(),
            }
        }
    }

    /// Minimal aggregate with an id and a name.
    #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
    struct TestAggregate {
        id: String,
        name: String,
    }

    cqrs_async_trait! {
        impl Aggregate for TestAggregate {
            const TYPE: &'static str = "TEST";

            type Event = TestEvent;
            type Error = TestError;

            fn aggregate_id(&self) -> String {
                self.id.clone()
            }

            fn with_aggregate_id(self, id: String) -> Self {
                Self { id, ..self }
            }

            fn apply(&mut self, event: Self::Event) -> Result<(), Self::Error> {
                match event {
                    TestEvent::Created { name } => self.name = name,
                    TestEvent::Updated { name } => self.name = name,
                }
                Ok(())
            }

            fn error(_status: StatusCode, details: &str) -> Self::Error {
                TestError(details.to_string())
            }
        }
    }

    /// View keyed by aggregate id that mirrors the aggregate's name and last version.
    #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
    struct TestView {
        id: String,
        name: String,
        version: usize,
    }

    impl View<TestAggregate> for TestView {
        const TYPE: &'static str = "TEST_VIEW";
        const IS_CHILD_OF_AGGREGATE: bool = true;

        fn view_id(event: &EventEnvelope<TestAggregate>) -> String {
            event.aggregate_id.clone()
        }

        fn update(&self, event: &EventEnvelope<TestAggregate>) -> Option<Self> {
            let mut updated = self.clone();
            updated.id = event.aggregate_id.clone();
            updated.version = event.version;

            match &event.payload {
                TestEvent::Created { name } => {
                    updated.name = name.clone();
                    Some(updated)
                }
                TestEvent::Updated { name } => {
                    updated.name = name.clone();
                    Some(updated)
                }
            }
        }
    }

    impl ViewElements<TestAggregate> for TestView {
        fn aggregate_id(&self) -> String {
            self.id.clone()
        }
    }

    /// Builds an envelope for `aggregate_id` with empty metadata, stamped with the current time.
    fn envelope(
        event_id: &str,
        aggregate_id: &str,
        version: usize,
        payload: TestEvent,
    ) -> EventEnvelope<TestAggregate> {
        EventEnvelope {
            event_id: event_id.to_string(),
            aggregate_id: aggregate_id.to_string(),
            version,
            payload,
            metadata: HashMap::new(),
            at: Utc::now(),
        }
    }

    #[test]
    fn test_in_memory_view_store() {
        let view_store = InMemoryViewStore::<TestAggregate, TestView>::new();

        let created = envelope(
            "event1",
            "agg1",
            1,
            TestEvent::Created {
                name: "Test 1".to_string(),
            },
        );
        view_store.update_view(&created).unwrap();

        let view = view_store.get_view("agg1").unwrap();
        assert_eq!(view.id, "agg1");
        assert_eq!(view.name, "Test 1");
        assert_eq!(view.version, 1);

        let updated = envelope(
            "event2",
            "agg1",
            2,
            TestEvent::Updated {
                name: "Test 1 Updated".to_string(),
            },
        );
        view_store.update_view(&updated).unwrap();

        let updated_view = view_store.get_view("agg1").unwrap();
        assert_eq!(updated_view.id, "agg1");
        assert_eq!(updated_view.name, "Test 1 Updated");
        assert_eq!(updated_view.version, 2);

        let all_views = view_store.get_all_views();
        assert_eq!(all_views.len(), 1);
        assert!(all_views.contains_key("agg1"));

        view_store.clear();

        assert!(view_store.get_view("agg1").is_none());
    }

    #[test]
    fn unknown_view_id_returns_none() {
        let view_store = InMemoryViewStore::<TestAggregate, TestView>::default();

        assert!(view_store.get_view("missing").is_none());
        assert!(view_store.get_all_views().is_empty());
    }

    #[test]
    fn views_are_kept_separately_per_aggregate() {
        let view_store = InMemoryViewStore::<TestAggregate, TestView>::new();

        view_store
            .update_view(&envelope(
                "event1",
                "agg1",
                1,
                TestEvent::Created {
                    name: "First".to_string(),
                },
            ))
            .unwrap();
        view_store
            .update_view(&envelope(
                "event2",
                "agg2",
                1,
                TestEvent::Created {
                    name: "Second".to_string(),
                },
            ))
            .unwrap();

        let all_views = view_store.get_all_views();
        assert_eq!(all_views.len(), 2);
        assert_eq!(all_views["agg1"].name, "First");
        assert_eq!(all_views["agg2"].name, "Second");

        view_store.clear();
        assert!(view_store.get_all_views().is_empty());
    }
}