entelix_graph/
in_memory_checkpointer.rs1use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use entelix_core::{Error, Result, TenantId, ThreadKey};
15use parking_lot::Mutex;
16
17use crate::checkpoint::{Checkpoint, CheckpointId, Checkpointer};
18
19type Partition = (TenantId, String);
25
26fn partition(key: &ThreadKey) -> Partition {
27 (key.tenant_id().clone(), key.thread_id().to_owned())
28}
29
30#[derive(Clone)]
38pub struct InMemoryCheckpointer<S>
39where
40 S: Clone + Send + Sync + 'static,
41{
42 inner: Arc<Mutex<HashMap<Partition, Vec<Checkpoint<S>>>>>,
43}
44
45impl<S> InMemoryCheckpointer<S>
46where
47 S: Clone + Send + Sync + 'static,
48{
49 pub fn new() -> Self {
51 Self {
52 inner: Arc::new(Mutex::new(HashMap::new())),
53 }
54 }
55
56 pub fn total_checkpoints(&self) -> usize {
59 self.inner.lock().values().map(Vec::len).sum()
60 }
61
62 pub fn thread_count(&self) -> usize {
65 self.inner.lock().len()
66 }
67}
68
69impl<S> Default for InMemoryCheckpointer<S>
70where
71 S: Clone + Send + Sync + 'static,
72{
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78#[async_trait]
79impl<S> Checkpointer<S> for InMemoryCheckpointer<S>
80where
81 S: Clone + Send + Sync + 'static,
82{
83 async fn put(&self, checkpoint: Checkpoint<S>) -> Result<()> {
84 let key = (checkpoint.tenant_id.clone(), checkpoint.thread_id.clone());
85 self.inner.lock().entry(key).or_default().push(checkpoint);
91 Ok(())
92 }
93
94 async fn get_latest(&self, key: &ThreadKey) -> Result<Option<Checkpoint<S>>> {
95 let guard = self.inner.lock();
96 Ok(guard
97 .get(&partition(key))
98 .and_then(|history| history.last().cloned()))
99 }
100
101 async fn get_by_id(&self, key: &ThreadKey, id: &CheckpointId) -> Result<Option<Checkpoint<S>>> {
102 let guard = self.inner.lock();
103 Ok(guard
104 .get(&partition(key))
105 .and_then(|h| h.iter().find(|cp| &cp.id == id).cloned()))
106 }
107
108 async fn list_history(&self, key: &ThreadKey, limit: usize) -> Result<Vec<Checkpoint<S>>> {
109 let guard = self.inner.lock();
110 Ok(guard
111 .get(&partition(key))
112 .map(|h| h.iter().rev().take(limit).cloned().collect::<Vec<_>>())
113 .unwrap_or_default())
114 }
115
116 async fn update_state(
117 &self,
118 key: &ThreadKey,
119 parent_id: &CheckpointId,
120 new_state: S,
121 ) -> Result<CheckpointId> {
122 let part = partition(key);
123 let parent_bits: Option<(Option<String>, usize)> = {
126 let guard = self.inner.lock();
127 guard
128 .get(&part)
129 .and_then(|h| h.iter().find(|cp| &cp.id == parent_id))
130 .map(|cp| (cp.next_node.clone(), cp.step.saturating_add(1)))
131 };
132 let (next_node, step) = parent_bits.ok_or_else(|| {
133 Error::invalid_request(format!(
134 "InMemoryCheckpointer::update_state: unknown parent_id in tenant '{}' thread '{}'",
135 key.tenant_id(),
136 key.thread_id()
137 ))
138 })?;
139 let new_checkpoint =
140 Checkpoint::new(key, step, new_state, next_node).with_parent(parent_id.clone());
141 let new_id = new_checkpoint.id.clone();
142 self.inner
143 .lock()
144 .entry(part)
145 .or_default()
146 .push(new_checkpoint);
147 Ok(new_id)
148 }
149}