flowgentra_ai/core/memory/
checkpointer.rs1use crate::core::error::Result;
7use crate::core::state::State;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::RwLock;
11
12#[derive(Debug, Clone, Serialize, Deserialize, Default)]
14pub struct CheckpointMetadata {
15 pub last_node: Option<String>,
17 pub execution_path: Vec<String>,
19 #[serde(default)]
21 pub extra: HashMap<String, serde_json::Value>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct Checkpoint {
27 #[serde(rename = "state")]
29 state_value: serde_json::Value,
30 pub metadata: CheckpointMetadata,
31}
32
33impl Checkpoint {
34 pub fn new<T: State>(state: &T, metadata: CheckpointMetadata) -> Result<Self> {
35 let state_value = state.to_value();
36 Ok(Checkpoint {
37 state_value,
38 metadata,
39 })
40 }
41
42 pub fn state<T: State>(&self) -> Result<T> {
44 T::from_json(self.state_value.clone())
45 }
46}
47
48pub trait Checkpointer: Send + Sync {
52 fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>>;
54
55 fn list_threads(&self) -> Result<Vec<String>> {
57 Ok(Vec::new())
58 }
59}
60
61pub trait GenericCheckpointer: Send + Sync {
63 fn save<T: State>(
65 &self,
66 thread_id: &str,
67 state: &T,
68 metadata: &CheckpointMetadata,
69 ) -> Result<()>;
70}
71
72#[allow(dead_code)]
74pub trait CheckpointStore: Checkpointer + GenericCheckpointer {}
75
76pub struct MemoryCheckpointer {
80 store: RwLock<HashMap<String, Checkpoint>>,
81}
82
83impl Default for MemoryCheckpointer {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl MemoryCheckpointer {
90 pub fn new() -> Self {
91 Self {
92 store: RwLock::new(HashMap::new()),
93 }
94 }
95}
96
97impl Checkpointer for MemoryCheckpointer {
98 fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>> {
99 let guard = self
100 .store
101 .read()
102 .map_err(|e| crate::core::error::FlowgentraError::StateError(e.to_string()))?;
103 Ok(guard.get(thread_id).cloned())
104 }
105
106 fn list_threads(&self) -> Result<Vec<String>> {
107 let guard = self
108 .store
109 .read()
110 .map_err(|e| crate::core::error::FlowgentraError::StateError(e.to_string()))?;
111 Ok(guard.keys().cloned().collect())
112 }
113}
114
115impl GenericCheckpointer for MemoryCheckpointer {
116 fn save<T: State>(
117 &self,
118 thread_id: &str,
119 state: &T,
120 metadata: &CheckpointMetadata,
121 ) -> Result<()> {
122 let checkpoint = Checkpoint::new(state, metadata.clone())?;
123 self.store
124 .write()
125 .map_err(|e| crate::core::error::FlowgentraError::StateError(e.to_string()))?
126 .insert(thread_id.to_string(), checkpoint);
127 Ok(())
128 }
129}
130
131impl<T: Checkpointer + GenericCheckpointer + ?Sized> CheckpointStore for T {}