use crate::error::{ModelError, ModelResult};
use rayon::prelude::*;
use scirs2_core::ndarray::{Array2, ArrayView2, ArrayView3};
use tracing::{debug, trace};
/// Tuning knobs controlling how multi-head work is parallelized.
#[derive(Debug, Clone, Copy)]
pub struct ParallelConfig {
    /// Minimum head count before rayon parallelism kicks in; below this
    /// threshold heads are processed sequentially (dispatch overhead
    /// would dominate for small head counts).
    pub min_heads_for_parallel: usize,
    /// Requested worker-thread count; 0 means "let rayon decide".
    pub num_threads: usize,
    /// Whether SIMD-accelerated paths may be used.
    /// NOTE(review): not consulted anywhere in this file — presumably
    /// read by other modules; confirm before removing.
    pub enable_simd: bool,
}
impl Default for ParallelConfig {
fn default() -> Self {
Self {
min_heads_for_parallel: 4,
num_threads: 0, enable_simd: true,
}
}
}
/// Runs per-head closures either sequentially or on rayon's thread pool,
/// choosing based on the configured head-count threshold.
pub struct MultiHeadExecutor {
    // Parallelism policy captured at construction time.
    config: ParallelConfig,
}
impl MultiHeadExecutor {
    /// Builds an executor with the given parallelism configuration.
    pub fn new(config: ParallelConfig) -> Self {
        debug!(
            "Created MultiHeadExecutor: min_heads={}, threads={}",
            config.min_heads_for_parallel,
            // num_threads == 0 is the "let rayon size the pool" sentinel.
            if config.num_threads == 0 {
                "auto".to_string()
            } else {
                config.num_threads.to_string()
            }
        );
        Self { config }
    }

    /// Maps `f` over head indices `0..num_heads` and collects the results.
    ///
    /// Runs sequentially when `num_heads` is below the configured
    /// threshold, otherwise via a rayon parallel iterator. In both cases
    /// collecting into `ModelResult<Vec<R>>` short-circuits on the first
    /// error, and results are returned in head-index order.
    pub fn par_map<F, R>(&self, num_heads: usize, f: F) -> ModelResult<Vec<R>>
    where
        F: Fn(usize) -> ModelResult<R> + Sync + Send,
        R: Send,
    {
        trace!("Parallel execution across {} heads", num_heads);
        if num_heads < self.config.min_heads_for_parallel {
            (0..num_heads).map(f).collect()
        } else {
            (0..num_heads).into_par_iter().map(f).collect()
        }
    }

    /// Splits `input` (seq_len × d_model) column-wise into `num_heads`
    /// slices of width `head_dim`, applies `f` to each slice (possibly in
    /// parallel), then concatenates the per-head outputs back along the
    /// column axis.
    ///
    /// # Errors
    /// Returns a dimension mismatch when `d_model != num_heads * head_dim`,
    /// propagates any error from `f`, and fails if the per-head outputs do
    /// not share a uniform shape.
    pub fn split_heads_and_process<F>(
        &self,
        input: &ArrayView2<f32>,
        num_heads: usize,
        head_dim: usize,
        f: F,
    ) -> ModelResult<Array2<f32>>
    where
        F: Fn(&ArrayView2<f32>) -> ModelResult<Array2<f32>> + Sync + Send,
    {
        let (seq_len, d_model) = input.dim();
        if d_model != num_heads * head_dim {
            return Err(ModelError::dimension_mismatch(
                "d_model vs num_heads × head_dim",
                num_heads * head_dim,
                d_model,
            ));
        }
        trace!(
            "Splitting {} × {} into {} heads of dim {}",
            seq_len,
            d_model,
            num_heads,
            head_dim
        );
        let head_outputs: Vec<Array2<f32>> = self.par_map(num_heads, |head_idx| {
            let start_col = head_idx * head_dim;
            let end_col = start_col + head_dim;
            let head_input = input.slice(s![.., start_col..end_col]);
            f(&head_input)
        })?;
        self.concat_heads(&head_outputs)
    }

    /// Concatenates per-head (seq_len × head_dim) arrays along the column
    /// axis into one (seq_len × num_heads·head_dim) array.
    ///
    /// # Errors
    /// Fails when `heads` is empty, or when any head's shape differs from
    /// the first head's shape.
    fn concat_heads(&self, heads: &[Array2<f32>]) -> ModelResult<Array2<f32>> {
        if heads.is_empty() {
            return Err(ModelError::invalid_config("No heads to concatenate"));
        }
        let (seq_len, head_dim) = heads[0].dim();
        // Validate shape uniformity up front so a bad head yields a
        // structured error instead of an ndarray panic inside `assign`.
        for head in heads.iter().skip(1) {
            let (rows, cols) = head.dim();
            if (rows, cols) != (seq_len, head_dim) {
                return Err(ModelError::dimension_mismatch(
                    "per-head output shapes",
                    seq_len * head_dim,
                    rows * cols,
                ));
            }
        }
        let num_heads = heads.len();
        let d_model = num_heads * head_dim;
        let mut output = Array2::zeros((seq_len, d_model));
        for (head_idx, head_output) in heads.iter().enumerate() {
            let start_col = head_idx * head_dim;
            let end_col = start_col + head_dim;
            output
                .slice_mut(s![.., start_col..end_col])
                .assign(head_output);
        }
        Ok(output)
    }
}
/// Projects `input` (seq_len × d_model) through each head's weight matrix,
/// yielding one (seq_len × head_dim) output per head.
///
/// `weights` is indexed as (head, d_model, head_dim). Heads are processed
/// sequentially below the configured threshold and with rayon otherwise;
/// results come back in head order either way.
pub fn parallel_multi_head_projection(
    input: &ArrayView2<f32>,
    weights: &ArrayView3<f32>,
    config: &ParallelConfig,
) -> ModelResult<Vec<Array2<f32>>> {
    let (seq_len, d_model) = input.dim();
    let (num_heads, weight_d_model, _head_dim) = weights.dim();
    if weight_d_model != d_model {
        return Err(ModelError::dimension_mismatch(
            "input d_model vs weight d_model",
            weight_d_model,
            d_model,
        ));
    }
    trace!(
        "Parallel projection: {} × {} with {} heads",
        seq_len,
        d_model,
        num_heads
    );
    // Multiply the shared input by one head's weight slab.
    let project = |idx: usize| -> ModelResult<Array2<f32>> {
        Ok(input.dot(&weights.slice(s![idx, .., ..])))
    };
    if num_heads >= config.min_heads_for_parallel {
        (0..num_heads).into_par_iter().map(project).collect()
    } else {
        (0..num_heads).map(project).collect()
    }
}
/// Concatenates per-head outputs along the column axis and applies the
/// final output projection (`concatenated · output_weight`).
///
/// `_config` is currently unused; combining is a single pass over the
/// heads followed by one matmul.
///
/// # Errors
/// Fails when `head_outputs` is empty, or when `output_weight` does not
/// have `num_heads * head_dim` rows (previously this panicked inside
/// `dot` on mismatched inner dimensions).
pub fn parallel_combine_heads(
    head_outputs: &[Array2<f32>],
    output_weight: &ArrayView2<f32>,
    _config: &ParallelConfig,
) -> ModelResult<Array2<f32>> {
    if head_outputs.is_empty() {
        return Err(ModelError::invalid_config("No head outputs to combine"));
    }
    let (seq_len, head_dim) = head_outputs[0].dim();
    let num_heads = head_outputs.len();
    let d_model = num_heads * head_dim;
    // Validate the projection shape up front so a mismatched output
    // weight yields a structured error instead of an ndarray panic.
    let (weight_rows, _weight_cols) = output_weight.dim();
    if weight_rows != d_model {
        return Err(ModelError::dimension_mismatch(
            "concatenated heads vs output weight rows",
            d_model,
            weight_rows,
        ));
    }
    let mut concatenated = Array2::zeros((seq_len, d_model));
    for (head_idx, head_output) in head_outputs.iter().enumerate() {
        let start_col = head_idx * head_dim;
        let end_col = start_col + head_dim;
        concatenated
            .slice_mut(s![.., start_col..end_col])
            .assign(head_output);
    }
    Ok(concatenated.dot(output_weight))
}
/// Computes scaled attention scores `Q · Kᵀ · scale` for each head.
///
/// Heads are processed sequentially below the configured threshold and
/// with rayon otherwise; output order matches head order.
///
/// # Errors
/// Fails when `queries` and `keys` contain different numbers of heads.
pub fn parallel_attention_scores(
    queries: &[Array2<f32>],
    keys: &[Array2<f32>],
    scale: f32,
    config: &ParallelConfig,
) -> ModelResult<Vec<Array2<f32>>> {
    if queries.len() != keys.len() {
        return Err(ModelError::dimension_mismatch(
            "number of query and key heads",
            keys.len(),
            queries.len(),
        ));
    }
    let num_heads = queries.len();
    // One head's scaled score matrix: Q · Kᵀ · scale.
    let scores_for = |idx: usize| -> ModelResult<Array2<f32>> {
        Ok(queries[idx].dot(&keys[idx].t()) * scale)
    };
    if num_heads >= config.min_heads_for_parallel {
        (0..num_heads).into_par_iter().map(scores_for).collect()
    } else {
        (0..num_heads).map(scores_for).collect()
    }
}
/// Applies a numerically stable row-wise softmax to each head's score
/// matrix, returning new arrays (inputs are not mutated).
///
/// Heads are processed sequentially below the configured threshold and
/// with rayon otherwise.
pub fn parallel_softmax(
    scores: &[Array2<f32>],
    config: &ParallelConfig,
) -> ModelResult<Vec<Array2<f32>>> {
    let num_heads = scores.len();
    let softmax_head = |idx: usize| -> ModelResult<Array2<f32>> {
        let mut probs = scores[idx].clone();
        for mut row in probs.rows_mut() {
            // Subtract the row max before exponentiating for stability.
            let row_max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            row.mapv_inplace(|v| (v - row_max).exp());
            let total: f32 = row.sum();
            // Guard against a zero denominator; leaves the row as raw
            // exponentials in that (degenerate) case.
            if total > 0.0 {
                row.mapv_inplace(|v| v / total);
            }
        }
        Ok(probs)
    };
    if num_heads >= config.min_heads_for_parallel {
        (0..num_heads).into_par_iter().map(softmax_head).collect()
    } else {
        (0..num_heads).map(softmax_head).collect()
    }
}
use scirs2_core::ndarray::s;
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array3};

    // Construction stores the config unchanged.
    #[test]
    fn test_executor_creation() {
        let config = ParallelConfig::default();
        let executor = MultiHeadExecutor::new(config);
        assert_eq!(executor.config.min_heads_for_parallel, 4);
    }

    // Below the threshold, par_map runs sequentially and preserves order.
    #[test]
    fn test_par_map_sequential() {
        let config = ParallelConfig {
            min_heads_for_parallel: 10,
            ..Default::default()
        };
        let executor = MultiHeadExecutor::new(config);
        let results = executor.par_map(3, |i| Ok(i * 2)).expect("par_map failed");
        assert_eq!(results, vec![0, 2, 4]);
    }

    // At/above the threshold, the rayon path still returns head order.
    #[test]
    fn test_par_map_parallel() {
        let config = ParallelConfig {
            min_heads_for_parallel: 2,
            ..Default::default()
        };
        let executor = MultiHeadExecutor::new(config);
        let results = executor.par_map(4, |i| Ok(i * 3)).expect("par_map failed");
        assert_eq!(results, vec![0, 3, 6, 9]);
    }

    // Identity per-head transform: split + concat round-trips the shape.
    #[test]
    fn test_split_heads_and_process() {
        let config = ParallelConfig::default();
        let executor = MultiHeadExecutor::new(config);
        let input = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let output = executor
            .split_heads_and_process(&input.view(), 2, 2, |head| Ok(head.to_owned()))
            .expect("split_heads_and_process failed");
        assert_eq!(output.dim(), (2, 4));
    }

    // Two heads, hand-built selector weights: each head projects 2 -> 3.
    #[test]
    fn test_parallel_multi_head_projection() {
        let config = ParallelConfig::default();
        let input = array![[1.0, 2.0], [3.0, 4.0]];
        let mut weights = Array3::zeros((2, 2, 3));
        weights
            .slice_mut(s![0, .., ..])
            .assign(&array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]);
        weights
            .slice_mut(s![1, .., ..])
            .assign(&array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let outputs = parallel_multi_head_projection(&input.view(), &weights.view(), &config)
            .expect("projection failed");
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].dim(), (2, 3));
    }

    // Identity output weight: combining is pure concatenation here.
    #[test]
    fn test_parallel_combine_heads() {
        let config = ParallelConfig::default();
        let head1 = array![[1.0, 2.0], [3.0, 4.0]];
        let head2 = array![[5.0, 6.0], [7.0, 8.0]];
        let heads = vec![head1, head2];
        let output_weight = Array2::eye(4);
        let combined =
            parallel_combine_heads(&heads, &output_weight.view(), &config).expect("combine failed");
        assert_eq!(combined.dim(), (2, 4));
    }

    // Single head, scale 1.0: scores shape is (seq_len, seq_len).
    #[test]
    fn test_parallel_attention_scores() {
        let config = ParallelConfig::default();
        let q1 = array![[1.0, 0.0], [0.0, 1.0]];
        let k1 = array![[1.0, 0.0], [0.0, 1.0]];
        let queries = vec![q1.clone()];
        let keys = vec![k1.clone()];
        let scores =
            parallel_attention_scores(&queries, &keys, 1.0, &config).expect("scores failed");
        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0].dim(), (2, 2));
    }

    // Softmax output rows must each sum to 1 (within float tolerance).
    #[test]
    fn test_parallel_softmax() {
        let config = ParallelConfig::default();
        let scores = vec![array![[1.0, 2.0], [3.0, 4.0]]];
        let probs = parallel_softmax(&scores, &config).expect("softmax failed");
        assert_eq!(probs.len(), 1);
        assert_eq!(probs[0].dim(), (2, 2));
        for row in probs[0].rows() {
            let sum: f32 = row.sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }

    // d_model (3) != num_heads * head_dim (4) must be rejected.
    #[test]
    fn test_dimension_mismatch() {
        let config = ParallelConfig::default();
        let input = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let result =
            MultiHeadExecutor::new(config)
                .split_heads_and_process(&input.view(), 2, 2, |head| Ok(head.to_owned()));
        assert!(result.is_err());
    }
}