use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};

use scirs2_core::ndarray::{ArrayD, Dimension, IxDyn};
use scirs2_core::parallel_ops::{
    current_num_threads, IndexedParallelIterator, ParallelIterator, ThreadPool, ThreadPoolBuilder,
};
use scirs2_core::Complex64;

use crate::error::Result;
use crate::prelude::SimulatorError;
/// Tuning knobs for parallel tensor-network contraction.
///
/// NOTE(review): of these fields, only `num_threads` is consulted by the
/// engine code in this file; the rest are recorded for future scheduling
/// logic — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct ParallelTensorConfig {
    /// Number of worker threads to spawn for a contraction.
    pub num_threads: usize,
    /// Work chunk size, presumably in elements — TODO confirm at use sites.
    pub chunk_size: usize,
    /// Whether idle workers may steal queued work from busy ones.
    pub enable_work_stealing: bool,
    /// Size threshold (bytes) below which work presumably stays sequential.
    pub parallel_threshold_bytes: usize,
    /// Strategy used to distribute work units across threads.
    pub load_balancing: LoadBalancingStrategy,
    /// Whether to take NUMA topology into account when placing work.
    pub numa_aware: bool,
    /// Optional pinning of worker threads to cores / NUMA nodes.
    pub thread_affinity: ThreadAffinityConfig,
}
impl Default for ParallelTensorConfig {
    /// One worker per available core, dynamic work stealing, and a 1 MiB
    /// threshold before parallel execution kicks in.
    fn default() -> Self {
        let threads = current_num_threads();
        Self {
            num_threads: threads,
            chunk_size: 1024,
            enable_work_stealing: true,
            parallel_threshold_bytes: 1 << 20, // 1 MiB
            load_balancing: LoadBalancingStrategy::DynamicWorkStealing,
            numa_aware: true,
            thread_affinity: ThreadAffinityConfig::default(),
        }
    }
}
/// How work units are distributed across worker threads.
///
/// NOTE(review): the engine in this file never branches on this value; the
/// variant meanings below are inferred from their names — confirm against
/// the scheduler that eventually consumes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Hand units to threads in fixed rotation.
    RoundRobin,
    /// Idle threads dynamically steal pending units.
    DynamicWorkStealing,
    /// Place work near the memory it touches (NUMA nodes).
    NumaAware,
    /// Schedule by estimated contraction cost.
    CostBased,
    /// Pick a strategy based on observed runtime behavior.
    Adaptive,
}
/// Thread-to-core pinning preferences for worker threads.
#[derive(Debug, Clone, Default)]
pub struct ThreadAffinityConfig {
    /// Whether to pin worker threads at all.
    pub enable_affinity: bool,
    /// Preferred core index per worker — presumably indexed by thread id;
    /// TODO confirm once a consumer exists.
    pub core_mapping: Vec<usize>,
    /// Presumably maps thread id -> preferred NUMA node — verify.
    pub numa_preferences: HashMap<usize, usize>,
}
/// One pairwise-contraction task in the dependency graph of a network
/// contraction.
#[derive(Debug, Clone)]
pub struct TensorWorkUnit {
    /// Unique id of this unit (index into the contraction sequence).
    pub id: usize,
    /// Ids of the two input tensors (original inputs or intermediates).
    pub input_tensors: Vec<usize>,
    /// Id under which the contraction result is stored.
    pub output_tensor: usize,
    /// Per-input lists of axes to contract over (one list per input).
    pub contraction_indices: Vec<Vec<usize>>,
    /// Heuristic cost estimate used for prioritization.
    pub estimated_cost: f64,
    /// Estimated bytes needed by this unit (heuristic placeholder).
    pub memory_requirement: usize,
    /// Ids of work units that must complete before this one can run.
    pub dependencies: HashSet<usize>,
    /// Scheduling priority; higher-priority units are dequeued first.
    pub priority: i32,
}
/// Thread-safe queue of contraction work units with dependency tracking.
#[derive(Debug)]
pub struct TensorWorkQueue {
    // Units not yet started, sorted by priority at construction.
    pending: Mutex<VecDeque<TensorWorkUnit>>,
    // Ids of finished units; gates the dependency-satisfied check.
    completed: RwLock<HashSet<usize>>,
    // Start timestamp per unit currently being executed.
    in_progress: RwLock<HashMap<usize, Instant>>,
    // Number of units handed to `new`; completion target for `is_complete`.
    total_units: usize,
    // NOTE(review): stored but never read anywhere in this file — confirm
    // whether it is still needed.
    config: ParallelTensorConfig,
}
impl TensorWorkQueue {
    /// Builds a queue from `work_units`, ordered so that higher-priority
    /// units come first; ties are broken in favor of fewer dependencies.
    #[must_use]
    pub fn new(work_units: Vec<TensorWorkUnit>, config: ParallelTensorConfig) -> Self {
        let total_units = work_units.len();
        let mut ordered = work_units;
        ordered.sort_by(|a, b| {
            b.priority
                .cmp(&a.priority)
                .then_with(|| a.dependencies.len().cmp(&b.dependencies.len()))
        });
        Self {
            pending: Mutex::new(VecDeque::from(ordered)),
            completed: RwLock::new(HashSet::new()),
            in_progress: RwLock::new(HashMap::new()),
            total_units,
            config,
        }
    }
    /// Claims the first pending unit whose dependencies are all completed,
    /// marking it in-progress. Returns `None` when nothing is runnable yet.
    pub fn get_work(&self) -> Option<TensorWorkUnit> {
        let mut queue = self
            .pending
            .lock()
            .expect("pending lock should not be poisoned");
        // Scan under a read lock on `completed`; the guard is released
        // before the in-progress write lock is taken below.
        let ready = {
            let done = self
                .completed
                .read()
                .expect("completed lock should not be poisoned");
            (0..queue.len()).find(|&i| queue[i].dependencies.iter().all(|d| done.contains(d)))
        };
        let idx = ready?;
        let unit = queue
            .remove(idx)
            .expect("index idx is guaranteed to be within bounds");
        self.in_progress
            .write()
            .expect("in_progress lock should not be poisoned")
            .insert(unit.id, Instant::now());
        Some(unit)
    }
    /// Marks `work_id` finished: records it as completed and drops its
    /// in-progress entry (both locks held together, as dependents may be
    /// claimed as soon as the completed set is visible).
    pub fn complete_work(&self, work_id: usize) {
        let mut done = self
            .completed
            .write()
            .expect("completed lock should not be poisoned");
        done.insert(work_id);
        let mut active = self
            .in_progress
            .write()
            .expect("in_progress lock should not be poisoned");
        active.remove(&work_id);
    }
    /// True once every unit handed to `new` has been completed.
    pub fn is_complete(&self) -> bool {
        let done = self
            .completed
            .read()
            .expect("completed lock should not be poisoned");
        done.len() == self.total_units
    }
    /// Returns `(completed, in_progress, pending)` unit counts. The three
    /// counts are sampled under separate locks, so the snapshot may be
    /// momentarily inconsistent while workers are active.
    pub fn get_progress(&self) -> (usize, usize, usize) {
        let done = self
            .completed
            .read()
            .expect("completed lock should not be poisoned")
            .len();
        let active = self
            .in_progress
            .read()
            .expect("in_progress lock should not be poisoned")
            .len();
        let queued = self
            .pending
            .lock()
            .expect("pending lock should not be poisoned")
            .len();
        (done, active, queued)
    }
}
/// Multi-threaded tensor-network contraction engine.
pub struct ParallelTensorEngine {
    // Engine configuration; `num_threads` drives how many workers spawn.
    config: ParallelTensorConfig,
    // Pool built at construction.
    // NOTE(review): never used — workers are spawned with `std::thread` in
    // `execute_parallel_contractions`; confirm whether this is still needed.
    thread_pool: ThreadPool,
    // Aggregated runtime statistics, shared across contraction calls.
    stats: Arc<Mutex<ParallelTensorStats>>,
}
/// Aggregated performance statistics for [`ParallelTensorEngine`].
///
/// NOTE(review): only the first three fields are updated by the code in this
/// file; the remainder always hold their defaults — confirm intent.
#[derive(Debug, Clone, Default)]
pub struct ParallelTensorStats {
    /// Total pairwise contractions executed across all calls.
    pub total_contractions: u64,
    /// Cumulative wall-clock time spent in `contract_network`.
    pub total_computation_time: Duration,
    /// Estimated sequential time divided by measured parallel time.
    pub parallel_efficiency: f64,
    /// Peak memory usage in bytes (never written here).
    pub peak_memory_usage: usize,
    /// Per-thread utilization fractions (never written here).
    pub thread_utilization: Vec<f64>,
    /// Load-balance metric (never written here).
    pub load_balance_factor: f64,
    /// Cache hit rate (never written here).
    pub cache_hit_rate: f64,
}
impl ParallelTensorEngine {
    /// Builds an engine with `config.num_threads` workers.
    ///
    /// # Errors
    /// Returns `SimulatorError::InitializationFailed` if the backing thread
    /// pool cannot be constructed.
    pub fn new(config: ParallelTensorConfig) -> Result<Self> {
        let thread_pool = ThreadPoolBuilder::new()
            .num_threads(config.num_threads)
            .build()
            .map_err(|e| {
                SimulatorError::InitializationFailed(format!("Thread pool creation failed: {e}"))
            })?;
        Ok(Self {
            config,
            thread_pool,
            stats: Arc::new(Mutex::new(ParallelTensorStats::default())),
        })
    }
    /// Contracts a tensor network by executing `contraction_sequence` in
    /// parallel and returns the final tensor.
    ///
    /// Input tensors are registered under ids `0..tensors.len()`; each
    /// contraction produces a fresh intermediate id. The tensor with the
    /// highest id after all units finish is the result.
    ///
    /// # Errors
    /// Returns an error if `tensors` is empty, a worker thread fails, or a
    /// contraction pairs incompatible axes.
    pub fn contract_network(
        &self,
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
    ) -> Result<ArrayD<Complex64>> {
        // Guard: with no inputs there is no tensor to return (the original
        // code panicked indexing an empty result map).
        if tensors.is_empty() {
            return Err(SimulatorError::ComputationError(
                "contract_network requires at least one input tensor".to_string(),
            )
            .into());
        }
        let start_time = Instant::now();
        let work_units = self.create_work_units(tensors, contraction_sequence)?;
        let work_queue = Arc::new(TensorWorkQueue::new(work_units, self.config.clone()));
        // Seed the shared result table with the input tensors (ids 0..n).
        let intermediate_results =
            Arc::new(RwLock::new(HashMap::<usize, ArrayD<Complex64>>::new()));
        {
            let mut results = intermediate_results
                .write()
                .expect("intermediate_results lock should not be poisoned");
            for (i, tensor) in tensors.iter().enumerate() {
                results.insert(i, tensor.clone());
            }
        }
        let final_result = self.execute_parallel_contractions(work_queue, intermediate_results)?;
        // Update aggregate statistics.
        let elapsed = start_time.elapsed();
        let mut stats = self
            .stats
            .lock()
            .expect("stats lock should not be poisoned");
        stats.total_contractions += contraction_sequence.len() as u64;
        stats.total_computation_time += elapsed;
        let sequential_estimate = self.estimate_sequential_time(contraction_sequence);
        // Guard against division by zero when the contraction finishes in
        // less than the timer resolution.
        let elapsed_secs = elapsed.as_secs_f64();
        stats.parallel_efficiency = if elapsed_secs > 0.0 {
            sequential_estimate.as_secs_f64() / elapsed_secs
        } else {
            1.0
        };
        Ok(final_result)
    }
    /// Converts a contraction sequence into prioritized work units with
    /// dependency edges from each unit to the producers of its inputs.
    fn create_work_units(
        &self,
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
    ) -> Result<Vec<TensorWorkUnit>> {
        let mut work_units = Vec::with_capacity(contraction_sequence.len());
        // Tensor id -> id of the work unit that produces it. O(1) dependency
        // lookup instead of rescanning all prior units per input.
        let mut producer: HashMap<usize, usize> = HashMap::new();
        let mut next_tensor_id = tensors.len();
        for (i, contraction) in contraction_sequence.iter().enumerate() {
            let estimated_cost = self.estimate_contraction_cost(contraction, tensors)?;
            let memory_requirement = self.estimate_memory_requirement(contraction, tensors)?;
            let mut dependencies = HashSet::new();
            for &input_id in &[contraction.tensor1_id, contraction.tensor2_id] {
                // Ids at or past the input count refer to intermediates that
                // an earlier unit must produce first.
                if input_id >= tensors.len() {
                    if let Some(&unit_id) = producer.get(&input_id) {
                        dependencies.insert(unit_id);
                    }
                }
            }
            work_units.push(TensorWorkUnit {
                id: i,
                input_tensors: vec![contraction.tensor1_id, contraction.tensor2_id],
                output_tensor: next_tensor_id,
                contraction_indices: vec![
                    contraction.tensor1_indices.clone(),
                    contraction.tensor2_indices.clone(),
                ],
                estimated_cost,
                memory_requirement,
                dependencies,
                priority: self.calculate_priority(estimated_cost, memory_requirement),
            });
            producer.insert(next_tensor_id, i);
            next_tensor_id += 1;
        }
        Ok(work_units)
    }
    /// Spawns worker threads that drain the queue; returns the tensor with
    /// the highest id once every unit has completed.
    fn execute_parallel_contractions(
        &self,
        work_queue: Arc<TensorWorkQueue>,
        intermediate_results: Arc<RwLock<HashMap<usize, ArrayD<Complex64>>>>,
    ) -> Result<ArrayD<Complex64>> {
        let num_threads = self.config.num_threads;
        // Set by the first failing worker so the others stop instead of
        // spinning forever on a queue that can no longer complete.
        let abort = Arc::new(AtomicBool::new(false));
        let mut handles = Vec::with_capacity(num_threads);
        for thread_id in 0..num_threads {
            let work_queue = Arc::clone(&work_queue);
            let intermediate_results = Arc::clone(&intermediate_results);
            let config = self.config.clone();
            let abort = Arc::clone(&abort);
            handles.push(thread::spawn(move || {
                Self::worker_thread(thread_id, work_queue, intermediate_results, config, abort)
            }));
        }
        // Join every worker before propagating the first error, so no thread
        // is left running detached (the previous early-return leaked them).
        let mut first_error = None;
        for handle in handles {
            let joined = handle.join().map_err(|e| {
                SimulatorError::ComputationError(format!("Thread join failed: {e:?}"))
            })?;
            if let Err(e) = joined {
                if first_error.is_none() {
                    first_error = Some(e);
                }
            }
        }
        if let Some(e) = first_error {
            return Err(e);
        }
        let results = intermediate_results
            .read()
            .expect("intermediate_results lock should not be poisoned");
        // The final contraction writes the highest tensor id.
        let max_id = results.keys().max().copied().ok_or_else(|| {
            SimulatorError::ComputationError("no result tensor produced".to_string())
        })?;
        Ok(results[&max_id].clone())
    }
    /// Worker loop: repeatedly claims a ready unit, contracts its inputs,
    /// stores the result, and marks the unit complete. Sleeps briefly when
    /// no unit is runnable (dependencies still in flight).
    fn worker_thread(
        _thread_id: usize,
        work_queue: Arc<TensorWorkQueue>,
        intermediate_results: Arc<RwLock<HashMap<usize, ArrayD<Complex64>>>>,
        _config: ParallelTensorConfig,
        abort: Arc<AtomicBool>,
    ) -> Result<()> {
        while !work_queue.is_complete() {
            if abort.load(Ordering::SeqCst) {
                // Another worker failed; the queue will never complete.
                return Ok(());
            }
            let work_unit = match work_queue.get_work() {
                Some(unit) => unit,
                None => {
                    thread::sleep(Duration::from_millis(1));
                    continue;
                }
            };
            match Self::run_work_unit(&work_unit, &intermediate_results) {
                Ok(result) => {
                    intermediate_results
                        .write()
                        .expect("intermediate_results lock should not be poisoned")
                        .insert(work_unit.output_tensor, result);
                    work_queue.complete_work(work_unit.id);
                }
                Err(e) => {
                    // Signal the other workers before propagating the error.
                    abort.store(true, Ordering::SeqCst);
                    return Err(e);
                }
            }
        }
        Ok(())
    }
    /// Fetches the two input tensors of `work_unit` and contracts them.
    fn run_work_unit(
        work_unit: &TensorWorkUnit,
        intermediate_results: &RwLock<HashMap<usize, ArrayD<Complex64>>>,
    ) -> Result<ArrayD<Complex64>> {
        // Missing ids become an error instead of the previous panic on
        // map-indexing.
        let fetch = |id: usize| -> std::result::Result<ArrayD<Complex64>, SimulatorError> {
            intermediate_results
                .read()
                .expect("intermediate_results lock should not be poisoned")
                .get(&id)
                .cloned()
                .ok_or_else(|| {
                    SimulatorError::ComputationError(format!("missing input tensor {id}"))
                })
        };
        let tensor1 = fetch(work_unit.input_tensors[0])?;
        let tensor2 = fetch(work_unit.input_tensors[1])?;
        Self::perform_tensor_contraction(
            &tensor1,
            &tensor2,
            &work_unit.contraction_indices[0],
            &work_unit.contraction_indices[1],
        )
    }
    /// Contracts `tensor1` and `tensor2` over paired axes: axis
    /// `indices1[k]` of `tensor1` is summed against axis `indices2[k]` of
    /// `tensor2`. Output axes are the free axes of `tensor1` followed by the
    /// free axes of `tensor2`, each in their original order. Axis lists must
    /// not contain duplicates.
    ///
    /// This replaces the previous stub, which returned an all-zero tensor of
    /// the correct shape without performing any contraction.
    ///
    /// # Errors
    /// Returns `SimulatorError::ComputationError` when the axis lists differ
    /// in length, reference out-of-range axes, or pair axes of different
    /// sizes.
    fn perform_tensor_contraction(
        tensor1: &ArrayD<Complex64>,
        tensor2: &ArrayD<Complex64>,
        indices1: &[usize],
        indices2: &[usize],
    ) -> Result<ArrayD<Complex64>> {
        let shape1 = tensor1.shape().to_vec();
        let shape2 = tensor2.shape().to_vec();
        if indices1.len() != indices2.len() {
            return Err(SimulatorError::ComputationError(format!(
                "contraction index count mismatch: {} vs {}",
                indices1.len(),
                indices2.len()
            ))
            .into());
        }
        for (&a, &b) in indices1.iter().zip(indices2.iter()) {
            if a >= shape1.len() || b >= shape2.len() || shape1[a] != shape2[b] {
                return Err(SimulatorError::ComputationError(format!(
                    "incompatible contraction axes {a} and {b}"
                ))
                .into());
            }
        }
        // Free (uncontracted) axes of each tensor, in original order.
        let free1: Vec<usize> = (0..shape1.len()).filter(|i| !indices1.contains(i)).collect();
        let free2: Vec<usize> = (0..shape2.len()).filter(|i| !indices2.contains(i)).collect();
        let free1_dims: Vec<usize> = free1.iter().map(|&i| shape1[i]).collect();
        let free2_dims: Vec<usize> = free2.iter().map(|&i| shape2[i]).collect();
        let contracted_dims: Vec<usize> = indices1.iter().map(|&i| shape1[i]).collect();
        let free1_total: usize = free1_dims.iter().product();
        let free2_total: usize = free2_dims.iter().product();
        let contracted_total: usize = contracted_dims.iter().product();
        let output_shape: Vec<usize> = free1_dims
            .iter()
            .chain(free2_dims.iter())
            .copied()
            .collect();
        let mut output = ArrayD::zeros(IxDyn(&output_shape));
        // Writes the row-major digits of `lin` (with respect to `dims`) into
        // the `targets` positions of `idx`.
        fn decode(mut lin: usize, dims: &[usize], targets: &[usize], idx: &mut [usize]) {
            for k in (0..dims.len()).rev() {
                idx[targets[k]] = lin % dims[k];
                lin /= dims[k];
            }
        }
        let out_rank = output_shape.len();
        let out_targets1: Vec<usize> = (0..free1.len()).collect();
        let out_targets2: Vec<usize> = (free1.len()..out_rank).collect();
        let mut idx1 = vec![0usize; shape1.len()];
        let mut idx2 = vec![0usize; shape2.len()];
        let mut out_idx = vec![0usize; out_rank];
        // Naive triple loop: every pair of free multi-indices, summed over
        // the joint contracted multi-index.
        for o1 in 0..free1_total {
            decode(o1, &free1_dims, &free1, &mut idx1);
            decode(o1, &free1_dims, &out_targets1, &mut out_idx);
            for o2 in 0..free2_total {
                decode(o2, &free2_dims, &free2, &mut idx2);
                decode(o2, &free2_dims, &out_targets2, &mut out_idx);
                let mut acc = Complex64::new(0.0, 0.0);
                for c in 0..contracted_total {
                    decode(c, &contracted_dims, indices1, &mut idx1);
                    decode(c, &contracted_dims, indices2, &mut idx2);
                    acc += tensor1[IxDyn(&idx1)] * tensor2[IxDyn(&idx2)];
                }
                output[IxDyn(&out_idx)] = acc;
            }
        }
        Ok(output)
    }
    /// Heuristic cost of a contraction, proportional to the number of
    /// contracted axis pairs.
    /// NOTE(review): tensor shapes are deliberately not consulted yet.
    fn estimate_contraction_cost(
        &self,
        contraction: &ContractionPair,
        _tensors: &[ArrayD<Complex64>],
    ) -> Result<f64> {
        let cost = contraction.tensor1_indices.len() as f64
            * contraction.tensor2_indices.len() as f64
            * 1000.0;
        Ok(cost)
    }
    /// Placeholder memory estimate: a flat 1 MiB per contraction.
    const fn estimate_memory_requirement(
        &self,
        _contraction: &ContractionPair,
        _tensors: &[ArrayD<Complex64>],
    ) -> Result<usize> {
        Ok(1024 * 1024)
    }
    /// Scheduling priority: expensive units and small-memory units run
    /// earlier. Both factors are small enough that the `as i32` casts do not
    /// truncate for realistic inputs.
    fn calculate_priority(&self, cost: f64, memory: usize) -> i32 {
        let cost_factor = (cost / 1000.0) as i32;
        let memory_factor = (1_000_000 / (memory + 1)) as i32;
        cost_factor + memory_factor
    }
    /// Rough single-threaded runtime estimate (a flat 1 s per contraction),
    /// used only to compute `parallel_efficiency`.
    const fn estimate_sequential_time(&self, contraction_sequence: &[ContractionPair]) -> Duration {
        let estimated_ops = contraction_sequence.len() as u64 * 1000;
        Duration::from_millis(estimated_ops)
    }
    /// Returns a snapshot of the engine's accumulated statistics.
    #[must_use]
    pub fn get_stats(&self) -> ParallelTensorStats {
        self.stats
            .lock()
            .expect("stats lock should not be poisoned")
            .clone()
    }
}
/// Description of a single pairwise contraction in a sequence.
#[derive(Debug, Clone)]
pub struct ContractionPair {
    /// Id of the first input tensor.
    pub tensor1_id: usize,
    /// Id of the second input tensor.
    pub tensor2_id: usize,
    /// Axes of tensor 1 to contract.
    pub tensor1_indices: Vec<usize>,
    /// Axes of tensor 2 to contract, paired positionally with tensor 1's.
    pub tensor2_indices: Vec<usize>,
}
/// Convenience entry points that configure a [`ParallelTensorEngine`] for a
/// particular load-balancing strategy and run one network contraction.
pub mod strategies {
    use super::{
        ArrayD, Complex64, ContractionPair, LoadBalancingStrategy, NumaTopology,
        ParallelTensorConfig, ParallelTensorEngine, Result,
    };
    /// Contracts the network using dynamic work stealing on `num_threads`
    /// workers.
    pub fn work_stealing_contraction(
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
        num_threads: usize,
    ) -> Result<ArrayD<Complex64>> {
        let config = ParallelTensorConfig {
            load_balancing: LoadBalancingStrategy::DynamicWorkStealing,
            num_threads,
            ..ParallelTensorConfig::default()
        };
        ParallelTensorEngine::new(config)?.contract_network(tensors, contraction_sequence)
    }
    /// Contracts the network with the NUMA-aware strategy selected.
    ///
    /// NOTE(review): `numa_topology` is currently not consulted; the config
    /// only records that NUMA awareness was requested.
    pub fn numa_aware_contraction(
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
        numa_topology: &NumaTopology,
    ) -> Result<ArrayD<Complex64>> {
        let config = ParallelTensorConfig {
            numa_aware: true,
            load_balancing: LoadBalancingStrategy::NumaAware,
            ..ParallelTensorConfig::default()
        };
        ParallelTensorEngine::new(config)?.contract_network(tensors, contraction_sequence)
    }
    /// Contracts the network with the adaptive strategy selected and work
    /// stealing enabled.
    pub fn adaptive_contraction(
        tensors: &[ArrayD<Complex64>],
        contraction_sequence: &[ContractionPair],
    ) -> Result<ArrayD<Complex64>> {
        let config = ParallelTensorConfig {
            enable_work_stealing: true,
            load_balancing: LoadBalancingStrategy::Adaptive,
            ..ParallelTensorConfig::default()
        };
        ParallelTensorEngine::new(config)?.contract_network(tensors, contraction_sequence)
    }
}
/// Description of the machine's NUMA layout.
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes.
    pub num_nodes: usize,
    /// Core count per node, indexed by node id.
    pub cores_per_node: Vec<usize>,
    /// Memory in bytes per node, indexed by node id.
    pub memory_per_node: Vec<usize>,
}
impl Default for NumaTopology {
    /// Single-node fallback topology: every logical core on one node with an
    /// assumed 8 GiB of memory (not probed from the machine).
    fn default() -> Self {
        let cores = current_num_threads();
        let node_memory = 8 * 1024 * 1024 * 1024; // 8 GiB assumption
        Self {
            num_nodes: 1,
            cores_per_node: vec![cores],
            memory_per_node: vec![node_memory],
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;
    /// A trivial 2x2 · 2x2 network should contract without error.
    #[test]
    fn test_parallel_tensor_engine() {
        let engine = ParallelTensorEngine::new(ParallelTensorConfig::default())
            .expect("should create parallel tensor engine");
        let inputs = vec![Array::zeros(IxDyn(&[2, 2])), Array::zeros(IxDyn(&[2, 2]))];
        let pair = ContractionPair {
            tensor1_id: 0,
            tensor2_id: 1,
            tensor1_indices: vec![1],
            tensor2_indices: vec![0],
        };
        let outcome = engine.contract_network(&inputs, &[pair]);
        assert!(outcome.is_ok());
    }
    /// A dependency-free unit is immediately claimable, and completing it
    /// drains the queue.
    #[test]
    fn test_work_queue() {
        let unit = TensorWorkUnit {
            id: 0,
            input_tensors: vec![0, 1],
            output_tensor: 2,
            contraction_indices: vec![vec![0], vec![1]],
            estimated_cost: 100.0,
            memory_requirement: 1024,
            dependencies: HashSet::new(),
            priority: 1,
        };
        let queue = TensorWorkQueue::new(vec![unit], ParallelTensorConfig::default());
        assert!(queue.get_work().is_some());
        queue.complete_work(0);
        assert!(queue.is_complete());
    }
    /// The work-stealing convenience wrapper succeeds on a small network.
    #[test]
    fn test_parallel_strategies() {
        let inputs = vec![Array::ones(IxDyn(&[2, 2])), Array::ones(IxDyn(&[2, 2]))];
        let pair = ContractionPair {
            tensor1_id: 0,
            tensor2_id: 1,
            tensor1_indices: vec![1],
            tensor2_indices: vec![0],
        };
        let outcome = strategies::work_stealing_contraction(&inputs, &[pair], 2);
        assert!(outcome.is_ok());
    }
}