atomr_agents_channel_core/
store.rs1use std::collections::HashMap;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14
15use crate::content::ChannelMessageRecord;
16use crate::error::Result;
17use crate::ids::{ChannelId, ThreadId};
18use crate::spec::ChannelSpec;
19use crate::thread::Thread;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ThreadSummary {
24 pub id: ThreadId,
25 pub channel: ChannelId,
26 pub peer: String,
27 pub target_kind: String,
28 pub history_len: usize,
29}
30
31impl ThreadSummary {
32 pub fn of(t: &Thread) -> Self {
33 Self {
34 id: t.id.clone(),
35 channel: t.channel.clone(),
36 peer: t.peer.as_str().to_string(),
37 target_kind: t.target.kind().to_string(),
38 history_len: t.history.len(),
39 }
40 }
41}
42
43#[async_trait]
44pub trait ChannelStore: Send + Sync + 'static {
45 async fn upsert_channel(&self, spec: &ChannelSpec) -> Result<()>;
47 async fn get_channel(&self, id: &ChannelId) -> Result<Option<ChannelSpec>>;
48 async fn list_channels(&self) -> Result<Vec<ChannelSpec>>;
49 async fn delete_channel(&self, id: &ChannelId) -> Result<()>;
50
51 async fn upsert_thread(&self, thread: &Thread) -> Result<()>;
53 async fn get_thread(&self, id: &ThreadId) -> Result<Option<Thread>>;
54 async fn list_threads(&self, channel: &ChannelId) -> Result<Vec<ThreadSummary>>;
55 async fn delete_thread(&self, id: &ThreadId) -> Result<()>;
56
57 async fn append_message(&self, rec: &ChannelMessageRecord) -> Result<()>;
59 async fn list_messages(&self, thread: &ThreadId, limit: usize) -> Result<Vec<ChannelMessageRecord>>;
60
61 async fn lookup_outbound_by_key(
63 &self,
64 thread: &ThreadId,
65 idempotency_key: &str,
66 ) -> Result<Option<String>>;
67 async fn has_inbound(&self, channel: &ChannelId, provider_msg_id: &str) -> Result<bool>;
68}
69
70#[async_trait]
71impl ChannelStore for Arc<dyn ChannelStore> {
72 async fn upsert_channel(&self, spec: &ChannelSpec) -> Result<()> {
73 (**self).upsert_channel(spec).await
74 }
75 async fn get_channel(&self, id: &ChannelId) -> Result<Option<ChannelSpec>> {
76 (**self).get_channel(id).await
77 }
78 async fn list_channels(&self) -> Result<Vec<ChannelSpec>> {
79 (**self).list_channels().await
80 }
81 async fn delete_channel(&self, id: &ChannelId) -> Result<()> {
82 (**self).delete_channel(id).await
83 }
84 async fn upsert_thread(&self, thread: &Thread) -> Result<()> {
85 (**self).upsert_thread(thread).await
86 }
87 async fn get_thread(&self, id: &ThreadId) -> Result<Option<Thread>> {
88 (**self).get_thread(id).await
89 }
90 async fn list_threads(&self, channel: &ChannelId) -> Result<Vec<ThreadSummary>> {
91 (**self).list_threads(channel).await
92 }
93 async fn delete_thread(&self, id: &ThreadId) -> Result<()> {
94 (**self).delete_thread(id).await
95 }
96 async fn append_message(&self, rec: &ChannelMessageRecord) -> Result<()> {
97 (**self).append_message(rec).await
98 }
99 async fn list_messages(&self, thread: &ThreadId, limit: usize) -> Result<Vec<ChannelMessageRecord>> {
100 (**self).list_messages(thread, limit).await
101 }
102 async fn lookup_outbound_by_key(
103 &self,
104 thread: &ThreadId,
105 idempotency_key: &str,
106 ) -> Result<Option<String>> {
107 (**self).lookup_outbound_by_key(thread, idempotency_key).await
108 }
109 async fn has_inbound(&self, channel: &ChannelId, provider_msg_id: &str) -> Result<bool> {
110 (**self).has_inbound(channel, provider_msg_id).await
111 }
112}
113
114#[derive(Default)]
116struct StoreInner {
117 channels: HashMap<ChannelId, ChannelSpec>,
118 threads: HashMap<ThreadId, Thread>,
119 messages: HashMap<ThreadId, Vec<ChannelMessageRecord>>,
120 inbound_seen: std::collections::HashSet<(ChannelId, String)>,
122}
123
124#[derive(Default, Clone)]
125pub struct InMemoryChannelStore {
126 inner: Arc<RwLock<StoreInner>>,
127}
128
129impl InMemoryChannelStore {
130 pub fn new() -> Self {
131 Self::default()
132 }
133}
134
135#[async_trait]
136impl ChannelStore for InMemoryChannelStore {
137 async fn upsert_channel(&self, spec: &ChannelSpec) -> Result<()> {
138 self.inner.write().channels.insert(spec.id.clone(), spec.clone());
139 Ok(())
140 }
141
142 async fn get_channel(&self, id: &ChannelId) -> Result<Option<ChannelSpec>> {
143 Ok(self.inner.read().channels.get(id).cloned())
144 }
145
146 async fn list_channels(&self) -> Result<Vec<ChannelSpec>> {
147 let mut v: Vec<_> = self.inner.read().channels.values().cloned().collect();
148 v.sort_by(|a, b| a.id.as_str().cmp(b.id.as_str()));
149 Ok(v)
150 }
151
152 async fn delete_channel(&self, id: &ChannelId) -> Result<()> {
153 let mut g = self.inner.write();
154 g.channels.remove(id);
155 let drop_threads: Vec<_> = g
156 .threads
157 .iter()
158 .filter(|(_, t)| &t.channel == id)
159 .map(|(tid, _)| tid.clone())
160 .collect();
161 for tid in drop_threads {
162 g.threads.remove(&tid);
163 g.messages.remove(&tid);
164 }
165 g.inbound_seen.retain(|(cid, _)| cid != id);
166 Ok(())
167 }
168
169 async fn upsert_thread(&self, thread: &Thread) -> Result<()> {
170 self.inner.write().threads.insert(thread.id.clone(), thread.clone());
171 Ok(())
172 }
173
174 async fn get_thread(&self, id: &ThreadId) -> Result<Option<Thread>> {
175 Ok(self.inner.read().threads.get(id).cloned())
176 }
177
178 async fn list_threads(&self, channel: &ChannelId) -> Result<Vec<ThreadSummary>> {
179 let mut v: Vec<_> = self
180 .inner
181 .read()
182 .threads
183 .values()
184 .filter(|t| &t.channel == channel)
185 .map(ThreadSummary::of)
186 .collect();
187 v.sort_by(|a, b| a.id.as_str().cmp(b.id.as_str()));
188 Ok(v)
189 }
190
191 async fn delete_thread(&self, id: &ThreadId) -> Result<()> {
192 let mut g = self.inner.write();
193 g.threads.remove(id);
194 g.messages.remove(id);
195 Ok(())
196 }
197
198 async fn append_message(&self, rec: &ChannelMessageRecord) -> Result<()> {
199 let mut g = self.inner.write();
200 if let Some(pid) = &rec.provider_msg_id {
201 if matches!(rec.direction, crate::content::Direction::Inbound) {
202 let channel = g.threads.get(&rec.thread_id).map(|t| t.channel.clone());
203 if let Some(channel) = channel {
204 g.inbound_seen.insert((channel, pid.clone()));
205 }
206 }
207 }
208 g.messages
209 .entry(rec.thread_id.clone())
210 .or_default()
211 .push(rec.clone());
212 Ok(())
213 }
214
215 async fn list_messages(&self, thread: &ThreadId, limit: usize) -> Result<Vec<ChannelMessageRecord>> {
216 let g = self.inner.read();
217 let v = g
218 .messages
219 .get(thread)
220 .map(|v| {
221 let take = if limit == 0 || limit > v.len() {
222 v.len()
223 } else {
224 limit
225 };
226 v[v.len() - take..].to_vec()
227 })
228 .unwrap_or_default();
229 Ok(v)
230 }
231
232 async fn lookup_outbound_by_key(
233 &self,
234 thread: &ThreadId,
235 idempotency_key: &str,
236 ) -> Result<Option<String>> {
237 Ok(self.inner.read().messages.get(thread).and_then(|v| {
238 v.iter()
239 .find(|r| r.idempotency_key.as_deref() == Some(idempotency_key))
240 .and_then(|r| r.provider_msg_id.clone())
241 }))
242 }
243
244 async fn has_inbound(&self, channel: &ChannelId, provider_msg_id: &str) -> Result<bool> {
245 Ok(self
246 .inner
247 .read()
248 .inbound_seen
249 .contains(&(channel.clone(), provider_msg_id.to_string())))
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::content::{Direction, MessageContent};
    use crate::ids::PeerId;
    use crate::spec::{Capabilities, ProviderKind};
    use crate::target::ThreadTarget;
    use atomr_agents_callable::FnCallable;
    use std::sync::Arc;

    /// Builds a thread on `channel` for `peer` whose target is a callable that
    /// echoes its input back.
    fn fake_thread(channel: &ChannelId, peer: &str) -> Thread {
        let handle: atomr_agents_callable::CallableHandle =
            Arc::new(FnCallable::new(|v, _ctx| async move { Ok(v) }));
        let target = ThreadTarget::callable(handle);
        Thread::new(channel.clone(), PeerId::from(peer), target)
    }

    #[tokio::test]
    async fn channel_round_trip() {
        let store = InMemoryChannelStore::new();
        let channel_id = ChannelId::from("memory:dev");
        let spec = ChannelSpec::new(channel_id, ProviderKind::Memory)
            .with_capabilities(Capabilities::text_only());

        store.upsert_channel(&spec).await.unwrap();
        let listed = store.list_channels().await.unwrap();
        assert_eq!(listed.len(), 1);

        store.delete_channel(&spec.id).await.unwrap();
        let listed = store.list_channels().await.unwrap();
        assert!(listed.is_empty());
    }

    #[tokio::test]
    async fn thread_round_trip_and_message_log() {
        let store = InMemoryChannelStore::new();
        let chan = ChannelId::from("memory:dev");
        let thread = fake_thread(&chan, "alice");

        let spec = ChannelSpec::new(chan.clone(), ProviderKind::Memory);
        store.upsert_channel(&spec).await.unwrap();
        store.upsert_thread(&thread).await.unwrap();
        let summaries = store.list_threads(&chan).await.unwrap();
        assert_eq!(summaries.len(), 1);

        let inbound = ChannelMessageRecord {
            thread_id: thread.id.clone(),
            id: "m1".into(),
            direction: Direction::Inbound,
            content: MessageContent::text("hi"),
            provider_msg_id: Some("pmid-1".into()),
            idempotency_key: None,
            at: chrono::Utc::now(),
        };
        store.append_message(&inbound).await.unwrap();

        let seen = store.has_inbound(&chan, "pmid-1").await.unwrap();
        assert!(seen);
        let log = store.list_messages(&thread.id, 0).await.unwrap();
        assert_eq!(log.len(), 1);
    }
}