use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Kind of graph mutation carried by a `Change`.
///
/// The enum has no payload, so it is `Copy` and can be used directly as a
/// `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChangeType {
    /// A node was inserted.
    NodeInserted,
    /// An edge was inserted.
    EdgeInserted,
    /// A node was updated.
    NodeUpdated,
    /// An edge was updated.
    EdgeUpdated,
    /// A node was deleted.
    NodeDeleted,
    /// An edge was deleted.
    EdgeDeleted,
}
/// A single graph mutation event that can be published through `PubSub`.
#[derive(Debug, Clone)]
pub struct Change {
/// Which kind of mutation this event describes.
pub change_type: ChangeType,
/// Topic string that `Subscriber::is_interested` compares against a subscriber's topics.
pub topic: String,
/// Affected node id: `Some` when built by a node constructor, `None` for edge constructors.
pub node_id: Option<u32>,
/// Edge source id: `Some` when built by an edge constructor, `None` for node constructors.
pub src_id: Option<u32>,
/// Edge destination id: `Some` when built by an edge constructor, `None` for node constructors.
pub dst_id: Option<u32>,
/// Value of `chrono::Utc::now().timestamp()` captured when the constructor ran.
pub timestamp: i64,
}
impl Change {
pub fn node_inserted(node_id: u32, topic: String) -> Self {
Self {
change_type: ChangeType::NodeInserted,
topic,
node_id: Some(node_id),
src_id: None,
dst_id: None,
timestamp: chrono::Utc::now().timestamp(),
}
}
pub fn edge_inserted(src_id: u32, dst_id: u32, topic: String) -> Self {
Self {
change_type: ChangeType::EdgeInserted,
topic,
node_id: None,
src_id: Some(src_id),
dst_id: Some(dst_id),
timestamp: chrono::Utc::now().timestamp(),
}
}
pub fn node_deleted(node_id: u32, topic: String) -> Self {
Self {
change_type: ChangeType::NodeDeleted,
topic,
node_id: Some(node_id),
src_id: None,
dst_id: None,
timestamp: chrono::Utc::now().timestamp(),
}
}
pub fn edge_deleted(src_id: u32, dst_id: u32, topic: String) -> Self {
Self {
change_type: ChangeType::EdgeDeleted,
topic,
node_id: None,
src_id: Some(src_id),
dst_id: Some(dst_id),
timestamp: chrono::Utc::now().timestamp(),
}
}
}
/// Boxed callback that receives a reference to each `Change` delivered to a subscriber.
///
/// The `Send` bound lets a boxed callback be moved to another thread.
pub type SubscriberCallback = Box<dyn Fn(&Change) + Send>;
/// A callback paired with the list of topics it wants to hear about.
pub struct Subscriber {
// Topics this subscriber accepts; an entry equal to "*" accepts every topic
// (see `is_interested`).
topics: Vec<String>,
// Invoked by `notify` for each change whose topic is accepted.
callback: SubscriberCallback,
}
impl Subscriber {
pub fn new(topics: Vec<String>, callback: SubscriberCallback) -> Self {
Self { topics, callback }
}
pub fn is_interested(&self, topic: &str) -> bool {
self.topics.iter().any(|t| t == topic || t == "*")
}
pub fn notify(&self, change: &Change) {
if self.is_interested(&change.topic) {
(self.callback)(change);
}
}
}
/// In-process publish/subscribe hub that also keeps a log of every published
/// `Change` and per-group read state over that log.
pub struct PubSub {
// Registered subscribers, keyed by the id returned from `subscribe`.
subscribers: Arc<Mutex<HashMap<usize, Subscriber>>>,
// Every change passed to `publish`, in publish order, until `clear_change_log`.
change_log: Arc<Mutex<Vec<Change>>>,
// Id that the next call to `subscribe` will hand out.
next_subscriber_id: Arc<Mutex<usize>>,
// Consumer-group state, keyed by group name.
consumer_groups: Arc<Mutex<HashMap<String, ConsumerGroupState>>>,
}
/// Read state that `PubSub` keeps for one consumer group.
#[derive(Debug, Clone)]
struct ConsumerGroupState {
// Group name (same string used as the key in `PubSub::consumer_groups`).
name: String,
// Topic filter set by `subscribe_group`; `fetch_messages` treats an empty list
// as "match everything".
topics: Vec<String>,
// Index into the change log at which `fetch_messages` starts scanning;
// updated by `ack`.
last_offset: usize,
// Change-log offsets handed out by `fetch_messages` and not yet removed by `ack`.
pending: Vec<usize>,
}
impl PubSub {
pub fn new() -> Self {
Self {
subscribers: Arc::new(Mutex::new(HashMap::new())),
change_log: Arc::new(Mutex::new(Vec::new())),
next_subscriber_id: Arc::new(Mutex::new(0)),
consumer_groups: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn subscribe(&self, topics: Vec<String>, callback: SubscriberCallback) -> usize {
let mut id_guard = self.next_subscriber_id.lock().unwrap();
let id = *id_guard;
*id_guard = id + 1;
let subscriber = Subscriber::new(topics, callback);
let mut subscribers = self.subscribers.lock().unwrap();
subscribers.insert(id, subscriber);
id
}
pub fn unsubscribe(&self, subscriber_id: usize) {
let mut subscribers = self.subscribers.lock().unwrap();
subscribers.remove(&subscriber_id);
}
pub fn publish(&self, change: Change) {
{
let mut log = self.change_log.lock().unwrap();
log.push(change.clone());
}
let subscribers = self.subscribers.lock().unwrap();
for subscriber in subscribers.values() {
subscriber.notify(&change);
}
}
pub fn replay_wal(&self) -> Result<usize, String> {
let log = self.change_log.lock().unwrap();
let subscribers = self.subscribers.lock().unwrap();
let mut notified = 0;
for change in log.iter() {
for subscriber in subscribers.values() {
if subscriber.is_interested(&change.topic) {
subscriber.notify(change);
notified += 1;
}
}
}
Ok(notified)
}
pub fn change_log_size(&self) -> usize {
let log = self.change_log.lock().unwrap();
log.len()
}
pub fn clear_change_log(&self) {
let mut log = self.change_log.lock().unwrap();
log.clear();
}
pub fn subscriber_count(&self) -> usize {
let subscribers = self.subscribers.lock().unwrap();
subscribers.len()
}
pub fn create_consumer_group(
&self,
name: String,
) -> Result<(), crate::errors::SqliteGraphError> {
let mut groups = self.consumer_groups.lock().unwrap();
if groups.contains_key(&name) {
return Err(crate::errors::SqliteGraphError::QueryError(format!(
"Consumer group '{}' already exists",
name
)));
}
groups.insert(
name.clone(),
ConsumerGroupState {
name: name.clone(),
topics: Vec::new(), last_offset: 0,
pending: Vec::new(),
},
);
Ok(())
}
pub fn subscribe_group(
&self,
group: &str,
topics: Vec<String>,
) -> Result<usize, crate::errors::SqliteGraphError> {
let mut groups = self.consumer_groups.lock().unwrap();
if !groups.contains_key(group) {
return Err(crate::errors::SqliteGraphError::QueryError(format!(
"Consumer group '{}' not found",
group
)));
}
let group_state = groups.get_mut(group).unwrap();
group_state.topics = topics.clone();
Ok(self.subscribe(
topics,
Box::new(|_change| {
}),
))
}
pub fn fetch_messages(
&self,
group: &str,
limit: usize,
) -> Result<
Vec<crate::sharding::streams_spec::ConsumerGroupMessage>,
crate::errors::SqliteGraphError,
> {
let log = self.change_log.lock().unwrap();
let mut groups = self.consumer_groups.lock().unwrap();
let group_state = groups.get_mut(group).ok_or_else(|| {
crate::errors::SqliteGraphError::QueryError(format!(
"Consumer group '{}' not found",
group
))
})?;
let topics = group_state.topics.clone();
let mut messages = Vec::new();
let start_offset = group_state.last_offset;
for (idx, change) in log.iter().enumerate() {
if idx < start_offset {
continue; }
if messages.len() >= limit {
break;
}
if !topics.is_empty() && !topics.iter().any(|t| change.topic == *t || t == "*") {
continue;
}
group_state.pending.push(idx);
messages.push(crate::sharding::streams_spec::ConsumerGroupMessage {
offset: idx,
change: change.clone(),
});
}
Ok(messages)
}
pub fn ack(&self, group: &str, offset: usize) -> Result<(), crate::errors::SqliteGraphError> {
let mut groups = self.consumer_groups.lock().unwrap();
let group_state = groups.get_mut(group).ok_or_else(|| {
crate::errors::SqliteGraphError::QueryError(format!(
"Consumer group '{}' not found",
group
))
})?;
if let Some(pos) = group_state.pending.iter().position(|&x| x == offset) {
group_state.pending.remove(pos);
group_state.last_offset = offset + 1;
}
Ok(())
}
pub fn get_consumer_group(
&self,
group: &str,
) -> Result<crate::sharding::streams_spec::ConsumerGroup, crate::errors::SqliteGraphError> {
let groups = self.consumer_groups.lock().unwrap();
let group_state = groups.get(group).ok_or_else(|| {
crate::errors::SqliteGraphError::QueryError(format!(
"Consumer group '{}' not found",
group
))
})?;
Ok(crate::sharding::streams_spec::ConsumerGroup {
name: group_state.name.clone(),
last_offset: group_state.last_offset,
pending: group_state.pending.clone(),
})
}
pub fn list_consumer_groups(&self) -> Result<Vec<String>, crate::errors::SqliteGraphError> {
let groups = self.consumer_groups.lock().unwrap();
Ok(groups.keys().cloned().collect())
}
}
impl Default for PubSub {
fn default() -> Self {
Self::new()
}
}
/// Topic string `"graph.*"`.
///
/// NOTE(review): `Subscriber::is_interested` only special-cases the bare `"*"`
/// entry and otherwise compares topics for exact equality, so the matching code
/// in this file does not expand this value as a pattern. Confirm how callers
/// are expected to use it.
pub const DEFAULT_TOPIC: &str = "graph.*";
/// Topic name for node-related changes.
pub const NODE_TOPIC: &str = "graph.node";
/// Topic name for edge-related changes.
pub const EDGE_TOPIC: &str = "graph.edge";
// Unit tests for the hub. `publish` and `replay_wal` invoke callbacks
// synchronously on the calling thread, so assertions can run immediately
// after the call without sleeping.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_pubsub_creation() {
        let pubsub = PubSub::new();
        assert_eq!(pubsub.subscriber_count(), 0);
        assert_eq!(pubsub.change_log_size(), 0);
    }

    #[test]
    fn test_subscribe_unsubscribe() {
        let pubsub = PubSub::new();
        let id = pubsub.subscribe(vec![NODE_TOPIC.to_string()], Box::new(|_| {}));
        assert_eq!(pubsub.subscriber_count(), 1);
        pubsub.unsubscribe(id);
        assert_eq!(pubsub.subscriber_count(), 0);
    }

    #[test]
    fn test_publish_single_subscriber() {
        let pubsub = PubSub::new();
        let notified = Arc::new(Mutex::new(0));
        let notified_clone = notified.clone();
        pubsub.subscribe(
            vec![NODE_TOPIC.to_string()],
            Box::new(move |_| {
                *notified_clone.lock().unwrap() += 1;
            }),
        );
        let change = Change::node_inserted(100, NODE_TOPIC.to_string());
        pubsub.publish(change);
        assert_eq!(*notified.lock().unwrap(), 1);
    }

    #[test]
    fn test_publish_multiple_subscribers() {
        let pubsub = PubSub::new();
        let count1 = Arc::new(Mutex::new(0));
        let count1_clone = count1.clone();
        let count2 = Arc::new(Mutex::new(0));
        let count2_clone = count2.clone();
        pubsub.subscribe(
            vec![NODE_TOPIC.to_string()],
            Box::new(move |_| {
                *count1_clone.lock().unwrap() += 1;
            }),
        );
        pubsub.subscribe(
            vec![NODE_TOPIC.to_string()],
            Box::new(move |_| {
                *count2_clone.lock().unwrap() += 1;
            }),
        );
        let change = Change::node_inserted(100, NODE_TOPIC.to_string());
        pubsub.publish(change);
        assert_eq!(*count1.lock().unwrap(), 1);
        assert_eq!(*count2.lock().unwrap(), 1);
    }

    #[test]
    fn test_topic_filtering() {
        let pubsub = PubSub::new();
        let notified = Arc::new(Mutex::new(0));
        let notified_clone = notified.clone();
        pubsub.subscribe(
            vec![EDGE_TOPIC.to_string()],
            Box::new(move |_| {
                *notified_clone.lock().unwrap() += 1;
            }),
        );
        // A node change must not reach an edge-only subscriber.
        let node_change = Change::node_inserted(100, NODE_TOPIC.to_string());
        pubsub.publish(node_change);
        assert_eq!(*notified.lock().unwrap(), 0);
        let edge_change = Change::edge_inserted(100, 200, EDGE_TOPIC.to_string());
        pubsub.publish(edge_change);
        assert_eq!(*notified.lock().unwrap(), 1);
    }

    #[test]
    fn test_wildcard_topic() {
        let pubsub = PubSub::new();
        let notified = Arc::new(Mutex::new(0));
        let notified_clone = notified.clone();
        pubsub.subscribe(
            vec!["*".to_string()],
            Box::new(move |_| {
                *notified_clone.lock().unwrap() += 1;
            }),
        );
        let change = Change::node_inserted(100, NODE_TOPIC.to_string());
        pubsub.publish(change);
        assert_eq!(*notified.lock().unwrap(), 1);
    }

    #[test]
    fn test_wal_replay() {
        let pubsub = PubSub::new();
        let count = Arc::new(Mutex::new(0));
        let count_clone = count.clone();
        pubsub.subscribe(
            vec![NODE_TOPIC.to_string()],
            Box::new(move |_| {
                *count_clone.lock().unwrap() += 1;
            }),
        );
        pubsub.publish(Change::node_inserted(100, NODE_TOPIC.to_string()));
        pubsub.publish(Change::edge_inserted(100, 200, EDGE_TOPIC.to_string()));
        assert_eq!(pubsub.change_log_size(), 2);
        // A subscriber added after the fact only sees changes through replay.
        let count2 = Arc::new(Mutex::new(0));
        let count2_clone = count2.clone();
        pubsub.subscribe(
            vec![NODE_TOPIC.to_string()],
            Box::new(move |_| {
                *count2_clone.lock().unwrap() += 1;
            }),
        );
        let notified = pubsub.replay_wal().unwrap();
        // One logged node change delivered to each of the two node subscribers.
        assert_eq!(notified, 2);
        assert_eq!(*count2.lock().unwrap(), 1);
    }

    #[test]
    fn test_change_log_persistence() {
        let pubsub = PubSub::new();
        assert_eq!(pubsub.change_log_size(), 0);
        pubsub.publish(Change::node_inserted(100, NODE_TOPIC.to_string()));
        assert_eq!(pubsub.change_log_size(), 1);
        pubsub.publish(Change::edge_inserted(100, 200, EDGE_TOPIC.to_string()));
        assert_eq!(pubsub.change_log_size(), 2);
        pubsub.clear_change_log();
        assert_eq!(pubsub.change_log_size(), 0);
    }

    #[test]
    fn test_edge_deletion() {
        let pubsub = PubSub::new();
        let notified = Arc::new(Mutex::new(false));
        let notified_clone = notified.clone();
        pubsub.subscribe(
            vec![EDGE_TOPIC.to_string()],
            Box::new(move |change| {
                if change.change_type == ChangeType::EdgeDeleted {
                    *notified_clone.lock().unwrap() = true;
                }
            }),
        );
        let change = Change::edge_deleted(100, 200, EDGE_TOPIC.to_string());
        pubsub.publish(change);
        assert!(*notified.lock().unwrap());
    }

    #[test]
    fn test_timestamp_ordering() {
        let pubsub = PubSub::new();
        pubsub.publish(Change::node_inserted(100, NODE_TOPIC.to_string()));
        pubsub.publish(Change::node_inserted(200, NODE_TOPIC.to_string()));
        let log = pubsub.change_log.lock().unwrap();
        assert_eq!(log.len(), 2);
        assert!(log[0].timestamp > 0);
        assert!(log[1].timestamp > 0);
        // Log order follows publish order, so timestamps must not go backwards.
        assert!(log[0].timestamp <= log[1].timestamp);
    }
}