1use std::collections::VecDeque;
15use std::time::{Duration, SystemTime};
16
17use crate::error::StreamingError;
18
/// One keyed, timestamped event arriving on either side of a temporal join.
#[derive(Clone, Debug)]
pub struct JoinEvent<T> {
    /// Event time; compared against the opposite side's timestamps when matching.
    pub timestamp: SystemTime,
    /// Join key; only events whose keys are equal are ever paired.
    pub key: String,
    /// Caller-supplied data carried unchanged into any joined output.
    pub payload: T,
}

impl<T> JoinEvent<T> {
    /// Builds an event, converting `key` into an owned `String`.
    pub fn new(timestamp: SystemTime, key: impl Into<String>, payload: T) -> Self {
        let key: String = key.into();
        Self {
            payload,
            key,
            timestamp,
        }
    }
}
42
43#[derive(Debug, Clone)]
47pub struct JoinedPair<L, R> {
48 pub left: JoinEvent<L>,
50 pub right: JoinEvent<R>,
52 pub time_delta: Duration,
54}
55
/// Rule used to decide whether a left and a right event with equal keys match.
///
/// `Eq` is derived alongside `PartialEq` because every field (`Duration`) is
/// totally comparable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinMode {
    /// Match when the absolute timestamp difference is at most the configured
    /// `time_tolerance`.
    Inner,
    /// Currently uses exactly the same matching rule as `Inner`.
    ///
    /// NOTE(review): unmatched left events are never emitted — a `JoinedPair`
    /// always carries a right event — so true left-outer semantics are not
    /// implemented by the joiner.
    LeftOuter,
    /// Match when
    /// `right.timestamp + lower <= left.timestamp <= right.timestamp + upper`.
    ///
    /// `time_tolerance` is not consulted in this mode. Because both bounds are
    /// unsigned, a matching left event can never precede its right event.
    Interval {
        /// Minimum amount by which the left event must follow the right event.
        lower: Duration,
        /// Maximum amount by which the left event may follow the right event.
        upper: Duration,
    },
}
73
74#[derive(Debug, Clone)]
78pub struct TemporalJoinConfig {
79 pub time_tolerance: Duration,
82 pub max_buffer_size: usize,
84 pub mode: JoinMode,
86}
87
88impl Default for TemporalJoinConfig {
89 fn default() -> Self {
90 Self {
91 time_tolerance: Duration::from_secs(5),
92 max_buffer_size: 10_000,
93 mode: JoinMode::Inner,
94 }
95 }
96}
97
98pub struct TemporalJoiner<L: Clone, R: Clone> {
106 config: TemporalJoinConfig,
107 left_buffer: VecDeque<JoinEvent<L>>,
108 right_buffer: VecDeque<JoinEvent<R>>,
109 output: VecDeque<JoinedPair<L, R>>,
110 total_joined: u64,
111 total_expired_left: u64,
112 total_expired_right: u64,
113}
114
115impl<L: Clone, R: Clone> TemporalJoiner<L, R> {
116 pub fn new(config: TemporalJoinConfig) -> Self {
118 Self {
119 config,
120 left_buffer: VecDeque::new(),
121 right_buffer: VecDeque::new(),
122 output: VecDeque::new(),
123 total_joined: 0,
124 total_expired_left: 0,
125 total_expired_right: 0,
126 }
127 }
128
129 pub fn add_left(&mut self, event: JoinEvent<L>) -> Result<(), StreamingError> {
135 if self.left_buffer.len() >= self.config.max_buffer_size {
136 self.left_buffer.pop_front();
137 self.total_expired_left += 1;
138 }
139 self.try_join_with_left(&event);
140 self.left_buffer.push_back(event);
141 Ok(())
142 }
143
144 pub fn add_right(&mut self, event: JoinEvent<R>) -> Result<(), StreamingError> {
150 if self.right_buffer.len() >= self.config.max_buffer_size {
151 self.right_buffer.pop_front();
152 self.total_expired_right += 1;
153 }
154 self.try_join_with_right(&event);
155 self.right_buffer.push_back(event);
156 Ok(())
157 }
158
159 pub fn drain_output(&mut self) -> Vec<JoinedPair<L, R>> {
161 self.output.drain(..).collect()
162 }
163
164 pub fn total_joined(&self) -> u64 {
166 self.total_joined
167 }
168
169 pub fn left_buffer_size(&self) -> usize {
171 self.left_buffer.len()
172 }
173
174 pub fn right_buffer_size(&self) -> usize {
176 self.right_buffer.len()
177 }
178
179 pub fn total_expired_left(&self) -> u64 {
181 self.total_expired_left
182 }
183
184 pub fn total_expired_right(&self) -> u64 {
186 self.total_expired_right
187 }
188
189 fn time_delta(a: SystemTime, b: SystemTime) -> Duration {
193 a.duration_since(b)
194 .unwrap_or_else(|_| b.duration_since(a).unwrap_or(Duration::ZERO))
195 }
196
197 fn matches(&self, left_time: SystemTime, right_time: SystemTime) -> Option<Duration> {
200 let delta = Self::time_delta(left_time, right_time);
201 match &self.config.mode {
202 JoinMode::Inner | JoinMode::LeftOuter => {
203 if delta <= self.config.time_tolerance {
204 Some(delta)
205 } else {
206 None
207 }
208 }
209 JoinMode::Interval { lower, upper } => {
210 let lower_bound = right_time + *lower;
212 let upper_bound = right_time + *upper;
213 if left_time >= lower_bound && left_time <= upper_bound {
214 Some(delta)
215 } else {
216 None
217 }
218 }
219 }
220 }
221
222 fn try_join_with_left(&mut self, left: &JoinEvent<L>) {
224 for right in &self.right_buffer {
225 if right.key != left.key {
226 continue;
227 }
228 if let Some(delta) = self.matches(left.timestamp, right.timestamp) {
229 self.output.push_back(JoinedPair {
230 left: left.clone(),
231 right: right.clone(),
232 time_delta: delta,
233 });
234 self.total_joined += 1;
235 }
236 }
237 }
238
239 fn try_join_with_right(&mut self, right: &JoinEvent<R>) {
241 for left in &self.left_buffer {
242 if left.key != right.key {
243 continue;
244 }
245 if let Some(delta) = self.matches(left.timestamp, right.timestamp) {
246 self.output.push_back(JoinedPair {
247 left: left.clone(),
248 right: right.clone(),
249 time_delta: delta,
250 });
251 self.total_joined += 1;
252 }
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::UNIX_EPOCH;

    /// Instant lying `secs` seconds after the Unix epoch.
    fn ts(secs: u64) -> SystemTime {
        UNIX_EPOCH + Duration::from_secs(secs)
    }

    /// Left-stream event with a fixed payload.
    fn left_event(secs: u64, key: &str) -> JoinEvent<&'static str> {
        JoinEvent::new(ts(secs), key.to_string(), "left_payload")
    }

    /// Right-stream event with a fixed payload.
    fn right_event(secs: u64, key: &str) -> JoinEvent<&'static str> {
        JoinEvent::new(ts(secs), key.to_string(), "right_payload")
    }

    /// Joiner over `&str` payloads using the default configuration.
    fn default_joiner() -> TemporalJoiner<&'static str, &'static str> {
        TemporalJoiner::new(TemporalJoinConfig::default())
    }

    /// Interval window `[right + 2s, right + 8s]` shared by the interval tests.
    fn interval_config() -> TemporalJoinConfig {
        TemporalJoinConfig {
            mode: JoinMode::Interval {
                lower: Duration::from_secs(2),
                upper: Duration::from_secs(8),
            },
            max_buffer_size: 100,
            time_tolerance: Duration::from_secs(1),
        }
    }

    /// Default configuration with only the per-side buffer cap overridden.
    fn capped_config(max_buffer_size: usize) -> TemporalJoinConfig {
        TemporalJoinConfig {
            max_buffer_size,
            ..Default::default()
        }
    }

    #[test]
    fn test_inner_join_matching_key_and_time() {
        let mut joiner = default_joiner();
        joiner.add_left(left_event(100, "k1")).expect("add ok");
        joiner.add_right(right_event(102, "k1")).expect("add ok");

        let pairs = joiner.drain_output();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].left.key, "k1");
    }

    #[test]
    fn test_inner_join_miss_outside_tolerance() {
        let mut joiner = default_joiner();
        joiner.add_left(left_event(100, "k1")).expect("add ok");
        // Ten seconds apart: beyond the default five-second tolerance.
        joiner.add_right(right_event(110, "k1")).expect("add ok");

        let pairs = joiner.drain_output();
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_no_join_on_key_mismatch() {
        let mut joiner = default_joiner();
        joiner.add_left(left_event(100, "k1")).expect("add ok");
        joiner.add_right(right_event(100, "k2")).expect("add ok");

        let pairs = joiner.drain_output();
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_left_outer_mode_config() {
        let config = TemporalJoinConfig {
            mode: JoinMode::LeftOuter,
            ..Default::default()
        };
        let mut joiner = TemporalJoiner::<&str, &str>::new(config);
        joiner.add_left(left_event(100, "k1")).expect("add ok");
        joiner.add_right(right_event(103, "k1")).expect("add ok");

        let pairs = joiner.drain_output();
        assert_eq!(pairs.len(), 1);
    }

    #[test]
    fn test_interval_join_matches_within_interval() {
        let mut joiner = TemporalJoiner::new(interval_config());
        joiner.add_right(right_event(100, "k1")).expect("add ok");
        // 105 lies inside [102, 108].
        joiner.add_left(left_event(105, "k1")).expect("add ok");

        let pairs = joiner.drain_output();
        assert_eq!(pairs.len(), 1);
    }

    #[test]
    fn test_interval_join_no_match_outside_interval() {
        let mut joiner = TemporalJoiner::new(interval_config());
        joiner.add_right(right_event(100, "k1")).expect("add ok");
        // 110 lies past the upper edge of [102, 108].
        joiner.add_left(left_event(110, "k1")).expect("add ok");

        let pairs = joiner.drain_output();
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_buffer_eviction_when_max_exceeded() {
        let mut joiner = TemporalJoiner::<&str, &str>::new(capped_config(3));
        (0u64..5).for_each(|i| {
            joiner.add_left(left_event(i * 1000, "kx")).expect("add ok");
        });

        assert_eq!(joiner.total_expired_left(), 2);
        assert_eq!(joiner.left_buffer_size(), 3);
    }

    #[test]
    fn test_time_delta_computation_is_correct() {
        let mut joiner = default_joiner();
        joiner.add_left(left_event(1000, "k1")).expect("add ok");
        joiner.add_right(right_event(1003, "k1")).expect("add ok");

        let pairs = joiner.drain_output();
        assert_eq!(pairs[0].time_delta, Duration::from_secs(3));
    }

    #[test]
    fn test_total_joined_counter() {
        let mut joiner = default_joiner();
        for (key, base) in [("k1", 100), ("k2", 200)] {
            joiner.add_left(left_event(base, key)).expect("add ok");
            joiner.add_right(right_event(base + 1, key)).expect("add ok");
        }
        joiner.drain_output();

        assert_eq!(joiner.total_joined(), 2);
    }

    #[test]
    fn test_add_left_then_right_same_as_right_then_left() {
        let mut forward = default_joiner();
        forward.add_left(left_event(100, "k")).expect("ok");
        forward.add_right(right_event(102, "k")).expect("ok");
        let p1 = forward.drain_output();

        let mut backward = default_joiner();
        backward.add_right(right_event(102, "k")).expect("ok");
        backward.add_left(left_event(100, "k")).expect("ok");
        let p2 = backward.drain_output();

        assert_eq!(p1.len(), 1);
        assert_eq!(p2.len(), 1);
        assert_eq!(p1[0].time_delta, p2[0].time_delta);
    }

    #[test]
    fn test_multiple_right_events_match_single_left() {
        let mut joiner = default_joiner();
        for secs in 100..=101 {
            joiner.add_right(right_event(secs, "k")).expect("ok");
        }
        joiner.add_left(left_event(102, "k")).expect("ok");

        let pairs = joiner.drain_output();
        assert_eq!(pairs.len(), 2);
    }

    #[test]
    fn test_expired_right_counter() {
        let mut joiner = TemporalJoiner::<&str, &str>::new(capped_config(2));
        for secs in 0..3 {
            joiner.add_right(right_event(secs, "a")).expect("ok");
        }

        assert_eq!(joiner.total_expired_right(), 1);
    }

    #[test]
    fn test_no_cross_key_contamination() {
        let mut joiner = default_joiner();
        for key in ["alpha", "beta"] {
            joiner.add_left(left_event(100, key)).expect("ok");
        }
        joiner.add_right(right_event(101, "alpha")).expect("ok");

        let pairs = joiner.drain_output();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].left.key, "alpha");
    }
}
438}