use crate::{Result, SUBSCRIPTION_REL_SUB_DB};
use borderless::common::{Id, Introduction};
use borderless::events::Topic;
use borderless::{AgentId, Context};
use borderless_kv_store::{Db, RawWrite, RoCursor, RoTx, Tx};
use std::str::FromStr;
/// Builds the storage key (or key prefix) for a subscription entry.
///
/// Every part is ASCII-lowercased and the topic is stripped of leading/trailing `/`.
/// Trailing parts are left out when empty, so the result also serves as a prefix for
/// cursor scans:
/// - topic and subscriber both empty: `<publisher>\n`
/// - subscriber empty: `<publisher>\n<topic>\n`
/// - otherwise: `<publisher>\n<topic>\n<subscriber>` (the topic part may be empty)
fn generate_key(publisher: Id, topic: String, subscriber: Option<AgentId>) -> String {
    let mut key = match publisher {
        Id::Contract { contract_id } => contract_id.to_string(),
        Id::Agent { agent_id } => agent_id.to_string(),
    }
    .to_ascii_lowercase();
    let subscriber_part = match subscriber {
        Some(agent) => agent.to_string().to_ascii_lowercase(),
        None => String::new(),
    };
    let topic_part = topic.trim_matches('/').to_ascii_lowercase();

    key.push('\n');
    // Publisher-only prefix: matches every topic of this publisher.
    if topic_part.is_empty() && subscriber_part.is_empty() {
        return key;
    }
    key.push_str(&topic_part);
    key.push('\n');
    key.push_str(&subscriber_part);
    key
}
fn extract_entry(key: &[u8], value: &[u8]) -> Result<(Topic, AgentId)> {
let key = std::str::from_utf8(key).with_context(|| "DB key deserialization failed")?;
let method = std::str::from_utf8(value).with_context(|| "DB value deserialization failed")?;
let mut parts = key.splitn(3, '\n');
match (parts.next(), parts.next(), parts.next()) {
(Some(p), Some(topic), Some(s)) => {
let subscriber = AgentId::from_str(s).with_context(|| "Invalid subscriber")?;
let publisher = p.parse().with_context(|| "Invalid publisher")?;
Ok((Topic::new(publisher, topic, method), subscriber))
}
_ => Err(crate::Error::msg("Malformed key error")),
}
}
/// Reads and writes agent subscriptions stored in a key-value database.
///
/// The handler only borrows the database; the trait bounds needed to operate on it are
/// placed on the `impl` blocks rather than on the struct definition, so the type can be
/// named without repeating those bounds.
pub struct SubscriptionHandler<'a, S> {
    /// Database that holds the subscription sub-database.
    db: &'a S,
}
impl<'a, S: Db> SubscriptionHandler<'a, S> {
pub fn new(db: &'a S) -> Self {
Self { db }
}
pub fn init(&self, txn: &mut <S as Db>::RwTx<'_>, introduction: Introduction) -> Result<()> {
match introduction.id {
Id::Contract { .. } => {} Id::Agent { agent_id } => {
for s in introduction.subscriptions {
self.subscribe_txn(txn, agent_id, s)?
}
}
}
Ok(())
}
pub fn subscribe(&self, subscriber: AgentId, topic: Topic) -> Result<()> {
let mut txn = self.db.begin_rw_txn()?;
self.subscribe_txn(&mut txn, subscriber, topic)?;
Ok(txn.commit()?)
}
fn subscribe_txn(
&self,
txn: &mut <S as Db>::RwTx<'_>,
subscriber: AgentId,
topic: Topic,
) -> Result<()> {
let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
let key = generate_key(topic.publisher, topic.topic, Some(subscriber));
txn.write(&db_ptr, &key, &topic.method)?;
Ok(())
}
pub fn unsubscribe(&self, subscriber: AgentId, topic: Topic) -> Result<()> {
let mut txn = self.db.begin_rw_txn()?;
self.unsubscribe_txn(&mut txn, subscriber, topic)?;
Ok(txn.commit()?)
}
fn unsubscribe_txn(
&self,
txn: &mut <S as Db>::RwTx<'_>,
subscriber: AgentId,
topic: Topic,
) -> Result<()> {
let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
let key = generate_key(topic.publisher, topic.topic, Some(subscriber));
Ok(txn.delete(&db_ptr, &key)?)
}
pub fn get_topic_subscribers(
&self,
publisher: Id,
topic: String,
) -> Result<Vec<(AgentId, String)>> {
let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
let txn = self.db.begin_ro_txn()?;
let mut cursor = txn.ro_cursor(&db_ptr)?;
let mut subscribers = Vec::new();
let prefix = generate_key(publisher, topic, None);
for (key, value) in cursor.iter_from(&prefix) {
if !key.starts_with(prefix.as_bytes()) {
break;
}
let (topic, subscriber) = extract_entry(key, value)?;
subscribers.push((subscriber, topic.method));
}
drop(cursor);
Ok(subscribers)
}
pub fn get_subscriptions(&self, target: AgentId) -> Result<Vec<Topic>> {
let db_ptr = self.db.open_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
let txn = self.db.begin_ro_txn()?;
let mut cursor = txn.ro_cursor(&db_ptr)?;
let mut topics = Vec::new();
for (key, value) in cursor.iter() {
let (topic, subscriber) = extract_entry(key, value)?;
if target != subscriber {
continue;
}
topics.push(topic);
}
drop(cursor);
Ok(topics)
}
pub fn unsubscribe_all(&self, txn: &mut <S as Db>::RwTx<'_>, subscriber: Id) -> Result<()> {
let subscriber = match subscriber {
Id::Contract { .. } => return Ok(()), Id::Agent { agent_id } => agent_id,
};
let subscriptions = self.get_subscriptions(subscriber)?;
for topic in subscriptions {
self.unsubscribe_txn(txn, subscriber, topic)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use crate::db::subscriptions::SubscriptionHandler;
    use crate::SUBSCRIPTION_REL_SUB_DB;
    use borderless::common::Id;
    use borderless::events::Topic;
    use borderless::{AgentId, ContractId, Result};
    use borderless_kv_store::backend::lmdb::Lmdb;
    use borderless_kv_store::Db;
    use tempfile::{tempdir, TempDir};

    /// Number of subscribers / publishers generated by each test.
    const N: usize = 10;

    /// Sample topic names that several tests cycle through.
    const TOPICS: [&str; 5] = ["Soccer", "Tennis", "Golf", "Basketball", "Football"];

    /// Opens an LMDB environment in a fresh temporary directory and creates the
    /// subscription sub-database in it.
    ///
    /// The [`TempDir`] guard is returned together with the environment: dropping it removes
    /// the directory, so each test keeps it alive for as long as the environment is in use.
    fn open_tmp_lmdb() -> (TempDir, Lmdb) {
        let tmp_dir = tempdir().expect("failed to create temporary directory");
        let env = Lmdb::new(tmp_dir.path(), 1).expect("failed to open LMDB environment");
        env.create_sub_db(SUBSCRIPTION_REL_SUB_DB)
            .expect("failed to create subscription sub-database");
        (tmp_dir, env)
    }

    /// Generates `N` random agent ids.
    fn random_agents() -> Vec<AgentId> {
        std::iter::repeat_with(AgentId::generate).take(N).collect()
    }

    #[test]
    fn subscription() -> Result<()> {
        let (_tmp, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);
        let subscribers = random_agents();
        let publishers: Vec<Id> = std::iter::repeat_with(|| Id::contract(ContractId::generate()))
            .take(N)
            .collect();
        let topic = "MyTopic";
        for (&subscriber, &publisher) in subscribers.iter().zip(&publishers) {
            let t = Topic::new(publisher, topic.to_string(), "method".to_string());
            handler.subscribe(subscriber, t)?;
        }
        for (&subscriber, &publisher) in subscribers.iter().zip(&publishers) {
            let subscriptions = handler.get_subscriptions(subscriber)?;
            assert_eq!(subscriptions.len(), 1);
            assert_eq!(subscriptions[0].publisher, publisher);
            // Topic names come back in their stored, ASCII-lowercased form.
            assert_eq!(subscriptions[0].topic, topic.to_ascii_lowercase());
        }
        Ok(())
    }

    #[test]
    fn unsubscription() -> Result<()> {
        let (_tmp, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);
        let subscribers = random_agents();
        let publishers: Vec<Id> = std::iter::repeat_with(|| Id::agent(AgentId::generate()))
            .take(N)
            .collect();
        let topic = "MyTopic";
        for (&s, &p) in subscribers.iter().zip(&publishers) {
            handler.subscribe(s, Topic::new(p, topic.to_string(), "method".to_string()))?;
        }
        // The method is not part of the key, so unsubscribing works with any method value.
        for (&s, &p) in subscribers.iter().zip(&publishers) {
            handler.unsubscribe(s, Topic::new(p, topic.to_string(), String::default()))?;
        }
        for &p in &publishers {
            assert!(handler.get_topic_subscribers(p, topic.to_string())?.is_empty());
        }
        for &s in &subscribers {
            assert!(handler.get_subscriptions(s)?.is_empty());
        }
        Ok(())
    }

    #[test]
    fn fetch_topic_subscribers() -> Result<()> {
        let (_tmp, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);
        let mut subscribers = random_agents();
        let publisher = Id::contract(ContractId::generate());
        let topic = "tennis";
        for &s in &subscribers {
            let t = Topic::new(publisher, topic.to_string(), "method".to_string());
            handler.subscribe(s, t)?;
        }
        let found = handler.get_topic_subscribers(publisher, topic.to_string())?;
        assert!(found.iter().all(|(_, method)| method == "method"));
        let mut output: Vec<AgentId> = found.into_iter().map(|(aid, _)| aid).collect();
        subscribers.sort();
        output.sort();
        assert_eq!(subscribers, output, "Mismatch in topic subscribers");
        Ok(())
    }

    #[test]
    fn fetch_subscribers() -> Result<()> {
        let (_tmp, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);
        let mut subscribers = random_agents();
        let publisher = Id::contract(ContractId::generate());
        for (&s, name) in subscribers.iter().zip(TOPICS.iter().copied().cycle()) {
            let t = Topic::new(publisher, name.to_string(), "method".to_string());
            handler.subscribe(s, t)?;
        }
        // An empty topic selects every topic of the publisher.
        let mut output: Vec<AgentId> = handler
            .get_topic_subscribers(publisher, String::default())?
            .into_iter()
            .map(|(aid, _)| aid)
            .collect();
        subscribers.sort();
        output.sort();
        assert_eq!(subscribers, output, "Mismatch in subscribers");
        Ok(())
    }

    #[test]
    fn fetch_subscriptions() -> Result<()> {
        let (_tmp, lmdb) = open_tmp_lmdb();
        let handler = SubscriptionHandler::new(&lmdb);
        let subscriber = AgentId::generate();
        let mut subscriptions: Vec<Topic> = Vec::with_capacity(N);
        for name in TOPICS.iter().copied().cycle().take(N) {
            // A distinct publisher per entry keeps all N keys unique; the topic is lowercased
            // up-front so it compares equal to the stored form.
            let topic = Topic::new(
                Id::contract(ContractId::generate()),
                name.to_ascii_lowercase(),
                "method".to_string(),
            );
            handler.subscribe(subscriber, topic.clone())?;
            subscriptions.push(topic);
        }
        let output = handler.get_subscriptions(subscriber)?;
        assert_eq!(
            output.len(),
            subscriptions.len(),
            "Mismatch in subscription count"
        );
        for t in &output {
            assert!(subscriptions.contains(t), "Mismatch in subscriptions");
        }
        Ok(())
    }
}