lellm_graph/checkpoint/
mutation_log.rs1use std::collections::HashMap;
20use std::sync::RwLock;
21use std::time::SystemTime;
22
23use async_trait::async_trait;
24use serde::{Deserialize, Serialize};
25
26use super::checkpoint::NodeId;
27use super::checkpoint::{CheckpointId, CheckpointStoreError, TraceId};
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct MutationLogEntry {
37 pub trace_id: TraceId,
39 pub step: usize,
41 pub node_id: NodeId,
43 pub checkpoint_id: Option<CheckpointId>,
45 pub mutation_index: usize,
47 pub mutation: serde_json::Value,
49 pub timestamp: SystemTime,
51}
52
53impl MutationLogEntry {
54 pub fn new(
55 trace_id: TraceId,
56 step: usize,
57 node_id: NodeId,
58 checkpoint_id: Option<CheckpointId>,
59 mutation_index: usize,
60 mutation: serde_json::Value,
61 ) -> Self {
62 Self {
63 trace_id,
64 step,
65 node_id,
66 checkpoint_id,
67 mutation_index,
68 mutation,
69 timestamp: SystemTime::now(),
70 }
71 }
72}
73
74#[async_trait]
80pub trait MutationLogStore: Send + Sync {
81 async fn append(&self, entry: MutationLogEntry) -> Result<(), CheckpointStoreError>;
83
84 async fn append_batch(
86 &self,
87 entries: Vec<MutationLogEntry>,
88 ) -> Result<(), CheckpointStoreError> {
89 for entry in entries {
90 self.append(entry).await?;
91 }
92 Ok(())
93 }
94
95 async fn replay(
97 &self,
98 trace_id: &TraceId,
99 from_step: usize,
100 ) -> Result<Vec<MutationLogEntry>, CheckpointStoreError>;
101
102 async fn truncate(
104 &self,
105 trace_id: &TraceId,
106 keep_from_step: usize,
107 ) -> Result<usize, CheckpointStoreError>;
108}
109
110#[derive(Default)]
116pub struct InMemoryMutationLog {
117 entries: RwLock<Vec<MutationLogEntry>>,
118 index: RwLock<HashMap<TraceId, Vec<usize>>>,
120}
121
122impl InMemoryMutationLog {
123 pub fn new() -> Self {
124 Self::default()
125 }
126
127 pub fn len(&self) -> usize {
128 self.entries.read().unwrap().len()
129 }
130
131 pub fn is_empty(&self) -> bool {
132 self.len() == 0
133 }
134}
135
136#[async_trait]
137impl MutationLogStore for InMemoryMutationLog {
138 async fn append(&self, entry: MutationLogEntry) -> Result<(), CheckpointStoreError> {
139 let trace_id = entry.trace_id;
140 let idx = {
141 let mut entries = self
142 .entries
143 .write()
144 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
145 let idx = entries.len();
146 entries.push(entry);
147 idx
148 };
149 {
150 let mut index_map = self
151 .index
152 .write()
153 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
154 index_map.entry(trace_id).or_default().push(idx);
155 }
156 Ok(())
157 }
158
159 async fn replay(
160 &self,
161 trace_id: &TraceId,
162 from_step: usize,
163 ) -> Result<Vec<MutationLogEntry>, CheckpointStoreError> {
164 let entry_indices = {
165 let index_map = self
166 .index
167 .read()
168 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
169 index_map.get(trace_id).cloned().unwrap_or_default()
170 };
171
172 let entries = self
173 .entries
174 .read()
175 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
176
177 let mut result = Vec::new();
178 for &idx in &entry_indices {
179 if idx < entries.len() {
180 let entry = &entries[idx];
181 if entry.step >= from_step {
182 result.push(entry.clone());
183 }
184 }
185 }
186 Ok(result)
187 }
188
189 async fn truncate(
190 &self,
191 trace_id: &TraceId,
192 keep_from_step: usize,
193 ) -> Result<usize, CheckpointStoreError> {
194 let entry_indices: Vec<usize> = {
195 let index_map = self
196 .index
197 .read()
198 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
199 index_map.get(trace_id).cloned().unwrap_or_default()
200 };
201
202 let entries = self
203 .entries
204 .read()
205 .map_err(|e| CheckpointStoreError::Storage(e.to_string()))?;
206
207 let mut removed = 0;
208 for &idx in &entry_indices {
209 if idx < entries.len() && entries[idx].step < keep_from_step {
210 removed += 1;
211 }
212 }
213
214 Ok(removed)
217 }
218}
219
220pub fn mutations_to_log_entries<E: Serialize>(
226 trace_id: TraceId,
227 step: usize,
228 node_id: NodeId,
229 checkpoint_id: Option<CheckpointId>,
230 mutations: impl IntoIterator<Item = E>,
231) -> Vec<MutationLogEntry> {
232 let mut result = Vec::new();
233 for (idx, mutation) in mutations.into_iter().enumerate() {
234 if let Ok(value) = serde_json::to_value(&mutation) {
235 result.push(MutationLogEntry::new(
236 trace_id,
237 step,
238 node_id.clone(),
239 checkpoint_id.clone(),
240 idx,
241 value,
242 ));
243 }
244 }
245 result
246}