Skip to main content

cqrs_rust_lib/read/
memory.rs

1use crate::{Aggregate, CqrsError, EventEnvelope, View};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex};
4use tracing::debug;
5
6/// A simple in-memory view store that can be used for testing or simple applications.
7pub struct InMemoryViewStore<A, V>
8where
9    A: Aggregate,
10    V: View<A>,
11{
12    views: Arc<Mutex<HashMap<String, V>>>,
13    _phantom: std::marker::PhantomData<A>,
14}
15
16impl<A, V> InMemoryViewStore<A, V>
17where
18    A: Aggregate,
19    V: View<A>,
20{
21    /// Creates a new in-memory view store.
22    #[must_use]
23    pub fn new() -> Self {
24        Self {
25            views: Arc::new(Mutex::new(HashMap::new())),
26            _phantom: std::marker::PhantomData,
27        }
28    }
29
30    /// Gets a view by its ID.
31    pub fn get_view(&self, view_id: &str) -> Option<V> {
32        let views = self.views.lock().unwrap();
33        views.get(view_id).cloned()
34    }
35
36    /// Gets all views in the store.
37    pub fn get_all_views(&self) -> HashMap<String, V> {
38        let views = self.views.lock().unwrap();
39        views.clone()
40    }
41
42    /// Updates a view with an event.
43    pub fn update_view(&self, event: &EventEnvelope<A>) -> Result<(), CqrsError> {
44        debug!("Updating view with event");
45
46        let view_id = V::view_id(event);
47        let mut views = self.views.lock().unwrap();
48
49        let view = views.entry(view_id.clone()).or_default();
50
51        if let Some(updated_view) = view.update(event) {
52            debug!(view_id = %view_id, "View updated successfully");
53            views.insert(view_id, updated_view);
54        } else {
55            debug!(view_id = %view_id, "View not updated (no changes)");
56        }
57
58        Ok(())
59    }
60
61    /// Clears all views from the store.
62    pub fn clear(&self) {
63        let mut views = self.views.lock().unwrap();
64        views.clear();
65    }
66}
67
68impl<A, V> Default for InMemoryViewStore<A, V>
69where
70    A: Aggregate,
71    V: View<A>,
72{
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Aggregate, Event, ViewElements};
    use chrono::Utc;
    use http::StatusCode;
    use serde::{Deserialize, Serialize};
    use std::error::Error;
    use std::fmt;

    /// Minimal error type for the test aggregate; displays its message verbatim.
    #[derive(Debug, Clone)]
    struct TestError(String);

    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(&self.0)
        }
    }

    impl Error for TestError {}

    /// Event payloads emitted by the test aggregate.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
    enum TestEvent {
        Created { name: String },
        Updated { name: String },
    }

    impl Event for TestEvent {
        fn event_type(&self) -> String {
            let label = match self {
                Self::Created { .. } => "Created",
                Self::Updated { .. } => "Updated",
            };
            label.to_string()
        }
    }

    /// Aggregate used only to drive the view store in tests.
    #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
    struct TestAggregate {
        id: String,
        name: String,
    }

    cqrs_async_trait! {
    impl Aggregate for TestAggregate {
        const TYPE: &'static str = "TEST";

        type Event = TestEvent;
        type Error = TestError;

        fn aggregate_id(&self) -> String {
            self.id.clone()
        }

        fn with_aggregate_id(self, id: String) -> Self {
            Self { id, ..self }
        }

        fn apply(&mut self, event: Self::Event) -> Result<(), Self::Error> {
            match event {
                TestEvent::Created { name } => self.name = name,
                TestEvent::Updated { name } => self.name = name,
            }
            Ok(())
        }

        fn error(_status: StatusCode, details: &str) -> Self::Error {
            TestError(details.to_string())
        }
    }
    }

    /// Read model that mirrors the aggregate's id, latest name, and last seen version.
    #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
    struct TestView {
        id: String,
        name: String,
        version: usize,
    }

    impl View<TestAggregate> for TestView {
        const TYPE: &'static str = "TEST_VIEW";
        const IS_CHILD_OF_AGGREGATE: bool = true;

        fn view_id(event: &EventEnvelope<TestAggregate>) -> String {
            event.aggregate_id.clone()
        }

        fn update(&self, event: &EventEnvelope<TestAggregate>) -> Option<Self> {
            // Every field is overwritten, and both event kinds carry the latest name,
            // so a single arm builds the new view directly.
            match &event.payload {
                TestEvent::Created { name } | TestEvent::Updated { name } => Some(Self {
                    id: event.aggregate_id.clone(),
                    name: name.clone(),
                    version: event.version,
                }),
            }
        }
    }

    impl ViewElements<TestAggregate> for TestView {
        fn aggregate_id(&self) -> String {
            self.id.clone()
        }
    }

    /// Builds an envelope for aggregate `agg1` with empty metadata, stamped with the current time.
    fn envelope(event_id: &str, version: usize, payload: TestEvent) -> EventEnvelope<TestAggregate> {
        EventEnvelope {
            event_id: event_id.to_string(),
            aggregate_id: "agg1".to_string(),
            version,
            payload,
            metadata: HashMap::new(),
            at: Utc::now(),
        }
    }

    /// Expected view state for `agg1` with the given name and version.
    fn expected(name: &str, version: usize) -> TestView {
        TestView {
            id: "agg1".to_string(),
            name: name.to_string(),
            version,
        }
    }

    #[test]
    fn test_in_memory_view_store() {
        let store = InMemoryViewStore::<TestAggregate, TestView>::new();

        // The first event creates the view.
        let created_event = envelope(
            "event1",
            1,
            TestEvent::Created {
                name: "Test 1".to_string(),
            },
        );
        store.update_view(&created_event).unwrap();
        assert_eq!(store.get_view("agg1").unwrap(), expected("Test 1", 1));

        // A second event for the same aggregate replaces the stored view.
        let updated_event = envelope(
            "event2",
            2,
            TestEvent::Updated {
                name: "Test 1 Updated".to_string(),
            },
        );
        store.update_view(&updated_event).unwrap();
        assert_eq!(store.get_view("agg1").unwrap(), expected("Test 1 Updated", 2));

        // The snapshot holds exactly that one view.
        let snapshot = store.get_all_views();
        assert_eq!(snapshot.len(), 1);
        assert!(snapshot.contains_key("agg1"));

        // Clearing empties the store.
        store.clear();
        assert!(store.get_view("agg1").is_none());
    }
}