1use crate::{Result, SUBSCRIPTION_REL_SUB_DB};
2use borderless::common::{Id, Introduction};
3use borderless::events::Topic;
4use borderless::{AgentId, Context, ContractId};
5use borderless_kv_store::{Db, RawWrite, RoCursor, RoTx};
6use std::str::FromStr;
7
8fn generate_key(publisher: Id, topic: String, subscriber: Option<AgentId>) -> String {
14 let publisher = match publisher {
16 Id::Contract { contract_id } => contract_id.to_string().to_ascii_lowercase(),
17 Id::Agent { agent_id } => agent_id.to_string().to_ascii_lowercase(),
18 };
19 let subscriber = subscriber
21 .map(|agent| agent.to_string().to_ascii_lowercase())
22 .unwrap_or_default();
23 let topic = topic.trim_matches('/').to_ascii_lowercase();
25
26 if topic.is_empty() && subscriber.is_empty() {
29 format!("{publisher}\n")
30 } else {
31 format!("{publisher}\n{topic}\n{subscriber}")
33 }
34}
35
36fn extract_key(key: &[u8]) -> Result<(String, AgentId)> {
40 let key = std::str::from_utf8(key).with_context(|| "DB key deserialization failed")?;
41
42 let mut parts = key.splitn(3, '\n');
43 match (parts.next(), parts.next(), parts.next()) {
44 (Some(publisher), Some(topic), Some(s)) => {
45 let subscriber =
46 AgentId::from_str(s).with_context(|| "AgentId deserialization error")?;
47 let full_topic = format!("/{publisher}/{topic}");
48 Ok((full_topic, subscriber))
49 }
50 _ => Err(crate::Error::msg(
51 "SubscriptionHandler: malformed key error",
52 )),
53 }
54}
55
56pub struct SubscriptionHandler<'a, S: Db> {
57 db: &'a S,
58}
59
60impl<'a, S: Db> SubscriptionHandler<'a, S> {
61 pub fn new(db: &'a S) -> Self {
62 Self { db }
63 }
64
65 pub fn init(&self, txn: &mut <S as Db>::RwTx<'_>, introduction: Introduction) -> Result<()> {
67 match introduction.id {
69 Id::Contract { .. } => {} Id::Agent { agent_id } => {
71 for s in introduction.subscriptions {
72 self.subscribe(txn, agent_id, s)?
73 }
74 }
75 }
76 Ok(())
77 }
78
79 pub fn subscribe(
81 &self,
82 txn: &mut <S as Db>::RwTx<'_>,
83 subscriber: AgentId,
84 topic: Topic,
85 ) -> Result<()> {
86 let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
87 let key = generate_key(topic.publisher, topic.topic, Some(subscriber));
88 txn.write(&db_ptr, &key, &topic.method)?;
89 Ok(())
90 }
91
92 pub fn unsubscribe(
94 &self,
95 txn: &mut <S as Db>::RwTx<'_>,
96 subscriber: AgentId,
97 publisher: Id,
98 topic: String,
99 ) -> Result<()> {
100 let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
101 let key = generate_key(publisher, topic, Some(subscriber));
102 Ok(txn.delete(&db_ptr, &key)?)
103 }
104
105 pub fn get_topic_subscribers(
107 &self,
108 publisher: Id,
109 topic: String,
110 ) -> Result<Vec<(AgentId, String)>> {
111 let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
113 let txn = self.db.begin_ro_txn()?;
114 let mut cursor = txn.ro_cursor(&db_ptr)?;
115
116 let mut subscribers = Vec::new();
117
118 let prefix = generate_key(publisher, topic, None);
120
121 for (key, value) in cursor.iter_from(&prefix) {
122 if !key.starts_with(prefix.as_bytes()) {
124 break;
125 }
126 let topic =
128 String::from_utf8(value.to_vec()).with_context(|| "Failed to deserialize topic")?;
129 let (_, subscriber) = extract_key(key)?;
131 subscribers.push((subscriber, topic));
132 }
133 drop(cursor);
135 Ok(subscribers)
136 }
137
138 pub fn get_subscribers(&self, publisher: Id) -> Result<Vec<(AgentId, String)>> {
140 self.get_topic_subscribers(publisher, String::default())
141 }
142
143 pub fn get_subscriptions(&self, target: AgentId) -> Result<Vec<String>> {
145 let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
147 let txn = self.db.begin_ro_txn()?;
148 let mut cursor = txn.ro_cursor(&db_ptr)?;
149
150 let mut topics = Vec::new();
151
152 for (key, _) in cursor.iter() {
154 let (full_topic, subscriber) = extract_key(key)?;
155 if target != subscriber {
157 continue;
158 }
159 topics.push(full_topic);
161 }
162 drop(cursor);
164 Ok(topics)
165 }
166
167 pub fn unsubscribe_all(&self, txn: &mut <S as Db>::RwTx<'_>, subscriber: Id) -> Result<()> {
168 let subscriber = match subscriber {
169 Id::Contract { .. } => return Ok(()), Id::Agent { agent_id } => agent_id,
171 };
172 let subscriptions = self.get_subscriptions(subscriber)?;
174
175 for s in subscriptions {
176 let mut parts = s.trim_matches('/').splitn(2, '/');
177 let p = parts.next().expect("Malformed key");
178 let topic = parts.next().expect("Malformed key").to_string();
179 let publisher = if let Ok(cid) = ContractId::from_str(p) {
180 Id::from(cid)
181 } else if let Ok(aid) = AgentId::from_str(p) {
182 Id::from(aid)
183 } else {
184 return Err(crate::error::Error::msg("Invalid publisher"));
185 };
186 self.unsubscribe(txn, subscriber, publisher, topic)?;
188 }
189 Ok(())
190 }
191}
192
#[cfg(test)]
mod tests {
    use crate::db::subscriptions::SubscriptionHandler;
    use crate::SUBSCRIPTION_REL_SUB_DB;
    use borderless::common::Id;
    use borderless::events::Topic;
    use borderless::{AgentId, ContractId, Result};
    use borderless_kv_store::backend::lmdb::Lmdb;
    use borderless_kv_store::{Db, Tx};
    use tempfile::{tempdir, TempDir};

    const N: usize = 10;

    /// Opens a fresh LMDB environment in a temporary directory.
    ///
    /// The `TempDir` guard is returned with the environment because dropping
    /// it deletes the directory. Callers must keep it alive (bind it to a
    /// named variable, not `_`) while the database is in use.
    fn open_tmp_lmdb() -> (TempDir, Lmdb) {
        let tmp_dir = tempdir().unwrap();
        let env = Lmdb::new(tmp_dir.path(), 1).unwrap();
        env.create_sub_db(SUBSCRIPTION_REL_SUB_DB).unwrap();
        (tmp_dir, env)
    }

    #[test]
    fn subscription() -> Result<()> {
        let (_dir, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);
        let mut txn = lmdb.begin_rw_txn()?;

        let subscribers: Vec<AgentId> = std::iter::repeat_with(AgentId::generate)
            .take(N)
            .collect();
        let publishers: Vec<Id> = std::iter::repeat_with(|| Id::agent(AgentId::generate()))
            .take(N)
            .collect();
        let topic = "MyTopic";

        for (s, p) in subscribers.iter().zip(&publishers) {
            let t = Topic::new(*p, topic.to_string(), "method".to_string());
            handler.subscribe(&mut txn, *s, t)?;
        }
        txn.commit()?;

        for (s, p) in subscribers.iter().zip(&publishers) {
            let subscriptions = handler.get_subscriptions(*s)?;
            assert_eq!(subscriptions.len(), 1);
            let full_topic = format!("/{}/{}", p, topic.to_ascii_lowercase());
            assert_eq!(subscriptions[0], full_topic);
        }
        Ok(())
    }

    #[test]
    fn unsubscription() -> Result<()> {
        let (_dir, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let subscribers: Vec<AgentId> = std::iter::repeat_with(AgentId::generate)
            .take(N)
            .collect();
        let publishers: Vec<Id> = std::iter::repeat_with(|| Id::agent(AgentId::generate()))
            .take(N)
            .collect();
        let topic = "MyTopic";

        let mut txn = lmdb.begin_rw_txn()?;
        for (s, p) in subscribers.iter().zip(&publishers) {
            let t = Topic::new(*p, topic.to_string(), "method".to_string());
            handler.subscribe(&mut txn, *s, t)?;
        }
        txn.commit()?;

        // `unsubscribe` returns `Result<()>`, so only propagate errors here
        // and verify the removal after committing.
        let mut txn = lmdb.begin_rw_txn()?;
        for (s, p) in subscribers.iter().zip(&publishers) {
            handler.unsubscribe(&mut txn, *s, *p, topic.to_string())?;
        }
        txn.commit()?;

        for s in &subscribers {
            assert!(
                handler.get_subscriptions(*s)?.is_empty(),
                "Subscription was not removed"
            );
        }
        Ok(())
    }

    #[test]
    fn unsubscribe_all_removes_every_subscription() -> Result<()> {
        let (_dir, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let subscriber = AgentId::generate();
        let other = AgentId::generate();
        let publisher = Id::contract(ContractId::generate());
        // The empty topic previously made `unsubscribe_all` panic.
        let topics = ["", "Soccer", "Tennis"];

        let mut txn = lmdb.begin_rw_txn()?;
        for topic in topics.iter() {
            let t = Topic::new(publisher, topic.to_string(), "method".to_string());
            handler.subscribe(&mut txn, subscriber, t)?;
        }
        let t = Topic::new(publisher, "Soccer".to_string(), "method".to_string());
        handler.subscribe(&mut txn, other, t)?;
        txn.commit()?;

        let mut txn = lmdb.begin_rw_txn()?;
        handler.unsubscribe_all(&mut txn, Id::agent(subscriber))?;
        txn.commit()?;

        assert!(handler.get_subscriptions(subscriber)?.is_empty());
        assert_eq!(handler.get_subscriptions(other)?.len(), 1);
        Ok(())
    }

    #[test]
    fn fetch_topic_subscribers() -> Result<()> {
        let (_dir, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let mut subscribers: Vec<AgentId> = std::iter::repeat_with(AgentId::generate)
            .take(N)
            .collect();
        let publisher = Id::contract(ContractId::generate());
        let topic = "tennis";

        let mut txn = lmdb.begin_rw_txn()?;
        for s in &subscribers {
            let t = Topic::new(publisher, topic.to_string(), "method".to_string());
            handler.subscribe(&mut txn, *s, t)?;
        }
        txn.commit()?;

        let mut output: Vec<AgentId> = handler
            .get_topic_subscribers(publisher, topic.to_string())?
            .into_iter()
            .map(|(aid, _)| aid)
            .collect();
        subscribers.sort();
        output.sort();
        assert_eq!(subscribers, output, "Mismatch in topic subscribers");
        Ok(())
    }

    #[test]
    fn fetch_subscribers() -> Result<()> {
        let (_dir, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let mut subscribers: Vec<AgentId> = std::iter::repeat_with(AgentId::generate)
            .take(N)
            .collect();
        let publisher = Id::contract(ContractId::generate());
        let topics = ["Soccer", "Tennis", "Golf", "Basketball", "Football"];

        let mut txn = lmdb.begin_rw_txn()?;
        for (i, s) in subscribers.iter().enumerate() {
            let t = Topic::new(
                publisher,
                topics[i % topics.len()].to_string(),
                "method".to_string(),
            );
            handler.subscribe(&mut txn, *s, t)?;
        }
        txn.commit()?;

        let mut output: Vec<AgentId> = handler
            .get_subscribers(publisher)?
            .into_iter()
            .map(|(aid, _)| aid)
            .collect();
        subscribers.sort();
        output.sort();
        assert_eq!(subscribers, output, "Mismatch in subscribers");
        Ok(())
    }

    #[test]
    fn fetch_subscriptions() -> Result<()> {
        let (_dir, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let subscriber = AgentId::generate();
        let topics = ["Soccer", "Tennis", "Golf", "Basketball", "Football"];

        let mut full_topic: Vec<String> = Vec::with_capacity(N);
        let mut txn = lmdb.begin_rw_txn()?;
        for i in 0..N {
            let p = AgentId::generate();
            let topic = topics[i % topics.len()].to_string();
            full_topic.push(format!("/{}/{}", p, topic.to_ascii_lowercase()));
            let t = Topic::new(Id::agent(p), topic, "method".to_string());
            handler.subscribe(&mut txn, subscriber, t)?;
        }
        txn.commit()?;

        let mut output = handler.get_subscriptions(subscriber)?;
        output.sort();
        full_topic.sort();
        assert_eq!(full_topic, output, "Mismatch in subscriptions");
        Ok(())
    }
}