1use crate::{Result, SUBSCRIPTION_REL_SUB_DB};
2use borderless::common::{Id, Introduction};
3use borderless::events::Topic;
4use borderless::{AgentId, Context};
5use borderless_kv_store::{Db, RawWrite, RoCursor, RoTx, Tx};
6use std::str::FromStr;
7
8fn generate_key(publisher: Id, topic: String, subscriber: Option<AgentId>) -> String {
14 let publisher = match publisher {
16 Id::Contract { contract_id } => contract_id.to_string().to_ascii_lowercase(),
17 Id::Agent { agent_id } => agent_id.to_string().to_ascii_lowercase(),
18 };
19 let subscriber = subscriber
21 .map(|agent| agent.to_string().to_ascii_lowercase())
22 .unwrap_or_default();
23 let topic = topic.trim_matches('/').to_ascii_lowercase();
25
26 match (topic.is_empty(), subscriber.is_empty()) {
29 (true, true) => format!("{publisher}\n"),
30 (false, true) => format!("{publisher}\n{topic}\n"),
31 _ => format!("{publisher}\n{topic}\n{subscriber}"),
32 }
33}
34
35fn extract_entry(key: &[u8], value: &[u8]) -> Result<(Topic, AgentId)> {
39 let key = std::str::from_utf8(key).with_context(|| "DB key deserialization failed")?;
40 let method = std::str::from_utf8(value).with_context(|| "DB value deserialization failed")?;
41
42 let mut parts = key.splitn(3, '\n');
43 match (parts.next(), parts.next(), parts.next()) {
44 (Some(p), Some(topic), Some(s)) => {
45 let subscriber = AgentId::from_str(s).with_context(|| "Invalid subscriber")?;
47 let publisher = p.parse().with_context(|| "Invalid publisher")?;
49 Ok((Topic::new(publisher, topic, method), subscriber))
50 }
51 _ => Err(crate::Error::msg("Malformed key error")),
52 }
53}
54
55pub struct SubscriptionHandler<'a, S: Db> {
56 db: &'a S,
57}
58
59impl<'a, S: Db> SubscriptionHandler<'a, S> {
60 pub fn new(db: &'a S) -> Self {
61 Self { db }
62 }
63
64 pub fn init(&self, txn: &mut <S as Db>::RwTx<'_>, introduction: Introduction) -> Result<()> {
66 match introduction.id {
68 Id::Contract { .. } => {} Id::Agent { agent_id } => {
70 for s in introduction.subscriptions {
71 self.subscribe_txn(txn, agent_id, s)?
72 }
73 }
74 }
75 Ok(())
76 }
77
78 pub fn subscribe(&self, subscriber: AgentId, topic: Topic) -> Result<()> {
82 let mut txn = self.db.begin_rw_txn()?;
83 self.subscribe_txn(&mut txn, subscriber, topic)?;
84 Ok(txn.commit()?)
85 }
86
87 fn subscribe_txn(
91 &self,
92 txn: &mut <S as Db>::RwTx<'_>,
93 subscriber: AgentId,
94 topic: Topic,
95 ) -> Result<()> {
96 let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
98 let key = generate_key(topic.publisher, topic.topic, Some(subscriber));
100 txn.write(&db_ptr, &key, &topic.method)?;
101 Ok(())
102 }
103
104 pub fn unsubscribe(&self, subscriber: AgentId, topic: Topic) -> Result<()> {
108 let mut txn = self.db.begin_rw_txn()?;
109 self.unsubscribe_txn(&mut txn, subscriber, topic)?;
110 Ok(txn.commit()?)
111 }
112
113 fn unsubscribe_txn(
117 &self,
118 txn: &mut <S as Db>::RwTx<'_>,
119 subscriber: AgentId,
120 topic: Topic,
121 ) -> Result<()> {
122 let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
124 let key = generate_key(topic.publisher, topic.topic, Some(subscriber));
126 Ok(txn.delete(&db_ptr, &key)?)
127 }
128
129 pub fn get_topic_subscribers(
131 &self,
132 publisher: Id,
133 topic: String,
134 ) -> Result<Vec<(AgentId, String)>> {
135 let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
137 let txn = self.db.begin_ro_txn()?;
138 let mut cursor = txn.ro_cursor(&db_ptr)?;
139
140 let mut subscribers = Vec::new();
141
142 let prefix = generate_key(publisher, topic, None);
144
145 for (key, value) in cursor.iter_from(&prefix) {
146 if !key.starts_with(prefix.as_bytes()) {
148 break;
149 }
150 let (topic, subscriber) = extract_entry(key, value)?;
151 subscribers.push((subscriber, topic.method));
153 }
154 drop(cursor);
156 Ok(subscribers)
157 }
158
159 pub fn get_subscriptions(&self, target: AgentId) -> Result<Vec<Topic>> {
161 let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
163 let txn = self.db.begin_ro_txn()?;
164 let mut cursor = txn.ro_cursor(&db_ptr)?;
165
166 let mut topics = Vec::new();
167 for (key, value) in cursor.iter() {
168 let (topic, subscriber) = extract_entry(key, value)?;
169 if target != subscriber {
171 continue;
172 }
173 topics.push(topic);
175 }
176 drop(cursor);
178 Ok(topics)
179 }
180
181 pub fn unsubscribe_all(&self, txn: &mut <S as Db>::RwTx<'_>, subscriber: Id) -> Result<()> {
182 let subscriber = match subscriber {
183 Id::Contract { .. } => return Ok(()), Id::Agent { agent_id } => agent_id,
185 };
186 let subscriptions = self.get_subscriptions(subscriber)?;
188 for topic in subscriptions {
190 self.unsubscribe_txn(txn, subscriber, topic)?;
191 }
192 Ok(())
193 }
194}
195
#[cfg(test)]
mod tests {
    use crate::db::subscriptions::SubscriptionHandler;
    use crate::SUBSCRIPTION_REL_SUB_DB;
    use borderless::common::Id;
    use borderless::events::Topic;
    use borderless::{AgentId, ContractId, Result};
    use borderless_kv_store::backend::lmdb::Lmdb;
    use borderless_kv_store::Db;
    use tempfile::tempdir;

    /// Number of subscribers / publishers generated per test.
    const N: usize = 10;

    /// Opens a fresh LMDB environment in a temporary directory and creates
    /// the subscription sub-database in it.
    // NOTE(review): `tmp_dir` is dropped when this function returns, which
    // presumably removes the directory while the environment is still open —
    // confirm this is intended and portable.
    fn open_tmp_lmdb() -> Lmdb {
        let tmp_dir = tempdir().unwrap();
        let env = Lmdb::new(tmp_dir.path(), 1).unwrap();
        env.create_sub_db(SUBSCRIPTION_REL_SUB_DB).unwrap();
        env
    }

    /// Each subscriber subscribes to one publisher and must get exactly that
    /// subscription back, with the topic name lowercased.
    #[test]
    fn subscription() -> Result<()> {
        let lmdb = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let subscribers: Vec<AgentId> = std::iter::repeat_with(AgentId::generate)
            .take(N)
            .collect();
        let publishers: Vec<Id> = std::iter::repeat_with(|| Id::contract(ContractId::generate()))
            .take(N)
            .collect();
        let topic_name = "MyTopic";

        for (&subscriber, &publisher) in subscribers.iter().zip(&publishers) {
            let topic = Topic::new(publisher, topic_name.to_string(), "method".to_string());
            handler.subscribe(subscriber, topic)?;
        }

        for (&subscriber, &publisher) in subscribers.iter().zip(&publishers) {
            let subscriptions = handler.get_subscriptions(subscriber)?;
            assert_eq!(subscriptions.len(), 1);
            assert_eq!(subscriptions[0].publisher, publisher);
            assert_eq!(subscriptions[0].topic, topic_name.to_ascii_lowercase());
        }
        Ok(())
    }

    /// After unsubscribing, no publisher may have subscribers left on the
    /// topic.
    #[test]
    fn unsubscription() -> Result<()> {
        let lmdb = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let subscribers: Vec<AgentId> = std::iter::repeat_with(AgentId::generate)
            .take(N)
            .collect();
        let publishers: Vec<Id> = std::iter::repeat_with(|| Id::agent(AgentId::generate()))
            .take(N)
            .collect();
        let topic_name = "MyTopic";

        for (&subscriber, &publisher) in subscribers.iter().zip(&publishers) {
            let topic = Topic::new(publisher, topic_name.to_string(), "method".to_string());
            handler.subscribe(subscriber, topic)?;
        }

        for (&subscriber, &publisher) in subscribers.iter().zip(&publishers) {
            // The method is deliberately left empty when unsubscribing.
            let topic = Topic::new(publisher, topic_name.to_string(), String::default());
            handler.unsubscribe(subscriber, topic)?;
        }

        for publisher in publishers {
            assert!(handler
                .get_topic_subscribers(publisher, topic_name.to_string())?
                .is_empty());
        }
        Ok(())
    }

    /// All subscribers of a single topic are returned by a topic lookup.
    #[test]
    fn fetch_topic_subscribers() -> Result<()> {
        let lmdb = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let mut subscribers: Vec<AgentId> = std::iter::repeat_with(AgentId::generate)
            .take(N)
            .collect();
        let publisher = Id::contract(ContractId::generate());
        let topic_name = "tennis";

        for &subscriber in &subscribers {
            let topic = Topic::new(publisher, topic_name.to_string(), "method".to_string());
            handler.subscribe(subscriber, topic)?;
        }

        let mut output: Vec<AgentId> = handler
            .get_topic_subscribers(publisher, topic_name.to_string())?
            .into_iter()
            .map(|(agent, _)| agent)
            .collect();
        subscribers.sort();
        output.sort();
        assert_eq!(subscribers, output, "Mismatch in topic subscribers");
        Ok(())
    }

    /// An empty topic returns the subscribers of every topic of the publisher.
    #[test]
    fn fetch_subscribers() -> Result<()> {
        let lmdb = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let mut subscribers: Vec<AgentId> = std::iter::repeat_with(AgentId::generate)
            .take(N)
            .collect();
        let publisher = Id::contract(ContractId::generate());
        let topics = ["Soccer", "Tennis", "Golf", "Basketball", "Football"];

        for (&subscriber, name) in subscribers.iter().zip(topics.iter().cycle()) {
            let topic = Topic::new(publisher, name.to_string(), "method".to_string());
            handler.subscribe(subscriber, topic)?;
        }

        let mut output: Vec<AgentId> = handler
            .get_topic_subscribers(publisher, String::default())?
            .into_iter()
            .map(|(agent, _)| agent)
            .collect();
        subscribers.sort();
        output.sort();
        assert_eq!(subscribers, output, "Mismatch in subscribers");
        Ok(())
    }

    /// A subscriber gets back exactly the topics it subscribed to.
    #[test]
    fn fetch_subscriptions() -> Result<()> {
        let lmdb = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);

        let subscriber = AgentId::generate();
        let topics = ["Soccer", "Tennis", "Golf", "Basketball", "Football"];

        let mut subscriptions: Vec<Topic> = Vec::with_capacity(N);
        for name in topics.iter().cycle().take(N) {
            let publisher = ContractId::generate();
            let topic = Topic::new(
                Id::contract(publisher),
                name.to_ascii_lowercase(),
                "method".to_string(),
            );
            handler.subscribe(subscriber, topic.clone())?;
            subscriptions.push(topic);
        }

        let output = handler.get_subscriptions(subscriber)?;
        // Without the length check the loop below would pass vacuously on an
        // empty result.
        assert_eq!(
            output.len(),
            subscriptions.len(),
            "Mismatch in number of subscriptions"
        );
        for topic in output {
            assert!(subscriptions.contains(&topic), "Mismatch in subscriptions");
        }
        Ok(())
    }
}