Skip to main content

lellm_graph/
store.rs

1//! Checkpoint 存储后端实现 — 内存后端。
2
3use std::collections::HashMap;
4use std::sync::RwLock;
5
6use async_trait::async_trait;
7
8use crate::checkpoint::{Checkpoint, CheckpointId, CheckpointStore, CheckpointStoreError, TraceId};
9
10/// 基于内存的 Checkpoint 存储后端。
11///
12/// 通过 `save_with_trace()` 关联 trace_id,或在存储层组织关联。
13#[derive(Default)]
14pub struct InMemoryCheckpointStore {
15    store: RwLock<HashMap<CheckpointId, Checkpoint>>,
16    /// trace_id → [CheckpointId] 索引(按时间正序)
17    index: RwLock<HashMap<TraceId, Vec<CheckpointId>>>,
18}
19
20impl InMemoryCheckpointStore {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    pub fn len(&self) -> usize {
26        self.store.read().unwrap().len()
27    }
28
29    pub fn is_empty(&self) -> bool {
30        self.len() == 0
31    }
32}
33
34#[async_trait]
35impl CheckpointStore for InMemoryCheckpointStore {
36    async fn save_with_trace(
37        &self,
38        trace_id: &TraceId,
39        checkpoint: &Checkpoint,
40    ) -> Result<(), CheckpointStoreError> {
41        let id = checkpoint.checkpoint_id.clone();
42
43        {
44            let mut store = self
45                .store
46                .write()
47                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
48            store.insert(id.clone(), checkpoint.clone());
49        }
50
51        {
52            let mut index = self
53                .index
54                .write()
55                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
56            index.entry(*trace_id).or_default().push(id);
57        }
58
59        Ok(())
60    }
61
62    async fn load(&self, id: &CheckpointId) -> Result<Option<Checkpoint>, CheckpointStoreError> {
63        let store = self
64            .store
65            .read()
66            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
67        Ok(store.get(id).cloned())
68    }
69
70    async fn load_latest(
71        &self,
72        trace_id: &TraceId,
73    ) -> Result<Option<Checkpoint>, CheckpointStoreError> {
74        let last_id = {
75            let index = self
76                .index
77                .read()
78                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
79            index.get(trace_id).and_then(|ids| ids.last()).cloned()
80        };
81
82        match last_id {
83            Some(id) => self.load(&id).await,
84            None => Ok(None),
85        }
86    }
87
88    async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError> {
89        let index = self
90            .index
91            .read()
92            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
93        let ids = index.get(trace_id).cloned().unwrap_or_default();
94        Ok(ids.into_iter().rev().collect())
95    }
96
97    async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError> {
98        let mut store = self
99            .store
100            .write()
101            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
102        store
103            .remove(id)
104            .map(|_| true)
105            .ok_or_else(|| CheckpointStoreError::Storage("failed to acquire write lock".into()))
106    }
107
108    async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError> {
109        let to_delete: Vec<CheckpointId> = {
110            let mut index = self
111                .index
112                .write()
113                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
114            match index.get_mut(trace_id) {
115                Some(ids) if ids.len() > keep => {
116                    let remove_count = ids.len() - keep;
117                    ids.drain(..remove_count).collect()
118                }
119                _ => return Ok(0),
120            }
121        };
122
123        let mut store = self
124            .store
125            .write()
126            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
127        for id in &to_delete {
128            store.remove(id);
129        }
130
131        Ok(to_delete.len())
132    }
133}