Skip to main content

atomr_agents_channel_core/
store.rs

//! Persistence surface for channel state.
//!
//! Uses the same dual-backend pattern as
//! [`atomr_agents_meetings_harness::MeetingsStore`]: a single trait that both
//! the in-memory default and a feature-gated, checkpointer-backed
//! implementation satisfy. The orchestrator holds an `Arc<dyn ChannelStore>`.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14
15use crate::content::ChannelMessageRecord;
16use crate::error::Result;
17use crate::ids::{ChannelId, ThreadId};
18use crate::spec::ChannelSpec;
19use crate::thread::Thread;
20
21/// Lightweight summary row for listing threads in a channel.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ThreadSummary {
24    pub id: ThreadId,
25    pub channel: ChannelId,
26    pub peer: String,
27    pub target_kind: String,
28    pub history_len: usize,
29}
30
31impl ThreadSummary {
32    pub fn of(t: &Thread) -> Self {
33        Self {
34            id: t.id.clone(),
35            channel: t.channel.clone(),
36            peer: t.peer.as_str().to_string(),
37            target_kind: t.target.kind().to_string(),
38            history_len: t.history.len(),
39        }
40    }
41}
42
43#[async_trait]
44pub trait ChannelStore: Send + Sync + 'static {
45    // ----- Channels --------------------------------------------------
46    async fn upsert_channel(&self, spec: &ChannelSpec) -> Result<()>;
47    async fn get_channel(&self, id: &ChannelId) -> Result<Option<ChannelSpec>>;
48    async fn list_channels(&self) -> Result<Vec<ChannelSpec>>;
49    async fn delete_channel(&self, id: &ChannelId) -> Result<()>;
50
51    // ----- Threads ---------------------------------------------------
52    async fn upsert_thread(&self, thread: &Thread) -> Result<()>;
53    async fn get_thread(&self, id: &ThreadId) -> Result<Option<Thread>>;
54    async fn list_threads(&self, channel: &ChannelId) -> Result<Vec<ThreadSummary>>;
55    async fn delete_thread(&self, id: &ThreadId) -> Result<()>;
56
57    // ----- Messages --------------------------------------------------
58    async fn append_message(&self, rec: &ChannelMessageRecord) -> Result<()>;
59    async fn list_messages(&self, thread: &ThreadId, limit: usize) -> Result<Vec<ChannelMessageRecord>>;
60
61    /// Idempotency / dedup helpers.
62    async fn lookup_outbound_by_key(
63        &self,
64        thread: &ThreadId,
65        idempotency_key: &str,
66    ) -> Result<Option<String>>;
67    async fn has_inbound(&self, channel: &ChannelId, provider_msg_id: &str) -> Result<bool>;
68}
69
70#[async_trait]
71impl ChannelStore for Arc<dyn ChannelStore> {
72    async fn upsert_channel(&self, spec: &ChannelSpec) -> Result<()> {
73        (**self).upsert_channel(spec).await
74    }
75    async fn get_channel(&self, id: &ChannelId) -> Result<Option<ChannelSpec>> {
76        (**self).get_channel(id).await
77    }
78    async fn list_channels(&self) -> Result<Vec<ChannelSpec>> {
79        (**self).list_channels().await
80    }
81    async fn delete_channel(&self, id: &ChannelId) -> Result<()> {
82        (**self).delete_channel(id).await
83    }
84    async fn upsert_thread(&self, thread: &Thread) -> Result<()> {
85        (**self).upsert_thread(thread).await
86    }
87    async fn get_thread(&self, id: &ThreadId) -> Result<Option<Thread>> {
88        (**self).get_thread(id).await
89    }
90    async fn list_threads(&self, channel: &ChannelId) -> Result<Vec<ThreadSummary>> {
91        (**self).list_threads(channel).await
92    }
93    async fn delete_thread(&self, id: &ThreadId) -> Result<()> {
94        (**self).delete_thread(id).await
95    }
96    async fn append_message(&self, rec: &ChannelMessageRecord) -> Result<()> {
97        (**self).append_message(rec).await
98    }
99    async fn list_messages(&self, thread: &ThreadId, limit: usize) -> Result<Vec<ChannelMessageRecord>> {
100        (**self).list_messages(thread, limit).await
101    }
102    async fn lookup_outbound_by_key(
103        &self,
104        thread: &ThreadId,
105        idempotency_key: &str,
106    ) -> Result<Option<String>> {
107        (**self).lookup_outbound_by_key(thread, idempotency_key).await
108    }
109    async fn has_inbound(&self, channel: &ChannelId, provider_msg_id: &str) -> Result<bool> {
110        (**self).has_inbound(channel, provider_msg_id).await
111    }
112}
113
114/// Process-local, volatile channel store.
115#[derive(Default)]
116struct StoreInner {
117    channels: HashMap<ChannelId, ChannelSpec>,
118    threads: HashMap<ThreadId, Thread>,
119    messages: HashMap<ThreadId, Vec<ChannelMessageRecord>>,
120    /// `(channel_id, provider_msg_id)` set used for inbound dedup.
121    inbound_seen: std::collections::HashSet<(ChannelId, String)>,
122}
123
124#[derive(Default, Clone)]
125pub struct InMemoryChannelStore {
126    inner: Arc<RwLock<StoreInner>>,
127}
128
129impl InMemoryChannelStore {
130    pub fn new() -> Self {
131        Self::default()
132    }
133}
134
135#[async_trait]
136impl ChannelStore for InMemoryChannelStore {
137    async fn upsert_channel(&self, spec: &ChannelSpec) -> Result<()> {
138        self.inner.write().channels.insert(spec.id.clone(), spec.clone());
139        Ok(())
140    }
141
142    async fn get_channel(&self, id: &ChannelId) -> Result<Option<ChannelSpec>> {
143        Ok(self.inner.read().channels.get(id).cloned())
144    }
145
146    async fn list_channels(&self) -> Result<Vec<ChannelSpec>> {
147        let mut v: Vec<_> = self.inner.read().channels.values().cloned().collect();
148        v.sort_by(|a, b| a.id.as_str().cmp(b.id.as_str()));
149        Ok(v)
150    }
151
152    async fn delete_channel(&self, id: &ChannelId) -> Result<()> {
153        let mut g = self.inner.write();
154        g.channels.remove(id);
155        let drop_threads: Vec<_> = g
156            .threads
157            .iter()
158            .filter(|(_, t)| &t.channel == id)
159            .map(|(tid, _)| tid.clone())
160            .collect();
161        for tid in drop_threads {
162            g.threads.remove(&tid);
163            g.messages.remove(&tid);
164        }
165        g.inbound_seen.retain(|(cid, _)| cid != id);
166        Ok(())
167    }
168
169    async fn upsert_thread(&self, thread: &Thread) -> Result<()> {
170        self.inner.write().threads.insert(thread.id.clone(), thread.clone());
171        Ok(())
172    }
173
174    async fn get_thread(&self, id: &ThreadId) -> Result<Option<Thread>> {
175        Ok(self.inner.read().threads.get(id).cloned())
176    }
177
178    async fn list_threads(&self, channel: &ChannelId) -> Result<Vec<ThreadSummary>> {
179        let mut v: Vec<_> = self
180            .inner
181            .read()
182            .threads
183            .values()
184            .filter(|t| &t.channel == channel)
185            .map(ThreadSummary::of)
186            .collect();
187        v.sort_by(|a, b| a.id.as_str().cmp(b.id.as_str()));
188        Ok(v)
189    }
190
191    async fn delete_thread(&self, id: &ThreadId) -> Result<()> {
192        let mut g = self.inner.write();
193        g.threads.remove(id);
194        g.messages.remove(id);
195        Ok(())
196    }
197
198    async fn append_message(&self, rec: &ChannelMessageRecord) -> Result<()> {
199        let mut g = self.inner.write();
200        if let Some(pid) = &rec.provider_msg_id {
201            if matches!(rec.direction, crate::content::Direction::Inbound) {
202                let channel = g.threads.get(&rec.thread_id).map(|t| t.channel.clone());
203                if let Some(channel) = channel {
204                    g.inbound_seen.insert((channel, pid.clone()));
205                }
206            }
207        }
208        g.messages
209            .entry(rec.thread_id.clone())
210            .or_default()
211            .push(rec.clone());
212        Ok(())
213    }
214
215    async fn list_messages(&self, thread: &ThreadId, limit: usize) -> Result<Vec<ChannelMessageRecord>> {
216        let g = self.inner.read();
217        let v = g
218            .messages
219            .get(thread)
220            .map(|v| {
221                let take = if limit == 0 || limit > v.len() {
222                    v.len()
223                } else {
224                    limit
225                };
226                v[v.len() - take..].to_vec()
227            })
228            .unwrap_or_default();
229        Ok(v)
230    }
231
232    async fn lookup_outbound_by_key(
233        &self,
234        thread: &ThreadId,
235        idempotency_key: &str,
236    ) -> Result<Option<String>> {
237        Ok(self.inner.read().messages.get(thread).and_then(|v| {
238            v.iter()
239                .find(|r| r.idempotency_key.as_deref() == Some(idempotency_key))
240                .and_then(|r| r.provider_msg_id.clone())
241        }))
242    }
243
244    async fn has_inbound(&self, channel: &ChannelId, provider_msg_id: &str) -> Result<bool> {
245        Ok(self
246            .inner
247            .read()
248            .inbound_seen
249            .contains(&(channel.clone(), provider_msg_id.to_string())))
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::content::{Direction, MessageContent};
    use crate::ids::PeerId;
    use crate::spec::{Capabilities, ProviderKind};
    use crate::target::ThreadTarget;
    use atomr_agents_callable::FnCallable;
    use std::sync::Arc;

    /// A thread on `channel` whose target is an identity callable.
    fn fake_thread(channel: &ChannelId, peer: &str) -> Thread {
        let handle: atomr_agents_callable::CallableHandle =
            Arc::new(FnCallable::new(|v, _ctx| async move { Ok(v) }));
        Thread::new(
            channel.clone(),
            PeerId::from(peer),
            ThreadTarget::callable(handle),
        )
    }

    /// An inbound text record on `thread` with the given ids.
    fn inbound(
        thread: &ThreadId,
        id: &'static str,
        provider_msg_id: &'static str,
    ) -> ChannelMessageRecord {
        ChannelMessageRecord {
            thread_id: thread.clone(),
            id: id.into(),
            direction: Direction::Inbound,
            content: MessageContent::text("hi"),
            provider_msg_id: Some(provider_msg_id.into()),
            idempotency_key: None,
            at: chrono::Utc::now(),
        }
    }

    #[tokio::test]
    async fn channel_round_trip() {
        let s = InMemoryChannelStore::new();
        let spec = ChannelSpec::new(ChannelId::from("memory:dev"), ProviderKind::Memory)
            .with_capabilities(Capabilities::text_only());
        s.upsert_channel(&spec).await.unwrap();
        assert_eq!(s.list_channels().await.unwrap().len(), 1);
        s.delete_channel(&spec.id).await.unwrap();
        assert!(s.list_channels().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn thread_round_trip_and_message_log() {
        let s = InMemoryChannelStore::new();
        let chan = ChannelId::from("memory:dev");
        let t = fake_thread(&chan, "alice");
        s.upsert_channel(&ChannelSpec::new(chan.clone(), ProviderKind::Memory))
            .await
            .unwrap();
        s.upsert_thread(&t).await.unwrap();
        assert_eq!(s.list_threads(&chan).await.unwrap().len(), 1);

        s.append_message(&inbound(&t.id, "m1", "pmid-1")).await.unwrap();
        assert!(s.has_inbound(&chan, "pmid-1").await.unwrap());
        assert_eq!(s.list_messages(&t.id, 0).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn list_messages_returns_most_recent_tail() {
        let s = InMemoryChannelStore::new();
        let chan = ChannelId::from("memory:dev");
        let t = fake_thread(&chan, "alice");
        s.upsert_thread(&t).await.unwrap();
        for (id, pmid) in [("m1", "p1"), ("m2", "p2"), ("m3", "p3")] {
            s.append_message(&inbound(&t.id, id, pmid)).await.unwrap();
        }

        let tail = s.list_messages(&t.id, 2).await.unwrap();
        let tail_ids: Vec<_> = tail.iter().map(|r| r.provider_msg_id.as_deref()).collect();
        assert_eq!(tail_ids, vec![Some("p2"), Some("p3")]);
        // `0` means "no limit", and an over-large limit is clamped to the log.
        assert_eq!(s.list_messages(&t.id, 0).await.unwrap().len(), 3);
        assert_eq!(s.list_messages(&t.id, 10).await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn delete_channel_cascades_to_threads_messages_and_dedup() {
        let s = InMemoryChannelStore::new();
        let chan = ChannelId::from("memory:dev");
        s.upsert_channel(&ChannelSpec::new(chan.clone(), ProviderKind::Memory))
            .await
            .unwrap();
        let t = fake_thread(&chan, "alice");
        s.upsert_thread(&t).await.unwrap();
        s.append_message(&inbound(&t.id, "m1", "pmid-1")).await.unwrap();
        assert!(s.has_inbound(&chan, "pmid-1").await.unwrap());

        s.delete_channel(&chan).await.unwrap();
        assert!(s.get_channel(&chan).await.unwrap().is_none());
        assert!(s.get_thread(&t.id).await.unwrap().is_none());
        assert!(s.list_threads(&chan).await.unwrap().is_empty());
        assert!(s.list_messages(&t.id, 0).await.unwrap().is_empty());
        assert!(!s.has_inbound(&chan, "pmid-1").await.unwrap());
    }
}