use ndarray::Array1;
use rayon::prelude::*;
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

/// Thread pool that fans `Task`s out to worker threads over a shared channel.
pub struct WorkStealingPool {
    sender: Sender<Task>,
    workers: Vec<WorkerHandle>,
    shutdown: Arc<Mutex<bool>>,
}

/// Units of work accepted by the pool.
pub enum Task {
    VectorSimilarity {
        query: Array1<f32>,
        targets: Vec<Array1<f32>>,
        result_sender: Sender<Vec<f32>>,
    },
    BatchEvaluation {
        positions: Vec<String>,
        result_sender: Sender<Vec<f32>>,
    },
    DataProcessing {
        data: Vec<u8>,
        processor: Box<dyn Fn(&[u8]) -> Vec<u8> + Send + Sync>,
        result_sender: Sender<Vec<u8>>,
    },
    Shutdown,
}

struct WorkerHandle {
    handle: thread::JoinHandle<()>,
    id: usize,
}

impl WorkStealingPool {
    pub fn new(num_workers: usize) -> Self {
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let shutdown = Arc::new(Mutex::new(false));

        let mut workers = Vec::new();

        for id in 0..num_workers {
            let receiver = Arc::clone(&receiver);
            let shutdown = Arc::clone(&shutdown);

            let handle = thread::spawn(move || {
                Self::worker_loop(id, receiver, shutdown);
            });

            workers.push(WorkerHandle { handle, id });
        }

        Self {
            sender,
            workers,
            shutdown,
        }
    }

    pub fn submit(&self, task: Task) -> Result<(), &'static str> {
        self.sender.send(task).map_err(|_| "Failed to submit task")
    }

    fn worker_loop(
        _worker_id: usize,
        receiver: Arc<Mutex<Receiver<Task>>>,
        shutdown: Arc<Mutex<bool>>,
    ) {
        loop {
            // Stop if shutdown has been requested.
            if let Ok(shutdown_flag) = shutdown.lock() {
                if *shutdown_flag {
                    break;
                }
            }

            // Hold the receiver lock only long enough to poll for the next task.
            let task = {
                if let Ok(receiver) = receiver.lock() {
                    match receiver.try_recv() {
                        Ok(task) => Some(task),
                        Err(_) => None,
                    }
                } else {
                    None
                }
            };

            if let Some(task) = task {
                match task {
                    Task::VectorSimilarity {
                        query,
                        targets,
                        result_sender,
                    } => {
                        let similarities = Self::compute_vector_similarities(&query, &targets);
                        let _ = result_sender.send(similarities);
                    }
                    Task::BatchEvaluation {
                        positions,
                        result_sender,
                    } => {
                        let evaluations = Self::compute_batch_evaluations(&positions);
                        let _ = result_sender.send(evaluations);
                    }
                    Task::DataProcessing {
                        data,
                        processor,
                        result_sender,
                    } => {
                        let result = processor(&data);
                        let _ = result_sender.send(result);
                    }
                    Task::Shutdown => break,
                }
            } else {
                // No work available; back off briefly before polling again.
                thread::sleep(Duration::from_millis(1));
            }
        }
    }

    fn compute_vector_similarities(query: &Array1<f32>, targets: &[Array1<f32>]) -> Vec<f32> {
        use crate::utils::simd::SimdVectorOps;

        targets
            .par_iter()
            .map(|target| SimdVectorOps::cosine_similarity(query, target))
            .collect()
    }

    fn compute_batch_evaluations(positions: &[String]) -> Vec<f32> {
        // Placeholder: returns 0.0 for each position until a real evaluator is wired in.
        vec![0.0; positions.len()]
    }

    pub fn shutdown(self) {
        if let Ok(mut shutdown_flag) = self.shutdown.lock() {
            *shutdown_flag = true;
        }

        // Queue one explicit shutdown task per worker in addition to setting the flag.
        for _ in 0..self.workers.len() {
            let _ = self.sender.send(Task::Shutdown);
        }

        for worker in self.workers {
            let _ = worker.handle.join();
        }
    }
}

/// Splits similarity queries into `batch_size` chunks and runs them on a `WorkStealingPool`.
pub struct ParallelSimilarityProcessor {
    pool: WorkStealingPool,
    batch_size: usize,
}

impl ParallelSimilarityProcessor {
    pub fn new(num_workers: usize, batch_size: usize) -> Self {
        Self {
            pool: WorkStealingPool::new(num_workers),
            batch_size,
        }
    }

    pub fn process_similarities(&self, query: Array1<f32>, targets: Vec<Array1<f32>>) -> Vec<f32> {
        let chunk_size = self.batch_size;
        let chunks: Vec<_> = targets.chunks(chunk_size).collect();
        let mut result_receivers = Vec::new();

        for chunk in chunks {
            let (result_sender, result_receiver) = mpsc::channel();

            let task = Task::VectorSimilarity {
                query: query.clone(),
                targets: chunk.to_vec(),
                result_sender,
            };

            if self.pool.submit(task).is_ok() {
                result_receivers.push(result_receiver);
            }
        }

        // Collect results in submission order so the output lines up with `targets`.
        let mut all_similarities = Vec::new();
        for receiver in result_receivers {
            if let Ok(similarities) = receiver.recv() {
                all_similarities.extend(similarities);
            }
        }

        all_similarities
    }
}

/// Simple multi-worker pipeline: items are pulled from a shared input queue,
/// transformed by `processor`, and pushed onto a shared output queue.
pub struct ParallelDataPipeline<T, U> {
    input_queue: Arc<Mutex<Vec<T>>>,
    output_queue: Arc<Mutex<Vec<U>>>,
    processors: Vec<thread::JoinHandle<()>>,
    shutdown: Arc<Mutex<bool>>,
}

impl<T, U> ParallelDataPipeline<T, U>
where
    T: Send + 'static,
    U: Send + 'static,
{
    pub fn new<F>(num_processors: usize, processor: F) -> Self
    where
        F: Fn(T) -> U + Send + Sync + Clone + 'static,
    {
        let input_queue = Arc::new(Mutex::new(Vec::new()));
        let output_queue = Arc::new(Mutex::new(Vec::new()));
        let shutdown = Arc::new(Mutex::new(false));
        let mut processors = Vec::new();

        for _ in 0..num_processors {
            let input_queue = Arc::clone(&input_queue);
            let output_queue = Arc::clone(&output_queue);
            let shutdown = Arc::clone(&shutdown);
            let processor = processor.clone();

            let handle = thread::spawn(move || {
                loop {
                    if let Ok(shutdown_flag) = shutdown.lock() {
                        if *shutdown_flag {
                            break;
                        }
                    }

                    // Take one item while holding the input lock as briefly as possible.
                    let work_item = {
                        if let Ok(mut queue) = input_queue.lock() {
                            queue.pop()
                        } else {
                            None
                        }
                    };

                    if let Some(item) = work_item {
                        let result = processor(item);

                        if let Ok(mut queue) = output_queue.lock() {
                            queue.push(result);
                        }
                    } else {
                        thread::sleep(Duration::from_millis(1));
                    }
                }
            });

            processors.push(handle);
        }

        Self {
            input_queue,
            output_queue,
            processors,
            shutdown,
        }
    }

    pub fn enqueue(&self, items: Vec<T>) {
        if let Ok(mut queue) = self.input_queue.lock() {
            queue.extend(items);
        }
    }

    pub fn dequeue_results(&self) -> Vec<U> {
        if let Ok(mut queue) = self.output_queue.lock() {
            std::mem::take(&mut *queue)
        } else {
            Vec::new()
        }
    }

    pub fn pending_count(&self) -> usize {
        self.input_queue.lock().map(|q| q.len()).unwrap_or(0)
    }

    pub fn result_count(&self) -> usize {
        self.output_queue.lock().map(|q| q.len()).unwrap_or(0)
    }

    pub fn shutdown(self) {
        if let Ok(mut flag) = self.shutdown.lock() {
            *flag = true;
        }

        for handle in self.processors {
            let _ = handle.join();
        }
    }
}

/// Distributes position evaluation requests across a fixed set of worker threads in round-robin order.
pub struct ParallelPositionEvaluator {
    workers: Vec<EvaluationWorker>,
    current_worker: Arc<Mutex<usize>>,
}

struct EvaluationWorker {
    sender: Sender<EvaluationRequest>,
    _handle: thread::JoinHandle<()>,
}

struct EvaluationRequest {
    position: String,
    response: Sender<f32>,
}

impl ParallelPositionEvaluator {
    pub fn new(num_workers: usize) -> Self {
        let mut workers = Vec::new();

        for _ in 0..num_workers {
            let (sender, receiver) = mpsc::channel::<EvaluationRequest>();

            let handle = thread::spawn(move || {
                for request in receiver {
                    let evaluation = Self::evaluate_position_sync(&request.position);
                    let _ = request.response.send(evaluation);
                }
            });

            workers.push(EvaluationWorker {
                sender,
                _handle: handle,
            });
        }

        Self {
            workers,
            current_worker: Arc::new(Mutex::new(0)),
        }
    }

    pub fn evaluate_positions(&self, positions: Vec<String>) -> Vec<f32> {
        let mut response_receivers = Vec::new();

        for position in positions {
            // Pick the next worker in round-robin order.
            let worker_idx = {
                if let Ok(mut idx) = self.current_worker.lock() {
                    let current = *idx;
                    *idx = (current + 1) % self.workers.len();
                    current
                } else {
                    0
                }
            };

            let (response_sender, response_receiver) = mpsc::channel();
            let request = EvaluationRequest {
                position,
                response: response_sender,
            };

            if self.workers[worker_idx].sender.send(request).is_ok() {
                response_receivers.push(response_receiver);
            }
        }

        // Gather responses in request order.
        let mut evaluations = Vec::new();
        for receiver in response_receivers {
            if let Ok(evaluation) = receiver.recv() {
                evaluations.push(evaluation);
            }
        }

        evaluations
    }

    fn evaluate_position_sync(_fen: &str) -> f32 {
        // Placeholder: a real evaluation would replace this constant.
        0.0
    }
}

/// Stateless helpers for data-parallel vector math built on rayon.
pub struct ParallelVectorOps;

impl ParallelVectorOps {
    pub fn parallel_dot_products(vectors_a: &[Array1<f32>], vectors_b: &[Array1<f32>]) -> Vec<f32> {
        use crate::utils::simd::SimdVectorOps;

        vectors_a
            .par_iter()
            .zip(vectors_b.par_iter())
            .map(|(a, b)| SimdVectorOps::dot_product(a, b))
            .collect()
    }

    pub fn parallel_similarity_matrix(vectors: &[Array1<f32>]) -> Vec<Vec<f32>> {
        use crate::utils::simd::SimdVectorOps;

        vectors
            .par_iter()
            .map(|vec_a| {
                // Only parallelize the inner loop for larger inputs; nested
                // parallelism is not worth the overhead on small sets.
                if vectors.len() > 100 {
                    vectors
                        .par_iter()
                        .map(|vec_b| SimdVectorOps::cosine_similarity(vec_a, vec_b))
                        .collect()
                } else {
                    vectors
                        .iter()
                        .map(|vec_b| SimdVectorOps::cosine_similarity(vec_a, vec_b))
                        .collect()
                }
            })
            .collect()
    }

    pub fn parallel_knn_search(
        query: &Array1<f32>,
        dataset: &[Array1<f32>],
        k: usize,
    ) -> Vec<(usize, f32)> {
        use crate::utils::simd::SimdVectorOps;

        let mut similarities: Vec<(usize, f32)> = dataset
            .par_iter()
            .enumerate()
            .map(|(idx, vector)| {
                let similarity = SimdVectorOps::cosine_similarity(query, vector);
                (idx, similarity)
            })
            .collect();

        // Keep only the k most similar entries, highest similarity first.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        similarities.truncate(k);

        similarities
    }

    pub fn parallel_batch_normalize(vectors: &mut [Array1<f32>]) {
        use crate::utils::simd::SimdVectorOps;

        vectors.par_iter_mut().for_each(|vector| {
            let norm = SimdVectorOps::squared_norm(vector).sqrt();
            if norm > 0.0 {
                *vector = SimdVectorOps::scale_vector(vector, 1.0 / norm);
            }
        });
    }
}

/// Tracks per-worker task counts to measure throughput and load balance.
pub struct ParallelPerformanceMonitor {
    start_time: Instant,
    task_counts: Arc<Mutex<Vec<usize>>>,
    total_tasks: Arc<Mutex<usize>>,
}

impl ParallelPerformanceMonitor {
    pub fn new(num_workers: usize) -> Self {
        Self {
            start_time: Instant::now(),
            task_counts: Arc::new(Mutex::new(vec![0; num_workers])),
            total_tasks: Arc::new(Mutex::new(0)),
        }
    }

    pub fn record_task_completion(&self, worker_id: usize) {
        if let Ok(mut counts) = self.task_counts.lock() {
            if worker_id < counts.len() {
                counts[worker_id] += 1;
            }
        }

        if let Ok(mut total) = self.total_tasks.lock() {
            *total += 1;
        }
    }

    pub fn get_stats(&self) -> ParallelPerformanceStats {
        let elapsed = self.start_time.elapsed();
        let total_tasks = self.total_tasks.lock().map(|t| *t).unwrap_or(0);
        let task_counts = self
            .task_counts
            .lock()
            .map(|counts| counts.clone())
            .unwrap_or_default();

        let tasks_per_second = if elapsed.as_secs_f64() > 0.0 {
            total_tasks as f64 / elapsed.as_secs_f64()
        } else {
            0.0
        };

        let mean_tasks = if !task_counts.is_empty() {
            task_counts.iter().sum::<usize>() as f64 / task_counts.len() as f64
        } else {
            0.0
        };

        let variance = if !task_counts.is_empty() {
            task_counts
                .iter()
                .map(|&count| {
                    let diff = count as f64 - mean_tasks;
                    diff * diff
                })
                .sum::<f64>()
                / task_counts.len() as f64
        } else {
            0.0
        };

        // Relative spread of per-worker counts; 0.0 means a perfectly even load.
        let load_balance = variance.sqrt() / mean_tasks.max(1.0);

        ParallelPerformanceStats {
            elapsed_time: elapsed,
            total_tasks,
            tasks_per_second,
            worker_task_counts: task_counts,
            load_balance_factor: load_balance,
        }
    }
}

#[derive(Debug, Clone)]
pub struct ParallelPerformanceStats {
    pub elapsed_time: Duration,
    pub total_tasks: usize,
    pub tasks_per_second: f64,
    pub worker_task_counts: Vec<usize>,
    pub load_balance_factor: f64,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_work_stealing_pool() {
        let pool = WorkStealingPool::new(2);
        let (result_sender, result_receiver) = mpsc::channel();

        let query = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let targets = vec![
            Array1::from_vec(vec![1.0, 0.0, 0.0]),
            Array1::from_vec(vec![0.0, 1.0, 0.0]),
        ];

        let task = Task::VectorSimilarity {
            query,
            targets,
            result_sender,
        };

        pool.submit(task).unwrap();

        let result = result_receiver.recv_timeout(Duration::from_secs(1));
        assert!(result.is_ok());

        pool.shutdown();
    }
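
    // Added sketch: exercises the `Task::DataProcessing` variant end to end. It
    // relies only on the pool and channel plumbing defined in this file; the
    // byte-incrementing closure is a hypothetical stand-in for a real processor.
    #[test]
    fn test_work_stealing_pool_data_processing() {
        let pool = WorkStealingPool::new(2);
        let (result_sender, result_receiver) = mpsc::channel();

        let task = Task::DataProcessing {
            data: vec![1, 2, 3],
            processor: Box::new(|bytes: &[u8]| bytes.iter().map(|b| b.wrapping_add(1)).collect()),
            result_sender,
        };

        pool.submit(task).unwrap();

        let result = result_receiver.recv_timeout(Duration::from_secs(1));
        assert_eq!(result.unwrap(), vec![2, 3, 4]);

        pool.shutdown();
    }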

    #[test]
    fn test_parallel_similarity_processor() {
        let processor = ParallelSimilarityProcessor::new(2, 10);

        let query = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let targets = vec![
            Array1::from_vec(vec![1.0, 0.0, 0.0]),
            Array1::from_vec(vec![0.0, 1.0, 0.0]),
            Array1::from_vec(vec![0.0, 0.0, 1.0]),
        ];

        let similarities = processor.process_similarities(query, targets);
        assert_eq!(similarities.len(), 3);

        assert!((similarities[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_parallel_data_pipeline() {
        let pipeline = ParallelDataPipeline::new(2, |x: i32| x * 2);

        pipeline.enqueue(vec![1, 2, 3, 4, 5]);

        thread::sleep(Duration::from_millis(100));

        let results = pipeline.dequeue_results();
        assert!(!results.is_empty());

        pipeline.shutdown();
    }
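
    // Added sketch: round-trips a few positions through ParallelPositionEvaluator.
    // The position strings are arbitrary placeholders; since evaluate_position_sync
    // currently returns a constant 0.0, the test checks that one neutral score comes
    // back per position.
    #[test]
    fn test_parallel_position_evaluator() {
        let evaluator = ParallelPositionEvaluator::new(2);

        let positions = vec![
            "pos_a".to_string(),
            "pos_b".to_string(),
            "pos_c".to_string(),
        ];

        let evaluations = evaluator.evaluate_positions(positions);
        assert_eq!(evaluations.len(), 3);
        assert!(evaluations.iter().all(|&score| score.abs() < f32::EPSILON));
    }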

    #[test]
    fn test_parallel_vector_ops() {
        let vectors_a = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        let vectors_b = vec![
            Array1::from_vec(vec![2.0, 1.0]),
            Array1::from_vec(vec![1.0, 2.0]),
        ];

        let dot_products = ParallelVectorOps::parallel_dot_products(&vectors_a, &vectors_b);
        assert_eq!(dot_products.len(), 2);

        assert!((dot_products[0] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_parallel_knn_search() {
        let query = Array1::from_vec(vec![1.0, 0.0]);
        let dataset = vec![
            Array1::from_vec(vec![1.0, 0.0]),
            Array1::from_vec(vec![0.0, 1.0]),
            Array1::from_vec(vec![0.5, 0.5]),
        ];

        let results = ParallelVectorOps::parallel_knn_search(&query, &dataset, 2);
        assert_eq!(results.len(), 2);

        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-6);
    }
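
    // Added sketch: checks parallel_batch_normalize under the assumption that
    // SimdVectorOps::squared_norm and scale_vector are conventional squared-norm
    // and scalar-scaling routines; the resulting norm is re-checked with plain ndarray.
    #[test]
    fn test_parallel_batch_normalize() {
        let mut vectors = vec![
            Array1::from_vec(vec![3.0_f32, 4.0]),
            Array1::from_vec(vec![0.0_f32, 0.0]),
        ];

        ParallelVectorOps::parallel_batch_normalize(&mut vectors);

        // The non-zero vector should now have approximately unit length.
        let norm = vectors[0].dot(&vectors[0]).sqrt();
        assert!((norm - 1.0).abs() < 1e-5);

        // The zero vector is left untouched rather than divided by zero.
        assert_eq!(vectors[1], Array1::from_vec(vec![0.0_f32, 0.0]));
    }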

    #[test]
    fn test_performance_monitor() {
        let monitor = ParallelPerformanceMonitor::new(3);

        monitor.record_task_completion(0);
        monitor.record_task_completion(1);
        monitor.record_task_completion(0);
        monitor.record_task_completion(2);

        let stats = monitor.get_stats();
        assert_eq!(stats.total_tasks, 4);
        assert_eq!(stats.worker_task_counts[0], 2);
        assert_eq!(stats.worker_task_counts[1], 1);
        assert_eq!(stats.worker_task_counts[2], 1);
    }
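
    // Added sketch: sanity-checks parallel_similarity_matrix, assuming
    // SimdVectorOps::cosine_similarity is a standard cosine similarity
    // (so each vector compared with itself scores ~1.0).
    #[test]
    fn test_parallel_similarity_matrix() {
        let vectors = vec![
            Array1::from_vec(vec![1.0_f32, 0.0]),
            Array1::from_vec(vec![0.0_f32, 1.0]),
            Array1::from_vec(vec![1.0_f32, 1.0]),
        ];

        let matrix = ParallelVectorOps::parallel_similarity_matrix(&vectors);

        // One row per input vector, each row covering every vector.
        assert_eq!(matrix.len(), 3);
        assert!(matrix.iter().all(|row| row.len() == 3));

        // Diagonal entries compare a vector with itself.
        for (i, row) in matrix.iter().enumerate() {
            assert!((row[i] - 1.0).abs() < 1e-5);
        }
    }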
}