use crate::rete::stream_join_node::{JoinedEvent, StreamJoinNode};
use crate::streaming::event::StreamEvent;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Boxed callback that receives each joined event produced by a join.
type JoinResultHandler = Box<dyn Fn(JoinedEvent) + Send + Sync>;

/// Registry of stream joins: owns the join nodes, routes incoming events to
/// them by source stream, and forwards their output to per-join handlers.
pub struct StreamJoinManager {
    /// Join nodes keyed by join id, each guarded by its own mutex.
    joins: HashMap<String, Arc<Mutex<StreamJoinNode>>>,
    /// Routing table: stream id -> ids of the joins registered on that stream
    /// (as either their left or right input).
    stream_to_joins: HashMap<String, Vec<String>>,
    /// Result callbacks keyed by join id.
    result_handlers: HashMap<String, JoinResultHandler>,
}
impl StreamJoinManager {
    /// Creates a manager with no registered joins.
    pub fn new() -> Self {
        Self {
            joins: HashMap::new(),
            stream_to_joins: HashMap::new(),
            result_handlers: HashMap::new(),
        }
    }

    /// Registers `join_node` under `join_id` and routes events from its left
    /// and right streams to it; `result_handler` receives every joined event
    /// the node emits.
    ///
    /// If `join_id` is already registered, the previous node, handler and
    /// routes are removed first, so stale routes never point at the new node.
    pub fn register_join(
        &mut self,
        join_id: String,
        join_node: StreamJoinNode,
        result_handler: Box<dyn Fn(JoinedEvent) + Send + Sync>,
    ) {
        // Re-registration must not leave the old routes behind: they would
        // deliver events twice, or from streams the new node does not consume.
        self.unregister_join(&join_id);

        let left_stream = join_node.left_stream.clone();
        let right_stream = join_node.right_stream.clone();

        // NOTE(review): when left_stream == right_stream the id is routed
        // twice on the same stream and `process_event` takes the left branch
        // both times. Confirm the intended self-join semantics.
        self.stream_to_joins
            .entry(left_stream)
            .or_default()
            .push(join_id.clone());
        self.stream_to_joins
            .entry(right_stream)
            .or_default()
            .push(join_id.clone());

        self.joins
            .insert(join_id.clone(), Arc::new(Mutex::new(join_node)));
        self.result_handlers.insert(join_id, result_handler);
    }

    /// Removes the join registered under `join_id`, together with its handler
    /// and its routes. Unknown ids are ignored.
    ///
    /// Routing entries that become empty are dropped, and a poisoned node
    /// mutex does not prevent removal.
    pub fn unregister_join(&mut self, join_id: &str) {
        self.result_handlers.remove(join_id);

        let join = match self.joins.remove(join_id) {
            Some(join) => join,
            None => return,
        };

        // Only the stream names are read, so a poisoned guard is safe to use.
        let (left_stream, right_stream) = {
            let node = match join.lock() {
                Ok(guard) => guard,
                Err(poisoned) => poisoned.into_inner(),
            };
            (node.left_stream.clone(), node.right_stream.clone())
        };

        for stream in &[left_stream, right_stream] {
            let now_empty = match self.stream_to_joins.get_mut(stream) {
                Some(ids) => {
                    ids.retain(|id| id != join_id);
                    ids.is_empty()
                }
                None => false,
            };
            if now_empty {
                self.stream_to_joins.remove(stream);
            }
        }
    }

    /// Feeds `event` to every join registered on `event.metadata.source`.
    ///
    /// A join whose left stream equals the source receives the event through
    /// `process_left`, any other through `process_right`. Results are passed
    /// to the join's handler after the node's lock has been released, so a
    /// handler may inspect the manager and a panicking handler cannot poison
    /// the node.
    ///
    /// # Panics
    ///
    /// Panics if a join node's mutex is poisoned.
    pub fn process_event(&self, event: StreamEvent) {
        let stream_id = &event.metadata.source;
        let join_ids = match self.stream_to_joins.get(stream_id) {
            Some(ids) => ids,
            None => return,
        };

        for join_id in join_ids {
            let join = match self.joins.get(join_id) {
                Some(join) => join,
                None => continue,
            };

            let mut node = join.lock().expect("stream join node mutex poisoned");
            let results = if node.left_stream == *stream_id {
                node.process_left(event.clone())
            } else {
                node.process_right(event.clone())
            };
            // Release the lock before running user callbacks.
            drop(node);

            if let Some(handler) = self.result_handlers.get(join_id) {
                for joined in results {
                    handler(joined);
                }
            }
        }
    }

    /// Forwards `watermark` to every join registered on `stream_id` and hands
    /// whatever each join emits in response to that join's handler.
    ///
    /// As in [`Self::process_event`], handlers run after the node's lock has
    /// been released.
    ///
    /// # Panics
    ///
    /// Panics if a join node's mutex is poisoned.
    pub fn update_watermark(&self, stream_id: &str, watermark: i64) {
        let join_ids = match self.stream_to_joins.get(stream_id) {
            Some(ids) => ids,
            None => return,
        };

        for join_id in join_ids {
            let join = match self.joins.get(join_id) {
                Some(join) => join,
                None => continue,
            };

            let mut node = join.lock().expect("stream join node mutex poisoned");
            let results = node.update_watermark(watermark);
            // Release the lock before running user callbacks.
            drop(node);

            if let Some(handler) = self.result_handlers.get(join_id) {
                for joined in results {
                    handler(joined);
                }
            }
        }
    }

    /// Returns the statistics of every registered join, keyed by join id.
    ///
    /// # Panics
    ///
    /// Panics if a join node's mutex is poisoned.
    pub fn get_all_stats(&self) -> HashMap<String, crate::rete::stream_join_node::JoinNodeStats> {
        self.joins
            .iter()
            .map(|(join_id, join)| {
                let node = join.lock().expect("stream join node mutex poisoned");
                (join_id.clone(), node.get_stats())
            })
            .collect()
    }

    /// Returns the statistics of the join registered under `join_id`, or
    /// `None` if there is no such join.
    ///
    /// # Panics
    ///
    /// Panics if the join node's mutex is poisoned.
    pub fn get_join_stats(
        &self,
        join_id: &str,
    ) -> Option<crate::rete::stream_join_node::JoinNodeStats> {
        self.joins.get(join_id).map(|join| {
            let node = join.lock().expect("stream join node mutex poisoned");
            node.get_stats()
        })
    }

    /// Removes every join, route and handler.
    pub fn clear(&mut self) {
        self.joins.clear();
        self.stream_to_joins.clear();
        self.result_handlers.clear();
    }
}
impl Default for StreamJoinManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rete::stream_join_node::{JoinStrategy, JoinType};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Builds an event on `stream_id` at `timestamp` whose only data field is
    /// `user_id`.
    fn create_test_event(stream_id: &str, timestamp: i64, user_id: &str) -> StreamEvent {
        use crate::streaming::event::EventMetadata;
        use crate::types::Value;

        let user_field = ("user_id".to_string(), Value::String(user_id.to_string()));
        let metadata = EventMetadata {
            timestamp: timestamp as u64,
            source: stream_id.to_string(),
            sequence: 0,
            tags: HashMap::new(),
        };

        StreamEvent {
            id: format!("test_{}_{}", stream_id, timestamp),
            event_type: "test".to_string(),
            data: std::iter::once(user_field).collect(),
            metadata,
        }
    }

    /// Builds a time-window join between `"left"` and `right_stream`, keyed on
    /// the `user_id` field of both sides, with a condition that always holds.
    fn user_id_join(right_stream: &str, join_type: JoinType, window_secs: u64) -> StreamJoinNode {
        StreamJoinNode::new(
            "left".to_string(),
            right_stream.to_string(),
            join_type,
            JoinStrategy::TimeWindow {
                duration: Duration::from_secs(window_secs),
            },
            Box::new(|e| e.data.get("user_id").and_then(|v| v.as_string())),
            Box::new(|e| e.data.get("user_id").and_then(|v| v.as_string())),
            Box::new(|_, _| true),
        )
    }

    /// Returns a shared counter and a result handler that increments it once
    /// per joined event.
    fn counting_handler() -> (Arc<AtomicUsize>, Box<dyn Fn(JoinedEvent) + Send + Sync>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let sink = counter.clone();
        let handler: Box<dyn Fn(JoinedEvent) + Send + Sync> = Box::new(move |_| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        (counter, handler)
    }

    #[test]
    fn test_register_and_route_events() {
        let mut manager = StreamJoinManager::new();
        let (result_count, handler) = counting_handler();

        manager.register_join(
            "join1".to_string(),
            user_id_join("right", JoinType::Inner, 10),
            handler,
        );

        manager.process_event(create_test_event("left", 1000, "user1"));
        manager.process_event(create_test_event("right", 1005, "user1"));

        assert_eq!(result_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiple_joins_same_stream() {
        let mut manager = StreamJoinManager::new();
        let (result_count1, handler1) = counting_handler();
        let (result_count2, handler2) = counting_handler();

        // Both joins share the "left" stream but differ on the right side.
        manager.register_join(
            "join1".to_string(),
            user_id_join("right", JoinType::Inner, 10),
            handler1,
        );
        manager.register_join(
            "join2".to_string(),
            user_id_join("other", JoinType::Inner, 10),
            handler2,
        );

        manager.process_event(create_test_event("left", 1000, "user1"));
        manager.process_event(create_test_event("right", 1005, "user1"));
        manager.process_event(create_test_event("other", 1005, "user1"));

        assert_eq!(result_count1.load(Ordering::SeqCst), 1);
        assert_eq!(result_count2.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_unregister_join() {
        let mut manager = StreamJoinManager::new();
        let (result_count, handler) = counting_handler();

        manager.register_join(
            "join1".to_string(),
            user_id_join("right", JoinType::Inner, 10),
            handler,
        );
        manager.unregister_join("join1");

        // With the join gone, neither event may reach the handler.
        manager.process_event(create_test_event("left", 1000, "user1"));
        manager.process_event(create_test_event("right", 1005, "user1"));

        assert_eq!(result_count.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_watermark_update() {
        let mut manager = StreamJoinManager::new();
        let (result_count, handler) = counting_handler();

        manager.register_join(
            "join1".to_string(),
            user_id_join("right", JoinType::LeftOuter, 5),
            handler,
        );

        manager.process_event(create_test_event("left", 1000, "user1"));
        assert_eq!(result_count.load(Ordering::SeqCst), 1);

        manager.update_watermark("left", 10000);
        assert_eq!(result_count.load(Ordering::SeqCst), 1);
    }
}