1use crate::{
10 processing::{Watermark, WindowType},
11 StreamEvent,
12};
13use anyhow::{anyhow, Result};
14use chrono::{DateTime, Duration, Utc};
15use serde::{Deserialize, Serialize};
16use std::collections::{HashMap, VecDeque};
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::{debug, warn};
20
/// Kind of join performed between the left and the right stream.
///
/// The variant decides whether an event that finds no partner at the moment it
/// is processed still produces an output event.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum JoinType {
    /// Emit output only for pairs that actually join.
    Inner,
    /// Additionally emit output for a left event that found no right partner.
    LeftOuter,
    /// Additionally emit output for a right event that found no left partner.
    RightOuter,
    /// Emit output for unmatched events from either side.
    FullOuter,
}
33
/// Shared, thread-safe callback that derives the join key of an event.
///
/// Returning `None` means the event carries no key; the processor then returns
/// an empty result for it without buffering it.
pub type JoinKeyExtractor = Arc<dyn Fn(&StreamEvent) -> Option<String> + Send + Sync>;

/// Shared, thread-safe predicate evaluated as `(left, right)` on a candidate
/// pair with equal keys; the pair is joined only when it returns `true`.
pub type JoinCondition = Arc<dyn Fn(&StreamEvent, &StreamEvent) -> bool + Send + Sync>;

/// Shared, thread-safe callback that builds the output event of a join.
///
/// The first argument is the primary event and the second its partner, or
/// `None` when an outer join emits an unmatched event.
pub type JoinResultTransformer =
    Arc<dyn Fn(&StreamEvent, Option<&StreamEvent>) -> Result<StreamEvent> + Send + Sync>;
43
/// Configuration of a [`StreamJoinProcessor`].
#[derive(Clone)]
pub struct JoinConfig {
    /// Which unmatched events (if any) still produce output.
    pub join_type: JoinType,
    /// Optional window specification.
    // NOTE(review): not read by any code visible in this file — confirm where it is applied.
    pub window: Option<WindowType>,
    /// Derives the join key of events arriving on the left stream.
    pub left_key_extractor: JoinKeyExtractor,
    /// Derives the join key of events arriving on the right stream.
    pub right_key_extractor: JoinKeyExtractor,
    /// Optional extra predicate a key-equal pair must satisfy to be joined.
    pub join_condition: Option<JoinCondition>,
    /// Builds the output event for a joined pair or an unmatched event.
    pub result_transformer: JoinResultTransformer,
    /// When set, pairs whose timestamps differ by more than this are not joined.
    pub temporal_tolerance: Option<Duration>,
    /// Maximum number of events buffered per key and per side; the oldest
    /// events are evicted first.
    pub buffer_size: usize,
    /// Whether statistics should be collected.
    // NOTE(review): the processor in this file updates statistics regardless of this flag.
    pub collect_stats: bool,
    /// How far behind the watermark an event may be before it is dropped as
    /// late, and before buffered events are expired.
    pub allowed_lateness: Duration,
}
68
69impl JoinConfig {
70 pub fn new(
71 join_type: JoinType,
72 left_key_extractor: JoinKeyExtractor,
73 right_key_extractor: JoinKeyExtractor,
74 result_transformer: JoinResultTransformer,
75 ) -> Self {
76 Self {
77 join_type,
78 window: None,
79 left_key_extractor,
80 right_key_extractor,
81 join_condition: None,
82 result_transformer,
83 temporal_tolerance: None,
84 buffer_size: 10000,
85 collect_stats: true,
86 allowed_lateness: Duration::minutes(5),
87 }
88 }
89
90 pub fn with_window(mut self, window: WindowType) -> Self {
91 self.window = Some(window);
92 self
93 }
94
95 pub fn with_temporal_tolerance(mut self, tolerance: Duration) -> Self {
96 self.temporal_tolerance = Some(tolerance);
97 self
98 }
99
100 pub fn with_condition(mut self, condition: JoinCondition) -> Self {
101 self.join_condition = Some(condition);
102 self
103 }
104}
105
/// Counters describing the work done by a [`StreamJoinProcessor`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct JoinStatistics {
    /// Left events that were buffered (late or key-less events are not counted).
    pub left_events_processed: u64,
    /// Right events that were buffered (late or key-less events are not counted).
    pub right_events_processed: u64,
    /// Number of joined pairs produced.
    pub matched_pairs: u64,
    /// Outer-join outputs emitted for a left event without a partner.
    pub unmatched_left: u64,
    /// Outer-join outputs emitted for a right event without a partner.
    pub unmatched_right: u64,
    /// Events rejected because they were older than the lateness cutoff.
    pub late_events_dropped: u64,
    /// Total number of buffered left events across all keys, as last recorded.
    pub buffer_size_left: usize,
    /// Total number of buffered right events across all keys, as last recorded.
    pub buffer_size_right: usize,
    /// The most recent value passed to `update_watermark`, if any.
    pub last_watermark: Option<DateTime<Utc>>,
}
119
/// Joins two event streams on a key, buffering events from both sides.
///
/// All mutable state sits behind `Arc<RwLock<..>>`, so the processing methods
/// take `&self`.
pub struct StreamJoinProcessor {
    /// Join settings fixed at construction time.
    config: JoinConfig,
    /// Buffered left events, grouped by join key in arrival order.
    left_buffer: Arc<RwLock<HashMap<String, VecDeque<StreamEvent>>>>,
    /// Buffered right events, grouped by join key in arrival order.
    right_buffer: Arc<RwLock<HashMap<String, VecDeque<StreamEvent>>>>,
    /// Current watermark used for late-event detection and buffer expiry.
    watermark: Arc<RwLock<Watermark>>,
    /// Running counters exposed through `get_statistics`.
    statistics: Arc<RwLock<JoinStatistics>>,
}
128
129impl StreamJoinProcessor {
130 pub fn new(config: JoinConfig) -> Self {
131 Self {
132 config,
133 left_buffer: Arc::new(RwLock::new(HashMap::new())),
134 right_buffer: Arc::new(RwLock::new(HashMap::new())),
135 watermark: Arc::new(RwLock::new(Watermark::new())),
136 statistics: Arc::new(RwLock::new(JoinStatistics::default())),
137 }
138 }
139
140 pub async fn process_left(&self, event: StreamEvent) -> Result<Vec<StreamEvent>> {
142 let event_time = event.timestamp();
143
144 if self.is_late_event(event_time).await {
146 self.statistics.write().await.late_events_dropped += 1;
147 warn!("Dropping late left event: {:?}", event_time);
148 return Ok(vec![]);
149 }
150
151 let key = match (self.config.left_key_extractor)(&event) {
153 Some(k) => k,
154 None => {
155 debug!("No join key found for left event");
156 return Ok(vec![]);
157 }
158 };
159
160 {
162 let mut left_buffer = self.left_buffer.write().await;
163 left_buffer
164 .entry(key.clone())
165 .or_insert_with(VecDeque::new)
166 .push_back(event.clone());
167
168 if let Some(events) = left_buffer.get_mut(&key) {
170 while events.len() > self.config.buffer_size {
171 events.pop_front();
172 }
173 }
174 }
175
176 {
178 let mut stats = self.statistics.write().await;
179 stats.left_events_processed += 1;
180 stats.buffer_size_left = self
181 .left_buffer
182 .read()
183 .await
184 .values()
185 .map(|v| v.len())
186 .sum();
187 }
188
189 self.join_with_right(&key, &event).await
191 }
192
193 pub async fn process_right(&self, event: StreamEvent) -> Result<Vec<StreamEvent>> {
195 let event_time = event.timestamp();
196
197 if self.is_late_event(event_time).await {
199 self.statistics.write().await.late_events_dropped += 1;
200 warn!("Dropping late right event: {:?}", event_time);
201 return Ok(vec![]);
202 }
203
204 let key = match (self.config.right_key_extractor)(&event) {
206 Some(k) => k,
207 None => {
208 debug!("No join key found for right event");
209 return Ok(vec![]);
210 }
211 };
212
213 {
215 let mut right_buffer = self.right_buffer.write().await;
216 right_buffer
217 .entry(key.clone())
218 .or_insert_with(VecDeque::new)
219 .push_back(event.clone());
220
221 if let Some(events) = right_buffer.get_mut(&key) {
223 while events.len() > self.config.buffer_size {
224 events.pop_front();
225 }
226 }
227 }
228
229 {
231 let mut stats = self.statistics.write().await;
232 stats.right_events_processed += 1;
233 stats.buffer_size_right = self
234 .right_buffer
235 .read()
236 .await
237 .values()
238 .map(|v| v.len())
239 .sum();
240 }
241
242 self.join_with_left(&key, &event).await
244 }
245
246 pub async fn update_watermark(&self, watermark: DateTime<Utc>) -> Result<()> {
248 (*self.watermark.write().await).update(watermark);
249 self.statistics.write().await.last_watermark = Some(watermark);
250
251 self.clean_expired_events().await?;
253
254 Ok(())
255 }
256
257 pub async fn get_statistics(&self) -> JoinStatistics {
259 self.statistics.read().await.clone()
260 }
261
262 async fn join_with_right(
264 &self,
265 key: &str,
266 left_event: &StreamEvent,
267 ) -> Result<Vec<StreamEvent>> {
268 let mut results = Vec::new();
269 let right_buffer = self.right_buffer.read().await;
270
271 if let Some(right_events) = right_buffer.get(key) {
272 for right_event in right_events {
273 if self.should_join(left_event, right_event).await {
274 let joined = (self.config.result_transformer)(left_event, Some(right_event))?;
275 results.push(joined);
276 self.statistics.write().await.matched_pairs += 1;
277 }
278 }
279 }
280
281 if results.is_empty()
283 && matches!(
284 self.config.join_type,
285 JoinType::LeftOuter | JoinType::FullOuter
286 )
287 {
288 let joined = (self.config.result_transformer)(left_event, None)?;
289 results.push(joined);
290 self.statistics.write().await.unmatched_left += 1;
291 }
292
293 Ok(results)
294 }
295
296 async fn join_with_left(
298 &self,
299 key: &str,
300 right_event: &StreamEvent,
301 ) -> Result<Vec<StreamEvent>> {
302 let mut results = Vec::new();
303 let left_buffer = self.left_buffer.read().await;
304
305 if let Some(left_events) = left_buffer.get(key) {
306 for left_event in left_events {
307 if self.should_join(left_event, right_event).await {
308 let joined = (self.config.result_transformer)(left_event, Some(right_event))?;
309 results.push(joined);
310 self.statistics.write().await.matched_pairs += 1;
311 }
312 }
313 }
314
315 if results.is_empty()
317 && matches!(
318 self.config.join_type,
319 JoinType::RightOuter | JoinType::FullOuter
320 )
321 {
322 let joined = match &self.config.join_type {
324 JoinType::RightOuter => {
325 create_null_joined_event(right_event, true)?
327 }
328 _ => (self.config.result_transformer)(right_event, None)?,
329 };
330 results.push(joined);
331 self.statistics.write().await.unmatched_right += 1;
332 }
333
334 Ok(results)
335 }
336
337 async fn should_join(&self, left: &StreamEvent, right: &StreamEvent) -> bool {
339 if let Some(tolerance) = self.config.temporal_tolerance {
341 let time_diff = (left.timestamp() - right.timestamp()).abs();
342 if time_diff > tolerance {
343 return false;
344 }
345 }
346
347 if let Some(condition) = &self.config.join_condition {
349 condition(left, right)
350 } else {
351 true
352 }
353 }
354
355 async fn is_late_event(&self, event_time: DateTime<Utc>) -> bool {
357 let watermark = self.watermark.read().await;
358 let watermark_time = (*watermark).current();
359
360 event_time < watermark_time - self.config.allowed_lateness
361 }
362
363 async fn clean_expired_events(&self) -> Result<()> {
365 let watermark_time = self.watermark.read().await.current();
366 let expiry_time = watermark_time - self.config.allowed_lateness;
367
368 {
370 let mut left_buffer = self.left_buffer.write().await;
371 for events in left_buffer.values_mut() {
372 events.retain(|e| e.timestamp() >= expiry_time);
373 }
374 left_buffer.retain(|_, v| !v.is_empty());
375 }
376
377 {
379 let mut right_buffer = self.right_buffer.write().await;
380 for events in right_buffer.values_mut() {
381 events.retain(|e| e.timestamp() >= expiry_time);
382 }
383 right_buffer.retain(|_, v| !v.is_empty());
384 }
385
386 Ok(())
387 }
388}
389
390fn create_null_joined_event(event: &StreamEvent, is_right_null: bool) -> Result<StreamEvent> {
392 let mut metadata = event.metadata().clone();
394 metadata.properties.insert(
395 "join_type".to_string(),
396 if is_right_null {
397 "right_null".to_string()
398 } else {
399 "left_null".to_string()
400 },
401 );
402
403 match event {
404 StreamEvent::TripleAdded {
405 subject,
406 predicate,
407 object,
408 graph,
409 metadata: _,
410 } => Ok(StreamEvent::TripleAdded {
411 subject: subject.clone(),
412 predicate: predicate.clone(),
413 object: object.clone(),
414 graph: graph.clone(),
415 metadata: metadata.clone(),
416 }),
417 _ => Ok(event.clone()),
418 }
419}
420
/// Step-by-step builder for a [`StreamJoinProcessor`].
///
/// Both key extractors and the result transformer must be supplied before
/// `build` succeeds; everything else is optional.
pub struct JoinBuilder {
    /// Join semantics of the processor being built.
    join_type: JoinType,
    /// Key extractor for the left stream; required by `build`.
    left_key_extractor: Option<JoinKeyExtractor>,
    /// Key extractor for the right stream; required by `build`.
    right_key_extractor: Option<JoinKeyExtractor>,
    /// Output builder; required by `build`.
    result_transformer: Option<JoinResultTransformer>,
    /// Optional window specification, copied into the configuration.
    window: Option<WindowType>,
    /// Optional maximum timestamp distance between joined events.
    temporal_tolerance: Option<Duration>,
    /// Optional extra predicate a key-equal pair must satisfy.
    join_condition: Option<JoinCondition>,
    /// Maximum number of buffered events per key and side (default 10000).
    buffer_size: usize,
    /// Lateness allowance behind the watermark (default five minutes).
    allowed_lateness: Duration,
}
433
434impl JoinBuilder {
435 pub fn new(join_type: JoinType) -> Self {
436 Self {
437 join_type,
438 left_key_extractor: None,
439 right_key_extractor: None,
440 result_transformer: None,
441 window: None,
442 temporal_tolerance: None,
443 join_condition: None,
444 buffer_size: 10000,
445 allowed_lateness: Duration::minutes(5),
446 }
447 }
448
449 pub fn with_keys(
450 mut self,
451 left_extractor: JoinKeyExtractor,
452 right_extractor: JoinKeyExtractor,
453 ) -> Self {
454 self.left_key_extractor = Some(left_extractor);
455 self.right_key_extractor = Some(right_extractor);
456 self
457 }
458
459 pub fn with_transformer(mut self, transformer: JoinResultTransformer) -> Self {
460 self.result_transformer = Some(transformer);
461 self
462 }
463
464 pub fn with_window(mut self, window: WindowType) -> Self {
465 self.window = Some(window);
466 self
467 }
468
469 pub fn with_temporal_tolerance(mut self, tolerance: Duration) -> Self {
470 self.temporal_tolerance = Some(tolerance);
471 self
472 }
473
474 pub fn with_condition(mut self, condition: JoinCondition) -> Self {
475 self.join_condition = Some(condition);
476 self
477 }
478
479 pub fn with_buffer_size(mut self, size: usize) -> Self {
480 self.buffer_size = size;
481 self
482 }
483
484 pub fn with_allowed_lateness(mut self, lateness: Duration) -> Self {
485 self.allowed_lateness = lateness;
486 self
487 }
488
489 pub fn build(self) -> Result<StreamJoinProcessor> {
490 let config = JoinConfig {
491 join_type: self.join_type,
492 window: self.window,
493 left_key_extractor: self
494 .left_key_extractor
495 .ok_or_else(|| anyhow!("Left key extractor is required"))?,
496 right_key_extractor: self
497 .right_key_extractor
498 .ok_or_else(|| anyhow!("Right key extractor is required"))?,
499 join_condition: self.join_condition,
500 result_transformer: self
501 .result_transformer
502 .ok_or_else(|| anyhow!("Result transformer is required"))?,
503 temporal_tolerance: self.temporal_tolerance,
504 buffer_size: self.buffer_size,
505 collect_stats: true,
506 allowed_lateness: self.allowed_lateness,
507 };
508
509 Ok(StreamJoinProcessor::new(config))
510 }
511}
512
513pub mod patterns {
515 use super::*;
516 use crate::StreamEvent;
517
518 pub fn subject_key_extractor() -> JoinKeyExtractor {
520 Arc::new(|event: &StreamEvent| match event {
521 StreamEvent::TripleAdded { subject, .. }
522 | StreamEvent::TripleRemoved { subject, .. } => Some(subject.clone()),
523 _ => None,
524 })
525 }
526
527 pub fn predicate_key_extractor() -> JoinKeyExtractor {
529 Arc::new(|event: &StreamEvent| match event {
530 StreamEvent::TripleAdded { predicate, .. }
531 | StreamEvent::TripleRemoved { predicate, .. } => Some(predicate.clone()),
532 _ => None,
533 })
534 }
535
536 pub fn graph_key_extractor() -> JoinKeyExtractor {
538 Arc::new(|event: &StreamEvent| match event {
539 StreamEvent::TripleAdded { graph, .. } | StreamEvent::TripleRemoved { graph, .. } => {
540 graph.clone()
541 }
542 _ => None,
543 })
544 }
545
546 pub fn merge_transformer() -> JoinResultTransformer {
548 Arc::new(|left: &StreamEvent, right: Option<&StreamEvent>| {
549 let mut metadata = left.metadata().clone();
550
551 if let Some(right_event) = right {
552 for (k, v) in right_event.metadata().properties.iter() {
554 metadata.properties.insert(format!("right_{k}"), v.clone());
555 }
556 metadata
557 .properties
558 .insert("join_result".to_string(), "matched".to_string());
559 } else {
560 metadata
561 .properties
562 .insert("join_result".to_string(), "unmatched".to_string());
563 }
564
565 match left {
567 StreamEvent::TripleAdded {
568 subject,
569 predicate,
570 object,
571 graph,
572 ..
573 } => Ok(StreamEvent::TripleAdded {
574 subject: subject.clone(),
575 predicate: predicate.clone(),
576 object: object.clone(),
577 graph: graph.clone(),
578 metadata,
579 }),
580 _ => Ok(left.clone()),
581 }
582 })
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{event::EventMetadata, StreamEvent};

    /// Subject shared by every event in these tests so that they all join on
    /// the same key.
    const SUBJECT: &str = "http://example.org/subject1";

    /// Builds a `TripleAdded` event for `subject` stamped with `timestamp`.
    fn create_test_event(subject: &str, timestamp: DateTime<Utc>) -> StreamEvent {
        let metadata = EventMetadata {
            event_id: uuid::Uuid::new_v4().to_string(),
            timestamp,
            source: "test".to_string(),
            user: None,
            context: None,
            caused_by: None,
            version: "1.0".to_string(),
            properties: std::collections::HashMap::new(),
            checksum: None,
        };

        StreamEvent::TripleAdded {
            subject: subject.to_string(),
            predicate: "http://example.org/predicate".to_string(),
            object: "http://example.org/object".to_string(),
            graph: None,
            metadata,
        }
    }

    /// Starts a builder that joins both streams on the subject and merges
    /// metadata; each test adds its own options before calling `build`.
    fn subject_join(join_type: JoinType) -> JoinBuilder {
        JoinBuilder::new(join_type)
            .with_keys(
                patterns::subject_key_extractor(),
                patterns::subject_key_extractor(),
            )
            .with_transformer(patterns::merge_transformer())
    }

    #[tokio::test]
    async fn test_inner_join() {
        let processor = subject_join(JoinType::Inner).build().unwrap();
        let now = Utc::now();

        // The left event arrives first: nothing is buffered on the right yet.
        let from_left = processor
            .process_left(create_test_event(SUBJECT, now))
            .await
            .unwrap();
        assert_eq!(from_left.len(), 0);

        // The right event finds the buffered left event.
        let from_right = processor
            .process_right(create_test_event(SUBJECT, now + Duration::seconds(1)))
            .await
            .unwrap();
        assert_eq!(from_right.len(), 1);

        let stats = processor.get_statistics().await;
        assert_eq!(stats.matched_pairs, 1);
        assert_eq!(stats.unmatched_left, 0);
        assert_eq!(stats.unmatched_right, 0);
    }

    #[tokio::test]
    async fn test_left_outer_join() {
        let processor = subject_join(JoinType::LeftOuter).build().unwrap();
        let now = Utc::now();

        // Without a right partner the left event is still emitted.
        let emitted = processor
            .process_left(create_test_event(SUBJECT, now))
            .await
            .unwrap();
        assert_eq!(emitted.len(), 1);

        let stats = processor.get_statistics().await;
        assert_eq!(stats.unmatched_left, 1);
    }

    #[tokio::test]
    async fn test_temporal_join() {
        let processor = subject_join(JoinType::Inner)
            .with_temporal_tolerance(Duration::seconds(5))
            .build()
            .unwrap();
        let now = Utc::now();

        processor
            .process_left(create_test_event(SUBJECT, now))
            .await
            .unwrap();

        // Three seconds apart: inside the five second tolerance.
        let near = create_test_event(SUBJECT, now + Duration::seconds(3));
        let near_results = processor.process_right(near).await.unwrap();
        assert_eq!(near_results.len(), 1);

        // Ten seconds apart: outside the tolerance.
        let far = create_test_event(SUBJECT, now + Duration::seconds(10));
        let far_results = processor.process_right(far).await.unwrap();
        assert_eq!(far_results.len(), 0);

        let stats = processor.get_statistics().await;
        assert_eq!(stats.matched_pairs, 1);
    }

    #[tokio::test]
    async fn test_late_event_handling() {
        let processor = subject_join(JoinType::Inner)
            .with_allowed_lateness(Duration::minutes(1))
            .build()
            .unwrap();
        let now = Utc::now();

        processor.update_watermark(now).await.unwrap();

        // Two minutes behind the watermark with one minute allowed: dropped.
        let late_event = create_test_event(SUBJECT, now - Duration::minutes(2));
        let emitted = processor.process_left(late_event).await.unwrap();
        assert_eq!(emitted.len(), 0);

        let stats = processor.get_statistics().await;
        assert_eq!(stats.late_events_dropped, 1);
    }
}