1use super::types::{MessageEnvelope, MessageType, Part, PartContent};
2use crate::error::{EnvoyError, Result};
3
/// Graph entity kind under which individual messages are persisted.
const KIND_MESSAGE: &str = "EnvoyMessage";
/// Graph entity kind for per-recipient sequence counters.
const KIND_MSG_SEQ_COUNTER: &str = "EnvoyMsgSeqCounter";

/// Stateless facade for storing, polling, fetching and acknowledging Envoy
/// messages; all state lives in the graph handle passed to each method.
#[derive(Debug, Clone, Copy)]
pub struct MessageStore;
9
10impl Default for MessageStore {
11 fn default() -> Self {
12 Self::new()
13 }
14}
15
16impl MessageStore {
17 pub fn new() -> Self {
18 Self
19 }
20
21 #[allow(clippy::too_many_arguments)]
23 pub fn store(
24 &self,
25 graph: &sqlitegraph::SqliteGraph,
26 msg_type: MessageType,
27 from: String,
28 to: String,
29 task_id: Option<String>,
30 context_id: Option<String>,
31 parts: Vec<Part>,
32 ) -> Result<MessageEnvelope> {
33 use sqlitegraph::GraphEntity;
34
35 let msg_type_val = serde_json::to_value(&msg_type)?;
36
37 let temp = MessageEnvelope {
38 message_id: String::new(),
39 msg_type,
40 from,
41 to,
42 task_id,
43 context_id,
44 timestamp: String::new(),
45 sequence_id: 0,
46 parts,
47 };
48 temp.validate()?;
49
50 let timestamp = chrono::Utc::now().to_rfc3339();
51
52 let counter_name = format!("msg-seq-{}", temp.to);
54 let sequence_id = if let Some(mut entity) =
55 graph.find_entity_by_kind_and_name(KIND_MSG_SEQ_COUNTER, &counter_name)?
56 {
57 let next = entity
58 .data
59 .get("next")
60 .and_then(|v| v.as_i64())
61 .unwrap_or(1);
62 entity.data["next"] = serde_json::json!(next + 1);
63 graph.update_entity(&entity)?;
64 next
65 } else {
66 let entity = GraphEntity {
67 id: 0,
68 kind: KIND_MSG_SEQ_COUNTER.to_string(),
69 name: counter_name,
70 file_path: None,
71 data: serde_json::json!({"next": 2}),
72 };
73 graph.insert_entity(&entity)?;
74 1
75 };
76
77 let entity = GraphEntity {
78 id: 0,
79 kind: KIND_MESSAGE.to_string(),
80 name: format!("msg-{}", uuid::Uuid::new_v4()),
81 file_path: None,
82 data: serde_json::json!({
83 "msg_type": msg_type_val,
84 "from": temp.from,
85 "to": temp.to,
86 "task_id": temp.task_id,
87 "context_id": temp.context_id,
88 "timestamp": timestamp,
89 "sequence_id": sequence_id,
90 "parts": serde_json::to_value(&temp.parts)?,
91 }),
92 };
93 let id = graph.insert_entity(&entity)?;
94
95 Ok(MessageEnvelope {
96 message_id: id.to_string(),
97 msg_type: temp.msg_type,
98 from: temp.from,
99 to: temp.to,
100 task_id: temp.task_id,
101 context_id: temp.context_id,
102 timestamp,
103 sequence_id,
104 parts: temp.parts,
105 })
106 }
107
108 pub fn store_notification(
111 &self,
112 graph: &sqlitegraph::SqliteGraph,
113 to: &str,
114 event_type: &str,
115 data: &serde_json::Value,
116 ) -> Result<MessageEnvelope> {
117 let text = serde_json::to_string(data).unwrap_or_default();
118 self.store(
119 graph,
120 MessageType::System,
121 "envoy".to_string(),
122 to.to_string(),
123 None,
124 Some(event_type.to_string()),
125 vec![Part {
126 content: PartContent::Text(text),
127 }],
128 )
129 }
130
131 pub fn ack(
133 &self,
134 graph: &sqlitegraph::SqliteGraph,
135 message_id: &str,
136 agent_id: &str,
137 ) -> Result<Vec<String>> {
138 let id: i64 = message_id
139 .parse()
140 .map_err(|_| EnvoyError::MessageNotFound(message_id.to_string()))?;
141 let mut entity = graph
142 .get_entity(id)
143 .map_err(|_| EnvoyError::MessageNotFound(message_id.to_string()))?;
144 if entity.kind != KIND_MESSAGE {
145 return Err(EnvoyError::MessageNotFound(message_id.to_string()));
146 }
147
148 let mut acked: Vec<String> = entity
149 .data
150 .get("acked_by")
151 .and_then(|v| serde_json::from_value(v.clone()).ok())
152 .unwrap_or_default();
153
154 if !acked.iter().any(|a| a == agent_id) {
155 acked.push(agent_id.to_string());
156 }
157
158 entity.data["acked_by"] = serde_json::to_value(&acked)?;
159 graph.update_entity(&entity)?;
160 Ok(acked)
161 }
162
163 pub fn poll(
166 &self,
167 graph: &sqlitegraph::SqliteGraph,
168 to: &str,
169 since: i64,
170 limit: i64,
171 include_acked: bool,
172 ) -> Result<Vec<MessageEnvelope>> {
173 let limit = limit.min(100);
174 let entities = graph.find_entities_by_kind(KIND_MESSAGE)?;
175 let mut messages: Vec<MessageEnvelope> = entities
176 .iter()
177 .filter(|e| {
178 let msg_to = e.data.get("to").and_then(|v| v.as_str()).unwrap_or("");
179 let seq = e
180 .data
181 .get("sequence_id")
182 .and_then(|v| v.as_i64())
183 .unwrap_or(0);
184 if msg_to != to || seq <= since {
185 return false;
186 }
187 if include_acked {
188 return true;
189 }
190 let acked_by: Vec<String> = e
192 .data
193 .get("acked_by")
194 .and_then(|v| serde_json::from_value(v.clone()).ok())
195 .unwrap_or_default();
196 !acked_by.iter().any(|a| a == to)
197 })
198 .map(entity_to_envelope)
199 .filter_map(|r| r.ok())
200 .collect();
201 messages.sort_by_key(|m| m.sequence_id);
202 messages.truncate(limit as usize);
203 Ok(messages)
204 }
205
206 pub fn get(
208 &self,
209 graph: &sqlitegraph::SqliteGraph,
210 message_id: &str,
211 ) -> Result<MessageEnvelope> {
212 let id: i64 = message_id
213 .parse()
214 .map_err(|_| EnvoyError::MessageNotFound(message_id.to_string()))?;
215 let entity = graph
216 .get_entity(id)
217 .map_err(|_| EnvoyError::MessageNotFound(message_id.to_string()))?;
218 if entity.kind != KIND_MESSAGE {
219 return Err(EnvoyError::MessageNotFound(message_id.to_string()));
220 }
221 entity_to_envelope(&entity)
222 }
223
224 pub fn count_all(&self, graph: &sqlitegraph::SqliteGraph) -> Result<i64> {
226 Ok(graph.find_entities_by_kind(KIND_MESSAGE)?.len() as i64)
227 }
228}
229
230fn entity_to_envelope(entity: &sqlitegraph::GraphEntity) -> Result<MessageEnvelope> {
231 let msg_type: MessageType = entity
232 .data
233 .get("msg_type")
234 .and_then(|v| serde_json::from_value(v.clone()).ok())
235 .unwrap_or(MessageType::Direct);
236 let parts: Vec<Part> = entity
237 .data
238 .get("parts")
239 .and_then(|v| serde_json::from_value(v.clone()).ok())
240 .unwrap_or_default();
241 Ok(MessageEnvelope {
242 message_id: entity.id.to_string(),
243 msg_type,
244 from: read_json_str(&entity.data, "from"),
245 to: read_json_str(&entity.data, "to"),
246 task_id: entity
247 .data
248 .get("task_id")
249 .and_then(|v| v.as_str())
250 .map(String::from),
251 context_id: entity
252 .data
253 .get("context_id")
254 .and_then(|v| v.as_str())
255 .map(String::from),
256 timestamp: read_json_str(&entity.data, "timestamp"),
257 sequence_id: entity
258 .data
259 .get("sequence_id")
260 .and_then(|v| v.as_i64())
261 .unwrap_or(0),
262 parts,
263 })
264}
265
266fn read_json_str(data: &serde_json::Value, key: &str) -> String {
267 data.get(key)
268 .and_then(|v| v.as_str())
269 .unwrap_or("")
270 .to_string()
271}
272
#[cfg(test)]
mod tests {
    use super::super::types::{MessageType, Part, PartContent};
    use super::*;
    use crate::engine::Engine;

    /// Builds a single text-part payload.
    fn text_parts(text: &str) -> Vec<Part> {
        vec![Part {
            content: PartContent::Text(text.into()),
        }]
    }

    #[test]
    fn message_store_assigns_ids() {
        let engine = Engine::open_in_memory().unwrap();
        let graph = engine.graph();
        let store = MessageStore::new();

        let stored = store
            .store(
                graph,
                MessageType::Direct,
                "id1".into(),
                "id2".into(),
                None,
                None,
                text_parts("hello"),
            )
            .unwrap();
        assert!(!stored.message_id.is_empty());
        assert!(!stored.timestamp.is_empty());
        assert_eq!(stored.sequence_id, 1);

        let stored2 = store
            .store(
                graph,
                MessageType::Direct,
                "id1".into(),
                "id2".into(),
                None,
                None,
                text_parts("world"),
            )
            .unwrap();
        assert_eq!(stored2.sequence_id, 2);

        let msgs = store.poll(graph, "id2", 0, 50, true).unwrap();
        assert_eq!(msgs.len(), 2);

        let msgs = store.poll(graph, "id2", 1, 50, true).unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].sequence_id, 2);
    }

    #[test]
    fn sequence_ids_are_per_recipient() {
        let engine = Engine::open_in_memory().unwrap();
        let graph = engine.graph();
        let store = MessageStore::new();

        let to_a = store
            .store(
                graph,
                MessageType::Direct,
                "id1".into(),
                "id2".into(),
                None,
                None,
                text_parts("a"),
            )
            .unwrap();
        let to_b = store
            .store(
                graph,
                MessageType::Direct,
                "id1".into(),
                "id3".into(),
                None,
                None,
                text_parts("b"),
            )
            .unwrap();
        assert_eq!(to_a.sequence_id, 1);
        assert_eq!(to_b.sequence_id, 1);
    }

    #[test]
    fn ack_is_idempotent_and_only_recipient_ack_hides_message() {
        let engine = Engine::open_in_memory().unwrap();
        let graph = engine.graph();
        let store = MessageStore::new();

        let stored = store
            .store(
                graph,
                MessageType::Direct,
                "id1".into(),
                "id2".into(),
                None,
                None,
                text_parts("hi"),
            )
            .unwrap();

        // An ack by someone other than the recipient does not hide the message.
        let acked = store.ack(graph, &stored.message_id, "id1").unwrap();
        assert_eq!(acked, vec!["id1".to_string()]);
        assert_eq!(store.poll(graph, "id2", 0, 50, false).unwrap().len(), 1);

        let acked = store.ack(graph, &stored.message_id, "id2").unwrap();
        assert_eq!(acked, vec!["id1".to_string(), "id2".to_string()]);

        // Acking again does not duplicate the entry.
        let acked = store.ack(graph, &stored.message_id, "id2").unwrap();
        assert_eq!(acked, vec!["id1".to_string(), "id2".to_string()]);

        assert!(store.poll(graph, "id2", 0, 50, false).unwrap().is_empty());
        assert_eq!(store.poll(graph, "id2", 0, 50, true).unwrap().len(), 1);

        assert!(store.ack(graph, "not-a-number", "id2").is_err());
    }

    #[test]
    fn get_round_trips_and_rejects_unknown_ids() {
        let engine = Engine::open_in_memory().unwrap();
        let graph = engine.graph();
        let store = MessageStore::new();

        let before = store.count_all(graph).unwrap();
        let stored = store
            .store(
                graph,
                MessageType::Direct,
                "id1".into(),
                "id2".into(),
                Some("task-1".into()),
                None,
                text_parts("payload"),
            )
            .unwrap();
        assert_eq!(store.count_all(graph).unwrap(), before + 1);

        let fetched = store.get(graph, &stored.message_id).unwrap();
        assert_eq!(fetched.message_id, stored.message_id);
        assert_eq!(fetched.sequence_id, stored.sequence_id);
        assert_eq!(fetched.from, "id1");
        assert_eq!(fetched.to, "id2");
        assert_eq!(fetched.task_id.as_deref(), Some("task-1"));
        assert_eq!(fetched.context_id, None);
        assert_eq!(fetched.timestamp, stored.timestamp);

        assert!(store.get(graph, "not-a-number").is_err());
    }
}
324}