1use std::collections::HashMap;
4use std::sync::RwLock;
5
6use async_trait::async_trait;
7
8use crate::checkpoint::{Checkpoint, CheckpointId, CheckpointStore, CheckpointStoreError};
9use crate::ids::TraceId;
10
/// A [`CheckpointStore`] that keeps all checkpoints in process memory.
///
/// Nothing is persisted: the contents live only as long as this value.
#[derive(Default)]
pub struct InMemoryCheckpointStore {
    /// Every saved checkpoint, keyed by its checkpoint id.
    store: RwLock<HashMap<CheckpointId, Checkpoint>>,
    /// For each parent trace, the ids of its checkpoints in the order they
    /// were appended by `save` (the last element is treated as the latest).
    index: RwLock<HashMap<TraceId, Vec<CheckpointId>>>,
}
17
18impl InMemoryCheckpointStore {
19 pub fn new() -> Self {
20 Self::default()
21 }
22
23 pub fn len(&self) -> usize {
24 self.store.read().unwrap().len()
25 }
26
27 pub fn is_empty(&self) -> bool {
28 self.len() == 0
29 }
30}
31
32#[async_trait]
33impl CheckpointStore for InMemoryCheckpointStore {
34 async fn save(&self, checkpoint: &Checkpoint) -> Result<(), CheckpointStoreError> {
35 let ck = checkpoint.clone();
36 let id = ck.checkpoint_id.clone();
37 let trace = ck.parent_trace_id;
38
39 {
40 let mut store = self
41 .store
42 .write()
43 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
44 store.insert(id.clone(), ck);
45 }
46
47 {
48 let mut index = self
49 .index
50 .write()
51 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
52 index.entry(trace).or_default().push(id);
53 }
54
55 Ok(())
56 }
57
58 async fn load(&self, id: &CheckpointId) -> Result<Option<Checkpoint>, CheckpointStoreError> {
59 let store = self
60 .store
61 .read()
62 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
63 Ok(store.get(id).cloned())
64 }
65
66 async fn load_latest(
67 &self,
68 trace_id: &TraceId,
69 ) -> Result<Option<Checkpoint>, CheckpointStoreError> {
70 let last_id = {
71 let index = self
72 .index
73 .read()
74 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
75 index.get(trace_id).and_then(|ids| ids.last()).cloned()
76 };
77
78 match last_id {
79 Some(id) => self.load(&id).await,
80 None => Ok(None),
81 }
82 }
83
84 async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError> {
85 let index = self
86 .index
87 .read()
88 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
89 let ids = index.get(trace_id).cloned().unwrap_or_default();
90 Ok(ids.into_iter().rev().collect())
91 }
92
93 async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError> {
94 let trace_id = {
95 let mut store = self
96 .store
97 .write()
98 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
99 store.remove(id).map(|ck| ck.parent_trace_id)
100 };
101
102 match trace_id {
103 Some(trace) => {
104 let mut index = self
105 .index
106 .write()
107 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
108 if let Some(ids) = index.get_mut(&trace) {
109 ids.retain(|iid| iid != id);
110 if ids.is_empty() {
111 index.remove(&trace);
112 }
113 }
114 Ok(true)
115 }
116 None => Ok(false),
117 }
118 }
119
120 async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError> {
121 let to_delete: Vec<CheckpointId> = {
122 let mut index = self
123 .index
124 .write()
125 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
126 match index.get_mut(trace_id) {
127 Some(ids) if ids.len() > keep => {
128 let remove_count = ids.len() - keep;
129 ids.drain(..remove_count).collect()
130 }
131 _ => return Ok(0),
132 }
133 };
134
135 let mut store = self
136 .store
137 .write()
138 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
139 for id in &to_delete {
140 store.remove(id);
141 }
142
143 Ok(to_delete.len())
144 }
145}
146
/// Convenience accessors for reading typed values out of a checkpoint.
pub trait CheckpointExt {
    /// Returns the state entry stored under `key` as a `u64`.
    ///
    /// Returns `None` when there is no such entry, or when it cannot be
    /// read as a `u64`.
    fn get_state_value(&self, key: &str) -> Option<u64>;
}
151
152impl CheckpointExt for Checkpoint {
153 fn get_state_value(&self, key: &str) -> Option<u64> {
154 self.state.get(key).and_then(|v| v.as_u64())
155 }
156}