1use crate::error::StoreError;
4use crate::traits::StateStore;
5use async_trait::async_trait;
6use attuned_core::{ComponentHealth, HealthCheck, StateSnapshot};
7use dashmap::DashMap;
8use std::collections::VecDeque;
9use std::sync::Arc;
10
/// Tuning knobs for [`MemoryStore`].
#[derive(Debug, Clone)]
pub struct MemoryStoreConfig {
    /// Whether to keep a per-user history of upserted snapshots in addition
    /// to the latest one.
    pub enable_history: bool,
    /// Maximum number of history entries retained per user when
    /// `enable_history` is set; older entries are discarded first.
    pub max_history_per_user: usize,
}

impl Default for MemoryStoreConfig {
    /// History disabled, with a cap of 100 entries per user should it be enabled.
    fn default() -> Self {
        MemoryStoreConfig {
            max_history_per_user: 100,
            enable_history: false,
        }
    }
}
28
29#[derive(Clone)]
34pub struct MemoryStore {
35 latest: Arc<DashMap<String, StateSnapshot>>,
36 history: Option<Arc<DashMap<String, VecDeque<StateSnapshot>>>>,
37 config: MemoryStoreConfig,
38}
39
40impl MemoryStore {
41 pub fn new(config: MemoryStoreConfig) -> Self {
43 let history = if config.enable_history {
44 Some(Arc::new(DashMap::new()))
45 } else {
46 None
47 };
48
49 Self {
50 latest: Arc::new(DashMap::new()),
51 history,
52 config,
53 }
54 }
55
56 pub fn len(&self) -> usize {
58 self.latest.len()
59 }
60
61 pub fn is_empty(&self) -> bool {
63 self.latest.is_empty()
64 }
65
66 pub fn clear(&self) {
68 self.latest.clear();
69 if let Some(ref history) = self.history {
70 history.clear();
71 }
72 }
73}
74
75impl Default for MemoryStore {
76 fn default() -> Self {
77 Self::new(MemoryStoreConfig::default())
78 }
79}
80
81#[async_trait]
82impl StateStore for MemoryStore {
83 #[tracing::instrument(skip(self, snapshot), fields(user_id = %snapshot.user_id))]
84 async fn upsert_latest(&self, snapshot: StateSnapshot) -> Result<(), StoreError> {
85 snapshot.validate()?;
87
88 let user_id = snapshot.user_id.clone();
89
90 if let Some(ref history) = self.history {
92 let mut entry = history.entry(user_id.clone()).or_insert_with(VecDeque::new);
93 entry.push_front(snapshot.clone());
94
95 while entry.len() > self.config.max_history_per_user {
97 entry.pop_back();
98 }
99 }
100
101 self.latest.insert(user_id, snapshot);
103
104 tracing::debug!("upserted state snapshot");
105 Ok(())
106 }
107
108 #[tracing::instrument(skip(self), fields(user_id = %user_id))]
109 async fn get_latest(&self, user_id: &str) -> Result<Option<StateSnapshot>, StoreError> {
110 let result = self.latest.get(user_id).map(|r| r.value().clone());
111 tracing::debug!(found = result.is_some(), "retrieved state snapshot");
112 Ok(result)
113 }
114
115 #[tracing::instrument(skip(self), fields(user_id = %user_id))]
116 async fn delete(&self, user_id: &str) -> Result<(), StoreError> {
117 self.latest.remove(user_id);
118 if let Some(ref history) = self.history {
119 history.remove(user_id);
120 }
121 tracing::debug!("deleted user state");
122 Ok(())
123 }
124
125 #[tracing::instrument(skip(self), fields(user_id = %user_id, limit = %limit))]
126 async fn get_history(
127 &self,
128 user_id: &str,
129 limit: usize,
130 ) -> Result<Vec<StateSnapshot>, StoreError> {
131 let result = match &self.history {
132 Some(history) => history
133 .get(user_id)
134 .map(|entry| entry.iter().take(limit).cloned().collect())
135 .unwrap_or_default(),
136 None => vec![],
137 };
138 tracing::debug!(count = result.len(), "retrieved history");
139 Ok(result)
140 }
141
142 async fn health_check(&self) -> Result<bool, StoreError> {
143 Ok(true)
144 }
145}
146
147#[async_trait]
148impl HealthCheck for MemoryStore {
149 async fn check(&self) -> ComponentHealth {
150 ComponentHealth::healthy("memory_store")
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use attuned_core::Source;

    /// Builds a valid self-reported snapshot for `user_id` with `warmth = 0.7`.
    fn test_snapshot(user_id: &str) -> StateSnapshot {
        StateSnapshot::builder()
            .user_id(user_id)
            .source(Source::SelfReport)
            .axis("warmth", 0.7)
            .build()
            .unwrap()
    }

    /// Like `test_snapshot`, but with `warmth = i / 10`. This lets snapshots
    /// upserted in sequence be told apart.
    fn numbered_snapshot(user_id: &str, i: u8) -> StateSnapshot {
        let mut snapshot = test_snapshot(user_id);
        snapshot
            .axes
            .insert("warmth".to_string(), f32::from(i) / 10.0);
        snapshot
    }

    /// Reads the `warmth` axis of `snapshot`, if present.
    fn warmth(snapshot: &StateSnapshot) -> Option<f32> {
        snapshot.axes.get("warmth").copied()
    }

    /// Store with history enabled and the given per-user cap.
    fn history_store(max_history_per_user: usize) -> MemoryStore {
        MemoryStore::new(MemoryStoreConfig {
            enable_history: true,
            max_history_per_user,
        })
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let store = MemoryStore::default();
        let snapshot = test_snapshot("user_1");

        store.upsert_latest(snapshot.clone()).await.unwrap();

        let retrieved = store.get_latest("user_1").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().user_id, "user_1");
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let store = MemoryStore::default();
        let result = store.get_latest("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_delete() {
        let store = MemoryStore::default();
        store.upsert_latest(test_snapshot("user_1")).await.unwrap();

        store.delete("user_1").await.unwrap();

        assert!(store.get_latest("user_1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete_removes_history() {
        let store = history_store(5);
        store.upsert_latest(test_snapshot("user_1")).await.unwrap();

        store.delete("user_1").await.unwrap();

        assert!(store.get_latest("user_1").await.unwrap().is_none());
        assert!(store.get_history("user_1", 10).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_history() {
        let store = history_store(5);

        for i in 0..3 {
            store
                .upsert_latest(numbered_snapshot("user_1", i))
                .await
                .unwrap();
        }

        let history = store.get_history("user_1", 10).await.unwrap();
        assert_eq!(history.len(), 3);
    }

    #[tokio::test]
    async fn test_history_limit() {
        let store = history_store(3);

        for i in 0..5 {
            store
                .upsert_latest(numbered_snapshot("user_1", i))
                .await
                .unwrap();
        }

        let history = store.get_history("user_1", 10).await.unwrap();
        assert_eq!(history.len(), 3);

        // Newest first; the two oldest snapshots (i = 0, 1) were evicted.
        let got: Vec<_> = history.iter().map(warmth).collect();
        let expected: Vec<_> = (2..5u8)
            .rev()
            .map(|i| Some(f32::from(i) / 10.0))
            .collect();
        assert_eq!(got, expected);
    }

    #[tokio::test]
    async fn test_history_respects_query_limit() {
        let store = history_store(5);

        for i in 0..4 {
            store
                .upsert_latest(numbered_snapshot("user_1", i))
                .await
                .unwrap();
        }

        let history = store.get_history("user_1", 2).await.unwrap();
        let got: Vec<_> = history.iter().map(warmth).collect();
        assert_eq!(
            got,
            vec![Some(f32::from(3u8) / 10.0), Some(f32::from(2u8) / 10.0)]
        );
    }

    #[tokio::test]
    async fn test_latest_matches_newest_history_entry() {
        let store = history_store(5);

        for i in 0..3 {
            store
                .upsert_latest(numbered_snapshot("user_1", i))
                .await
                .unwrap();
        }

        let latest = store.get_latest("user_1").await.unwrap().unwrap();
        let history = store.get_history("user_1", 1).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(warmth(&latest), warmth(&history[0]));
    }

    #[tokio::test]
    async fn test_history_disabled_returns_empty() {
        let store = MemoryStore::default();
        store.upsert_latest(test_snapshot("user_1")).await.unwrap();

        assert!(store.get_history("user_1", 10).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_clear() {
        let store = history_store(5);
        store.upsert_latest(test_snapshot("user_1")).await.unwrap();
        store.upsert_latest(test_snapshot("user_2")).await.unwrap();
        assert_eq!(store.len(), 2);

        store.clear();

        assert!(store.is_empty());
        assert!(store.get_history("user_1", 10).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let store = Arc::new(MemoryStore::default());

        let handles: Vec<_> = (0..100)
            .map(|i| {
                let store = store.clone();
                tokio::spawn(async move {
                    let snapshot = test_snapshot(&format!("user_{}", i));
                    store.upsert_latest(snapshot).await
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        assert_eq!(store.len(), 100);
    }
}