use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};

use scirs2_core::ndarray::{ArrayD, Dimension, IxDyn};
use scirs2_core::parallel_ops::{
    current_num_threads, IndexedParallelIterator, ParallelIterator, ThreadPool, ThreadPoolBuilder,
};
use scirs2_core::Complex64;

use crate::error::Result;
use crate::prelude::SimulatorError;
/// Configuration for the parallel tensor contraction engine.
///
/// NOTE(review): only `num_threads` is read by the engine in this file; the
/// remaining fields' semantics are inferred from their names — confirm
/// against the code that consumes them.
#[derive(Debug, Clone)]
pub struct ParallelTensorConfig {
    /// Number of worker threads spawned for a contraction run.
    pub num_threads: usize,
    /// Work chunk size (elements) for splitting contractions.
    pub chunk_size: usize,
    /// Whether idle workers may steal queued work from busy ones.
    pub enable_work_stealing: bool,
    /// Minimum tensor size in bytes before parallel execution is preferred.
    pub parallel_threshold_bytes: usize,
    /// Strategy used to distribute work units across threads.
    pub load_balancing: LoadBalancingStrategy,
    /// Whether to take NUMA topology into account when scheduling.
    pub numa_aware: bool,
    /// Thread-to-core pinning preferences.
    pub thread_affinity: ThreadAffinityConfig,
}
37
38impl Default for ParallelTensorConfig {
39 fn default() -> Self {
40 Self {
41 num_threads: current_num_threads(), chunk_size: 1024,
43 enable_work_stealing: true,
44 parallel_threshold_bytes: 1024 * 1024, load_balancing: LoadBalancingStrategy::DynamicWorkStealing,
46 numa_aware: true,
47 thread_affinity: ThreadAffinityConfig::default(),
48 }
49 }
50}
51
/// Strategies for distributing contraction work across worker threads.
///
/// NOTE(review): only `DynamicWorkStealing`, `NumaAware`, and `Adaptive` are
/// referenced in this file, and none change scheduling behaviour here —
/// variant semantics are inferred from their names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Hand units to threads in fixed rotation.
    RoundRobin,
    /// Idle threads steal queued work from busy ones.
    DynamicWorkStealing,
    /// Prefer placing work close to its NUMA node's memory.
    NumaAware,
    /// Schedule by estimated contraction cost.
    CostBased,
    /// Choose a strategy at runtime from observed behaviour.
    Adaptive,
}
66
/// Thread-pinning preferences for worker threads.
///
/// NOTE(review): none of these fields are read in this file — semantics are
/// inferred from names; confirm against the scheduler that consumes them.
#[derive(Debug, Clone, Default)]
pub struct ThreadAffinityConfig {
    /// Whether to pin worker threads to cores at all.
    pub enable_affinity: bool,
    /// Worker-index to core-id mapping.
    pub core_mapping: Vec<usize>,
    /// Thread-id to preferred NUMA node mapping.
    pub numa_preferences: HashMap<usize, usize>,
}
77
/// A single pairwise tensor contraction scheduled for parallel execution.
#[derive(Debug, Clone)]
pub struct TensorWorkUnit {
    /// Unique id of this unit (its index in the contraction sequence).
    pub id: usize,
    /// Ids of the two input tensors to contract.
    pub input_tensors: Vec<usize>,
    /// Id under which the contraction result is stored.
    pub output_tensor: usize,
    /// Axes to contract, one list per input tensor (paired positionally).
    pub contraction_indices: Vec<Vec<usize>>,
    /// Rough cost estimate used for priority calculation.
    pub estimated_cost: f64,
    /// Estimated memory needed, in bytes.
    pub memory_requirement: usize,
    /// Ids of work units that must complete before this one may run.
    pub dependencies: HashSet<usize>,
    /// Scheduling priority; higher-priority units are dequeued first.
    pub priority: i32,
}
98
/// Thread-safe queue of contraction work units with dependency tracking.
#[derive(Debug)]
pub struct TensorWorkQueue {
    /// Units waiting to run, ordered by priority at construction time.
    pending: Mutex<VecDeque<TensorWorkUnit>>,
    /// Ids of units that have finished.
    completed: RwLock<HashSet<usize>>,
    /// Ids of units currently running, with the time they were claimed.
    in_progress: RwLock<HashMap<usize, Instant>>,
    /// Number of units handed to `new`; completion target for `is_complete`.
    total_units: usize,
    /// Engine configuration.
    /// NOTE(review): not read by any method in this file — confirm needed.
    config: ParallelTensorConfig,
}
113
114impl TensorWorkQueue {
115 #[must_use]
117 pub fn new(work_units: Vec<TensorWorkUnit>, config: ParallelTensorConfig) -> Self {
118 let total_units = work_units.len();
119 let mut pending = VecDeque::from(work_units);
120
121 pending.make_contiguous().sort_by(|a, b| {
123 b.priority
124 .cmp(&a.priority)
125 .then_with(|| a.dependencies.len().cmp(&b.dependencies.len()))
126 });
127
128 Self {
129 pending: Mutex::new(pending),
130 completed: RwLock::new(HashSet::new()),
131 in_progress: RwLock::new(HashMap::new()),
132 total_units,
133 config,
134 }
135 }
136
137 pub fn get_work(&self) -> Option<TensorWorkUnit> {
139 let mut pending = self
141 .pending
142 .lock()
143 .expect("pending lock should not be poisoned");
144 let completed = self
145 .completed
146 .read()
147 .expect("completed lock should not be poisoned");
148
149 for i in 0..pending.len() {
151 let work_unit = &pending[i];
152 let dependencies_satisfied = work_unit
153 .dependencies
154 .iter()
155 .all(|dep| completed.contains(dep));
156
157 if dependencies_satisfied {
158 let work_unit = pending
160 .remove(i)
161 .expect("index i is guaranteed to be within bounds");
162
163 drop(completed);
165 let mut in_progress = self
166 .in_progress
167 .write()
168 .expect("in_progress lock should not be poisoned");
169 in_progress.insert(work_unit.id, Instant::now());
170
171 return Some(work_unit);
172 }
173 }
174
175 None
176 }
177
178 pub fn complete_work(&self, work_id: usize) {
180 let mut completed = self
181 .completed
182 .write()
183 .expect("completed lock should not be poisoned");
184 completed.insert(work_id);
185
186 let mut in_progress = self
187 .in_progress
188 .write()
189 .expect("in_progress lock should not be poisoned");
190 in_progress.remove(&work_id);
191 }
192
193 pub fn is_complete(&self) -> bool {
195 let completed = self
196 .completed
197 .read()
198 .expect("completed lock should not be poisoned");
199 completed.len() == self.total_units
200 }
201
202 pub fn get_progress(&self) -> (usize, usize, usize) {
204 let completed = self
205 .completed
206 .read()
207 .expect("completed lock should not be poisoned")
208 .len();
209 let in_progress = self
210 .in_progress
211 .read()
212 .expect("in_progress lock should not be poisoned")
213 .len();
214 let pending = self
215 .pending
216 .lock()
217 .expect("pending lock should not be poisoned")
218 .len();
219 (completed, in_progress, pending)
220 }
221}
222
/// Multi-threaded tensor-network contraction engine.
pub struct ParallelTensorEngine {
    /// Engine configuration (thread count, scheduling strategy, ...).
    config: ParallelTensorConfig,
    /// Dedicated thread pool.
    /// NOTE(review): never used — `execute_parallel_contractions` spawns
    /// plain `std::thread`s instead; confirm which is intended.
    thread_pool: ThreadPool,
    /// Aggregated execution statistics, shared behind a mutex.
    stats: Arc<Mutex<ParallelTensorStats>>,
}
232
/// Aggregated statistics about parallel contraction runs.
#[derive(Debug, Clone, Default)]
pub struct ParallelTensorStats {
    /// Total contractions executed across all `contract_network` calls.
    pub total_contractions: u64,
    /// Accumulated wall-clock time spent inside `contract_network`.
    pub total_computation_time: Duration,
    /// Estimated-sequential-time / measured-time ratio of the last run.
    pub parallel_efficiency: f64,
    /// Peak memory observed, in bytes.
    /// NOTE(review): never written in this file — confirm the producer.
    pub peak_memory_usage: usize,
    /// Per-thread utilisation fractions.
    /// NOTE(review): never written in this file.
    pub thread_utilization: Vec<f64>,
    /// Load-balance quality factor.
    /// NOTE(review): never written in this file.
    pub load_balance_factor: f64,
    /// Cache hit rate.
    /// NOTE(review): never written in this file.
    pub cache_hit_rate: f64,
}
251
252impl ParallelTensorEngine {
253 pub fn new(config: ParallelTensorConfig) -> Result<Self> {
255 let thread_pool = ThreadPoolBuilder::new() .num_threads(config.num_threads)
257 .build()
258 .map_err(|e| {
259 SimulatorError::InitializationFailed(format!("Thread pool creation failed: {e}"))
260 })?;
261
262 Ok(Self {
263 config,
264 thread_pool,
265 stats: Arc::new(Mutex::new(ParallelTensorStats::default())),
266 })
267 }
268
269 pub fn contract_network(
271 &self,
272 tensors: &[ArrayD<Complex64>],
273 contraction_sequence: &[ContractionPair],
274 ) -> Result<ArrayD<Complex64>> {
275 let start_time = Instant::now();
276
277 let work_units = self.create_work_units(tensors, contraction_sequence)?;
279
280 let work_queue = Arc::new(TensorWorkQueue::new(work_units, self.config.clone()));
282
283 let intermediate_results =
285 Arc::new(RwLock::new(HashMap::<usize, ArrayD<Complex64>>::new()));
286
287 {
289 let mut results = intermediate_results
290 .write()
291 .expect("intermediate_results lock should not be poisoned");
292 for (i, tensor) in tensors.iter().enumerate() {
293 results.insert(i, tensor.clone());
294 }
295 }
296
297 let final_result = self.execute_parallel_contractions(work_queue, intermediate_results)?;
299
300 let elapsed = start_time.elapsed();
302 let mut stats = self
303 .stats
304 .lock()
305 .expect("stats lock should not be poisoned");
306 stats.total_contractions += contraction_sequence.len() as u64;
307 stats.total_computation_time += elapsed;
308
309 let sequential_estimate = self.estimate_sequential_time(contraction_sequence);
311 stats.parallel_efficiency = sequential_estimate.as_secs_f64() / elapsed.as_secs_f64();
312
313 Ok(final_result)
314 }
315
316 fn create_work_units(
318 &self,
319 tensors: &[ArrayD<Complex64>],
320 contraction_sequence: &[ContractionPair],
321 ) -> Result<Vec<TensorWorkUnit>> {
322 let mut work_units: Vec<TensorWorkUnit> = Vec::new();
323 let mut next_tensor_id = tensors.len();
324
325 for (i, contraction) in contraction_sequence.iter().enumerate() {
326 let estimated_cost = self.estimate_contraction_cost(contraction, tensors)?;
327 let memory_requirement = self.estimate_memory_requirement(contraction, tensors)?;
328
329 let mut dependencies = HashSet::new();
331 for &input_id in &[contraction.tensor1_id, contraction.tensor2_id] {
332 if input_id >= tensors.len() {
333 for prev_unit in &work_units {
335 if prev_unit.output_tensor == input_id {
336 dependencies.insert(prev_unit.id);
337 break;
338 }
339 }
340 }
341 }
342
343 let work_unit = TensorWorkUnit {
344 id: i,
345 input_tensors: vec![contraction.tensor1_id, contraction.tensor2_id],
346 output_tensor: next_tensor_id,
347 contraction_indices: vec![
348 contraction.tensor1_indices.clone(),
349 contraction.tensor2_indices.clone(),
350 ],
351 estimated_cost,
352 memory_requirement,
353 dependencies,
354 priority: self.calculate_priority(estimated_cost, memory_requirement),
355 };
356
357 work_units.push(work_unit);
358 next_tensor_id += 1;
359 }
360
361 Ok(work_units)
362 }
363
364 fn execute_parallel_contractions(
366 &self,
367 work_queue: Arc<TensorWorkQueue>,
368 intermediate_results: Arc<RwLock<HashMap<usize, ArrayD<Complex64>>>>,
369 ) -> Result<ArrayD<Complex64>> {
370 let num_threads = self.config.num_threads;
371 let mut handles = Vec::new();
372
373 for thread_id in 0..num_threads {
375 let work_queue = work_queue.clone();
376 let intermediate_results = intermediate_results.clone();
377 let config = self.config.clone();
378
379 let handle = thread::spawn(move || {
380 Self::worker_thread(thread_id, work_queue, intermediate_results, config)
381 });
382 handles.push(handle);
383 }
384
385 for handle in handles {
387 handle.join().map_err(|e| {
388 SimulatorError::ComputationError(format!("Thread join failed: {e:?}"))
389 })??;
390 }
391
392 let results = intermediate_results
394 .read()
395 .expect("intermediate_results lock should not be poisoned");
396 let max_id = results.keys().max().copied().unwrap_or(0);
397 Ok(results[&max_id].clone())
398 }
399
400 fn worker_thread(
402 _thread_id: usize,
403 work_queue: Arc<TensorWorkQueue>,
404 intermediate_results: Arc<RwLock<HashMap<usize, ArrayD<Complex64>>>>,
405 _config: ParallelTensorConfig,
406 ) -> Result<()> {
407 while !work_queue.is_complete() {
408 if let Some(work_unit) = work_queue.get_work() {
409 let tensor1 = {
411 let results = intermediate_results
412 .read()
413 .expect("intermediate_results lock should not be poisoned");
414 results[&work_unit.input_tensors[0]].clone()
415 };
416
417 let tensor2 = {
418 let results = intermediate_results
419 .read()
420 .expect("intermediate_results lock should not be poisoned");
421 results[&work_unit.input_tensors[1]].clone()
422 };
423
424 let result = Self::perform_tensor_contraction(
426 &tensor1,
427 &tensor2,
428 &work_unit.contraction_indices[0],
429 &work_unit.contraction_indices[1],
430 )?;
431
432 {
434 let mut results = intermediate_results
435 .write()
436 .expect("intermediate_results lock should not be poisoned");
437 results.insert(work_unit.output_tensor, result);
438 }
439
440 work_queue.complete_work(work_unit.id);
442 } else {
443 thread::sleep(Duration::from_millis(1));
445 }
446 }
447
448 Ok(())
449 }
450
451 fn perform_tensor_contraction(
453 tensor1: &ArrayD<Complex64>,
454 tensor2: &ArrayD<Complex64>,
455 indices1: &[usize],
456 indices2: &[usize],
457 ) -> Result<ArrayD<Complex64>> {
458 let shape1 = tensor1.shape();
462 let shape2 = tensor2.shape();
463
464 let mut output_shape = Vec::new();
466 for (i, &size) in shape1.iter().enumerate() {
467 if !indices1.contains(&i) {
468 output_shape.push(size);
469 }
470 }
471 for (i, &size) in shape2.iter().enumerate() {
472 if !indices2.contains(&i) {
473 output_shape.push(size);
474 }
475 }
476
477 let output_dim = IxDyn(&output_shape);
479 let mut output = ArrayD::zeros(output_dim);
480
481 Ok(output)
484 }
485
486 fn estimate_contraction_cost(
488 &self,
489 contraction: &ContractionPair,
490 _tensors: &[ArrayD<Complex64>],
491 ) -> Result<f64> {
492 let cost = contraction.tensor1_indices.len() as f64
494 * contraction.tensor2_indices.len() as f64
495 * 1000.0; Ok(cost)
497 }
498
499 const fn estimate_memory_requirement(
501 &self,
502 _contraction: &ContractionPair,
503 _tensors: &[ArrayD<Complex64>],
504 ) -> Result<usize> {
505 Ok(1024 * 1024) }
508
509 fn calculate_priority(&self, cost: f64, memory: usize) -> i32 {
511 let cost_factor = (cost / 1000.0) as i32;
513 let memory_factor = (1_000_000 / (memory + 1)) as i32;
514 cost_factor + memory_factor
515 }
516
517 const fn estimate_sequential_time(&self, contraction_sequence: &[ContractionPair]) -> Duration {
519 let estimated_ops = contraction_sequence.len() as u64 * 1000; Duration::from_millis(estimated_ops)
521 }
522
523 #[must_use]
525 pub fn get_stats(&self) -> ParallelTensorStats {
526 self.stats
527 .lock()
528 .expect("stats lock should not be poisoned")
529 .clone()
530 }
531}
532
/// One pairwise contraction step in a tensor-network contraction sequence.
#[derive(Debug, Clone)]
pub struct ContractionPair {
    /// Id of the first input tensor.
    pub tensor1_id: usize,
    /// Id of the second input tensor.
    pub tensor2_id: usize,
    /// Axes of the first tensor to contract.
    pub tensor1_indices: Vec<usize>,
    /// Axes of the second tensor to contract, paired positionally with
    /// `tensor1_indices`.
    pub tensor2_indices: Vec<usize>,
}
545
546pub mod strategies {
548 use super::{
549 ArrayD, Complex64, ContractionPair, LoadBalancingStrategy, NumaTopology,
550 ParallelTensorConfig, ParallelTensorEngine, Result,
551 };
552
553 pub fn work_stealing_contraction(
555 tensors: &[ArrayD<Complex64>],
556 contraction_sequence: &[ContractionPair],
557 num_threads: usize,
558 ) -> Result<ArrayD<Complex64>> {
559 let config = ParallelTensorConfig {
560 num_threads,
561 load_balancing: LoadBalancingStrategy::DynamicWorkStealing,
562 ..Default::default()
563 };
564
565 let engine = ParallelTensorEngine::new(config)?;
566 engine.contract_network(tensors, contraction_sequence)
567 }
568
569 pub fn numa_aware_contraction(
571 tensors: &[ArrayD<Complex64>],
572 contraction_sequence: &[ContractionPair],
573 numa_topology: &NumaTopology,
574 ) -> Result<ArrayD<Complex64>> {
575 let config = ParallelTensorConfig {
576 load_balancing: LoadBalancingStrategy::NumaAware,
577 numa_aware: true,
578 ..Default::default()
579 };
580
581 let engine = ParallelTensorEngine::new(config)?;
582 engine.contract_network(tensors, contraction_sequence)
583 }
584
585 pub fn adaptive_contraction(
587 tensors: &[ArrayD<Complex64>],
588 contraction_sequence: &[ContractionPair],
589 ) -> Result<ArrayD<Complex64>> {
590 let config = ParallelTensorConfig {
591 load_balancing: LoadBalancingStrategy::Adaptive,
592 enable_work_stealing: true,
593 ..Default::default()
594 };
595
596 let engine = ParallelTensorEngine::new(config)?;
597 engine.contract_network(tensors, contraction_sequence)
598 }
599}
600
/// Description of the host's NUMA layout.
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes.
    pub num_nodes: usize,
    /// Core count per node, indexed by node id.
    pub cores_per_node: Vec<usize>,
    /// Memory per node in bytes, indexed by node id.
    pub memory_per_node: Vec<usize>,
}
611
612impl Default for NumaTopology {
613 fn default() -> Self {
614 let num_cores = current_num_threads(); Self {
616 num_nodes: 1,
617 cores_per_node: vec![num_cores],
618 memory_per_node: vec![8 * 1024 * 1024 * 1024], }
620 }
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Engine construction plus one 2x2 matrix-product-style contraction.
    #[test]
    fn test_parallel_tensor_engine() {
        let engine = ParallelTensorEngine::new(ParallelTensorConfig::default())
            .expect("should create parallel tensor engine");

        let tensors = vec![Array::zeros(IxDyn(&[2, 2])), Array::zeros(IxDyn(&[2, 2]))];
        let pair = ContractionPair {
            tensor1_id: 0,
            tensor2_id: 1,
            tensor1_indices: vec![1],
            tensor2_indices: vec![0],
        };

        assert!(engine.contract_network(&tensors, &[pair]).is_ok());
    }

    /// A single dependency-free unit can be claimed and completed.
    #[test]
    fn test_work_queue() {
        let unit = TensorWorkUnit {
            id: 0,
            input_tensors: vec![0, 1],
            output_tensor: 2,
            contraction_indices: vec![vec![0], vec![1]],
            estimated_cost: 100.0,
            memory_requirement: 1024,
            dependencies: HashSet::new(),
            priority: 1,
        };

        let queue = TensorWorkQueue::new(vec![unit], ParallelTensorConfig::default());

        assert!(queue.get_work().is_some());
        queue.complete_work(0);
        assert!(queue.is_complete());
    }

    /// The work-stealing convenience wrapper runs end to end on two threads.
    #[test]
    fn test_parallel_strategies() {
        let tensors = vec![Array::ones(IxDyn(&[2, 2])), Array::ones(IxDyn(&[2, 2]))];
        let pair = ContractionPair {
            tensor1_id: 0,
            tensor2_id: 1,
            tensor1_indices: vec![1],
            tensor2_indices: vec![0],
        };

        assert!(strategies::work_stealing_contraction(&tensors, &[pair], 2).is_ok());
    }
}