use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
/// Inter-phase synchronization strategy selected in [`ReductionConfig`].
///
/// NOTE(review): the execution semantics of each variant are implemented by
/// the consumer of the config, not in this file — only the selection and its
/// default are visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum SyncMode {
    /// Presumably cooperative-launch synchronization — confirm with the executor.
    Cooperative,
    /// Presumably a software barrier between phases — confirm with the executor.
    SoftwareBarrier,
    /// One launch per phase; the default strategy.
    #[default]
    MultiLaunch,
}
/// Reduction operator identifiers.
///
/// NOTE(review): not referenced anywhere else in this file; the operator is
/// presumably interpreted by the kernel/executor that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionOp {
    /// Sum of all elements.
    Sum,
    /// Product of all elements.
    Product,
    /// Maximum element.
    Max,
    /// Minimum element.
    Min,
    /// Element count (presumably of elements matching some predicate — confirm with the consumer).
    Count,
    /// Logical AND across elements.
    All,
    /// Logical OR across elements.
    Any,
}
/// Lifecycle of a single reduction phase.
///
/// The discriminants (0..=3) are stored in `AtomicU32`s by
/// `InterPhaseReduction` and decoded by `phase_state_from_u32`, so the
/// variant order must not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PhaseState {
    /// Not yet started (discriminant 0, the initial state).
    Pending,
    /// Started via `phase_start` (discriminant 1).
    Running,
    /// Finished via `phase_complete` (discriminant 2).
    Complete,
    /// Marked failed via `phase_failed` (discriminant 3; also the decode
    /// fallback for unknown raw values).
    Failed,
}
/// Configuration for an [`InterPhaseReduction`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReductionConfig {
    /// Synchronization strategy between phases.
    pub sync_mode: SyncMode,
    /// Number of reduction phases (and phase buffers). Default: 2.
    pub num_phases: u32,
    /// Elements reduced per block; each phase's buffer shrinks by this
    /// factor. Default: 256.
    pub block_size: u32,
    /// Grid size hint (not used by the host-side bookkeeping in this file).
    pub grid_size: u32,
    /// When true, `is_converged` compares the tracked value against
    /// `convergence_threshold`. Default: false.
    pub convergence_check: bool,
    /// A tracked value strictly below this threshold counts as converged.
    /// Default: 1e-6.
    pub convergence_threshold: f64,
}
impl Default for ReductionConfig {
fn default() -> Self {
Self {
sync_mode: SyncMode::MultiLaunch,
num_phases: 2,
block_size: 256,
grid_size: 1024,
convergence_check: false,
convergence_threshold: 1e-6,
}
}
}
/// Host-side bookkeeping for a multi-phase reduction over `input_size`
/// elements of type `T`.
///
/// Tracks a per-phase state machine (see [`PhaseState`]), owns one staging
/// buffer per phase, and records an `f64` convergence value bit-cast into an
/// `AtomicU64` (std has no atomic f64).
pub struct InterPhaseReduction<T> {
    // Configuration this reduction was built with.
    config: ReductionConfig,
    // Number of elements in the phase-0 buffer.
    input_size: usize,
    // One buffer per phase; sizes shrink by `block_size` per phase.
    phase_buffers: Vec<Vec<T>>,
    // Index of the most recently started phase.
    current_phase: AtomicU32,
    // Per-phase `PhaseState` discriminants.
    phase_states: Vec<AtomicU32>,
    // Set when the final phase completes.
    is_complete: AtomicBool,
    // `f64::to_bits` of the latest convergence value.
    convergence_value: AtomicU64,
}
impl<T: Default + Clone + Copy> InterPhaseReduction<T> {
    /// Creates a reduction over `input_size` elements using `sync_mode` and
    /// defaults for every other configuration field.
    pub fn new(input_size: usize, sync_mode: SyncMode) -> Self {
        Self::with_config(
            input_size,
            ReductionConfig {
                sync_mode,
                ..Default::default()
            },
        )
    }

    /// Creates a reduction with an explicit configuration.
    ///
    /// One buffer is allocated per phase: phase 0 holds `input_size`
    /// elements, and each later phase holds one element per block of the
    /// previous phase (never fewer than one element).
    pub fn with_config(input_size: usize, config: ReductionConfig) -> Self {
        let num_phases = config.num_phases as usize;
        // Guard against `block_size == 0`: `div_ceil(0)` panics with a
        // divide-by-zero, so clamp to a minimum block size of one.
        let block_size = (config.block_size as usize).max(1);
        let mut phase_buffers = Vec::with_capacity(num_phases);
        let mut size = input_size;
        for _ in 0..num_phases {
            phase_buffers.push(vec![T::default(); size]);
            // The next phase receives one partial result per block; keep at
            // least one element so every phase has somewhere to write.
            size = size.div_ceil(block_size).max(1);
        }
        let phase_states: Vec<_> = (0..num_phases)
            .map(|_| AtomicU32::new(PhaseState::Pending as u32))
            .collect();
        Self {
            config,
            input_size,
            phase_buffers,
            current_phase: AtomicU32::new(0),
            phase_states,
            is_complete: AtomicBool::new(false),
            convergence_value: AtomicU64::new(0),
        }
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &ReductionConfig {
        &self.config
    }

    /// Returns the number of input elements (the phase-0 buffer length).
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the index of the most recently started phase.
    pub fn current_phase(&self) -> u32 {
        self.current_phase.load(Ordering::Relaxed)
    }

    /// Transitions `phase` from `Pending` to `Running`.
    ///
    /// # Errors
    /// [`ReductionError::InvalidPhase`] if `phase` is out of range;
    /// [`ReductionError::InvalidPhaseState`] if the phase is not pending
    /// (already started, completed, or failed).
    pub fn phase_start(&self, phase: u32) -> Result<(), ReductionError> {
        if phase >= self.config.num_phases {
            return Err(ReductionError::InvalidPhase {
                phase,
                max_phases: self.config.num_phases,
            });
        }
        let expected = PhaseState::Pending as u32;
        let new = PhaseState::Running as u32;
        // CAS guarantees only one caller can start a given phase.
        match self.phase_states[phase as usize].compare_exchange(
            expected,
            new,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => {
                self.current_phase.store(phase, Ordering::Relaxed);
                Ok(())
            }
            Err(current) => Err(ReductionError::InvalidPhaseState {
                phase,
                current: phase_state_from_u32(current),
            }),
        }
    }

    /// Transitions `phase` from `Running` to `Complete`; completing the last
    /// phase marks the whole reduction complete.
    ///
    /// # Errors
    /// [`ReductionError::InvalidPhase`] if `phase` is out of range;
    /// [`ReductionError::InvalidPhaseState`] if the phase is not running.
    pub fn phase_complete(&self, phase: u32) -> Result<(), ReductionError> {
        if phase >= self.config.num_phases {
            return Err(ReductionError::InvalidPhase {
                phase,
                max_phases: self.config.num_phases,
            });
        }
        let expected = PhaseState::Running as u32;
        let new = PhaseState::Complete as u32;
        match self.phase_states[phase as usize].compare_exchange(
            expected,
            new,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => {
                if phase == self.config.num_phases - 1 {
                    self.is_complete.store(true, Ordering::Release);
                }
                Ok(())
            }
            Err(current) => Err(ReductionError::InvalidPhaseState {
                phase,
                current: phase_state_from_u32(current),
            }),
        }
    }

    /// Forces `phase` into the `Failed` state regardless of its current
    /// state; out-of-range phases are silently ignored.
    pub fn phase_failed(&self, phase: u32) {
        if (phase as usize) < self.phase_states.len() {
            self.phase_states[phase as usize].store(PhaseState::Failed as u32, Ordering::Release);
        }
    }

    /// Returns the current state of `phase`.
    ///
    /// Out-of-range phases report `Pending` rather than an error; callers
    /// needing strict validation should range-check first.
    pub fn phase_state(&self, phase: u32) -> PhaseState {
        if phase >= self.config.num_phases {
            return PhaseState::Pending;
        }
        phase_state_from_u32(self.phase_states[phase as usize].load(Ordering::Acquire))
    }

    /// True once the final phase has completed.
    pub fn is_complete(&self) -> bool {
        self.is_complete.load(Ordering::Acquire)
    }

    /// Read-only view of the staging buffer for `phase`, if it exists.
    pub fn get_buffer(&self, phase: u32) -> Option<&[T]> {
        self.phase_buffers.get(phase as usize).map(|v| v.as_slice())
    }

    /// Mutable view of the staging buffer for `phase`, if it exists.
    pub fn get_buffer_mut(&mut self, phase: u32) -> Option<&mut [T]> {
        self.phase_buffers
            .get_mut(phase as usize)
            .map(|v| v.as_mut_slice())
    }

    /// Length of the buffer for `phase` (0 for out-of-range phases).
    pub fn buffer_size(&self, phase: u32) -> usize {
        self.phase_buffers
            .get(phase as usize)
            .map(|v| v.len())
            .unwrap_or(0)
    }

    /// Records the latest convergence value, bit-cast into the `AtomicU64`
    /// (std has no atomic f64).
    pub fn set_convergence(&self, value: f64) {
        self.convergence_value
            .store(value.to_bits(), Ordering::Release);
    }

    /// Latest value recorded by `set_convergence` (0.0 initially).
    pub fn convergence(&self) -> f64 {
        f64::from_bits(self.convergence_value.load(Ordering::Acquire))
    }

    /// True when convergence checking is enabled and the tracked value is
    /// strictly below the configured threshold.
    pub fn is_converged(&self) -> bool {
        if !self.config.convergence_check {
            return false;
        }
        self.convergence() < self.config.convergence_threshold
    }

    /// Restores the initial state: phase 0 current, every phase `Pending`,
    /// completion and convergence cleared, all buffers refilled with
    /// `T::default()`.
    pub fn reset(&mut self) {
        self.current_phase.store(0, Ordering::Relaxed);
        self.is_complete.store(false, Ordering::Release);
        self.convergence_value.store(0, Ordering::Release);
        for state in &self.phase_states {
            state.store(PhaseState::Pending as u32, Ordering::Release);
        }
        for buffer in &mut self.phase_buffers {
            buffer.fill(T::default());
        }
    }
}
/// Decodes a raw atomic discriminant back into a [`PhaseState`].
///
/// Any value outside the known discriminants collapses to `Failed`.
fn phase_state_from_u32(value: u32) -> PhaseState {
    const PENDING: u32 = PhaseState::Pending as u32;
    const RUNNING: u32 = PhaseState::Running as u32;
    const COMPLETE: u32 = PhaseState::Complete as u32;
    match value {
        PENDING => PhaseState::Pending,
        RUNNING => PhaseState::Running,
        COMPLETE => PhaseState::Complete,
        _ => PhaseState::Failed,
    }
}
/// Errors produced by [`InterPhaseReduction`] phase transitions.
#[derive(Debug, thiserror::Error)]
pub enum ReductionError {
    /// The requested phase index is outside `0..max_phases`.
    #[error("Invalid phase {phase}, max phases: {max_phases}")]
    InvalidPhase {
        phase: u32,
        max_phases: u32,
    },
    /// The phase exists but is not in the state required for the requested
    /// transition (e.g. completing a phase that was never started).
    #[error("Invalid phase state for phase {phase}: {current:?}")]
    InvalidPhaseState {
        phase: u32,
        current: PhaseState,
    },
    /// The reduction has not finished all phases yet.
    ///
    /// NOTE(review): not constructed in this file; presumably used by callers.
    #[error("Reduction not complete, current phase: {current_phase}")]
    NotComplete {
        current_phase: u32,
    },
    /// A supplied buffer did not match the expected size.
    ///
    /// NOTE(review): not constructed in this file; presumably used by callers.
    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch {
        expected: usize,
        actual: usize,
    },
}
/// One-shot aggregation of per-participant `f64` partial results.
///
/// Each participant stores its value (bit-cast to `u64`) into its own slot
/// and bumps a shared submission counter; once every participant has
/// submitted, the final value can be folded on the host.
///
/// NOTE(review): a participant that submits twice is counted twice — callers
/// must ensure at most one `submit` per participant per round.
pub struct GlobalReduction {
    pub total_participants: u32,
    pub completed: AtomicU32,
    pub all_complete: AtomicBool,
    pub partial_results: Vec<AtomicU64>,
}
impl GlobalReduction {
    /// Creates a reduction expecting one submission from each of
    /// `participants` participants.
    pub fn new(participants: u32) -> Self {
        Self {
            total_participants: participants,
            completed: AtomicU32::new(0),
            all_complete: AtomicBool::new(false),
            partial_results: (0..participants).map(|_| AtomicU64::new(0)).collect(),
        }
    }
    /// Records `value` for `participant_id`.
    ///
    /// Returns `true` exactly when this submission was the final one; returns
    /// `false` for earlier submissions and for out-of-range participant ids.
    pub fn submit(&self, participant_id: u32, value: f64) -> bool {
        if participant_id >= self.total_participants {
            return false;
        }
        // Publish the value (Release) before bumping the counter so the
        // finalizers, which Acquire-load after observing completion, see it.
        self.partial_results[participant_id as usize].store(value.to_bits(), Ordering::Release);
        let submissions = self.completed.fetch_add(1, Ordering::AcqRel) + 1;
        let last = submissions == self.total_participants;
        if last {
            self.all_complete.store(true, Ordering::Release);
        }
        last
    }
    /// True once every participant has submitted.
    pub fn is_complete(&self) -> bool {
        self.all_complete.load(Ordering::Acquire)
    }
    /// Number of submissions recorded so far.
    pub fn completion_count(&self) -> u32 {
        self.completed.load(Ordering::Acquire)
    }
    /// Decodes the stored partial-result bits back into `f64`s.
    fn values(&self) -> impl Iterator<Item = f64> + '_ {
        self.partial_results
            .iter()
            .map(|slot| f64::from_bits(slot.load(Ordering::Acquire)))
    }
    /// Sum of all partial results, or `None` until every participant has
    /// submitted.
    pub fn finalize_sum(&self) -> Option<f64> {
        self.is_complete().then(|| self.values().sum())
    }
    /// Largest partial result, or `None` until complete. `NaN` comparisons
    /// fall back to `Equal`.
    pub fn finalize_max(&self) -> Option<f64> {
        if !self.is_complete() {
            return None;
        }
        self.values()
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }
    /// Smallest partial result, or `None` until complete. `NaN` comparisons
    /// fall back to `Equal`.
    pub fn finalize_min(&self) -> Option<f64> {
        if !self.is_complete() {
            return None;
        }
        self.values()
            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }
    /// Clears the counters and every slot for another round.
    pub fn reset(&self) {
        self.completed.store(0, Ordering::Release);
        self.all_complete.store(false, Ordering::Release);
        for slot in &self.partial_results {
            slot.store(0, Ordering::Release);
        }
    }
}
/// A reusable spin-wait rendezvous point for a fixed number of participants.
///
/// Unlike `std::sync::Barrier`, the round (generation) index is returned to
/// the caller. Arrivals are counted cumulatively rather than reset per
/// round: the previous implementation zeroed `arrived` *before* bumping
/// `generation`, so a thread entering the next round in that window could be
/// released by the old round's generation bump without its own peers having
/// arrived. Deriving the round from the cumulative arrival index removes
/// that reuse race. The counters wrap after `u32::MAX` arrivals.
pub struct CooperativeBarrier {
    expected: u32,
    // Cumulative arrival count across all rounds (never reset mid-use).
    arrived: AtomicU32,
    // Number of fully-arrived rounds released so far.
    generation: AtomicU32,
}
impl CooperativeBarrier {
    /// Creates a barrier that releases once `expected` threads have arrived.
    pub fn new(expected: u32) -> Self {
        Self {
            expected,
            arrived: AtomicU32::new(0),
            generation: AtomicU32::new(0),
        }
    }
    /// Spins until all `expected` participants of this round have arrived,
    /// then returns the 0-based round index.
    pub fn wait(&self) -> u32 {
        // Treat a zero-participant barrier as single-participant rather than
        // dividing by zero (the old code deadlocked in that case).
        let expected = self.expected.max(1);
        // 1-based cumulative arrival index; fixes which round we belong to.
        let arrival = self.arrived.fetch_add(1, Ordering::AcqRel) + 1;
        let round = (arrival - 1) / expected;
        if arrival % expected == 0 {
            // Last arrival of this round: release everyone spinning on it.
            // `generation >= round + 1` implies at least `round + 1` rounds
            // have fully arrived, so waiters of `round` can never wake early.
            self.generation.fetch_add(1, Ordering::AcqRel);
        } else {
            while self.generation.load(Ordering::Acquire) <= round {
                std::hint::spin_loop();
            }
        }
        round
    }
    /// Restores the barrier to its initial state.
    ///
    /// Must not be called while any thread is inside `wait`.
    pub fn reset(&self) {
        self.arrived.store(0, Ordering::Release);
        self.generation.store(0, Ordering::Release);
    }
}
pub struct ReductionBuilder {
config: ReductionConfig,
}
impl ReductionBuilder {
pub fn new() -> Self {
Self {
config: ReductionConfig::default(),
}
}
pub fn sync_mode(mut self, mode: SyncMode) -> Self {
self.config.sync_mode = mode;
self
}
pub fn phases(mut self, num: u32) -> Self {
self.config.num_phases = num;
self
}
pub fn block_size(mut self, size: u32) -> Self {
self.config.block_size = size;
self
}
pub fn grid_size(mut self, size: u32) -> Self {
self.config.grid_size = size;
self
}
pub fn with_convergence(mut self, threshold: f64) -> Self {
self.config.convergence_check = true;
self.config.convergence_threshold = threshold;
self
}
pub fn build(self) -> ReductionConfig {
self.config
}
pub fn build_reduction<T: Default + Clone + Copy>(
self,
input_size: usize,
) -> InterPhaseReduction<T> {
InterPhaseReduction::with_config(input_size, self.config)
}
}
impl Default for ReductionBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Walks both phases of a default two-phase reduction through the
    // Pending -> Running -> Complete state machine.
    #[test]
    fn test_inter_phase_reduction() {
        let reduction = InterPhaseReduction::<f64>::new(1024, SyncMode::MultiLaunch);
        assert_eq!(reduction.current_phase(), 0);
        assert!(!reduction.is_complete());
        reduction.phase_start(0).unwrap();
        assert_eq!(reduction.phase_state(0), PhaseState::Running);
        reduction.phase_complete(0).unwrap();
        assert_eq!(reduction.phase_state(0), PhaseState::Complete);
        reduction.phase_start(1).unwrap();
        reduction.phase_complete(1).unwrap();
        // Completing the final phase flips the global completion flag.
        assert!(reduction.is_complete());
    }
    // Buffer sizes shrink by block_size per phase and writes are visible
    // through the read-only accessor.
    #[test]
    fn test_phase_buffers() {
        let mut reduction = InterPhaseReduction::<f64>::with_config(
            1000,
            ReductionConfig {
                block_size: 256,
                num_phases: 3,
                ..Default::default()
            },
        );
        assert_eq!(reduction.buffer_size(0), 1000);
        assert!(reduction.buffer_size(1) < reduction.buffer_size(0));
        if let Some(buf) = reduction.get_buffer_mut(0) {
            buf[0] = 42.0;
        }
        assert_eq!(reduction.get_buffer(0).unwrap()[0], 42.0);
    }
    // Completion only trips on the final submission, and the sum matches.
    #[test]
    fn test_global_reduction() {
        let reduction = GlobalReduction::new(4);
        assert!(!reduction.is_complete());
        reduction.submit(0, 1.0);
        reduction.submit(1, 2.0);
        reduction.submit(2, 3.0);
        assert!(!reduction.is_complete());
        assert_eq!(reduction.completion_count(), 3);
        reduction.submit(3, 4.0);
        assert!(reduction.is_complete());
        assert_eq!(reduction.finalize_sum(), Some(10.0));
    }
    // Three threads rendezvous once; all observe generation 0.
    #[test]
    fn test_cooperative_barrier() {
        use std::sync::Arc;
        use std::thread;
        let barrier = Arc::new(CooperativeBarrier::new(3));
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let b = barrier.clone();
                thread::spawn(move || b.wait())
            })
            .collect();
        for h in handles {
            let generation_num = h.join().unwrap();
            assert_eq!(generation_num, 0);
        }
    }
    // Builder setters land in the built config.
    #[test]
    fn test_reduction_builder() {
        let config = ReductionBuilder::new()
            .sync_mode(SyncMode::Cooperative)
            .phases(3)
            .block_size(512)
            .with_convergence(1e-8)
            .build();
        assert_eq!(config.sync_mode, SyncMode::Cooperative);
        assert_eq!(config.num_phases, 3);
        assert_eq!(config.block_size, 512);
        assert!(config.convergence_check);
    }
    // Convergence is a strict less-than comparison against the threshold.
    #[test]
    fn test_convergence_tracking() {
        let reduction = InterPhaseReduction::<f64>::with_config(
            100,
            ReductionConfig {
                convergence_check: true,
                convergence_threshold: 1e-6,
                ..Default::default()
            },
        );
        reduction.set_convergence(1e-3);
        assert!(!reduction.is_converged());
        reduction.set_convergence(1e-8);
        assert!(reduction.is_converged());
    }
}