1use std::collections::HashMap;
4use std::sync::RwLock;
5
6use async_trait::async_trait;
7
8use crate::checkpoint::{Checkpoint, CheckpointId, CheckpointStore, CheckpointStoreError, TraceId};
9
10#[derive(Default)]
14pub struct InMemoryCheckpointStore {
15 store: RwLock<HashMap<CheckpointId, Checkpoint>>,
16 index: RwLock<HashMap<TraceId, Vec<CheckpointId>>>,
18}
19
20impl InMemoryCheckpointStore {
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 pub fn len(&self) -> usize {
26 self.store.read().unwrap().len()
27 }
28
29 pub fn is_empty(&self) -> bool {
30 self.len() == 0
31 }
32}
33
34#[async_trait]
35impl CheckpointStore for InMemoryCheckpointStore {
36 async fn save_with_trace(
37 &self,
38 trace_id: &TraceId,
39 checkpoint: &Checkpoint,
40 ) -> Result<(), CheckpointStoreError> {
41 let id = checkpoint.checkpoint_id.clone();
42
43 {
44 let mut store = self
45 .store
46 .write()
47 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
48 store.insert(id.clone(), checkpoint.clone());
49 }
50
51 {
52 let mut index = self
53 .index
54 .write()
55 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
56 index.entry(*trace_id).or_default().push(id);
57 }
58
59 Ok(())
60 }
61
62 async fn load(&self, id: &CheckpointId) -> Result<Option<Checkpoint>, CheckpointStoreError> {
63 let store = self
64 .store
65 .read()
66 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
67 Ok(store.get(id).cloned())
68 }
69
70 async fn load_latest(
71 &self,
72 trace_id: &TraceId,
73 ) -> Result<Option<Checkpoint>, CheckpointStoreError> {
74 let last_id = {
75 let index = self
76 .index
77 .read()
78 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
79 index.get(trace_id).and_then(|ids| ids.last()).cloned()
80 };
81
82 match last_id {
83 Some(id) => self.load(&id).await,
84 None => Ok(None),
85 }
86 }
87
88 async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError> {
89 let index = self
90 .index
91 .read()
92 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
93 let ids = index.get(trace_id).cloned().unwrap_or_default();
94 Ok(ids.into_iter().rev().collect())
95 }
96
97 async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError> {
98 let mut store = self
99 .store
100 .write()
101 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
102 store
103 .remove(id)
104 .map(|_| true)
105 .ok_or_else(|| CheckpointStoreError::Storage("failed to acquire write lock".into()))
106 }
107
108 async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError> {
109 let to_delete: Vec<CheckpointId> = {
110 let mut index = self
111 .index
112 .write()
113 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
114 match index.get_mut(trace_id) {
115 Some(ids) if ids.len() > keep => {
116 let remove_count = ids.len() - keep;
117 ids.drain(..remove_count).collect()
118 }
119 _ => return Ok(0),
120 }
121 };
122
123 let mut store = self
124 .store
125 .write()
126 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
127 for id in &to_delete {
128 store.remove(id);
129 }
130
131 Ok(to_delete.len())
132 }
133}