Skip to main content

lellm_graph/
store.rs

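//! Checkpoint storage backends — the `BlobCheckpointStore` SPI plus an in-memory implementation.
//!
//! The storage layer operates on `CheckpointBlob` (bytes in / bytes out) and is decoupled from
//! the State type and the serialization format.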
4
5use std::collections::HashMap;
6use std::sync::RwLock;
7
8use async_trait::async_trait;
9
10use crate::checkpoint::{CheckpointBlob, CheckpointId, CheckpointStoreError, TraceId};
11
// ─── BlobCheckpointStore Trait ─────────────────────────────────
13
14/// Checkpoint 存储后端 SPI — bytes in / bytes out。
15///
16/// 存储层无需知道 State 类型或序列化格式,只操作 `CheckpointBlob`。
17/// 通过 `TypedCheckpointStore` 组合 Codec 实现类型化的 save/load。
18#[async_trait]
19pub trait BlobCheckpointStore: Send + Sync {
20    /// 保存 CheckpointBlob 并关联 trace_id。
21    async fn save_with_trace(
22        &self,
23        trace_id: &TraceId,
24        blob: &CheckpointBlob,
25    ) -> Result<(), CheckpointStoreError>;
26
27    /// 加载指定 ID 的 CheckpointBlob。
28    async fn load(&self, id: &CheckpointId)
29    -> Result<Option<CheckpointBlob>, CheckpointStoreError>;
30
31    /// 加载 trace 最新的 CheckpointBlob。
32    async fn load_latest(
33        &self,
34        trace_id: &TraceId,
35    ) -> Result<Option<CheckpointBlob>, CheckpointStoreError>;
36
37    /// 列出 trace 的所有 CheckpointId(按时间倒序)。
38    async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError>;
39
40    /// 删除指定 ID 的 Checkpoint。
41    async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError>;
42
43    /// 修剪 trace 的旧 Checkpoint,保留最新的 keep 个。
44    async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError>;
45}
46
// ─── InMemoryBlobStore ─────────────────────────────────────────
48
49/// 基于内存的 Checkpoint 存储后端。
50///
51/// 通过 `save_with_trace()` 关联 trace_id,或在存储层组织关联。
52#[derive(Default)]
53pub struct InMemoryBlobStore {
54    store: RwLock<HashMap<CheckpointId, CheckpointBlob>>,
55    /// trace_id → [CheckpointId] 索引(按时间正序)
56    index: RwLock<HashMap<TraceId, Vec<CheckpointId>>>,
57}
58
59impl InMemoryBlobStore {
60    pub fn new() -> Self {
61        Self::default()
62    }
63
64    pub fn len(&self) -> usize {
65        self.store.read().unwrap().len()
66    }
67
68    pub fn is_empty(&self) -> bool {
69        self.len() == 0
70    }
71}
72
73#[async_trait]
74impl BlobCheckpointStore for InMemoryBlobStore {
75    async fn save_with_trace(
76        &self,
77        trace_id: &TraceId,
78        blob: &CheckpointBlob,
79    ) -> Result<(), CheckpointStoreError> {
80        let id = blob.id.clone();
81
82        {
83            let mut store = self
84                .store
85                .write()
86                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
87            store.insert(id.clone(), blob.clone());
88        }
89
90        {
91            let mut index = self
92                .index
93                .write()
94                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
95            index.entry(*trace_id).or_default().push(id);
96        }
97
98        Ok(())
99    }
100
101    async fn load(
102        &self,
103        id: &CheckpointId,
104    ) -> Result<Option<CheckpointBlob>, CheckpointStoreError> {
105        let store = self
106            .store
107            .read()
108            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
109        Ok(store.get(id).cloned())
110    }
111
112    async fn load_latest(
113        &self,
114        trace_id: &TraceId,
115    ) -> Result<Option<CheckpointBlob>, CheckpointStoreError> {
116        let last_id = {
117            let index = self
118                .index
119                .read()
120                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
121            index.get(trace_id).and_then(|ids| ids.last()).cloned()
122        };
123
124        match last_id {
125            Some(id) => self.load(&id).await,
126            None => Ok(None),
127        }
128    }
129
130    async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError> {
131        let index = self
132            .index
133            .read()
134            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
135        let ids = index.get(trace_id).cloned().unwrap_or_default();
136        Ok(ids.into_iter().rev().collect())
137    }
138
139    async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError> {
140        let mut store = self
141            .store
142            .write()
143            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
144        store
145            .remove(id)
146            .map(|_| true)
147            .ok_or_else(|| CheckpointStoreError::Storage("failed to acquire write lock".into()))
148    }
149
150    async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError> {
151        let to_delete: Vec<CheckpointId> = {
152            let mut index = self
153                .index
154                .write()
155                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
156            match index.get_mut(trace_id) {
157                Some(ids) if ids.len() > keep => {
158                    let remove_count = ids.len() - keep;
159                    ids.drain(..remove_count).collect()
160                }
161                _ => return Ok(0),
162            }
163        };
164
165        let mut store = self
166            .store
167            .write()
168            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
169        for id in &to_delete {
170            store.remove(id);
171        }
172
173        Ok(to_delete.len())
174    }
175}