1use arrow_array::{Array, RecordBatch};
24use fxhash::FxHasher;
25use std::hash::{Hash, Hasher};
26
27use crate::operator::Event;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
36pub enum RouterError {
37 #[error("column not found by name")]
39 ColumnNotFoundByName,
40
41 #[error("column index out of range")]
43 ColumnIndexOutOfRange,
44
45 #[error("row index out of range")]
47 RowIndexOutOfRange,
48
49 #[error("unsupported data type for routing")]
51 UnsupportedDataType,
52
53 #[error("empty batch")]
55 EmptyBatch,
56}
57
/// Describes which parts of a record are used as the routing key.
///
/// The default is [`KeySpec::RoundRobin`].
#[derive(Debug, Clone, Default)]
pub enum KeySpec {
    /// Hash the columns with these names, in the order given.
    Columns(Vec<String>),

    /// Hash the columns at these positional indices, in the order given.
    ColumnIndices(Vec<usize>),

    /// Ignore the data and hand out cores in rotation.
    #[default]
    RoundRobin,

    /// Hash every column of the batch, in schema order.
    AllColumns,
}
86
87pub struct KeyRouter {
100 num_cores: usize,
102 key_spec: KeySpec,
104 round_robin_counter: std::sync::atomic::AtomicUsize,
106}
107
108impl KeyRouter {
109 #[must_use]
115 pub fn new(num_cores: usize, key_spec: KeySpec) -> Self {
116 assert!(num_cores > 0, "num_cores must be > 0");
117 Self {
118 num_cores,
119 key_spec,
120 round_robin_counter: std::sync::atomic::AtomicUsize::new(0),
121 }
122 }
123
124 #[must_use]
126 pub fn num_cores(&self) -> usize {
127 self.num_cores
128 }
129
130 #[must_use]
132 pub fn key_spec(&self) -> &KeySpec {
133 &self.key_spec
134 }
135
136 pub fn route(&self, event: &Event) -> Result<usize, super::TpcError> {
144 match &self.key_spec {
145 KeySpec::RoundRobin => {
146 let counter = self
148 .round_robin_counter
149 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
150 Ok(counter % self.num_cores)
151 }
152 KeySpec::Columns(columns) => self.route_by_columns(event, columns),
153 KeySpec::ColumnIndices(indices) => self.route_by_indices(event, indices),
154 KeySpec::AllColumns => self.route_all_columns(event),
155 }
156 }
157
158 pub fn route_row(&self, batch: &RecordBatch, row: usize) -> Result<usize, super::TpcError> {
166 match &self.key_spec {
167 KeySpec::RoundRobin => {
168 let counter = self
169 .round_robin_counter
170 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
171 Ok(counter % self.num_cores)
172 }
173 KeySpec::Columns(columns) => self.route_row_by_columns(batch, row, columns),
174 KeySpec::ColumnIndices(indices) => self.route_row_by_indices(batch, row, indices),
175 KeySpec::AllColumns => self.route_row_all_columns(batch, row),
176 }
177 }
178
179 #[allow(clippy::cast_possible_truncation)]
181 fn hash_to_core(&self, hash: u64) -> usize {
182 (hash as usize) % self.num_cores
184 }
185
186 fn route_by_columns(
188 &self,
189 event: &Event,
190 columns: &[String],
191 ) -> Result<usize, super::TpcError> {
192 let batch = &event.data;
193 let mut hasher = FxHasher::default();
194
195 for col_name in columns {
196 let col_idx = batch
197 .schema()
198 .index_of(col_name)
199 .map_err(|_| RouterError::ColumnNotFoundByName)?;
200
201 hash_column(&mut hasher, batch.column(col_idx))?;
202 }
203
204 Ok(self.hash_to_core(hasher.finish()))
205 }
206
207 fn route_by_indices(&self, event: &Event, indices: &[usize]) -> Result<usize, super::TpcError> {
209 let batch = &event.data;
210 let mut hasher = FxHasher::default();
211
212 for &idx in indices {
213 if idx >= batch.num_columns() {
214 return Err(RouterError::ColumnIndexOutOfRange.into());
215 }
216 hash_column(&mut hasher, batch.column(idx))?;
217 }
218
219 Ok(self.hash_to_core(hasher.finish()))
220 }
221
222 fn route_all_columns(&self, event: &Event) -> Result<usize, super::TpcError> {
224 let batch = &event.data;
225 let mut hasher = FxHasher::default();
226
227 for col in batch.columns() {
228 hash_column(&mut hasher, col)?;
229 }
230
231 Ok(self.hash_to_core(hasher.finish()))
232 }
233
234 fn route_row_by_columns(
236 &self,
237 batch: &RecordBatch,
238 row: usize,
239 columns: &[String],
240 ) -> Result<usize, super::TpcError> {
241 let mut hasher = FxHasher::default();
242
243 for col_name in columns {
244 let col_idx = batch
245 .schema()
246 .index_of(col_name)
247 .map_err(|_| RouterError::ColumnNotFoundByName)?;
248
249 hash_row_value(&mut hasher, batch.column(col_idx), row)?;
250 }
251
252 Ok(self.hash_to_core(hasher.finish()))
253 }
254
255 fn route_row_by_indices(
257 &self,
258 batch: &RecordBatch,
259 row: usize,
260 indices: &[usize],
261 ) -> Result<usize, super::TpcError> {
262 let mut hasher = FxHasher::default();
263
264 for &idx in indices {
265 if idx >= batch.num_columns() {
266 return Err(RouterError::ColumnIndexOutOfRange.into());
267 }
268 hash_row_value(&mut hasher, batch.column(idx), row)?;
269 }
270
271 Ok(self.hash_to_core(hasher.finish()))
272 }
273
274 fn route_row_all_columns(
276 &self,
277 batch: &RecordBatch,
278 row: usize,
279 ) -> Result<usize, super::TpcError> {
280 let mut hasher = FxHasher::default();
281
282 for col in batch.columns() {
283 hash_row_value(&mut hasher, col, row)?;
284 }
285
286 Ok(self.hash_to_core(hasher.finish()))
287 }
288}
289
290fn hash_column(hasher: &mut FxHasher, array: &dyn Array) -> Result<(), RouterError> {
292 if array.is_empty() {
294 0u8.hash(hasher);
295 return Ok(());
296 }
297
298 hash_row_value(hasher, array, 0)
299}
300
301fn hash_row_value(hasher: &mut FxHasher, array: &dyn Array, row: usize) -> Result<(), RouterError> {
303 use arrow_array::{
304 BinaryArray, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
305 Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
306 };
307 use arrow_schema::DataType;
308
309 if row >= array.len() {
310 return Err(RouterError::RowIndexOutOfRange);
311 }
312
313 if array.is_null(row) {
314 0xDEAD_BEEF_u64.hash(hasher);
316 return Ok(());
317 }
318
319 match array.data_type() {
320 DataType::Int8 => {
321 let arr = array.as_any().downcast_ref::<Int8Array>().unwrap();
322 arr.value(row).hash(hasher);
323 }
324 DataType::Int16 => {
325 let arr = array.as_any().downcast_ref::<Int16Array>().unwrap();
326 arr.value(row).hash(hasher);
327 }
328 DataType::Int32 => {
329 let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
330 arr.value(row).hash(hasher);
331 }
332 DataType::Int64 => {
333 let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
334 arr.value(row).hash(hasher);
335 }
336 DataType::UInt8 => {
337 let arr = array.as_any().downcast_ref::<UInt8Array>().unwrap();
338 arr.value(row).hash(hasher);
339 }
340 DataType::UInt16 => {
341 let arr = array.as_any().downcast_ref::<UInt16Array>().unwrap();
342 arr.value(row).hash(hasher);
343 }
344 DataType::UInt32 => {
345 let arr = array.as_any().downcast_ref::<UInt32Array>().unwrap();
346 arr.value(row).hash(hasher);
347 }
348 DataType::UInt64 => {
349 let arr = array.as_any().downcast_ref::<UInt64Array>().unwrap();
350 arr.value(row).hash(hasher);
351 }
352 DataType::Float32 => {
353 let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
354 arr.value(row).to_bits().hash(hasher);
356 }
357 DataType::Float64 => {
358 let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
359 arr.value(row).to_bits().hash(hasher);
360 }
361 DataType::Utf8 => {
362 let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
363 arr.value(row).hash(hasher);
364 }
365 DataType::Binary => {
366 let arr = array.as_any().downcast_ref::<BinaryArray>().unwrap();
367 arr.value(row).hash(hasher);
368 }
369 DataType::Boolean => {
370 let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
371 arr.value(row).hash(hasher);
372 }
373 _ => {
374 let formatted =
377 arrow_cast::display::array_value_to_string(array, row).unwrap_or_default();
378 formatted.hash(hasher);
379 }
380 }
381
382 Ok(())
383}
384
385impl std::fmt::Debug for KeyRouter {
386 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387 f.debug_struct("KeyRouter")
388 .field("num_cores", &self.num_cores)
389 .field("key_spec", &self.key_spec)
390 .finish_non_exhaustive()
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use std::sync::Arc;

    /// Builds a single-row event with a `user_id` (Int64) and a `name`
    /// (Utf8) column.
    fn make_event(user_id: i64, name: &str, timestamp: i64) -> Event {
        let user_ids = Arc::new(Int64Array::from(vec![user_id]));
        let names = Arc::new(StringArray::from(vec![name]));
        let batch =
            RecordBatch::try_from_iter(vec![("user_id", user_ids as _), ("name", names as _)])
                .unwrap();
        Event::new(timestamp, batch)
    }

    #[test]
    fn test_round_robin() {
        let router = KeyRouter::new(4, KeySpec::RoundRobin);

        let event = make_event(1, "alice", 1000);

        let mut cores = Vec::new();
        for _ in 0..8 {
            cores.push(router.route(&event).unwrap());
        }

        assert_eq!(cores, vec![0, 1, 2, 3, 0, 1, 2, 3]);
    }

    #[test]
    fn test_route_row_round_robin() {
        let router = KeyRouter::new(3, KeySpec::RoundRobin);

        let user_ids = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let batch = RecordBatch::try_from_iter(vec![("user_id", user_ids as _)]).unwrap();

        // The rotation ignores which row is asked for.
        let mut cores = Vec::new();
        for _ in 0..6 {
            cores.push(router.route_row(&batch, 0).unwrap());
        }

        assert_eq!(cores, vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn test_route_by_column_name() {
        let router = KeyRouter::new(4, KeySpec::Columns(vec!["user_id".to_string()]));

        let event1 = make_event(100, "alice", 1000);
        let event2 = make_event(100, "bob", 2000);
        let event3 = make_event(200, "charlie", 3000);

        let core1 = router.route(&event1).unwrap();
        let core2 = router.route(&event2).unwrap();
        let core3 = router.route(&event3).unwrap();

        // Same key, same core; a different key only has to be in range.
        assert_eq!(core1, core2);
        assert!(core1 < 4 && core3 < 4);
    }

    #[test]
    fn test_route_by_column_index() {
        let router = KeyRouter::new(4, KeySpec::ColumnIndices(vec![0]));

        let event1 = make_event(100, "alice", 1000);
        let event2 = make_event(100, "bob", 2000);

        let core1 = router.route(&event1).unwrap();
        let core2 = router.route(&event2).unwrap();

        assert_eq!(core1, core2);
    }

    #[test]
    fn test_route_all_columns() {
        let router = KeyRouter::new(4, KeySpec::AllColumns);

        let event1 = make_event(100, "alice", 1000);
        let event2 = make_event(100, "alice", 2000);
        let event3 = make_event(100, "bob", 3000);

        let core1 = router.route(&event1).unwrap();
        let core2 = router.route(&event2).unwrap();
        let core3 = router.route(&event3).unwrap();

        // The timestamp is not part of the key, so events 1 and 2 agree.
        assert_eq!(core1, core2);
        assert!(core1 < 4 && core3 < 4);
    }

    #[test]
    fn test_route_empty_batch_is_deterministic() {
        let router = KeyRouter::new(4, KeySpec::AllColumns);

        let make_empty_event = || {
            let user_ids = Arc::new(Int64Array::from(Vec::<i64>::new()));
            let batch = RecordBatch::try_from_iter(vec![("user_id", user_ids as _)]).unwrap();
            Event::new(0, batch)
        };

        let core_a = router.route(&make_empty_event()).unwrap();
        let core_b = router.route(&make_empty_event()).unwrap();

        assert_eq!(core_a, core_b);
        assert!(core_a < 4);
    }

    #[test]
    fn test_route_column_not_found() {
        let router = KeyRouter::new(4, KeySpec::Columns(vec!["nonexistent".to_string()]));
        let event = make_event(100, "alice", 1000);

        let result = router.route(&event);
        assert!(matches!(
            result,
            Err(super::super::TpcError::RouterError(
                RouterError::ColumnNotFoundByName
            ))
        ));
    }

    #[test]
    fn test_route_index_out_of_range() {
        let router = KeyRouter::new(4, KeySpec::ColumnIndices(vec![10]));
        let event = make_event(100, "alice", 1000);

        let result = router.route(&event);
        assert!(matches!(
            result,
            Err(super::super::TpcError::RouterError(
                RouterError::ColumnIndexOutOfRange
            ))
        ));
    }

    #[test]
    fn test_route_row_out_of_range() {
        let router = KeyRouter::new(4, KeySpec::ColumnIndices(vec![0]));

        let user_ids = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let batch = RecordBatch::try_from_iter(vec![("user_id", user_ids as _)]).unwrap();

        // Row 3 is one past the last valid row.
        let result = router.route_row(&batch, 3);
        assert!(matches!(
            result,
            Err(super::super::TpcError::RouterError(
                RouterError::RowIndexOutOfRange
            ))
        ));
    }

    #[test]
    fn test_router_error_no_allocation() {
        let err1 = RouterError::ColumnNotFoundByName;
        // Relies on `Copy`: `err1` stays usable after the assignment.
        let err2 = err1;
        assert_eq!(err1, err2);

        let err3 = RouterError::ColumnIndexOutOfRange;
        let err4 = RouterError::RowIndexOutOfRange;
        let err5 = RouterError::UnsupportedDataType;
        let err6 = RouterError::EmptyBatch;

        assert_ne!(err1, err3);
        assert_ne!(err3, err4);
        assert_ne!(err4, err5);
        assert_ne!(err5, err6);
    }

    #[test]
    fn test_distribution() {
        let router = KeyRouter::new(4, KeySpec::Columns(vec!["user_id".to_string()]));

        let mut counts = [0usize; 4];
        for user_id in 0..1000 {
            let event = make_event(user_id, "user", 1000);
            let core = router.route(&event).unwrap();
            counts[core] += 1;
        }

        // A perfectly even split would be 250 per core.
        for count in &counts {
            assert!(*count > 150, "Core count too low: {count}");
            assert!(*count < 350, "Core count too high: {count}");
        }
    }

    #[test]
    fn test_route_row() {
        let router = KeyRouter::new(4, KeySpec::Columns(vec!["user_id".to_string()]));

        let user_ids = Arc::new(Int64Array::from(vec![100, 200, 100, 300]));
        let names = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
        let batch =
            RecordBatch::try_from_iter(vec![("user_id", user_ids as _), ("name", names as _)])
                .unwrap();

        // Rows 0 and 2 share the same user_id.
        let core0 = router.route_row(&batch, 0).unwrap();
        let core2 = router.route_row(&batch, 2).unwrap();
        assert_eq!(core0, core2);

        for row in 0..4 {
            let core = router.route_row(&batch, row).unwrap();
            assert!(core < 4);
        }
    }

    #[test]
    fn test_null_handling() {
        let router = KeyRouter::new(4, KeySpec::ColumnIndices(vec![0]));

        let user_ids = Arc::new(Int64Array::from(vec![Some(100), None, Some(100)]));
        let batch = RecordBatch::try_from_iter(vec![("user_id", user_ids as _)]).unwrap();

        let core0 = router.route_row(&batch, 0).unwrap();
        let core1 = router.route_row(&batch, 1).unwrap();
        let core2 = router.route_row(&batch, 2).unwrap();

        assert_eq!(core0, core2);
        // A null key must still route somewhere valid.
        assert!(core1 < 4);
    }

    #[test]
    #[should_panic(expected = "num_cores must be > 0")]
    fn test_zero_cores_panics() {
        let _ = KeyRouter::new(0, KeySpec::RoundRobin);
    }

    #[test]
    fn test_debug() {
        let router = KeyRouter::new(4, KeySpec::Columns(vec!["user_id".to_string()]));
        let debug_str = format!("{router:?}");
        assert!(debug_str.contains("KeyRouter"));
        assert!(debug_str.contains("num_cores"));
    }
}