lellm_graph/checkpoint/
store.rs1use std::collections::HashMap;
6use std::sync::RwLock;
7
8use async_trait::async_trait;
9
10use super::checkpoint::{CheckpointBlob, CheckpointId, CheckpointStoreError, TraceId};
11
12#[async_trait]
19pub trait BlobCheckpointStore: Send + Sync {
20 async fn save_with_trace(
22 &self,
23 trace_id: &TraceId,
24 blob: &CheckpointBlob,
25 ) -> Result<(), CheckpointStoreError>;
26
27 async fn load(&self, id: &CheckpointId)
29 -> Result<Option<CheckpointBlob>, CheckpointStoreError>;
30
31 async fn load_latest(
33 &self,
34 trace_id: &TraceId,
35 ) -> Result<Option<CheckpointBlob>, CheckpointStoreError>;
36
37 async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError>;
39
40 async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError>;
42
43 async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError>;
45}
46
47#[derive(Default)]
53pub struct InMemoryBlobStore {
54 store: RwLock<HashMap<CheckpointId, CheckpointBlob>>,
55 index: RwLock<HashMap<TraceId, Vec<CheckpointId>>>,
57}
58
59impl InMemoryBlobStore {
60 pub fn new() -> Self {
61 Self::default()
62 }
63
64 pub fn len(&self) -> usize {
65 self.store.read().unwrap().len()
66 }
67
68 pub fn is_empty(&self) -> bool {
69 self.len() == 0
70 }
71}
72
73#[async_trait]
74impl BlobCheckpointStore for InMemoryBlobStore {
75 async fn save_with_trace(
76 &self,
77 trace_id: &TraceId,
78 blob: &CheckpointBlob,
79 ) -> Result<(), CheckpointStoreError> {
80 let id = blob.id.clone();
81
82 {
83 let mut store = self
84 .store
85 .write()
86 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
87 store.insert(id.clone(), blob.clone());
88 }
89
90 {
91 let mut index = self
92 .index
93 .write()
94 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
95 index.entry(*trace_id).or_default().push(id);
96 }
97
98 Ok(())
99 }
100
101 async fn load(
102 &self,
103 id: &CheckpointId,
104 ) -> Result<Option<CheckpointBlob>, CheckpointStoreError> {
105 let store = self
106 .store
107 .read()
108 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
109 Ok(store.get(id).cloned())
110 }
111
112 async fn load_latest(
113 &self,
114 trace_id: &TraceId,
115 ) -> Result<Option<CheckpointBlob>, CheckpointStoreError> {
116 let last_id = {
117 let index = self
118 .index
119 .read()
120 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
121 index.get(trace_id).and_then(|ids| ids.last()).cloned()
122 };
123
124 match last_id {
125 Some(id) => self.load(&id).await,
126 None => Ok(None),
127 }
128 }
129
130 async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError> {
131 let index = self
132 .index
133 .read()
134 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
135 let ids = index.get(trace_id).cloned().unwrap_or_default();
136 Ok(ids.into_iter().rev().collect())
137 }
138
139 async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError> {
140 let mut store = self
141 .store
142 .write()
143 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
144 store
145 .remove(id)
146 .map(|_| true)
147 .ok_or_else(|| CheckpointStoreError::Storage("failed to acquire write lock".into()))
148 }
149
150 async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError> {
151 let to_delete: Vec<CheckpointId> = {
152 let mut index = self
153 .index
154 .write()
155 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
156 match index.get_mut(trace_id) {
157 Some(ids) if ids.len() > keep => {
158 let remove_count = ids.len() - keep;
159 ids.drain(..remove_count).collect()
160 }
161 _ => return Ok(0),
162 }
163 };
164
165 let mut store = self
166 .store
167 .write()
168 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
169 for id in &to_delete {
170 store.remove(id);
171 }
172
173 Ok(to_delete.len())
174 }
175}