1use async_trait::async_trait;
4use std::collections::BTreeMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8use super::{AuditStore, BaselineStore, StorageHealth};
9use crate::error::StoreError;
10use crate::models::{
11 AuditEvent, BaselineRecord, BaselineSummary, BaselineVersion, ListAuditEventsQuery,
12 ListAuditEventsResponse, ListBaselinesQuery, ListBaselinesResponse, ListVerdictsQuery,
13 ListVerdictsResponse, PaginationInfo, VerdictRecord,
14};
15
/// In-memory implementation of [`BaselineStore`] and [`AuditStore`].
///
/// All data lives in process memory behind `Arc<tokio::sync::RwLock<…>>`,
/// so it is lost when the store is dropped. Reads take shared locks and
/// writes take exclusive locks per collection.
#[derive(Debug, Default)]
pub struct InMemoryStore {
    // Baselines keyed by (project, benchmark, version). BTreeMap gives a
    // deterministic iteration order over the composite keys.
    #[allow(clippy::type_complexity)]
    baselines: Arc<RwLock<BTreeMap<(String, String, String), BaselineRecord>>>,
    // Append-only list of verdict records; filtered/sorted on read.
    verdicts: Arc<RwLock<Vec<VerdictRecord>>>,
    // Append-only audit event log; filtered/sorted on read.
    audit_events: Arc<RwLock<Vec<AuditEvent>>>,
}
24
25impl InMemoryStore {
26 pub fn new() -> Self {
28 Self {
29 baselines: Arc::new(RwLock::new(BTreeMap::new())),
30 verdicts: Arc::new(RwLock::new(Vec::new())),
31 audit_events: Arc::new(RwLock::new(Vec::new())),
32 }
33 }
34
35 fn key(project: &str, benchmark: &str, version: &str) -> (String, String, String) {
36 (
37 project.to_string(),
38 benchmark.to_string(),
39 version.to_string(),
40 )
41 }
42}
43
44#[async_trait]
45impl BaselineStore for InMemoryStore {
46 async fn create(&self, record: &BaselineRecord) -> Result<(), StoreError> {
47 let key = Self::key(&record.project, &record.benchmark, &record.version);
48 let mut baselines = self.baselines.write().await;
49
50 if baselines.contains_key(&key) {
51 return Err(StoreError::AlreadyExists(format!(
52 "project={}, benchmark={}, version={}",
53 record.project, record.benchmark, record.version
54 )));
55 }
56
57 baselines.insert(key, record.clone());
58 Ok(())
59 }
60
61 async fn get(
62 &self,
63 project: &str,
64 benchmark: &str,
65 version: &str,
66 ) -> Result<Option<BaselineRecord>, StoreError> {
67 let key = Self::key(project, benchmark, version);
68 let baselines = self.baselines.read().await;
69 Ok(baselines.get(&key).filter(|r| !r.deleted).cloned())
70 }
71
72 async fn get_latest(
73 &self,
74 project: &str,
75 benchmark: &str,
76 ) -> Result<Option<BaselineRecord>, StoreError> {
77 let baselines = self.baselines.read().await;
78 let latest = baselines
79 .values()
80 .filter(|r| r.project == project && r.benchmark == benchmark && !r.deleted)
81 .max_by_key(|r| r.created_at);
82 Ok(latest.cloned())
83 }
84
85 #[allow(clippy::collapsible_if)]
86 async fn list(
87 &self,
88 project: &str,
89 query: &ListBaselinesQuery,
90 ) -> Result<ListBaselinesResponse, StoreError> {
91 let baselines = self.baselines.read().await;
92 let parsed_tags = query.parsed_tags();
93
94 let mut filtered: Vec<_> = baselines
95 .values()
96 .filter(|r| {
97 if r.project != project || r.deleted {
99 return false;
100 }
101
102 if let Some(ref b) = query.benchmark {
104 if &r.benchmark != b {
105 return false;
106 }
107 }
108
109 if let Some(ref p) = query.benchmark_prefix {
111 if !r.benchmark.starts_with(p) {
112 return false;
113 }
114 }
115
116 if let Some(ref gr) = query.git_ref {
118 if r.git_ref.as_deref() != Some(gr) {
119 return false;
120 }
121 }
122
123 if let Some(ref gs) = query.git_sha {
125 if r.git_sha.as_deref() != Some(gs) {
126 return false;
127 }
128 }
129
130 if let Some(since) = query.since {
132 if r.created_at < since {
133 return false;
134 }
135 }
136
137 if let Some(until) = query.until {
139 if r.created_at > until {
140 return false;
141 }
142 }
143
144 if !parsed_tags.is_empty() {
146 for tag in &parsed_tags {
147 if !r.tags.contains(tag) {
148 return false;
149 }
150 }
151 }
152
153 true
154 })
155 .collect();
156
157 filtered.sort_by(|a, b| b.created_at.cmp(&a.created_at));
158
159 let total = filtered.len() as u64;
160 let offset = query.offset as usize;
161 let limit = query.limit as usize;
162
163 let paginated: Vec<_> = filtered
164 .into_iter()
165 .skip(offset)
166 .take(limit)
167 .map(|r| {
168 let mut summary: BaselineSummary = r.clone().into();
169 if query.include_receipt {
170 summary.receipt = Some(r.receipt.clone());
171 }
172 summary
173 })
174 .collect();
175
176 let has_more = (offset + paginated.len()) < total as usize;
177
178 Ok(ListBaselinesResponse {
179 baselines: paginated,
180 pagination: PaginationInfo {
181 total,
182 limit: query.limit,
183 offset: query.offset,
184 has_more,
185 },
186 })
187 }
188
189 async fn update(&self, record: &BaselineRecord) -> Result<(), StoreError> {
190 let key = Self::key(&record.project, &record.benchmark, &record.version);
191 let mut baselines = self.baselines.write().await;
192
193 if !baselines.contains_key(&key) {
194 return Err(StoreError::NotFound(format!(
195 "project={}, benchmark={}, version={}",
196 record.project, record.benchmark, record.version
197 )));
198 }
199
200 baselines.insert(key, record.clone());
201 Ok(())
202 }
203
204 async fn delete(
205 &self,
206 project: &str,
207 benchmark: &str,
208 version: &str,
209 ) -> Result<bool, StoreError> {
210 let key = Self::key(project, benchmark, version);
211 let mut baselines = self.baselines.write().await;
212
213 if let Some(record) = baselines.get_mut(&key) {
214 if record.deleted {
215 return Ok(false);
216 }
217 record.deleted = true;
218 return Ok(true);
219 }
220
221 Ok(false)
222 }
223
224 async fn hard_delete(
225 &self,
226 project: &str,
227 benchmark: &str,
228 version: &str,
229 ) -> Result<bool, StoreError> {
230 let key = Self::key(project, benchmark, version);
231 let mut baselines = self.baselines.write().await;
232 Ok(baselines.remove(&key).is_some())
233 }
234
235 async fn list_versions(
236 &self,
237 project: &str,
238 benchmark: &str,
239 ) -> Result<Vec<BaselineVersion>, StoreError> {
240 let baselines = self.baselines.read().await;
241
242 let mut versions: Vec<_> = baselines
243 .values()
244 .filter(|r| r.project == project && r.benchmark == benchmark && !r.deleted)
245 .map(|r| BaselineVersion {
246 version: r.version.clone(),
247 git_ref: r.git_ref.clone(),
248 git_sha: r.git_sha.clone(),
249 created_at: r.created_at,
250 created_by: None,
251 is_current: false,
252 source: r.source.clone(),
253 })
254 .collect();
255
256 versions.sort_by(|a, b| b.created_at.cmp(&a.created_at));
257
258 if let Some(first) = versions.first_mut() {
259 first.is_current = true;
260 }
261
262 Ok(versions)
263 }
264
265 async fn health_check(&self) -> Result<StorageHealth, StoreError> {
266 Ok(StorageHealth::Healthy)
267 }
268
269 fn backend_type(&self) -> &'static str {
270 "memory"
271 }
272
273 async fn create_verdict(&self, record: &VerdictRecord) -> Result<(), StoreError> {
274 let mut verdicts = self.verdicts.write().await;
275 verdicts.push(record.clone());
276 Ok(())
277 }
278
279 async fn list_verdicts(
280 &self,
281 project: &str,
282 query: &ListVerdictsQuery,
283 ) -> Result<ListVerdictsResponse, StoreError> {
284 let verdicts = self.verdicts.read().await;
285
286 let mut filtered: Vec<_> = verdicts
287 .iter()
288 .filter(|r| {
289 if r.project != project {
290 return false;
291 }
292
293 if let Some(ref b) = query.benchmark
294 && &r.benchmark != b
295 {
296 return false;
297 }
298
299 if let Some(ref s) = query.status
300 && &r.status != s
301 {
302 return false;
303 }
304
305 if let Some(since) = query.since
306 && r.created_at < since
307 {
308 return false;
309 }
310
311 if let Some(until) = query.until
312 && r.created_at > until
313 {
314 return false;
315 }
316
317 true
318 })
319 .cloned()
320 .collect();
321
322 filtered.sort_by(|a, b| b.created_at.cmp(&a.created_at));
323
324 let total = filtered.len() as u64;
325 let offset = query.offset as usize;
326 let limit = query.limit as usize;
327
328 let paginated: Vec<_> = filtered.into_iter().skip(offset).take(limit).collect();
329
330 let has_more = (offset + paginated.len()) < total as usize;
331
332 Ok(ListVerdictsResponse {
333 verdicts: paginated,
334 pagination: PaginationInfo {
335 total,
336 limit: query.limit,
337 offset: query.offset,
338 has_more,
339 },
340 })
341 }
342}
343
344#[async_trait]
345impl AuditStore for InMemoryStore {
346 async fn log_event(&self, event: &AuditEvent) -> Result<(), StoreError> {
347 let mut events = self.audit_events.write().await;
348 events.push(event.clone());
349 Ok(())
350 }
351
352 async fn list_events(
353 &self,
354 query: &ListAuditEventsQuery,
355 ) -> Result<ListAuditEventsResponse, StoreError> {
356 let events = self.audit_events.read().await;
357
358 let mut filtered: Vec<_> = events
359 .iter()
360 .filter(|e| {
361 if let Some(ref project) = query.project
362 && &e.project != project
363 {
364 return false;
365 }
366
367 if let Some(ref action) = query.action
368 && e.action.to_string() != *action
369 {
370 return false;
371 }
372
373 if let Some(ref resource_type) = query.resource_type
374 && e.resource_type.to_string() != *resource_type
375 {
376 return false;
377 }
378
379 if let Some(ref actor) = query.actor
380 && &e.actor != actor
381 {
382 return false;
383 }
384
385 if let Some(since) = query.since
386 && e.timestamp < since
387 {
388 return false;
389 }
390
391 if let Some(until) = query.until
392 && e.timestamp > until
393 {
394 return false;
395 }
396
397 true
398 })
399 .cloned()
400 .collect();
401
402 filtered.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
403
404 let total = filtered.len() as u64;
405 let offset = query.offset as usize;
406 let limit = query.limit as usize;
407
408 let paginated: Vec<_> = filtered.into_iter().skip(offset).take(limit).collect();
409
410 let has_more = (offset + paginated.len()) < total as usize;
411
412 Ok(ListAuditEventsResponse {
413 events: paginated,
414 pagination: PaginationInfo {
415 total,
416 limit: query.limit,
417 offset: query.offset,
418 has_more,
419 },
420 })
421 }
422}