//! Parallel tensor network contraction with dependency-aware work scheduling.

use crate::prelude::SimulatorError;
use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};
use scirs2_core::parallel_ops::*;
use scirs2_core::Complex64;
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};

use crate::error::Result;

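/// Configuration for parallel tensor network contraction.
///
/// Only the fields you care about need to be set; the rest can come from
/// `Default` via struct-update syntax, e.g. (values here are illustrative):
///
/// ```ignore
/// let config = ParallelTensorConfig {
///     num_threads: 4,
///     parallel_threshold_bytes: 256 * 1024,
///     ..Default::default()
/// };
/// ```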
#[derive(Debug, Clone)]
pub struct ParallelTensorConfig {
    /// Number of worker threads to spawn.
    pub num_threads: usize,
    /// Chunk size for splitting large element-wise operations.
    pub chunk_size: usize,
    /// Allow idle threads to steal pending work.
    pub enable_work_stealing: bool,
    /// Minimum tensor size in bytes before parallel execution is used.
    pub parallel_threshold_bytes: usize,
    /// Strategy for distributing work units across threads.
    pub load_balancing: LoadBalancingStrategy,
    /// Prefer NUMA-local memory placement when available.
    pub numa_aware: bool,
    /// Thread-to-core affinity settings.
    pub thread_affinity: ThreadAffinityConfig,
}

impl Default for ParallelTensorConfig {
    fn default() -> Self {
        Self {
            num_threads: rayon::current_num_threads(),
            chunk_size: 1024,
            enable_work_stealing: true,
            parallel_threshold_bytes: 1024 * 1024, // 1 MiB
            load_balancing: LoadBalancingStrategy::DynamicWorkStealing,
            numa_aware: true,
            thread_affinity: ThreadAffinityConfig::default(),
        }
    }
}

/// Strategy for distributing work units across worker threads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Hand out work units to threads in a fixed rotation.
    RoundRobin,
    /// Idle threads pull pending units from a shared queue.
    DynamicWorkStealing,
    /// Keep work close to the NUMA node that owns its data.
    NumaAware,
    /// Schedule the most expensive contractions first.
    CostBased,
    /// Adjust the strategy at runtime based on observed load.
    Adaptive,
}

/// Thread-to-core affinity settings.
///
/// All fields default to "no affinity", so `Default` can be derived.
#[derive(Debug, Clone, Default)]
pub struct ThreadAffinityConfig {
    /// Whether to pin worker threads to specific cores.
    pub enable_affinity: bool,
    /// Explicit thread-to-core mapping; empty means no pinning.
    pub core_mapping: Vec<usize>,
    /// Preferred NUMA node per thread index.
    pub numa_preferences: HashMap<usize, usize>,
}

/// A single tensor contraction scheduled for parallel execution.
#[derive(Debug, Clone)]
pub struct TensorWorkUnit {
    /// Unique identifier of this work unit.
    pub id: usize,
    /// Ids of the two input tensors.
    pub input_tensors: Vec<usize>,
    /// Id assigned to the output tensor.
    pub output_tensor: usize,
    /// Contracted axes of each input tensor.
    pub contraction_indices: Vec<Vec<usize>>,
    /// Estimated floating-point cost of the contraction.
    pub estimated_cost: f64,
    /// Estimated memory required for the output, in bytes.
    pub memory_requirement: usize,
    /// Ids of work units that must complete first.
    pub dependencies: HashSet<usize>,
    /// Scheduling priority; higher runs earlier.
    pub priority: i32,
}

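/// Dependency-aware work queue shared by all worker threads.
///
/// A minimal polling loop, as used by `worker_thread` below:
///
/// ```ignore
/// while !queue.is_complete() {
///     if let Some(unit) = queue.get_work() {
///         // ... contract the unit's inputs ...
///         queue.complete_work(unit.id);
///     }
/// }
/// ```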
#[derive(Debug)]
pub struct TensorWorkQueue {
    /// Units waiting to run, sorted by priority.
    pending: Mutex<VecDeque<TensorWorkUnit>>,
    /// Ids of finished units.
    completed: RwLock<HashSet<usize>>,
    /// Units currently running, with their start times.
    in_progress: RwLock<HashMap<usize, Instant>>,
    /// Total number of units handed to the queue.
    total_units: usize,
    /// Queue configuration (reserved for future scheduling heuristics).
    config: ParallelTensorConfig,
}

impl TensorWorkQueue {
    /// Create a new work queue from a set of work units.
    pub fn new(work_units: Vec<TensorWorkUnit>, config: ParallelTensorConfig) -> Self {
        let total_units = work_units.len();
        let mut pending = VecDeque::from(work_units);

        // Sort by priority (descending), breaking ties in favor of units
        // with fewer dependencies so they unblock others sooner.
        pending.make_contiguous().sort_by(|a, b| {
            b.priority
                .cmp(&a.priority)
                .then_with(|| a.dependencies.len().cmp(&b.dependencies.len()))
        });

        Self {
            pending: Mutex::new(pending),
            completed: RwLock::new(HashSet::new()),
            in_progress: RwLock::new(HashMap::new()),
            total_units,
            config,
        }
    }

    /// Pop the highest-priority unit whose dependencies are all satisfied.
    ///
    /// Returns `None` if no unit is currently runnable.
    pub fn get_work(&self) -> Option<TensorWorkUnit> {
        let mut pending = self.pending.lock().unwrap();
        let completed = self.completed.read().unwrap();

        // Scan in priority order for the first runnable unit.
        for i in 0..pending.len() {
            let work_unit = &pending[i];
            let dependencies_satisfied = work_unit
                .dependencies
                .iter()
                .all(|dep| completed.contains(dep));

            if dependencies_satisfied {
                let work_unit = pending.remove(i).unwrap();

                // Release the read lock before taking the write lock.
                drop(completed);
                let mut in_progress = self.in_progress.write().unwrap();
                in_progress.insert(work_unit.id, Instant::now());

                return Some(work_unit);
            }
        }

        None
    }

    /// Mark a work unit as finished, unblocking its dependents.
    pub fn complete_work(&self, work_id: usize) {
        let mut completed = self.completed.write().unwrap();
        completed.insert(work_id);

        let mut in_progress = self.in_progress.write().unwrap();
        in_progress.remove(&work_id);
    }

    /// True once every unit handed to the queue has completed.
    pub fn is_complete(&self) -> bool {
        let completed = self.completed.read().unwrap();
        completed.len() == self.total_units
    }

    /// Current progress as `(completed, in_progress, pending)` unit counts.
    pub fn get_progress(&self) -> (usize, usize, usize) {
        let completed = self.completed.read().unwrap().len();
        let in_progress = self.in_progress.read().unwrap().len();
        let pending = self.pending.lock().unwrap().len();
        (completed, in_progress, pending)
    }
}

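/// Parallel tensor network contraction engine.
///
/// A minimal usage sketch (assuming `tensors` and `sequence` are prepared
/// elsewhere):
///
/// ```ignore
/// let engine = ParallelTensorEngine::new(ParallelTensorConfig::default())?;
/// let result = engine.contract_network(&tensors, &sequence)?;
/// ```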
pub struct ParallelTensorEngine {
    /// Engine configuration.
    config: ParallelTensorConfig,
    /// Rayon thread pool (reserved; the current scheduler spawns its own
    /// worker threads in `execute_parallel_contractions`).
    #[allow(dead_code)]
    thread_pool: rayon::ThreadPool,
    /// Accumulated execution statistics.
    stats: Arc<Mutex<ParallelTensorStats>>,
}

/// Execution statistics collected across contractions.
#[derive(Debug, Clone, Default)]
pub struct ParallelTensorStats {
    /// Total number of contractions performed.
    pub total_contractions: u64,
    /// Wall-clock time spent in `contract_network`.
    pub total_computation_time: Duration,
    /// Estimated speedup over sequential execution, divided by thread count.
    pub parallel_efficiency: f64,
    /// Peak memory usage observed, in bytes.
    pub peak_memory_usage: usize,
    /// Per-thread utilization estimates.
    pub thread_utilization: Vec<f64>,
    /// Ratio describing how evenly work was spread across threads.
    pub load_balance_factor: f64,
    /// Cache hit rate for reused intermediate tensors.
    pub cache_hit_rate: f64,
}

impl ParallelTensorEngine {
    /// Create a new engine with the given configuration.
    pub fn new(config: ParallelTensorConfig) -> Result<Self> {
        let thread_pool = rayon::ThreadPoolBuilder::new()
            .num_threads(config.num_threads)
            .build()
            .map_err(|e| {
                SimulatorError::InitializationFailed(format!("Thread pool creation failed: {}", e))
            })?;

        Ok(Self {
            config,
            thread_pool,
            stats: Arc::new(Mutex::new(ParallelTensorStats::default())),
        })
    }

    /// Contract a tensor network in parallel following `contraction_sequence`.
    pub fn contract_network(
        &self,
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
    ) -> Result<ArrayD<Complex64>> {
        let start_time = Instant::now();

        // Turn the contraction sequence into dependency-tracked work units.
        let work_units = self.create_work_units(tensors, contraction_sequence)?;

        let work_queue = Arc::new(TensorWorkQueue::new(work_units, self.config.clone()));

        // Shared tensor store, seeded with the input tensors (ids 0..n).
        let intermediate_results =
            Arc::new(RwLock::new(HashMap::<usize, ArrayD<Complex64>>::new()));
        {
            let mut results = intermediate_results.write().unwrap();
            for (i, tensor) in tensors.iter().enumerate() {
                results.insert(i, tensor.clone());
            }
        }

        let final_result =
            self.execute_parallel_contractions(work_queue.clone(), intermediate_results.clone())?;

        // Update statistics.
        let elapsed = start_time.elapsed();
        let mut stats = self.stats.lock().unwrap();
        stats.total_contractions += contraction_sequence.len() as u64;
        stats.total_computation_time += elapsed;

        // Efficiency is the estimated speedup normalized by thread count,
        // so 1.0 means perfect scaling.
        let sequential_estimate = self.estimate_sequential_time(contraction_sequence);
        let speedup = sequential_estimate.as_secs_f64() / elapsed.as_secs_f64();
        stats.parallel_efficiency = speedup / self.config.num_threads as f64;

        Ok(final_result)
    }

    /// Build dependency-tracked work units from a contraction sequence.
    fn create_work_units(
        &self,
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
    ) -> Result<Vec<TensorWorkUnit>> {
        let mut work_units: Vec<TensorWorkUnit> = Vec::new();
        let mut next_tensor_id = tensors.len();

        for (i, contraction) in contraction_sequence.iter().enumerate() {
            let estimated_cost = self.estimate_contraction_cost(contraction, tensors)?;
            let memory_requirement = self.estimate_memory_requirement(contraction, tensors)?;

            // Any input id beyond the original tensors refers to an
            // intermediate result, so this unit depends on its producer.
            let mut dependencies = HashSet::new();
            for &input_id in &[contraction.tensor1_id, contraction.tensor2_id] {
                if input_id >= tensors.len() {
                    for prev_unit in &work_units {
                        if prev_unit.output_tensor == input_id {
                            dependencies.insert(prev_unit.id);
                            break;
                        }
                    }
                }
            }

            let work_unit = TensorWorkUnit {
                id: i,
                input_tensors: vec![contraction.tensor1_id, contraction.tensor2_id],
                output_tensor: next_tensor_id,
                contraction_indices: vec![
                    contraction.tensor1_indices.clone(),
                    contraction.tensor2_indices.clone(),
                ],
                estimated_cost,
                memory_requirement,
                dependencies,
                priority: self.calculate_priority(estimated_cost, memory_requirement),
            };

            work_units.push(work_unit);
            next_tensor_id += 1;
        }

        Ok(work_units)
    }

    /// Spawn worker threads that drain the work queue until it is empty.
    fn execute_parallel_contractions(
        &self,
        work_queue: Arc<TensorWorkQueue>,
        intermediate_results: Arc<RwLock<HashMap<usize, ArrayD<Complex64>>>>,
    ) -> Result<ArrayD<Complex64>> {
        let num_threads = self.config.num_threads;
        let mut handles = Vec::new();

        for thread_id in 0..num_threads {
            let work_queue = work_queue.clone();
            let intermediate_results = intermediate_results.clone();
            let config = self.config.clone();

            let handle = thread::spawn(move || {
                Self::worker_thread(thread_id, work_queue, intermediate_results, config)
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().map_err(|e| {
                SimulatorError::ComputationError(format!("Thread join failed: {:?}", e))
            })??;
        }

        // Output ids increase monotonically, so the highest id is the
        // final contraction result.
        let results = intermediate_results.read().unwrap();
        let max_id = results.keys().max().copied().unwrap_or(0);
        Ok(results[&max_id].clone())
    }

    /// Worker loop: repeatedly claim a runnable unit, contract it, and
    /// publish the result.
    fn worker_thread(
        _thread_id: usize,
        work_queue: Arc<TensorWorkQueue>,
        intermediate_results: Arc<RwLock<HashMap<usize, ArrayD<Complex64>>>>,
        _config: ParallelTensorConfig,
    ) -> Result<()> {
        while !work_queue.is_complete() {
            if let Some(work_unit) = work_queue.get_work() {
                // Fetch both inputs under a single read lock.
                let (tensor1, tensor2) = {
                    let results = intermediate_results.read().unwrap();
                    (
                        results[&work_unit.input_tensors[0]].clone(),
                        results[&work_unit.input_tensors[1]].clone(),
                    )
                };

                let result = Self::perform_tensor_contraction(
                    &tensor1,
                    &tensor2,
                    &work_unit.contraction_indices[0],
                    &work_unit.contraction_indices[1],
                )?;

                {
                    let mut results = intermediate_results.write().unwrap();
                    results.insert(work_unit.output_tensor, result);
                }

                work_queue.complete_work(work_unit.id);
            } else {
                // Nothing runnable yet; back off briefly.
                thread::sleep(Duration::from_millis(1));
            }
        }

        Ok(())
    }

    /// Contract `tensor1` and `tensor2` over the given axes.
    ///
    /// The contraction is performed by permuting the contracted axes to be
    /// adjacent, flattening both tensors to matrices, and multiplying: free
    /// axes of `tensor1` become rows, free axes of `tensor2` become columns.
    fn perform_tensor_contraction(
        tensor1: &ArrayD<Complex64>,
        tensor2: &ArrayD<Complex64>,
        indices1: &[usize],
        indices2: &[usize],
    ) -> Result<ArrayD<Complex64>> {
        let shape1 = tensor1.shape();
        let shape2 = tensor2.shape();

        // Free (uncontracted) axes of each tensor, in order.
        let free1: Vec<usize> = (0..shape1.len()).filter(|i| !indices1.contains(i)).collect();
        let free2: Vec<usize> = (0..shape2.len()).filter(|i| !indices2.contains(i)).collect();

        // Output shape: free axes of tensor1 followed by free axes of tensor2.
        let mut output_shape: Vec<usize> = free1.iter().map(|&i| shape1[i]).collect();
        output_shape.extend(free2.iter().map(|&i| shape2[i]));

        // Matrix dimensions after flattening.
        let m: usize = free1.iter().map(|&i| shape1[i]).product();
        let k: usize = indices1.iter().map(|&i| shape1[i]).product();
        let n: usize = free2.iter().map(|&i| shape2[i]).product();

        // Permute so contracted axes sit last in tensor1 and first in tensor2.
        let perm1: Vec<usize> = free1.iter().chain(indices1.iter()).copied().collect();
        let perm2: Vec<usize> = indices2.iter().chain(free2.iter()).copied().collect();

        let a = Array2::from_shape_vec(
            (m, k),
            tensor1.clone().permuted_axes(IxDyn(&perm1)).iter().copied().collect(),
        )
        .map_err(|e| SimulatorError::ComputationError(format!("Reshape failed: {}", e)))?;
        let b = Array2::from_shape_vec(
            (k, n),
            tensor2.clone().permuted_axes(IxDyn(&perm2)).iter().copied().collect(),
        )
        .map_err(|e| SimulatorError::ComputationError(format!("Reshape failed: {}", e)))?;

        // Contract via matrix multiplication and restore the output shape.
        let product = a.dot(&b);
        ArrayD::from_shape_vec(IxDyn(&output_shape), product.iter().copied().collect())
            .map_err(|e| SimulatorError::ComputationError(format!("Reshape failed: {}", e)))
    }

    /// Rough cost model for a contraction (placeholder heuristic).
    fn estimate_contraction_cost(
        &self,
        contraction: &ContractionPair,
        _tensors: &[ArrayD<Complex64>],
    ) -> Result<f64> {
        // Scale with the number of contracted index pairs.
        let cost = contraction.tensor1_indices.len() as f64
            * contraction.tensor2_indices.len() as f64
            * 1000.0;
        Ok(cost)
    }

    /// Rough memory estimate for a contraction (placeholder: 1 MiB).
    fn estimate_memory_requirement(
        &self,
        _contraction: &ContractionPair,
        _tensors: &[ArrayD<Complex64>],
    ) -> Result<usize> {
        Ok(1024 * 1024)
    }

    /// Priority favors expensive contractions with small outputs, e.g.
    /// cost 5000.0 with 1 KiB memory gives 5 + 975 = 980.
    fn calculate_priority(&self, cost: f64, memory: usize) -> i32 {
        let cost_factor = (cost / 1000.0) as i32;
        let memory_factor = (1_000_000 / (memory + 1)) as i32;
        cost_factor + memory_factor
    }

    /// Crude sequential-time estimate: assume 1000 ms of work per contraction.
    fn estimate_sequential_time(&self, contraction_sequence: &[ContractionPair]) -> Duration {
        let estimated_ops = contraction_sequence.len() as u64 * 1000;
        Duration::from_millis(estimated_ops)
    }

    /// Snapshot of the accumulated execution statistics.
    pub fn get_stats(&self) -> ParallelTensorStats {
        self.stats.lock().unwrap().clone()
    }
}

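/// A pairwise contraction between two tensors in a network.
///
/// For example (hypothetical ids), an ordinary matrix product `C = A · B`
/// contracts axis 1 of `A` with axis 0 of `B`:
///
/// ```ignore
/// let matmul = ContractionPair {
///     tensor1_id: 0,
///     tensor2_id: 1,
///     tensor1_indices: vec![1],
///     tensor2_indices: vec![0],
/// };
/// ```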
#[derive(Debug, Clone)]
pub struct ContractionPair {
    /// Id of the first tensor.
    pub tensor1_id: usize,
    /// Id of the second tensor.
    pub tensor2_id: usize,
    /// Axes of the first tensor to contract.
    pub tensor1_indices: Vec<usize>,
    /// Axes of the second tensor to contract.
    pub tensor2_indices: Vec<usize>,
}

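/// Convenience wrappers that pair a preset [`LoadBalancingStrategy`] with a
/// fresh engine and run a full contraction, e.g.:
///
/// ```ignore
/// let result = strategies::work_stealing_contraction(&tensors, &sequence, 8)?;
/// ```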
pub mod strategies {
    use super::*;

    /// Contract using dynamic work stealing across `num_threads` threads.
    pub fn work_stealing_contraction(
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
        num_threads: usize,
    ) -> Result<ArrayD<Complex64>> {
        let config = ParallelTensorConfig {
            num_threads,
            load_balancing: LoadBalancingStrategy::DynamicWorkStealing,
            ..Default::default()
        };

        let engine = ParallelTensorEngine::new(config)?;
        engine.contract_network(tensors, contraction_sequence)
    }

    /// Contract with NUMA-aware scheduling.
    ///
    /// The topology argument is currently unused; placement hints are not
    /// yet wired into the engine.
    pub fn numa_aware_contraction(
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
        _numa_topology: &NumaTopology,
    ) -> Result<ArrayD<Complex64>> {
        let config = ParallelTensorConfig {
            load_balancing: LoadBalancingStrategy::NumaAware,
            numa_aware: true,
            ..Default::default()
        };

        let engine = ParallelTensorEngine::new(config)?;
        engine.contract_network(tensors, contraction_sequence)
    }

    /// Contract with the adaptive strategy and work stealing enabled.
    pub fn adaptive_contraction(
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
    ) -> Result<ArrayD<Complex64>> {
        let config = ParallelTensorConfig {
            load_balancing: LoadBalancingStrategy::Adaptive,
            enable_work_stealing: true,
            ..Default::default()
        };

        let engine = ParallelTensorEngine::new(config)?;
        engine.contract_network(tensors, contraction_sequence)
    }
}

/// Description of the host's NUMA layout.
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes.
    pub num_nodes: usize,
    /// Core count per node.
    pub cores_per_node: Vec<usize>,
    /// Memory capacity per node, in bytes.
    pub memory_per_node: Vec<usize>,
}

impl Default for NumaTopology {
    fn default() -> Self {
        let num_cores = rayon::current_num_threads();
        Self {
            num_nodes: 1,
            cores_per_node: vec![num_cores],
            memory_per_node: vec![8 * 1024 * 1024 * 1024], // 8 GiB
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_parallel_tensor_engine() {
        let config = ParallelTensorConfig::default();
        let engine = ParallelTensorEngine::new(config).unwrap();

        let tensor1 = Array::zeros(IxDyn(&[2, 2]));
        let tensor2 = Array::zeros(IxDyn(&[2, 2]));
        let tensors = vec![tensor1, tensor2];

        let contraction = ContractionPair {
            tensor1_id: 0,
            tensor2_id: 1,
            tensor1_indices: vec![1],
            tensor2_indices: vec![0],
        };

        let result = engine.contract_network(&tensors, &[contraction]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_work_queue() {
        let work_unit = TensorWorkUnit {
            id: 0,
            input_tensors: vec![0, 1],
            output_tensor: 2,
            contraction_indices: vec![vec![0], vec![1]],
            estimated_cost: 100.0,
            memory_requirement: 1024,
            dependencies: HashSet::new(),
            priority: 1,
        };

        let config = ParallelTensorConfig::default();
        let queue = TensorWorkQueue::new(vec![work_unit], config);

        let work = queue.get_work();
        assert!(work.is_some());

        queue.complete_work(0);
        assert!(queue.is_complete());
    }

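    // Dependencies must gate work release: a higher-priority unit whose
    // dependency is unfinished stays queued until `complete_work` runs.
    #[test]
    fn test_work_queue_dependency_gating() {
        let unit0 = TensorWorkUnit {
            id: 0,
            input_tensors: vec![0, 1],
            output_tensor: 2,
            contraction_indices: vec![vec![0], vec![1]],
            estimated_cost: 100.0,
            memory_requirement: 1024,
            dependencies: HashSet::new(),
            priority: 1,
        };
        let unit1 = TensorWorkUnit {
            id: 1,
            input_tensors: vec![2, 1],
            output_tensor: 3,
            contraction_indices: vec![vec![0], vec![1]],
            estimated_cost: 100.0,
            memory_requirement: 1024,
            dependencies: HashSet::from([0]),
            priority: 10, // sorts first, but is blocked by unit 0
        };

        let queue = TensorWorkQueue::new(vec![unit0, unit1], ParallelTensorConfig::default());

        // Only the dependency-free unit is runnable at first.
        assert_eq!(queue.get_work().expect("unit 0 should be runnable").id, 0);
        assert!(queue.get_work().is_none());

        // Completing unit 0 unblocks unit 1.
        queue.complete_work(0);
        assert_eq!(queue.get_work().expect("unit 1 unblocked").id, 1);
        queue.complete_work(1);
        assert!(queue.is_complete());
    }
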
    #[test]
    fn test_parallel_strategies() {
        let tensor1 = Array::ones(IxDyn(&[2, 2]));
        let tensor2 = Array::ones(IxDyn(&[2, 2]));
        let tensors = vec![tensor1, tensor2];

        let contraction = ContractionPair {
            tensor1_id: 0,
            tensor2_id: 1,
            tensor1_indices: vec![1],
            tensor2_indices: vec![0],
        };

        let result = strategies::work_stealing_contraction(&tensors, &[contraction], 2);
        assert!(result.is_ok());
    }
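
    // Value check for the permute-reshape-matmul contraction kernel:
    // contracting two 2x2 all-ones matrices over one shared index is an
    // ordinary matrix product, so every output entry equals 2.
    #[test]
    fn test_contraction_values() {
        let tensor1: ArrayD<Complex64> = Array::ones(IxDyn(&[2, 2]));
        let tensor2: ArrayD<Complex64> = Array::ones(IxDyn(&[2, 2]));

        let result =
            ParallelTensorEngine::perform_tensor_contraction(&tensor1, &tensor2, &[1], &[0])
                .unwrap();

        assert_eq!(result.shape(), &[2, 2]);
        for value in result.iter() {
            assert!((value.re - 2.0).abs() < 1e-12);
            assert!(value.im.abs() < 1e-12);
        }
    }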
}