Skip to main content

envoy/message/
store.rs

1use super::types::{MessageEnvelope, MessageType, Part, PartContent};
2use crate::error::{EnvoyError, Result};
3
/// Graph entity kind under which individual messages are stored.
const KIND_MESSAGE: &str = "EnvoyMessage";
/// Graph entity kind of the per-recipient sequence counters (one entity per recipient).
const KIND_MSG_SEQ_COUNTER: &str = "EnvoyMsgSeqCounter";

/// Stateless message store. All methods take `&SqliteGraph` for the shared connection.
///
/// The store itself holds no data, so it is trivially copyable and printable.
#[derive(Debug, Clone, Copy)]
pub struct MessageStore;
9
10impl Default for MessageStore {
11    fn default() -> Self {
12        Self::new()
13    }
14}
15
16impl MessageStore {
17    pub fn new() -> Self {
18        Self
19    }
20
21    /// Store a message and return a fully-built MessageEnvelope.
22    #[allow(clippy::too_many_arguments)]
23    pub fn store(
24        &self,
25        graph: &sqlitegraph::SqliteGraph,
26        msg_type: MessageType,
27        from: String,
28        to: String,
29        task_id: Option<String>,
30        context_id: Option<String>,
31        parts: Vec<Part>,
32    ) -> Result<MessageEnvelope> {
33        use sqlitegraph::GraphEntity;
34
35        let msg_type_val = serde_json::to_value(&msg_type)?;
36
37        let temp = MessageEnvelope {
38            message_id: String::new(),
39            msg_type,
40            from,
41            to,
42            task_id,
43            context_id,
44            timestamp: String::new(),
45            sequence_id: 0,
46            parts,
47        };
48        temp.validate()?;
49
50        let timestamp = chrono::Utc::now().to_rfc3339();
51
52        // Per-recipient sequence counter
53        let counter_name = format!("msg-seq-{}", temp.to);
54        let sequence_id = if let Some(mut entity) =
55            graph.find_entity_by_kind_and_name(KIND_MSG_SEQ_COUNTER, &counter_name)?
56        {
57            let next = entity
58                .data
59                .get("next")
60                .and_then(|v| v.as_i64())
61                .unwrap_or(1);
62            entity.data["next"] = serde_json::json!(next + 1);
63            graph.update_entity(&entity)?;
64            next
65        } else {
66            let entity = GraphEntity {
67                id: 0,
68                kind: KIND_MSG_SEQ_COUNTER.to_string(),
69                name: counter_name,
70                file_path: None,
71                data: serde_json::json!({"next": 2}),
72            };
73            graph.insert_entity(&entity)?;
74            1
75        };
76
77        let entity = GraphEntity {
78            id: 0,
79            kind: KIND_MESSAGE.to_string(),
80            name: format!("msg-{}", uuid::Uuid::new_v4()),
81            file_path: None,
82            data: serde_json::json!({
83                "msg_type": msg_type_val,
84                "from": temp.from,
85                "to": temp.to,
86                "task_id": temp.task_id,
87                "context_id": temp.context_id,
88                "timestamp": timestamp,
89                "sequence_id": sequence_id,
90                "parts": serde_json::to_value(&temp.parts)?,
91            }),
92        };
93        let id = graph.insert_entity(&entity)?;
94
95        Ok(MessageEnvelope {
96            message_id: id.to_string(),
97            msg_type: temp.msg_type,
98            from: temp.from,
99            to: temp.to,
100            task_id: temp.task_id,
101            context_id: temp.context_id,
102            timestamp,
103            sequence_id,
104            parts: temp.parts,
105        })
106    }
107
108    /// Store a system notification for an offline agent.
109    /// Reuses the message entity schema so it appears in poll/reconnect catch-up.
110    pub fn store_notification(
111        &self,
112        graph: &sqlitegraph::SqliteGraph,
113        to: &str,
114        event_type: &str,
115        data: &serde_json::Value,
116    ) -> Result<MessageEnvelope> {
117        let text = serde_json::to_string(data).unwrap_or_default();
118        self.store(
119            graph,
120            MessageType::System,
121            "envoy".to_string(),
122            to.to_string(),
123            None,
124            Some(event_type.to_string()),
125            vec![Part {
126                content: PartContent::Text(text),
127            }],
128        )
129    }
130
131    /// Mark a message as consumed (ACKed) by an agent.
132    pub fn ack(
133        &self,
134        graph: &sqlitegraph::SqliteGraph,
135        message_id: &str,
136        agent_id: &str,
137    ) -> Result<Vec<String>> {
138        let id: i64 = message_id
139            .parse()
140            .map_err(|_| EnvoyError::MessageNotFound(message_id.to_string()))?;
141        let mut entity = graph
142            .get_entity(id)
143            .map_err(|_| EnvoyError::MessageNotFound(message_id.to_string()))?;
144        if entity.kind != KIND_MESSAGE {
145            return Err(EnvoyError::MessageNotFound(message_id.to_string()));
146        }
147
148        let mut acked: Vec<String> = entity
149            .data
150            .get("acked_by")
151            .and_then(|v| serde_json::from_value(v.clone()).ok())
152            .unwrap_or_default();
153
154        if !acked.iter().any(|a| a == agent_id) {
155            acked.push(agent_id.to_string());
156        }
157
158        entity.data["acked_by"] = serde_json::to_value(&acked)?;
159        graph.update_entity(&entity)?;
160        Ok(acked)
161    }
162
163    /// Get messages for a recipient since a given sequence_id.
164    /// When `include_acked` is false, only returns messages not yet ACKed by the recipient.
165    pub fn poll(
166        &self,
167        graph: &sqlitegraph::SqliteGraph,
168        to: &str,
169        since: i64,
170        limit: i64,
171        include_acked: bool,
172    ) -> Result<Vec<MessageEnvelope>> {
173        let limit = limit.min(100);
174        let entities = graph.find_entities_by_kind(KIND_MESSAGE)?;
175        let mut messages: Vec<MessageEnvelope> = entities
176            .iter()
177            .filter(|e| {
178                let msg_to = e.data.get("to").and_then(|v| v.as_str()).unwrap_or("");
179                let seq = e
180                    .data
181                    .get("sequence_id")
182                    .and_then(|v| v.as_i64())
183                    .unwrap_or(0);
184                if msg_to != to || seq <= since {
185                    return false;
186                }
187                if include_acked {
188                    return true;
189                }
190                // Filter: only include if recipient hasn't ACKed
191                let acked_by: Vec<String> = e
192                    .data
193                    .get("acked_by")
194                    .and_then(|v| serde_json::from_value(v.clone()).ok())
195                    .unwrap_or_default();
196                !acked_by.iter().any(|a| a == to)
197            })
198            .map(entity_to_envelope)
199            .filter_map(|r| r.ok())
200            .collect();
201        messages.sort_by_key(|m| m.sequence_id);
202        messages.truncate(limit as usize);
203        Ok(messages)
204    }
205
206    /// Get a single message by ID.
207    pub fn get(
208        &self,
209        graph: &sqlitegraph::SqliteGraph,
210        message_id: &str,
211    ) -> Result<MessageEnvelope> {
212        let id: i64 = message_id
213            .parse()
214            .map_err(|_| EnvoyError::MessageNotFound(message_id.to_string()))?;
215        let entity = graph
216            .get_entity(id)
217            .map_err(|_| EnvoyError::MessageNotFound(message_id.to_string()))?;
218        if entity.kind != KIND_MESSAGE {
219            return Err(EnvoyError::MessageNotFound(message_id.to_string()));
220        }
221        entity_to_envelope(&entity)
222    }
223
224    /// Get total message count.
225    pub fn count_all(&self, graph: &sqlitegraph::SqliteGraph) -> Result<i64> {
226        Ok(graph.find_entities_by_kind(KIND_MESSAGE)?.len() as i64)
227    }
228}
229
230fn entity_to_envelope(entity: &sqlitegraph::GraphEntity) -> Result<MessageEnvelope> {
231    let msg_type: MessageType = entity
232        .data
233        .get("msg_type")
234        .and_then(|v| serde_json::from_value(v.clone()).ok())
235        .unwrap_or(MessageType::Direct);
236    let parts: Vec<Part> = entity
237        .data
238        .get("parts")
239        .and_then(|v| serde_json::from_value(v.clone()).ok())
240        .unwrap_or_default();
241    Ok(MessageEnvelope {
242        message_id: entity.id.to_string(),
243        msg_type,
244        from: read_json_str(&entity.data, "from"),
245        to: read_json_str(&entity.data, "to"),
246        task_id: entity
247            .data
248            .get("task_id")
249            .and_then(|v| v.as_str())
250            .map(String::from),
251        context_id: entity
252            .data
253            .get("context_id")
254            .and_then(|v| v.as_str())
255            .map(String::from),
256        timestamp: read_json_str(&entity.data, "timestamp"),
257        sequence_id: entity
258            .data
259            .get("sequence_id")
260            .and_then(|v| v.as_i64())
261            .unwrap_or(0),
262        parts,
263    })
264}
265
266fn read_json_str(data: &serde_json::Value, key: &str) -> String {
267    data.get(key)
268        .and_then(|v| v.as_str())
269        .unwrap_or("")
270        .to_string()
271}
272
273#[cfg(test)]
274mod tests {
275    use super::super::types::{MessageType, Part, PartContent};
276    use super::*;
277    use crate::engine::Engine;
278
279    #[test]
280    fn message_store_assigns_ids() {
281        let engine = Engine::open_in_memory().unwrap();
282        let graph = engine.graph();
283        let store = MessageStore::new();
284
285        let stored = store
286            .store(
287                graph,
288                MessageType::Direct,
289                "id1".into(),
290                "id2".into(),
291                None,
292                None,
293                vec![Part {
294                    content: PartContent::Text("hello".into()),
295                }],
296            )
297            .unwrap();
298        assert!(!stored.message_id.is_empty());
299        assert!(!stored.timestamp.is_empty());
300        assert_eq!(stored.sequence_id, 1);
301
302        let stored2 = store
303            .store(
304                graph,
305                MessageType::Direct,
306                "id1".into(),
307                "id2".into(),
308                None,
309                None,
310                vec![Part {
311                    content: PartContent::Text("world".into()),
312                }],
313            )
314            .unwrap();
315        assert_eq!(stored2.sequence_id, 2);
316
317        let msgs = store.poll(graph, "id2", 0, 50, true).unwrap();
318        assert_eq!(msgs.len(), 2);
319
320        let msgs = store.poll(graph, "id2", 1, 50, true).unwrap();
321        assert_eq!(msgs.len(), 1);
322        assert_eq!(msgs[0].sequence_id, 2);
323    }
324}