1use crate::traits::MessageStore;
13use async_trait::async_trait;
14use bytes::Bytes;
15use ironfix_core::error::StoreError;
16use ironfix_core::message::{MsgType, OwnedMessage};
17use parking_lot::RwLock;
18use std::collections::BTreeMap;
19use std::sync::atomic::{AtomicU64, Ordering};
20use std::time::SystemTime;
21
22#[derive(Debug)]
27pub struct MemoryStore {
28 messages: RwLock<BTreeMap<u64, Bytes>>,
30 next_sender_seq: AtomicU64,
32 next_target_seq: AtomicU64,
34 creation_time: SystemTime,
36}
37
38impl MemoryStore {
39 #[must_use]
41 pub fn new() -> Self {
42 Self {
43 messages: RwLock::new(BTreeMap::new()),
44 next_sender_seq: AtomicU64::new(1),
45 next_target_seq: AtomicU64::new(1),
46 creation_time: SystemTime::now(),
47 }
48 }
49
50 #[must_use]
56 pub fn with_initial_seqs(sender_seq: u64, target_seq: u64) -> Self {
57 Self {
58 messages: RwLock::new(BTreeMap::new()),
59 next_sender_seq: AtomicU64::new(sender_seq),
60 next_target_seq: AtomicU64::new(target_seq),
61 creation_time: SystemTime::now(),
62 }
63 }
64
65 #[must_use]
67 pub fn message_count(&self) -> usize {
68 self.messages.read().len()
69 }
70
71 #[must_use]
73 pub fn contains(&self, seq_num: u64) -> bool {
74 self.messages.read().contains_key(&seq_num)
75 }
76}
77
78impl Default for MemoryStore {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84#[async_trait]
85impl MessageStore for MemoryStore {
86 async fn store(&self, seq_num: u64, message: &[u8]) -> Result<(), StoreError> {
87 let mut messages = self.messages.write();
88 messages.insert(seq_num, Bytes::copy_from_slice(message));
89 Ok(())
90 }
91
92 async fn get_range(&self, begin: u64, end: u64) -> Result<Vec<OwnedMessage>, StoreError> {
93 let messages = self.messages.read();
94 let end = if end == 0 { u64::MAX } else { end };
95
96 let result: Vec<OwnedMessage> = messages
97 .range(begin..=end)
98 .map(|(_, bytes)| OwnedMessage::new(bytes.clone(), MsgType::default(), vec![]))
99 .collect();
100
101 if result.is_empty() && begin <= end {
102 return Err(StoreError::RangeNotAvailable {
103 range: begin..end + 1,
104 });
105 }
106
107 Ok(result)
108 }
109
110 fn next_sender_seq(&self) -> u64 {
111 self.next_sender_seq.load(Ordering::SeqCst)
112 }
113
114 fn next_target_seq(&self) -> u64 {
115 self.next_target_seq.load(Ordering::SeqCst)
116 }
117
118 fn set_next_sender_seq(&self, seq: u64) {
119 self.next_sender_seq.store(seq, Ordering::SeqCst);
120 }
121
122 fn set_next_target_seq(&self, seq: u64) {
123 self.next_target_seq.store(seq, Ordering::SeqCst);
124 }
125
126 async fn reset(&self) -> Result<(), StoreError> {
127 let mut messages = self.messages.write();
128 messages.clear();
129 self.next_sender_seq.store(1, Ordering::SeqCst);
130 self.next_target_seq.store(1, Ordering::SeqCst);
131 Ok(())
132 }
133
134 fn creation_time(&self) -> SystemTime {
135 self.creation_time
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Stores every `(sequence number, payload)` pair, panicking on the first
    /// failed `store` call.
    async fn fill(store: &MemoryStore, entries: &[(u64, &[u8])]) {
        for &(seq_num, payload) in entries {
            store.store(seq_num, payload).await.unwrap();
        }
    }

    /// A fresh store starts empty with both counters at 1.
    #[tokio::test]
    async fn test_memory_store_new() {
        let store = MemoryStore::new();

        assert_eq!(store.next_sender_seq(), 1);
        assert_eq!(store.next_target_seq(), 1);
        assert_eq!(store.message_count(), 0);
    }

    /// Stored sequence numbers are reported by `contains`; others are not.
    #[tokio::test]
    async fn test_memory_store_store_and_retrieve() {
        let store = MemoryStore::new();
        let entries: [(u64, &[u8]); 3] = [(1, b"message1"), (2, b"message2"), (3, b"message3")];

        fill(&store, &entries).await;

        assert_eq!(store.message_count(), 3);
        for seq_num in 1..=3 {
            assert!(store.contains(seq_num));
        }
        assert!(!store.contains(4));
    }

    /// `get_range` returns only the messages present in the inclusive range.
    #[tokio::test]
    async fn test_memory_store_get_range() {
        let store = MemoryStore::new();
        let entries: [(u64, &[u8]); 4] = [(1, b"msg1"), (2, b"msg2"), (3, b"msg3"), (5, b"msg5")];

        fill(&store, &entries).await;

        // (begin, end, expected number of messages); sequence 4 is a gap.
        let cases: [(u64, u64, usize); 2] = [(1, 3, 3), (2, 5, 3)];
        for (begin, end, expected_len) in cases {
            let range = store.get_range(begin, end).await.unwrap();
            assert_eq!(range.len(), expected_len);
        }
    }

    /// Setters are observable through the matching getters.
    #[tokio::test]
    async fn test_memory_store_sequence_numbers() {
        let store = MemoryStore::new();

        store.set_next_sender_seq(10);
        assert_eq!(store.next_sender_seq(), 10);

        store.set_next_target_seq(20);
        assert_eq!(store.next_target_seq(), 20);
    }

    /// `reset` drops stored messages and restores both counters to 1.
    #[tokio::test]
    async fn test_memory_store_reset() {
        let store = MemoryStore::new();
        let entries: [(u64, &[u8]); 1] = [(1, b"msg1")];

        fill(&store, &entries).await;
        store.set_next_sender_seq(10);
        store.set_next_target_seq(20);

        store.reset().await.unwrap();

        assert_eq!(store.message_count(), 0);
        assert_eq!(store.next_sender_seq(), 1);
        assert_eq!(store.next_target_seq(), 1);
    }
}