1use anyhow::Result;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::sync::Arc;
25use tracing::{debug, info};
26
27use scirs2_core::ndarray_ext::{Array2, ArrayView1, Axis};
29use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
30use scirs2_core::simd_ops::SimdUnifiedOps;
31
/// Minimal in-process stand-ins for a profiler and a metrics registry.
///
/// Nothing in this module exports or aggregates data: the profiler is a
/// no-op, and every counter/timer handed out by [`MetricRegistry`] is a fresh,
/// unregistered instance.
mod simple_metrics {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::{Arc, Mutex, PoisonError};
    use std::time::Duration;

    /// Profiler placeholder whose `start`/`stop` calls do nothing.
    #[derive(Debug, Default)]
    pub struct Profiler;

    impl Profiler {
        /// Creates the (stateless) profiler.
        pub fn new() -> Self {
            Self
        }

        /// Marks the beginning of the span called `_name`; currently a no-op.
        pub fn start(&self, _name: &str) {}

        /// Marks the end of the span called `_name`; currently a no-op.
        pub fn stop(&self, _name: &str) {}
    }

    /// Monotonic event counter; clones share the same underlying value.
    #[derive(Debug, Clone, Default)]
    pub struct Counter {
        value: Arc<AtomicU64>,
    }

    impl Counter {
        /// Creates a counter starting at zero.
        pub fn new() -> Self {
            Self::default()
        }

        /// Adds one to the counter.
        pub fn inc(&self) {
            // Relaxed is enough: the value is a statistic and orders nothing else.
            self.value.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Collects observed durations; clones share the same sample list.
    #[derive(Debug, Clone, Default)]
    pub struct Timer {
        durations: Arc<Mutex<Vec<Duration>>>,
    }

    impl Timer {
        /// Creates a timer with no recorded samples.
        pub fn new() -> Self {
            Self::default()
        }

        /// Records one duration sample.
        ///
        /// Uses a blocking `std` mutex rather than `try_write` on an async
        /// lock, so a sample is never dropped just because another thread is
        /// recording at the same moment. The critical section is a single
        /// `push`, so it is safe to call from async code.
        pub fn observe(&self, duration: Duration) {
            self.durations
                .lock()
                // A poisoned lock only means another thread panicked mid-push;
                // the sample list itself is still usable.
                .unwrap_or_else(PoisonError::into_inner)
                .push(duration);
        }
    }

    /// Factory for [`Counter`] and [`Timer`] values.
    #[derive(Debug)]
    pub struct MetricRegistry;

    impl MetricRegistry {
        /// Returns a registry handle. Each call produces a new unit value;
        /// there is no shared global state behind it.
        pub fn global() -> Self {
            Self
        }

        /// Returns a new counter. The name is currently ignored.
        pub fn counter(&self, _name: &str) -> Counter {
            Counter::new()
        }

        /// Returns a new timer. The name is currently ignored.
        pub fn timer(&self, _name: &str) -> Timer {
            Timer::new()
        }
    }
}
103
104use simple_metrics::{Counter, MetricRegistry, Profiler, Timer};
105
/// Tuning options for [`SimdJoinProcessor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimdJoinConfig {
    /// Request the SIMD code paths. They are only taken when SIMD support is
    /// also detected at construction time; otherwise the scalar paths run.
    pub enable_simd: bool,
    /// Not read anywhere in this file. Presumably the SIMD register width in
    /// bits (the default is 256) — TODO confirm against the consumer.
    pub vector_width: usize,
    /// Number of probe-side keys handed to each parallel task during the
    /// hash-join probe phase.
    pub parallel_chunk_size: usize,
    /// Not read anywhere in this file. Presumably a hint to rely on compiler
    /// auto-vectorization — TODO confirm.
    pub auto_vectorization: bool,
    /// Not read anywhere in this file. Presumably a prefetch look-ahead
    /// distance; the unit is not established here — TODO confirm.
    pub prefetch_distance: usize,
    /// When `true`, the processor creates a profiler and brackets each join
    /// with `start`/`stop` calls.
    pub enable_profiling: bool,
}
122
123impl Default for SimdJoinConfig {
124 fn default() -> Self {
125 Self {
126 enable_simd: true,
127 vector_width: 256, parallel_chunk_size: 10000,
129 auto_vectorization: true,
130 prefetch_distance: 8,
131 enable_profiling: false,
132 }
133 }
134}
135
/// Names the join strategies.
///
/// NOTE(review): nothing in this file reads this enum; the variant
/// descriptions below are inferred from the method names on
/// `SimdJoinProcessor` and should be confirmed against the caller.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum JoinAlgorithm {
    /// Presumably selects the hash join (`hash_join`).
    Hash,
    /// Presumably selects the merge join (`merge_join`).
    Merge,
    /// Presumably selects the nested-loop join (`nested_loop_join_similarity`).
    NestedLoop,
    /// Presumably lets the caller pick a strategy at run time — TODO confirm.
    Adaptive,
}
148
/// Running counters kept by [`SimdJoinProcessor`].
///
/// Within this file only `hash_join` feeds these numbers (plus
/// `hash_table_builds`, which the SIMD build phase increments); the merge and
/// nested-loop joins do not update them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct JoinStatistics {
    /// Number of joins recorded.
    pub total_joins: u64,
    /// Number of recorded joins that ran while SIMD was both enabled in the
    /// configuration and detected as available.
    pub simd_joins: u64,
    /// Running mean of join wall-clock time, in milliseconds.
    pub avg_join_time_ms: f64,
    /// Highest observed throughput, measured as result rows per second.
    pub peak_throughput: f64,
    /// Set to the constant `1.5` once any SIMD join has been recorded; this is
    /// a placeholder, not a measured ratio.
    pub simd_speedup: f64,
    /// Sum of the result (output) row counts of all recorded joins.
    pub total_rows_processed: u64,
    /// Number of hash tables built by the SIMD hash-join build phase.
    pub hash_table_builds: u64,
}
167
/// Executes hash, merge and similarity joins over `Array2<f64>` tables,
/// choosing between SIMD/parallel and scalar code paths.
#[derive(Debug)]
pub struct SimdJoinProcessor {
    /// Options supplied at construction time.
    config: SimdJoinConfig,
    /// Shared running statistics, guarded by an async read/write lock.
    stats: Arc<tokio::sync::RwLock<JoinStatistics>>,
    /// Result of the SIMD capability probe performed once in `new`.
    simd_available: bool,
    /// Present only when `config.enable_profiling` was set.
    profiler: Option<Profiler>,
    /// Registry the counter and timer were obtained from; kept but not
    /// otherwise read in this file.
    _metrics: Arc<MetricRegistry>,
    /// Incremented once per `hash_join` call.
    join_counter: Arc<Counter>,
    /// Receives the duration of the SIMD hash-join path.
    join_timer: Arc<Timer>,
}
186
187impl SimdJoinProcessor {
188 pub fn new(config: SimdJoinConfig) -> Self {
190 let simd_available = Self::detect_simd_support();
191
192 info!(
193 "SIMD join processor initialized (SIMD available: {})",
194 simd_available
195 );
196
197 let metrics = Arc::new(MetricRegistry::global());
199 let join_counter = Arc::new(metrics.counter("simd_joins_total"));
200 let join_timer = Arc::new(metrics.timer("simd_join_duration"));
201
202 let profiler = if config.enable_profiling {
204 Some(Profiler::new())
205 } else {
206 None
207 };
208
209 Self {
210 config,
211 stats: Arc::new(tokio::sync::RwLock::new(JoinStatistics::default())),
212 simd_available,
213 profiler,
214 _metrics: metrics,
215 join_counter,
216 join_timer,
217 }
218 }
219
    /// Reports whether SIMD is usable, by returning whatever
    /// `f64::simd_available()` says.
    ///
    /// NOTE(review): `simd_available` presumably comes from the imported
    /// `SimdUnifiedOps` trait — confirm.
    fn detect_simd_support() -> bool {
        f64::simd_available()
    }
225
226 pub async fn hash_join(
228 &self,
229 left: &Array2<f64>,
230 right: &Array2<f64>,
231 left_key_col: usize,
232 right_key_col: usize,
233 ) -> Result<Array2<f64>> {
234 if let Some(ref profiler) = self.profiler {
235 profiler.start("simd_hash_join");
236 }
237
238 let start = std::time::Instant::now();
239 self.join_counter.inc();
240
241 debug!(
242 "Performing SIMD hash join: left={} rows, right={} rows",
243 left.nrows(),
244 right.nrows()
245 );
246
247 let result = if self.config.enable_simd && self.simd_available {
248 let timer_start = std::time::Instant::now();
249 let result = self
250 .simd_hash_join(left, right, left_key_col, right_key_col)
251 .await?;
252 self.join_timer.observe(timer_start.elapsed());
253 result
254 } else {
255 self.scalar_hash_join(left, right, left_key_col, right_key_col)
256 .await?
257 };
258
259 let elapsed = start.elapsed().as_secs_f64() * 1000.0;
260 self.update_stats(elapsed, result.nrows()).await;
261
262 if let Some(ref profiler) = self.profiler {
263 profiler.stop("simd_hash_join");
264 }
265
266 Ok(result)
267 }
268
269 async fn simd_hash_join(
271 &self,
272 left: &Array2<f64>,
273 right: &Array2<f64>,
274 left_key_col: usize,
275 right_key_col: usize,
276 ) -> Result<Array2<f64>> {
277 let hash_table = self.build_simd_hash_table(right, right_key_col).await?;
279
280 let matches_result = self
282 .simd_probe_hash_table(left, left_key_col, &hash_table)
283 .await?;
284
285 self.materialize_join_result(left, right, &matches_result)
287 }
288
289 async fn build_simd_hash_table(
291 &self,
292 data: &Array2<f64>,
293 key_col: usize,
294 ) -> Result<HashMap<u64, Vec<usize>>> {
295 let mut hash_table: HashMap<u64, Vec<usize>> = HashMap::new();
296
297 let keys = data.column(key_col);
299
300 let key_hashes: Vec<u64> = (0..keys.len())
302 .into_par_iter()
303 .map(|i| self.fast_hash(keys[i]))
304 .collect();
305
306 for (idx, hash) in key_hashes.into_iter().enumerate() {
308 hash_table.entry(hash).or_default().push(idx);
309 }
310
311 let mut stats = self.stats.write().await;
313 stats.hash_table_builds += 1;
314 drop(stats);
315
316 Ok(hash_table)
317 }
318
319 async fn simd_probe_hash_table(
321 &self,
322 left: &Array2<f64>,
323 key_col: usize,
324 hash_table: &HashMap<u64, Vec<usize>>,
325 ) -> Result<Vec<(usize, usize)>> {
326 let keys = left.column(key_col);
327
328 let keys_vec: Vec<f64> = if keys.as_slice().is_some() {
330 keys.as_slice().unwrap().to_vec()
331 } else {
332 keys.iter().copied().collect()
333 };
334
335 let chunk_size = self.config.parallel_chunk_size;
337 let chunks: Vec<_> = keys_vec.chunks(chunk_size).enumerate().collect();
338
339 let matches: Vec<(usize, usize)> = chunks
340 .into_par_iter()
341 .flat_map(|(chunk_idx, chunk)| {
342 let offset = chunk_idx * chunk_size;
343 let mut local_matches = Vec::new();
344
345 for (i, &key_val) in chunk.iter().enumerate() {
347 let hash = self.fast_hash(key_val);
348 if let Some(right_indices) = hash_table.get(&hash) {
349 for &right_idx in right_indices {
350 local_matches.push((offset + i, right_idx));
351 }
352 }
353 }
354
355 local_matches
356 })
357 .collect();
358
359 Ok(matches)
360 }
361
362 async fn scalar_hash_join(
364 &self,
365 left: &Array2<f64>,
366 right: &Array2<f64>,
367 left_key_col: usize,
368 right_key_col: usize,
369 ) -> Result<Array2<f64>> {
370 let mut hash_table: HashMap<u64, Vec<usize>> = HashMap::new();
371
372 for (idx, row) in right.axis_iter(Axis(0)).enumerate() {
374 let key = row[right_key_col];
375 let hash = self.fast_hash(key);
376 hash_table.entry(hash).or_default().push(idx);
377 }
378
379 let mut matches = Vec::new();
381 for (left_idx, row) in left.axis_iter(Axis(0)).enumerate() {
382 let key = row[left_key_col];
383 let hash = self.fast_hash(key);
384 if let Some(right_indices) = hash_table.get(&hash) {
385 for &right_idx in right_indices {
386 matches.push((left_idx, right_idx));
387 }
388 }
389 }
390
391 self.materialize_join_result(left, right, &matches)
393 }
394
395 pub async fn merge_join(
397 &self,
398 left: &Array2<f64>,
399 right: &Array2<f64>,
400 left_key_col: usize,
401 right_key_col: usize,
402 ) -> Result<Array2<f64>> {
403 if let Some(ref profiler) = self.profiler {
404 profiler.start("simd_merge_join");
405 }
406
407 debug!("Performing SIMD merge join");
408
409 let left_keys = left.column(left_key_col);
410 let right_keys = right.column(right_key_col);
411
412 let mut matches = Vec::new();
413 let mut left_idx = 0;
414 let mut right_idx = 0;
415
416 while left_idx < left_keys.len() && right_idx < right_keys.len() {
418 let left_key = left_keys[left_idx];
419 let right_key = right_keys[right_idx];
420
421 if (left_key - right_key).abs() < 1e-10 {
422 matches.push((left_idx, right_idx));
424 left_idx += 1;
425 right_idx += 1;
426 } else if left_key < right_key {
427 left_idx += 1;
428 } else {
429 right_idx += 1;
430 }
431 }
432
433 let result = self.materialize_join_result(left, right, &matches)?;
434
435 if let Some(ref profiler) = self.profiler {
436 profiler.stop("simd_merge_join");
437 }
438
439 Ok(result)
440 }
441
442 pub async fn nested_loop_join_similarity(
444 &self,
445 left: &Array2<f64>,
446 right: &Array2<f64>,
447 threshold: f64,
448 ) -> Result<Array2<f64>> {
449 if let Some(ref profiler) = self.profiler {
450 profiler.start("simd_nested_loop_join");
451 }
452
453 debug!("Performing SIMD nested loop join with similarity");
454
455 let matches: Vec<(usize, usize)> = if self.config.enable_simd && self.simd_available {
456 (0..left.nrows())
458 .into_par_iter()
459 .flat_map(|left_idx| {
460 let left_row = left.row(left_idx);
461 let mut local_matches = Vec::new();
462
463 for right_idx in 0..right.nrows() {
464 let right_row = right.row(right_idx);
465
466 let similarity = f64::simd_dot(&left_row, &right_row);
468
469 if similarity > threshold {
470 local_matches.push((left_idx, right_idx));
471 }
472 }
473
474 local_matches
475 })
476 .collect()
477 } else {
478 (0..left.nrows())
480 .flat_map(|left_idx| {
481 (0..right.nrows())
482 .filter_map(move |right_idx| {
483 let left_row = left.row(left_idx);
484 let right_row = right.row(right_idx);
485
486 let similarity: f64 = left_row
487 .iter()
488 .zip(right_row.iter())
489 .map(|(&a, &b)| a * b)
490 .sum();
491
492 if similarity > threshold {
493 Some((left_idx, right_idx))
494 } else {
495 None
496 }
497 })
498 .collect::<Vec<_>>()
499 })
500 .collect()
501 };
502
503 let result = self.materialize_join_result(left, right, &matches)?;
504
505 if let Some(ref profiler) = self.profiler {
506 profiler.stop("simd_nested_loop_join");
507 }
508
509 Ok(result)
510 }
511
512 fn materialize_join_result(
514 &self,
515 left: &Array2<f64>,
516 right: &Array2<f64>,
517 matches: &[(usize, usize)],
518 ) -> Result<Array2<f64>> {
519 let result_cols = left.ncols() + right.ncols();
520 let matches_len = matches.len();
521 let mut result_data = Vec::with_capacity(matches_len * result_cols);
522
523 for &(left_idx, right_idx) in matches {
524 for j in 0..left.ncols() {
526 result_data.push(left[[left_idx, j]]);
527 }
528 for j in 0..right.ncols() {
530 result_data.push(right[[right_idx, j]]);
531 }
532 }
533
534 Ok(Array2::from_shape_vec(
535 (matches_len, result_cols),
536 result_data,
537 )?)
538 }
539
540 fn fast_hash(&self, value: f64) -> u64 {
542 value.to_bits()
544 }
545
546 pub fn simd_compare_vectors(&self, vec1: &ArrayView1<f64>, vec2: &ArrayView1<f64>) -> f64 {
548 if self.config.enable_simd && self.simd_available {
549 f64::simd_dot(vec1, vec2)
551 } else {
552 vec1.iter().zip(vec2.iter()).map(|(&a, &b)| a * b).sum()
554 }
555 }
556
557 async fn update_stats(&self, elapsed_ms: f64, result_rows: usize) {
559 let mut stats = self.stats.write().await;
560 stats.total_joins += 1;
561 stats.total_rows_processed += result_rows as u64;
562
563 if self.config.enable_simd && self.simd_available {
564 stats.simd_joins += 1;
565 }
566
567 stats.avg_join_time_ms = (stats.avg_join_time_ms * (stats.total_joins - 1) as f64
568 + elapsed_ms)
569 / stats.total_joins as f64;
570
571 let throughput = result_rows as f64 / (elapsed_ms / 1000.0);
572 if throughput > stats.peak_throughput {
573 stats.peak_throughput = throughput;
574 }
575
576 if stats.simd_joins > 0 {
578 stats.simd_speedup = 1.5; }
580 }
581
582 pub async fn get_stats(&self) -> JoinStatistics {
584 self.stats.read().await.clone()
585 }
586
    /// Returns the SIMD availability flag cached when the processor was
    /// constructed. This is independent of `config.enable_simd`.
    pub fn is_simd_available(&self) -> bool {
        self.simd_available
    }
591
592 pub fn get_profiling_metrics(&self) -> Option<String> {
594 self.profiler.as_ref().map(|p| format!("{:?}", p))
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Processor with the SIMD paths switched off, so results do not depend
    /// on what the host supports.
    fn scalar_processor() -> SimdJoinProcessor {
        SimdJoinProcessor::new(SimdJoinConfig {
            enable_simd: false,
            ..Default::default()
        })
    }

    #[tokio::test]
    async fn test_simd_join_creation() {
        let processor = SimdJoinProcessor::new(SimdJoinConfig::default());

        // The cached flag must mirror the probe; the test must not assume
        // that the host actually has SIMD.
        assert_eq!(
            processor.is_simd_available(),
            SimdJoinProcessor::detect_simd_support()
        );

        let stats = processor.get_stats().await;
        assert_eq!(stats.total_joins, 0);
        assert_eq!(stats.total_rows_processed, 0);

        // Profiling is off by default.
        assert!(processor.get_profiling_metrics().is_none());
    }

    #[tokio::test]
    async fn test_hash_join() {
        let processor = scalar_processor();

        let left = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let right = array![[1.0, 7.0], [3.0, 8.0], [9.0, 10.0]];

        let result = processor
            .hash_join(&left, &right, 0, 0)
            .await
            .expect("scalar hash join should succeed");

        // Keys 1.0 and 3.0 match; 5.0 and 9.0 have no partner.
        let expected = array![[1.0, 2.0, 1.0, 7.0], [3.0, 4.0, 3.0, 8.0]];
        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_hash_join_with_simd() {
        let processor = SimdJoinProcessor::new(SimdJoinConfig::default());

        let left = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let right = array![[1.0, 7.0], [3.0, 8.0], [9.0, 10.0]];

        let result = processor
            .hash_join(&left, &right, 0, 0)
            .await
            .expect("hash join should succeed");

        // Same answer whichever code path the host ends up on.
        let expected = array![[1.0, 2.0, 1.0, 7.0], [3.0, 4.0, 3.0, 8.0]];
        assert_eq!(result, expected);

        let stats = processor.get_stats().await;
        assert_eq!(stats.total_joins, 1);
        assert_eq!(stats.total_rows_processed, 2);
    }

    #[tokio::test]
    async fn test_merge_join() {
        let processor = SimdJoinProcessor::new(SimdJoinConfig::default());

        let left = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let right = array![[1.0, 5.0], [2.0, 6.0], [3.0, 7.0]];

        let result = processor
            .merge_join(&left, &right, 0, 0)
            .await
            .expect("merge join should succeed");

        let expected = array![
            [1.0, 2.0, 1.0, 5.0],
            [2.0, 3.0, 2.0, 6.0],
            [3.0, 4.0, 3.0, 7.0]
        ];
        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_similarity_join() {
        let processor = SimdJoinProcessor::new(SimdJoinConfig::default());

        let left = array![[1.0, 0.0], [0.0, 1.0]];
        let right = array![[1.0, 0.0], [0.0, 1.0]];

        let result = processor
            .nested_loop_join_similarity(&left, &right, 0.5)
            .await
            .expect("similarity join should succeed");

        // Only identical unit vectors exceed the threshold, so each output
        // row is a vector followed by itself.
        assert_eq!(result.dim(), (2, 4));
        for row in result.axis_iter(Axis(0)) {
            assert_eq!(row[0], row[2]);
            assert_eq!(row[1], row[3]);
        }
    }

    #[tokio::test]
    async fn test_stats_tracking() {
        let processor = scalar_processor();

        let left = array![[1.0, 2.0], [3.0, 4.0]];
        let right = array![[1.0, 5.0], [3.0, 6.0]];

        processor
            .hash_join(&left, &right, 0, 0)
            .await
            .expect("scalar hash join should succeed");

        let stats = processor.get_stats().await;
        assert_eq!(stats.total_joins, 1);
        assert_eq!(stats.total_rows_processed, 2);
        // SIMD was disabled in the configuration, so no SIMD join is counted.
        assert_eq!(stats.simd_joins, 0);
    }

    #[tokio::test]
    async fn test_simd_detection() {
        let first = SimdJoinProcessor::new(SimdJoinConfig::default());
        let second = SimdJoinProcessor::new(SimdJoinConfig::default());

        let has_simd = first.is_simd_available();
        println!("SIMD available: {}", has_simd);

        // Detection is a property of the host, so it must not vary between
        // processors.
        assert_eq!(has_simd, second.is_simd_available());
    }

    #[tokio::test]
    async fn test_vector_comparison() {
        let processor = SimdJoinProcessor::new(SimdJoinConfig::default());

        let vec1 = array![1.0, 2.0, 3.0];
        let vec2 = array![4.0, 5.0, 6.0];

        // 1*4 + 2*5 + 3*6 = 32
        let similarity = processor.simd_compare_vectors(&vec1.view(), &vec2.view());
        assert!((similarity - 32.0).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_profiling() {
        let processor = SimdJoinProcessor::new(SimdJoinConfig {
            enable_profiling: true,
            ..Default::default()
        });

        let left = array![[1.0, 2.0], [3.0, 4.0]];
        let right = array![[1.0, 5.0], [3.0, 6.0]];

        let result = processor.hash_join(&left, &right, 0, 0).await;
        assert!(result.is_ok());

        assert!(processor.get_profiling_metrics().is_some());
    }

    #[tokio::test]
    async fn test_large_join() {
        let processor = SimdJoinProcessor::new(SimdJoinConfig {
            parallel_chunk_size: 1000,
            ..Default::default()
        });

        // Column 0 is `i * 10`, unique per row and identical in both tables,
        // so every row finds exactly one partner.
        let left = Array2::from_shape_fn((1000, 5), |(i, j)| (i * 10 + j) as f64);
        let right = Array2::from_shape_fn((1000, 5), |(i, j)| (i * 10 + j) as f64);

        let result = processor
            .hash_join(&left, &right, 0, 0)
            .await
            .expect("large hash join should succeed");
        assert_eq!(result.dim(), (1000, 10));

        let stats = processor.get_stats().await;
        assert_eq!(stats.total_rows_processed, 1000);
        assert!(stats.peak_throughput > 0.0);
    }
}