Skip to main content

lellm_graph/
store.rs

//! Checkpoint storage backend implementations — merged from lellm-runtime.
2
3use std::collections::HashMap;
4use std::sync::RwLock;
5
6use async_trait::async_trait;
7
8use crate::checkpoint::{Checkpoint, CheckpointId, CheckpointStore, CheckpointStoreError};
9use crate::ids::TraceId;
10
11/// 基于内存的 Checkpoint 存储后端。
12#[derive(Default)]
13pub struct InMemoryCheckpointStore {
14    store: RwLock<HashMap<CheckpointId, Checkpoint>>,
15    index: RwLock<HashMap<TraceId, Vec<CheckpointId>>>,
16}
17
18impl InMemoryCheckpointStore {
19    pub fn new() -> Self {
20        Self::default()
21    }
22
23    pub fn len(&self) -> usize {
24        self.store.read().unwrap().len()
25    }
26
27    pub fn is_empty(&self) -> bool {
28        self.len() == 0
29    }
30}
31
32#[async_trait]
33impl CheckpointStore for InMemoryCheckpointStore {
34    async fn save(&self, checkpoint: &Checkpoint) -> Result<(), CheckpointStoreError> {
35        let ck = checkpoint.clone();
36        let id = ck.checkpoint_id.clone();
37        let trace = ck.parent_trace_id;
38
39        {
40            let mut store = self
41                .store
42                .write()
43                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
44            store.insert(id.clone(), ck);
45        }
46
47        {
48            let mut index = self
49                .index
50                .write()
51                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
52            index.entry(trace).or_default().push(id);
53        }
54
55        Ok(())
56    }
57
58    async fn load(&self, id: &CheckpointId) -> Result<Option<Checkpoint>, CheckpointStoreError> {
59        let store = self
60            .store
61            .read()
62            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
63        Ok(store.get(id).cloned())
64    }
65
66    async fn load_latest(
67        &self,
68        trace_id: &TraceId,
69    ) -> Result<Option<Checkpoint>, CheckpointStoreError> {
70        let last_id = {
71            let index = self
72                .index
73                .read()
74                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
75            index.get(trace_id).and_then(|ids| ids.last()).cloned()
76        };
77
78        match last_id {
79            Some(id) => self.load(&id).await,
80            None => Ok(None),
81        }
82    }
83
84    async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError> {
85        let index = self
86            .index
87            .read()
88            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
89        let ids = index.get(trace_id).cloned().unwrap_or_default();
90        Ok(ids.into_iter().rev().collect())
91    }
92
93    async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError> {
94        let trace_id = {
95            let mut store = self
96                .store
97                .write()
98                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
99            store.remove(id).map(|ck| ck.parent_trace_id)
100        };
101
102        match trace_id {
103            Some(trace) => {
104                let mut index = self
105                    .index
106                    .write()
107                    .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
108                if let Some(ids) = index.get_mut(&trace) {
109                    ids.retain(|iid| iid != id);
110                    if ids.is_empty() {
111                        index.remove(&trace);
112                    }
113                }
114                Ok(true)
115            }
116            None => Ok(false),
117        }
118    }
119
120    async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError> {
121        let to_delete: Vec<CheckpointId> = {
122            let mut index = self
123                .index
124                .write()
125                .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
126            match index.get_mut(trace_id) {
127                Some(ids) if ids.len() > keep => {
128                    let remove_count = ids.len() - keep;
129                    ids.drain(..remove_count).collect()
130                }
131                _ => return Ok(0),
132            }
133        };
134
135        let mut store = self
136            .store
137            .write()
138            .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
139        for id in &to_delete {
140            store.remove(id);
141        }
142
143        Ok(to_delete.len())
144    }
145}
146
147/// Checkpoint 扩展 — 便捷读取物化状态中的值。
148pub trait CheckpointExt {
149    fn get_state_value(&self, key: &str) -> Option<u64>;
150}
151
152impl CheckpointExt for Checkpoint {
153    fn get_state_value(&self, key: &str) -> Option<u64> {
154        self.state.get(key).and_then(|v| v.as_u64())
155    }
156}