use std::collections::HashMap;
use super::{WindowJoinKey, WindowJoinResult, WindowJoinStats};
/// Configuration for [`SessionSessionJoin`].
#[derive(Debug, Clone)]
pub struct SessionSessionJoinConfig {
    /// Session inactivity gap in milliseconds: an event within this distance of a session's
    /// `[first, last]` timestamp span joins that session. Always `> 0`.
    pub gap_ms: i64,
    /// Extra milliseconds a session stays open past `last_event + gap_ms` before the
    /// watermark may close it. Defaults to `0`; never negative.
    pub allowed_lateness_ms: i64,
}

impl SessionSessionJoinConfig {
    /// Creates a config with the given session gap and no allowed lateness.
    ///
    /// # Panics
    ///
    /// Panics if `gap_ms <= 0`.
    pub fn new(gap_ms: i64) -> Self {
        assert!(gap_ms > 0, "gap_ms must be > 0");
        Self {
            gap_ms,
            allowed_lateness_ms: 0,
        }
    }

    /// Returns this config with `allowed_lateness_ms` replaced.
    ///
    /// # Panics
    ///
    /// Panics if `allowed_lateness_ms < 0`. A negative lateness would close sessions before
    /// their gap has elapsed and would mark events late whose sessions have not yet ended.
    #[must_use]
    pub fn with_lateness(mut self, allowed_lateness_ms: i64) -> Self {
        assert!(allowed_lateness_ms >= 0, "allowed_lateness_ms must be >= 0");
        self.allowed_lateness_ms = allowed_lateness_ms;
        self
    }
}
/// A run of same-side events for one key, tracked by its earliest and latest timestamps.
///
/// No trait bound on `E`: nothing here clones events. `#[derive(Clone)]` still makes the
/// session cloneable whenever `E: Clone`.
#[derive(Debug, Clone)]
struct Session<E> {
    /// Smallest event timestamp in the session (ms).
    first_ts_ms: i64,
    /// Largest event timestamp in the session (ms).
    last_ts_ms: i64,
    /// Buffered `(timestamp_ms, event)` pairs, in insertion order.
    events: Vec<(i64, E)>,
}

impl<E> Session<E> {
    /// Starts a session containing a single event.
    fn new(ts_ms: i64, event: E) -> Self {
        Self {
            first_ts_ms: ts_ms,
            last_ts_ms: ts_ms,
            events: vec![(ts_ms, event)],
        }
    }

    /// Appends an event, widening the `[first, last]` span if the timestamp lies outside it.
    fn extend(&mut self, ts_ms: i64, event: E) {
        self.first_ts_ms = self.first_ts_ms.min(ts_ms);
        self.last_ts_ms = self.last_ts_ms.max(ts_ms);
        self.events.push((ts_ms, event));
    }

    /// Session end: latest timestamp plus the inactivity gap (saturating at `i64::MAX`).
    fn end_ms(&self, gap: i64) -> i64 {
        self.last_ts_ms.saturating_add(gap)
    }

    /// Watermark at or after which the session counts as closed:
    /// `end_ms(gap) + lateness` (saturating at `i64::MAX`).
    fn close_at(&self, gap: i64, lateness: i64) -> i64 {
        self.end_ms(gap).saturating_add(lateness)
    }
}
/// Session-window inner join between a left and a right keyed event stream.
///
/// Per key, each side's events are grouped into sessions: an event belongs to a session when
/// its timestamp lies in `[first_ts - gap_ms, last_ts + gap_ms]`. A session is closed once
/// `last_ts + gap_ms + allowed_lateness_ms <= watermark`. When every session buffered for a
/// key is closed, each overlapping (left, right) session pair emits the cross product of its
/// events, and all of that key's state is purged.
pub struct SessionSessionJoin<L: Clone, R: Clone> {
    config: SessionSessionJoinConfig,
    /// Buffered left-side sessions, per key.
    left_sessions: HashMap<WindowJoinKey, Vec<Session<L>>>,
    /// Buffered right-side sessions, per key.
    right_sessions: HashMap<WindowJoinKey, Vec<Session<R>>>,
    /// Latest watermark accepted; `i64::MIN` until the first watermark arrives.
    last_watermark_ms: i64,
    stats: WindowJoinStats,
}

impl<L: Clone, R: Clone> SessionSessionJoin<L, R> {
    /// Creates an empty join with the given configuration and no watermark yet.
    pub fn new(config: SessionSessionJoinConfig) -> Self {
        Self {
            config,
            left_sessions: HashMap::new(),
            right_sessions: HashMap::new(),
            last_watermark_ms: i64::MIN,
            stats: WindowJoinStats::default(),
        }
    }

    /// Adds `event` at `ts_ms` to the session it falls into, or starts a new session.
    ///
    /// An event falls into a session when `ts_ms` lies in `[first - gap, last + gap]`. If the
    /// event lies within `gap` of several sessions, it bridges them. They are merged into the
    /// earliest-listed one, so that no two sessions on the same side remain within `gap` of
    /// each other. Returns `true` if an existing session absorbed the event, and `false` if a
    /// new session was created.
    fn extend_or_new<E: Clone>(
        sessions: &mut Vec<Session<E>>,
        ts_ms: i64,
        gap: i64,
        event: E,
    ) -> bool {
        // Saturating bounds: the naive `(ts - x).abs() <= gap` form overflows (and panics in
        // debug builds) for timestamps near `i64::MIN`.
        let touches = |s: &Session<E>| {
            ts_ms >= s.first_ts_ms.saturating_sub(gap) && ts_ms <= s.last_ts_ms.saturating_add(gap)
        };
        let found = sessions.iter().position(|s| touches(s));
        let target = match found {
            Some(idx) => idx,
            None => {
                sessions.push(Session::new(ts_ms, event));
                return false;
            }
        };
        // Every later session the event also touches is bridged by it: fold it into `target`.
        let mut i = target + 1;
        while i < sessions.len() {
            if touches(&sessions[i]) {
                let bridged = sessions.remove(i);
                let merged = &mut sessions[target];
                merged.first_ts_ms = merged.first_ts_ms.min(bridged.first_ts_ms);
                merged.last_ts_ms = merged.last_ts_ms.max(bridged.last_ts_ms);
                merged.events.extend(bridged.events);
            } else {
                i += 1;
            }
        }
        sessions[target].extend(ts_ms, event);
        true
    }

    /// An event is late when its own session end (`ts_ms + gap_ms`) is already behind
    /// `watermark - allowed_lateness_ms`. Nothing is late before the first watermark.
    fn is_late(&self, ts_ms: i64) -> bool {
        if self.last_watermark_ms == i64::MIN {
            return false;
        }
        ts_ms.saturating_add(self.config.gap_ms)
            < self
                .last_watermark_ms
                .saturating_sub(self.config.allowed_lateness_ms)
    }

    /// Buffers a left-side event, or counts it in `late_events_dropped` if it is late.
    pub fn push_left(&mut self, key: WindowJoinKey, ts_ms: i64, event: L) {
        if self.is_late(ts_ms) {
            self.stats.late_events_dropped += 1;
            return;
        }
        self.stats.left_events += 1;
        let gap = self.config.gap_ms;
        let entry = self.left_sessions.entry(key).or_default();
        let _ = Self::extend_or_new(entry, ts_ms, gap, event);
    }

    /// Buffers a right-side event, or counts it in `late_events_dropped` if it is late.
    pub fn push_right(&mut self, key: WindowJoinKey, ts_ms: i64, event: R) {
        if self.is_late(ts_ms) {
            self.stats.late_events_dropped += 1;
            return;
        }
        self.stats.right_events += 1;
        let gap = self.config.gap_ms;
        let entry = self.right_sessions.entry(key).or_default();
        let _ = Self::extend_or_new(entry, ts_ms, gap, event);
    }

    /// Returns `true` when every session in `sessions` has closed at `watermark_ms`.
    /// A side with no buffered sessions (`None`) counts as closed.
    fn all_closed<E: Clone>(
        sessions: Option<&Vec<Session<E>>>,
        gap: i64,
        lat: i64,
        watermark_ms: i64,
    ) -> bool {
        match sessions {
            Some(list) => list.iter().all(|s| s.close_at(gap, lat) <= watermark_ms),
            None => true,
        }
    }

    /// Advances event time to `watermark_ms` and returns the join results it releases.
    ///
    /// A watermark lower than the current one is ignored and yields nothing. Keys are
    /// processed in sorted order, so the output order is deterministic. A key is finalized
    /// only when every session buffered for it, on whichever sides have any, is closed.
    /// Finalizing emits one result per (left event, right event) pair drawn from overlapping
    /// sessions, then removes all of the key's sessions. A key seen on only one side is
    /// removed the same way (emitting nothing), so unmatched keys do not accumulate forever.
    ///
    /// NOTE(review): because a key waits for *all* of its sessions, a key that keeps
    /// receiving events may defer emission indefinitely. Also, once a key is purged, a later
    /// event that is not yet late but would have overlapped a purged session finds nothing
    /// to join.
    pub fn advance_watermark(&mut self, watermark_ms: i64) -> Vec<WindowJoinResult<L, R>> {
        if watermark_ms < self.last_watermark_ms {
            return Vec::new();
        }
        self.last_watermark_ms = watermark_ms;
        let gap = self.config.gap_ms;
        let lat = self.config.allowed_lateness_ms;
        let mut emitted = Vec::new();
        let mut purged = 0usize;

        let mut keys: Vec<WindowJoinKey> = self
            .left_sessions
            .keys()
            .chain(self.right_sessions.keys())
            .cloned()
            .collect();
        keys.sort();
        keys.dedup();

        for key in keys {
            let left_closed =
                Self::all_closed(self.left_sessions.get(&key), gap, lat, watermark_ms);
            let right_closed =
                Self::all_closed(self.right_sessions.get(&key), gap, lat, watermark_ms);
            if !(left_closed && right_closed) {
                continue;
            }
            // Take ownership rather than cloning: the key's state is discarded afterwards.
            let lefts = self.left_sessions.remove(&key).unwrap_or_default();
            let rights = self.right_sessions.remove(&key).unwrap_or_default();
            purged += lefts.len() + rights.len();
            for ls in &lefts {
                for rs in rights.iter().filter(|rs| self.sessions_overlap(ls, rs)) {
                    let pane_end_ms = ls.end_ms(gap).max(rs.end_ms(gap));
                    for (_, l_ev) in &ls.events {
                        for (_, r_ev) in &rs.events {
                            emitted.push(WindowJoinResult {
                                key: key.clone(),
                                left: l_ev.clone(),
                                right: r_ev.clone(),
                                pane_end_ms,
                            });
                        }
                    }
                }
            }
        }
        self.stats.joined_pairs += emitted.len() as u64;
        self.stats.windows_closed += purged as u64;
        emitted
    }

    /// Two sessions overlap when their `[first, last + gap]` intervals intersect (inclusive).
    fn sessions_overlap(&self, a: &Session<L>, b: &Session<R>) -> bool {
        let gap = self.config.gap_ms;
        a.first_ts_ms <= b.end_ms(gap) && b.first_ts_ms <= a.end_ms(gap)
    }

    /// Running counters for events seen, dropped, joined, and sessions purged.
    pub fn stats(&self) -> &WindowJoinStats {
        &self.stats
    }

    /// Total number of sessions currently buffered across both sides and all keys.
    pub fn session_count(&self) -> usize {
        self.left_sessions.values().map(|v| v.len()).sum::<usize>()
            + self.right_sessions.values().map(|v| v.len()).sum::<usize>()
    }

    /// Latest accepted watermark, or `i64::MIN` if none has been seen.
    pub fn watermark(&self) -> i64 {
        self.last_watermark_ms
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn overlapping_sessions_emit_on_close() {
        let cfg = SessionSessionJoinConfig::new(500);
        let mut j: SessionSessionJoin<&str, &str> = SessionSessionJoin::new(cfg);
        j.push_left("k".into(), 100, "L0");
        j.push_left("k".into(), 200, "L1");
        j.push_left("k".into(), 300, "L2");
        j.push_right("k".into(), 250, "R0");
        j.push_right("k".into(), 350, "R1");
        // One left session [100, 300] overlaps one right session [250, 350]: 3 x 2 pairs.
        let out = j.advance_watermark(900);
        assert_eq!(out.len(), 6);
        assert_eq!(j.stats().joined_pairs, 6);
        assert_eq!(j.session_count(), 0);
    }

    #[test]
    fn non_overlapping_sessions_dont_emit() {
        let cfg = SessionSessionJoinConfig::new(50);
        let mut j: SessionSessionJoin<&str, &str> = SessionSessionJoin::new(cfg);
        j.push_left("k".into(), 100, "L0");
        j.push_right("k".into(), 1_000, "R0");
        let out = j.advance_watermark(2_000);
        assert!(out.is_empty());
        assert_eq!(j.stats().joined_pairs, 0);
    }

    #[test]
    fn separate_keys_dont_join() {
        let cfg = SessionSessionJoinConfig::new(500);
        let mut j: SessionSessionJoin<&str, &str> = SessionSessionJoin::new(cfg);
        j.push_left("a".into(), 100, "La");
        j.push_right("b".into(), 200, "Rb");
        let out = j.advance_watermark(2_000);
        assert!(out.is_empty());
    }

    #[test]
    fn late_event_after_emit_is_dropped() {
        let cfg = SessionSessionJoinConfig::new(50);
        let mut j: SessionSessionJoin<&str, &str> = SessionSessionJoin::new(cfg);
        j.push_left("k".into(), 100, "L0");
        assert!(j.advance_watermark(10_000).is_empty());
        // 100 + gap (50) is far behind the watermark, so this event is rejected as late.
        j.push_left("k".into(), 100, "Late");
        assert_eq!(j.stats().late_events_dropped, 1);
        assert_eq!(j.stats().left_events, 1);
    }

    #[test]
    fn allowed_lateness_keeps_session_open() {
        let cfg = SessionSessionJoinConfig::new(50).with_lateness(1_000);
        let mut j: SessionSessionJoin<&str, &str> = SessionSessionJoin::new(cfg);
        j.push_left("k".into(), 100, "L0");
        // The left session closes at 100 + 50 + 1_000 = 1_150, so nothing is emitted at 800.
        let out = j.advance_watermark(800);
        assert!(out.is_empty());
        j.push_right("k".into(), 120, "R0");
        let out = j.advance_watermark(2_000);
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn watermark_emits_only_closed_sessions() {
        let cfg = SessionSessionJoinConfig::new(100);
        let mut j: SessionSessionJoin<&str, &str> = SessionSessionJoin::new(cfg);
        j.push_left("k".into(), 100, "L0");
        j.push_right("k".into(), 150, "R0");
        // The right session closes at 250, so the key is still pending at 220.
        let out = j.advance_watermark(220);
        assert!(out.is_empty());
        let out = j.advance_watermark(260);
        assert_eq!(out.len(), 1);
        assert_eq!(j.watermark(), 260);
    }
}