1use crate::prelude::SimulatorError;
7use scirs2_core::ndarray::{ArrayD, Dimension, IxDyn};
8use scirs2_core::parallel_ops::*;
9use scirs2_core::Complex64;
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::sync::{Arc, Mutex, RwLock};
12use std::thread;
13use std::time::{Duration, Instant};
14
15use crate::error::Result;
16
17#[derive(Debug, Clone)]
19pub struct ParallelTensorConfig {
20 pub num_threads: usize,
22 pub chunk_size: usize,
24 pub enable_work_stealing: bool,
26 pub parallel_threshold_bytes: usize,
28 pub load_balancing: LoadBalancingStrategy,
30 pub numa_aware: bool,
32 pub thread_affinity: ThreadAffinityConfig,
34}
35
36impl Default for ParallelTensorConfig {
37 fn default() -> Self {
38 Self {
39 num_threads: current_num_threads(), chunk_size: 1024,
41 enable_work_stealing: true,
42 parallel_threshold_bytes: 1024 * 1024, load_balancing: LoadBalancingStrategy::DynamicWorkStealing,
44 numa_aware: true,
45 thread_affinity: ThreadAffinityConfig::default(),
46 }
47 }
48}
49
/// Load-balancing strategy recorded in [`ParallelTensorConfig::load_balancing`].
///
/// The variants are only selected by the helpers in `strategies`; the engine in
/// this file does not branch on them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoadBalancingStrategy {
    /// Round-robin distribution (by name).
    RoundRobin,
    /// Dynamic work stealing (by name); the default.
    DynamicWorkStealing,
    /// NUMA-aware placement (by name).
    NumaAware,
    /// Cost-based assignment (by name).
    CostBased,
    /// Adaptive selection (by name).
    Adaptive,
}
64
/// Thread-affinity settings carried in [`ParallelTensorConfig`].
///
/// Not consulted by the engine in this file.
#[derive(Clone, Debug, Default)]
pub struct ThreadAffinityConfig {
    /// Whether affinity pinning is requested.
    pub enable_affinity: bool,
    /// Core indices to use; presumably indexed by worker id — TODO confirm.
    pub core_mapping: Vec<usize>,
    /// Presumably maps a thread/core index to a preferred NUMA node — TODO confirm.
    pub numa_preferences: HashMap<usize, usize>,
}
75
/// One pairwise contraction scheduled on the work queue.
#[derive(Clone, Debug)]
pub struct TensorWorkUnit {
    /// Position of this contraction in the contraction sequence.
    pub id: usize,
    /// Ids of the two operand tensors (`[tensor1_id, tensor2_id]`).
    pub input_tensors: Vec<usize>,
    /// Id under which the contraction result is stored.
    pub output_tensor: usize,
    /// Contracted axes: `[axes of tensor 1, axes of tensor 2]`.
    pub contraction_indices: Vec<Vec<usize>>,
    /// Heuristic cost estimate used to derive `priority`.
    pub estimated_cost: f64,
    /// Estimated memory requirement in bytes.
    pub memory_requirement: usize,
    /// Ids of work units whose outputs this unit consumes.
    pub dependencies: HashSet<usize>,
    /// Scheduling priority; higher values are queued first.
    pub priority: i32,
}
96
97#[derive(Debug)]
99pub struct TensorWorkQueue {
100 pending: Mutex<VecDeque<TensorWorkUnit>>,
102 completed: RwLock<HashSet<usize>>,
104 in_progress: RwLock<HashMap<usize, Instant>>,
106 total_units: usize,
108 config: ParallelTensorConfig,
110}
111
112impl TensorWorkQueue {
113 pub fn new(work_units: Vec<TensorWorkUnit>, config: ParallelTensorConfig) -> Self {
115 let total_units = work_units.len();
116 let mut pending = VecDeque::from(work_units);
117
118 pending.make_contiguous().sort_by(|a, b| {
120 b.priority
121 .cmp(&a.priority)
122 .then_with(|| a.dependencies.len().cmp(&b.dependencies.len()))
123 });
124
125 Self {
126 pending: Mutex::new(pending),
127 completed: RwLock::new(HashSet::new()),
128 in_progress: RwLock::new(HashMap::new()),
129 total_units,
130 config,
131 }
132 }
133
134 pub fn get_work(&self) -> Option<TensorWorkUnit> {
136 let mut pending = self.pending.lock().unwrap();
137 let completed = self.completed.read().unwrap();
138
139 for i in 0..pending.len() {
141 let work_unit = &pending[i];
142 let dependencies_satisfied = work_unit
143 .dependencies
144 .iter()
145 .all(|dep| completed.contains(dep));
146
147 if dependencies_satisfied {
148 let work_unit = pending.remove(i).unwrap();
149
150 drop(completed);
152 let mut in_progress = self.in_progress.write().unwrap();
153 in_progress.insert(work_unit.id, Instant::now());
154
155 return Some(work_unit);
156 }
157 }
158
159 None
160 }
161
162 pub fn complete_work(&self, work_id: usize) {
164 let mut completed = self.completed.write().unwrap();
165 completed.insert(work_id);
166
167 let mut in_progress = self.in_progress.write().unwrap();
168 in_progress.remove(&work_id);
169 }
170
171 pub fn is_complete(&self) -> bool {
173 let completed = self.completed.read().unwrap();
174 completed.len() == self.total_units
175 }
176
177 pub fn get_progress(&self) -> (usize, usize, usize) {
179 let completed = self.completed.read().unwrap().len();
180 let in_progress = self.in_progress.read().unwrap().len();
181 let pending = self.pending.lock().unwrap().len();
182 (completed, in_progress, pending)
183 }
184}
185
186pub struct ParallelTensorEngine {
188 config: ParallelTensorConfig,
190 thread_pool: ThreadPool, stats: Arc<Mutex<ParallelTensorStats>>,
194}
195
/// Statistics accumulated by [`ParallelTensorEngine`].
#[derive(Clone, Debug, Default)]
pub struct ParallelTensorStats {
    /// Total number of pairwise contractions requested across all calls.
    pub total_contractions: u64,
    /// Total wall-clock time spent in `contract_network`.
    pub total_computation_time: Duration,
    /// Ratio of the estimated sequential time to the measured time of the last call.
    pub parallel_efficiency: f64,
    /// Peak memory usage in bytes (not updated in this file).
    pub peak_memory_usage: usize,
    /// Per-thread utilization (not updated in this file).
    pub thread_utilization: Vec<f64>,
    /// Load-balance factor (not updated in this file).
    pub load_balance_factor: f64,
    /// Cache hit rate (not updated in this file).
    pub cache_hit_rate: f64,
}
214
215impl ParallelTensorEngine {
216 pub fn new(config: ParallelTensorConfig) -> Result<Self> {
218 let thread_pool = ThreadPoolBuilder::new() .num_threads(config.num_threads)
220 .build()
221 .map_err(|e| {
222 SimulatorError::InitializationFailed(format!("Thread pool creation failed: {e}"))
223 })?;
224
225 Ok(Self {
226 config,
227 thread_pool,
228 stats: Arc::new(Mutex::new(ParallelTensorStats::default())),
229 })
230 }
231
232 pub fn contract_network(
234 &self,
235 tensors: &[ArrayD<Complex64>],
236 contraction_sequence: &[ContractionPair],
237 ) -> Result<ArrayD<Complex64>> {
238 let start_time = Instant::now();
239
240 let work_units = self.create_work_units(tensors, contraction_sequence)?;
242
243 let work_queue = Arc::new(TensorWorkQueue::new(work_units, self.config.clone()));
245
246 let intermediate_results =
248 Arc::new(RwLock::new(HashMap::<usize, ArrayD<Complex64>>::new()));
249
250 {
252 let mut results = intermediate_results.write().unwrap();
253 for (i, tensor) in tensors.iter().enumerate() {
254 results.insert(i, tensor.clone());
255 }
256 }
257
258 let final_result = self.execute_parallel_contractions(work_queue, intermediate_results)?;
260
261 let elapsed = start_time.elapsed();
263 let mut stats = self.stats.lock().unwrap();
264 stats.total_contractions += contraction_sequence.len() as u64;
265 stats.total_computation_time += elapsed;
266
267 let sequential_estimate = self.estimate_sequential_time(contraction_sequence);
269 stats.parallel_efficiency = sequential_estimate.as_secs_f64() / elapsed.as_secs_f64();
270
271 Ok(final_result)
272 }
273
274 fn create_work_units(
276 &self,
277 tensors: &[ArrayD<Complex64>],
278 contraction_sequence: &[ContractionPair],
279 ) -> Result<Vec<TensorWorkUnit>> {
280 let mut work_units: Vec<TensorWorkUnit> = Vec::new();
281 let mut next_tensor_id = tensors.len();
282
283 for (i, contraction) in contraction_sequence.iter().enumerate() {
284 let estimated_cost = self.estimate_contraction_cost(contraction, tensors)?;
285 let memory_requirement = self.estimate_memory_requirement(contraction, tensors)?;
286
287 let mut dependencies = HashSet::new();
289 for &input_id in &[contraction.tensor1_id, contraction.tensor2_id] {
290 if input_id >= tensors.len() {
291 for prev_unit in &work_units {
293 if prev_unit.output_tensor == input_id {
294 dependencies.insert(prev_unit.id);
295 break;
296 }
297 }
298 }
299 }
300
301 let work_unit = TensorWorkUnit {
302 id: i,
303 input_tensors: vec![contraction.tensor1_id, contraction.tensor2_id],
304 output_tensor: next_tensor_id,
305 contraction_indices: vec![
306 contraction.tensor1_indices.clone(),
307 contraction.tensor2_indices.clone(),
308 ],
309 estimated_cost,
310 memory_requirement,
311 dependencies,
312 priority: self.calculate_priority(estimated_cost, memory_requirement),
313 };
314
315 work_units.push(work_unit);
316 next_tensor_id += 1;
317 }
318
319 Ok(work_units)
320 }
321
322 fn execute_parallel_contractions(
324 &self,
325 work_queue: Arc<TensorWorkQueue>,
326 intermediate_results: Arc<RwLock<HashMap<usize, ArrayD<Complex64>>>>,
327 ) -> Result<ArrayD<Complex64>> {
328 let num_threads = self.config.num_threads;
329 let mut handles = Vec::new();
330
331 for thread_id in 0..num_threads {
333 let work_queue = work_queue.clone();
334 let intermediate_results = intermediate_results.clone();
335 let config = self.config.clone();
336
337 let handle = thread::spawn(move || {
338 Self::worker_thread(thread_id, work_queue, intermediate_results, config)
339 });
340 handles.push(handle);
341 }
342
343 for handle in handles {
345 handle.join().map_err(|e| {
346 SimulatorError::ComputationError(format!("Thread join failed: {e:?}"))
347 })??;
348 }
349
350 let results = intermediate_results.read().unwrap();
352 let max_id = results.keys().max().copied().unwrap_or(0);
353 Ok(results[&max_id].clone())
354 }
355
356 fn worker_thread(
358 _thread_id: usize,
359 work_queue: Arc<TensorWorkQueue>,
360 intermediate_results: Arc<RwLock<HashMap<usize, ArrayD<Complex64>>>>,
361 _config: ParallelTensorConfig,
362 ) -> Result<()> {
363 while !work_queue.is_complete() {
364 if let Some(work_unit) = work_queue.get_work() {
365 let tensor1 = {
367 let results = intermediate_results.read().unwrap();
368 results[&work_unit.input_tensors[0]].clone()
369 };
370
371 let tensor2 = {
372 let results = intermediate_results.read().unwrap();
373 results[&work_unit.input_tensors[1]].clone()
374 };
375
376 let result = Self::perform_tensor_contraction(
378 &tensor1,
379 &tensor2,
380 &work_unit.contraction_indices[0],
381 &work_unit.contraction_indices[1],
382 )?;
383
384 {
386 let mut results = intermediate_results.write().unwrap();
387 results.insert(work_unit.output_tensor, result);
388 }
389
390 work_queue.complete_work(work_unit.id);
392 } else {
393 thread::sleep(Duration::from_millis(1));
395 }
396 }
397
398 Ok(())
399 }
400
401 fn perform_tensor_contraction(
403 tensor1: &ArrayD<Complex64>,
404 tensor2: &ArrayD<Complex64>,
405 indices1: &[usize],
406 indices2: &[usize],
407 ) -> Result<ArrayD<Complex64>> {
408 let shape1 = tensor1.shape();
412 let shape2 = tensor2.shape();
413
414 let mut output_shape = Vec::new();
416 for (i, &size) in shape1.iter().enumerate() {
417 if !indices1.contains(&i) {
418 output_shape.push(size);
419 }
420 }
421 for (i, &size) in shape2.iter().enumerate() {
422 if !indices2.contains(&i) {
423 output_shape.push(size);
424 }
425 }
426
427 let output_dim = IxDyn(&output_shape);
429 let mut output = ArrayD::zeros(output_dim);
430
431 Ok(output)
434 }
435
436 fn estimate_contraction_cost(
438 &self,
439 contraction: &ContractionPair,
440 _tensors: &[ArrayD<Complex64>],
441 ) -> Result<f64> {
442 let cost = contraction.tensor1_indices.len() as f64
444 * contraction.tensor2_indices.len() as f64
445 * 1000.0; Ok(cost)
447 }
448
449 const fn estimate_memory_requirement(
451 &self,
452 _contraction: &ContractionPair,
453 _tensors: &[ArrayD<Complex64>],
454 ) -> Result<usize> {
455 Ok(1024 * 1024) }
458
459 fn calculate_priority(&self, cost: f64, memory: usize) -> i32 {
461 let cost_factor = (cost / 1000.0) as i32;
463 let memory_factor = (1_000_000 / (memory + 1)) as i32;
464 cost_factor + memory_factor
465 }
466
467 const fn estimate_sequential_time(&self, contraction_sequence: &[ContractionPair]) -> Duration {
469 let estimated_ops = contraction_sequence.len() as u64 * 1000; Duration::from_millis(estimated_ops)
471 }
472
473 pub fn get_stats(&self) -> ParallelTensorStats {
475 self.stats.lock().unwrap().clone()
476 }
477}
478
/// One step of a contraction sequence: contract two tensors over the given axes.
///
/// Tensor ids refer either to an input tensor (its position in the input slice) or
/// to the result of an earlier step (`number_of_inputs + step_index`).
#[derive(Clone, Debug)]
pub struct ContractionPair {
    /// Id of the first operand.
    pub tensor1_id: usize,
    /// Id of the second operand.
    pub tensor2_id: usize,
    /// Axes of the first operand to contract away.
    pub tensor1_indices: Vec<usize>,
    /// Axes of the second operand to contract away; presumably paired positionally
    /// with `tensor1_indices`.
    pub tensor2_indices: Vec<usize>,
}
491
492pub mod strategies {
494 use super::*;
495
496 pub fn work_stealing_contraction(
498 tensors: &[ArrayD<Complex64>],
499 contraction_sequence: &[ContractionPair],
500 num_threads: usize,
501 ) -> Result<ArrayD<Complex64>> {
502 let config = ParallelTensorConfig {
503 num_threads,
504 load_balancing: LoadBalancingStrategy::DynamicWorkStealing,
505 ..Default::default()
506 };
507
508 let engine = ParallelTensorEngine::new(config)?;
509 engine.contract_network(tensors, contraction_sequence)
510 }
511
512 pub fn numa_aware_contraction(
514 tensors: &[ArrayD<Complex64>],
515 contraction_sequence: &[ContractionPair],
516 numa_topology: &NumaTopology,
517 ) -> Result<ArrayD<Complex64>> {
518 let config = ParallelTensorConfig {
519 load_balancing: LoadBalancingStrategy::NumaAware,
520 numa_aware: true,
521 ..Default::default()
522 };
523
524 let engine = ParallelTensorEngine::new(config)?;
525 engine.contract_network(tensors, contraction_sequence)
526 }
527
528 pub fn adaptive_contraction(
530 tensors: &[ArrayD<Complex64>],
531 contraction_sequence: &[ContractionPair],
532 ) -> Result<ArrayD<Complex64>> {
533 let config = ParallelTensorConfig {
534 load_balancing: LoadBalancingStrategy::Adaptive,
535 enable_work_stealing: true,
536 ..Default::default()
537 };
538
539 let engine = ParallelTensorEngine::new(config)?;
540 engine.contract_network(tensors, contraction_sequence)
541 }
542}
543
/// Description of the machine's NUMA layout.
#[derive(Clone, Debug)]
pub struct NumaTopology {
    /// Number of NUMA nodes.
    pub num_nodes: usize,
    /// Core count for each node.
    pub cores_per_node: Vec<usize>,
    /// Memory for each node; presumably bytes — TODO confirm.
    pub memory_per_node: Vec<usize>,
}
554
555impl Default for NumaTopology {
556 fn default() -> Self {
557 let num_cores = current_num_threads(); Self {
559 num_nodes: 1,
560 cores_per_node: vec![num_cores],
561 memory_per_node: vec![8 * 1024 * 1024 * 1024], }
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_parallel_tensor_engine() {
        let config = ParallelTensorConfig::default();
        let engine = ParallelTensorEngine::new(config).unwrap();

        let tensor1 = Array::zeros(IxDyn(&[2, 2]));
        let tensor2 = Array::zeros(IxDyn(&[2, 2]));
        let tensors = vec![tensor1, tensor2];

        let contraction = ContractionPair {
            tensor1_id: 0,
            tensor2_id: 1,
            tensor1_indices: vec![1],
            tensor2_indices: vec![0],
        };

        let result = engine.contract_network(&tensors, &[contraction]);
        assert!(result.is_ok());
        // Contracting axis 1 of a 2x2 with axis 0 of a 2x2 leaves a 2x2 tensor.
        assert_eq!(result.unwrap().shape(), &[2, 2]);
        assert_eq!(engine.get_stats().total_contractions, 1);
    }

    #[test]
    fn test_work_queue() {
        let work_unit = TensorWorkUnit {
            id: 0,
            input_tensors: vec![0, 1],
            output_tensor: 2,
            contraction_indices: vec![vec![0], vec![1]],
            estimated_cost: 100.0,
            memory_requirement: 1024,
            dependencies: HashSet::new(),
            priority: 1,
        };

        let config = ParallelTensorConfig::default();
        let queue = TensorWorkQueue::new(vec![work_unit], config);

        let work = queue.get_work();
        assert!(work.is_some());
        // The only unit has been handed out, so nothing else is available.
        assert!(queue.get_work().is_none());
        assert_eq!(queue.get_progress(), (0, 1, 0));

        queue.complete_work(0);
        assert!(queue.is_complete());
        assert_eq!(queue.get_progress(), (1, 0, 0));
    }

    #[test]
    fn test_parallel_strategies() {
        let tensor1 = Array::ones(IxDyn(&[2, 2]));
        let tensor2 = Array::ones(IxDyn(&[2, 2]));
        let tensors = vec![tensor1, tensor2];

        let contraction = ContractionPair {
            tensor1_id: 0,
            tensor2_id: 1,
            tensor1_indices: vec![1],
            tensor2_indices: vec![0],
        };

        let result = strategies::work_stealing_contraction(&tensors, &[contraction], 2);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().shape(), &[2, 2]);
    }
}
632}