synaptic_mongodb/
checkpointer.rs1use async_trait::async_trait;
2use bson::{doc, DateTime as BsonDateTime};
3use futures::TryStreamExt;
4use mongodb::{Collection, Database, IndexModel};
5use synaptic_core::SynapticError;
6use synaptic_graph::{Checkpoint, CheckpointConfig, Checkpointer};
7
8pub struct MongoCheckpointer {
26 collection: Collection<bson::Document>,
27}
28
29impl MongoCheckpointer {
30 pub async fn new(db: &Database, collection_name: &str) -> Result<Self, SynapticError> {
35 let collection: Collection<bson::Document> = db.collection(collection_name);
36
37 let unique_idx = IndexModel::builder()
39 .keys(doc! { "thread_id": 1, "checkpoint_id": 1 })
40 .options(
41 mongodb::options::IndexOptions::builder()
42 .unique(true)
43 .build(),
44 )
45 .build();
46
47 let seq_idx = IndexModel::builder()
49 .keys(doc! { "thread_id": 1, "seq": 1 })
50 .build();
51
52 collection
53 .create_index(unique_idx)
54 .await
55 .map_err(|e| SynapticError::Store(format!("MongoDB create unique index: {e}")))?;
56
57 collection
58 .create_index(seq_idx)
59 .await
60 .map_err(|e| SynapticError::Store(format!("MongoDB create seq index: {e}")))?;
61
62 Ok(Self { collection })
63 }
64}
65
66#[async_trait]
67impl Checkpointer for MongoCheckpointer {
68 async fn put(
69 &self,
70 config: &CheckpointConfig,
71 checkpoint: &Checkpoint,
72 ) -> Result<(), SynapticError> {
73 let state_json = serde_json::to_string(checkpoint)
75 .map_err(|e| SynapticError::Store(format!("Serialize: {e}")))?;
76
77 let count = self
79 .collection
80 .count_documents(doc! { "thread_id": &config.thread_id })
81 .await
82 .map_err(|e| SynapticError::Store(format!("MongoDB count: {e}")))?;
83
84 let document = doc! {
85 "thread_id": &config.thread_id,
86 "checkpoint_id": &checkpoint.id,
87 "seq": count as i64,
88 "state": &state_json,
89 "created_at": BsonDateTime::now(),
90 };
91
92 self.collection
94 .update_one(
95 doc! {
96 "thread_id": &config.thread_id,
97 "checkpoint_id": &checkpoint.id
98 },
99 doc! { "$setOnInsert": document },
100 )
101 .with_options(
102 mongodb::options::UpdateOptions::builder()
103 .upsert(true)
104 .build(),
105 )
106 .await
107 .map_err(|e| SynapticError::Store(format!("MongoDB upsert: {e}")))?;
108
109 Ok(())
110 }
111
112 async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError> {
113 let filter = if let Some(ref id) = config.checkpoint_id {
114 doc! { "thread_id": &config.thread_id, "checkpoint_id": id }
115 } else {
116 doc! { "thread_id": &config.thread_id }
117 };
118
119 let opts = mongodb::options::FindOneOptions::builder()
120 .sort(doc! { "seq": -1 })
121 .build();
122
123 let result = self
124 .collection
125 .find_one(filter)
126 .with_options(opts)
127 .await
128 .map_err(|e| SynapticError::Store(format!("MongoDB find_one: {e}")))?;
129
130 match result {
131 None => Ok(None),
132 Some(doc) => {
133 let state_str = doc
134 .get_str("state")
135 .map_err(|e| SynapticError::Store(format!("MongoDB get state field: {e}")))?;
136 let cp: Checkpoint = serde_json::from_str(state_str)
137 .map_err(|e| SynapticError::Store(format!("Deserialize: {e}")))?;
138 Ok(Some(cp))
139 }
140 }
141 }
142
143 async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError> {
144 let filter = doc! { "thread_id": &config.thread_id };
145 let opts = mongodb::options::FindOptions::builder()
146 .sort(doc! { "seq": 1 })
147 .build();
148
149 let mut cursor = self
150 .collection
151 .find(filter)
152 .with_options(opts)
153 .await
154 .map_err(|e| SynapticError::Store(format!("MongoDB find: {e}")))?;
155
156 let mut checkpoints = Vec::new();
157 while let Some(doc) = cursor
158 .try_next()
159 .await
160 .map_err(|e| SynapticError::Store(format!("MongoDB cursor: {e}")))?
161 {
162 let state_str = doc
163 .get_str("state")
164 .map_err(|e| SynapticError::Store(format!("MongoDB get state field: {e}")))?;
165 let cp: Checkpoint = serde_json::from_str(state_str)
166 .map_err(|e| SynapticError::Store(format!("Deserialize: {e}")))?;
167 checkpoints.push(cp);
168 }
169
170 Ok(checkpoints)
171 }
172}