use crate::common::{BiasCorrection, ParameterUpdate, StateMemoryStats};
use scirs2_core::parallel_ops::*; use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
#[derive(Debug, Clone)]
pub struct ParallelConfig {
pub num_threads: usize,
pub min_params_per_thread: usize,
pub enable_work_stealing: bool,
pub numa_aware: bool,
pub chunk_size: usize,
pub lock_free: bool,
}
impl Default for ParallelConfig {
fn default() -> Self {
Self {
num_threads: 0, min_params_per_thread: 1000,
enable_work_stealing: true,
numa_aware: false,
chunk_size: 1024,
lock_free: true,
}
}
}
impl ParallelConfig {
pub fn cpu_optimized() -> Self {
Self {
num_threads: num_cpus::get(),
chunk_size: 512,
enable_work_stealing: true,
..Default::default()
}
}
pub fn large_model() -> Self {
Self {
num_threads: num_cpus::get(),
min_params_per_thread: 10000,
chunk_size: 4096,
numa_aware: true,
..Default::default()
}
}
pub fn memory_bound() -> Self {
Self {
num_threads: (num_cpus::get() / 2).max(1),
chunk_size: 2048,
numa_aware: true,
..Default::default()
}
}
pub fn effective_num_threads(&self) -> usize {
if self.num_threads == 0 {
num_cpus::get()
} else {
self.num_threads
}
}
}
#[derive(Debug)]
pub struct ParallelOptimizerState {
parameter_states: RwLock<HashMap<String, Arc<Mutex<ParameterState>>>>,
global_step: Arc<std::sync::atomic::AtomicUsize>,
config: ParallelConfig,
}
#[derive(Debug)]
pub struct ParameterState {
pub momentum: Vec<f32>,
pub variance: Vec<f32>,
pub step: usize,
pub last_update: std::time::Instant,
}
impl ParameterState {
fn new(size: usize) -> Self {
Self {
momentum: vec![0.0; size],
variance: vec![0.0; size],
step: 0,
last_update: std::time::Instant::now(),
}
}
}
impl ParallelOptimizerState {
pub fn new(config: ParallelConfig) -> Self {
Self {
parameter_states: RwLock::new(HashMap::new()),
global_step: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
config,
}
}
pub fn get_or_create_state(&self, param_id: String, size: usize) -> Arc<Mutex<ParameterState>> {
{
let states = self
.parameter_states
.read()
.expect("parameter_states lock should not be poisoned");
if let Some(state) = states.get(¶m_id) {
return state.clone();
}
}
let mut states = self
.parameter_states
.write()
.expect("parameter_states lock should not be poisoned");
if let Some(state) = states.get(¶m_id) {
return state.clone();
}
let new_state = Arc::new(Mutex::new(ParameterState::new(size)));
states.insert(param_id, new_state.clone());
new_state
}
pub fn step(&self) {
self.global_step.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn get_step(&self) -> usize {
self.global_step.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn memory_usage(&self) -> StateMemoryStats {
let states = self
.parameter_states
.read()
.expect("parameter_states lock should not be poisoned");
let mut total_momentum = 0;
let mut total_variance = 0;
let num_params = states.len();
for state_arc in states.values() {
if let Ok(state) = state_arc.try_lock() {
total_momentum += state.momentum.len();
total_variance += state.variance.len();
}
}
StateMemoryStats {
momentum_elements: total_momentum,
variance_elements: total_variance,
third_moment_elements: 0,
total_bytes: (total_momentum + total_variance) * std::mem::size_of::<f32>(),
num_parameters: num_params,
}
}
pub fn clear(&self) {
let mut states = self
.parameter_states
.write()
.expect("parameter_states lock should not be poisoned");
states.clear();
self.global_step.store(0, std::sync::atomic::Ordering::Relaxed);
}
}
#[derive(Debug)]
pub struct ParallelAdam {
lr: f32,
betas: (f32, f32),
eps: f32,
weight_decay: f32,
state: ParallelOptimizerState,
}
impl ParallelAdam {
pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
Self::with_config(lr, betas, eps, weight_decay, ParallelConfig::default())
}
pub fn with_config(
lr: f32,
betas: (f32, f32),
eps: f32,
weight_decay: f32,
config: ParallelConfig,
) -> Self {
Self {
lr,
betas,
eps,
weight_decay,
state: ParallelOptimizerState::new(config),
}
}
pub fn update_parallel(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
let _chunk_size = self.state.config.chunk_size;
let min_params = self.state.config.min_params_per_thread;
if updates.len() < min_params || !self.should_parallelize(&updates) {
return self.update_sequential(updates);
}
let results: Result<Vec<()>> = updates
.into_par_iter()
.with_min_len(1)
.map(|(param_id, param, grad)| self.update_single_parameter(param_id, param, grad))
.collect();
results.map(|_| ())
}
fn update_sequential(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
for (param_id, param, grad) in updates {
self.update_single_parameter(param_id, param, grad)?;
}
Ok(())
}
fn update_single_parameter(
&self,
param_id: String,
param: &mut [f32],
grad: &[f32],
) -> Result<()> {
if param.len() != grad.len() {
return Err(TrustformersError::tensor_op_error(
"Parameter and gradient size mismatch",
"update_single_parameter",
));
}
let size = param.len();
let state_arc = self.state.get_or_create_state(param_id, size);
let chunk_size = self.state.config.chunk_size;
let mut param_state = state_arc.lock().expect("Parallel optimizer state lock poisoned");
param_state.step += 1;
param_state.last_update = std::time::Instant::now();
let step = param_state.step;
let (bias_correction1, bias_correction2) =
BiasCorrection::compute_adam_corrections(self.betas.0, self.betas.1, step);
let should_parallelize = size >= chunk_size * 2 && self.state.config.num_threads > 1;
if should_parallelize {
let ParameterState {
ref mut momentum,
ref mut variance,
..
} = *param_state;
self.update_parameter_parallel(
param,
grad,
momentum,
variance,
bias_correction1,
bias_correction2,
chunk_size,
);
} else {
let ParameterState {
ref mut momentum,
ref mut variance,
..
} = *param_state;
self.update_parameter_sequential(
param,
grad,
momentum,
variance,
bias_correction1,
bias_correction2,
);
}
Ok(())
}
fn update_parameter_parallel(
&self,
param: &mut [f32],
grad: &[f32],
momentum: &mut [f32],
variance: &mut [f32],
bias_correction1: f32,
bias_correction2: f32,
chunk_size: usize,
) {
param
.par_chunks_mut(chunk_size)
.zip(grad.par_chunks(chunk_size))
.zip(momentum.par_chunks_mut(chunk_size))
.zip(variance.par_chunks_mut(chunk_size))
.for_each(|(((p_chunk, g_chunk), m_chunk), v_chunk)| {
self.process_chunk(
p_chunk,
g_chunk,
m_chunk,
v_chunk,
bias_correction1,
bias_correction2,
);
});
}
fn update_parameter_sequential(
&self,
param: &mut [f32],
grad: &[f32],
momentum: &mut [f32],
variance: &mut [f32],
bias_correction1: f32,
bias_correction2: f32,
) {
self.process_chunk(
param,
grad,
momentum,
variance,
bias_correction1,
bias_correction2,
);
}
#[inline]
fn process_chunk(
&self,
param_chunk: &mut [f32],
grad_chunk: &[f32],
momentum_chunk: &mut [f32],
variance_chunk: &mut [f32],
bias_correction1: f32,
bias_correction2: f32,
) {
let len = param_chunk
.len()
.min(grad_chunk.len())
.min(momentum_chunk.len())
.min(variance_chunk.len());
for i in 0..len {
let grad_val = grad_chunk[i] + self.weight_decay * param_chunk[i];
ParameterUpdate::update_ema(&mut momentum_chunk[i], grad_val, self.betas.0);
ParameterUpdate::update_ema(&mut variance_chunk[i], grad_val * grad_val, self.betas.1);
let m_hat = momentum_chunk[i] / bias_correction1;
let v_hat = variance_chunk[i] / bias_correction2;
ParameterUpdate::adam_update(&mut param_chunk[i], self.lr, m_hat, v_hat, self.eps);
}
}
fn should_parallelize(&self, updates: &[(String, &mut [f32], &[f32])]) -> bool {
let total_elements: usize = updates.iter().map(|(_, param, _)| param.len()).sum();
let num_threads = self.state.config.effective_num_threads();
total_elements >= self.state.config.min_params_per_thread * num_threads
}
pub fn parallel_stats(&self) -> ParallelStats {
let memory_stats = self.state.memory_usage();
let num_threads = self.state.config.effective_num_threads();
ParallelStats {
num_threads,
memory_stats,
config: self.state.config.clone(),
current_step: self.state.get_step(),
}
}
pub fn configure_thread_pool(&self) -> Result<()> {
let num_threads = self.state.config.effective_num_threads();
ThreadPoolBuilder::new().num_threads(num_threads).build_global().map_err(|e| {
TrustformersError::tensor_op_error(
&format!("Failed to configure thread pool: {}", e),
"configure_thread_pool",
)
})?;
Ok(())
}
}
impl Optimizer for ParallelAdam {
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
self.update_single_parameter(
param_id,
param.as_slice_mut().expect("array must have contiguous layout"),
grad_arr.as_slice().expect("array must have contiguous layout"),
)
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for ParallelAdam",
"update",
)),
}
}
fn zero_grad(&mut self) {
}
fn step(&mut self) {
self.state.step();
}
fn get_lr(&self) -> f32 {
self.lr
}
fn set_lr(&mut self, lr: f32) {
self.lr = lr;
}
}
#[derive(Debug, Clone)]
pub struct ParallelStats {
pub num_threads: usize,
pub memory_stats: StateMemoryStats,
pub config: ParallelConfig,
pub current_step: usize,
}
impl ParallelStats {
pub fn theoretical_speedup(&self, _sequential_time_ms: f64) -> f64 {
let parallel_fraction = 0.95; let num_threads = self.num_threads as f64;
1.0 / ((1.0 - parallel_fraction) + (parallel_fraction / num_threads))
}
pub fn optimization_suggestions(&self) -> Vec<String> {
let mut suggestions = Vec::new();
if self.num_threads == 1 {
suggestions.push(
"Consider increasing number of threads for better parallelization".to_string(),
);
}
if self.num_threads > num_cpus::get() {
suggestions.push("Number of threads exceeds CPU cores; consider reducing".to_string());
}
if self.config.chunk_size < 256 {
suggestions
.push("Small chunk size may cause overhead; consider increasing".to_string());
}
if self.config.chunk_size > 8192 {
suggestions.push("Large chunk size may reduce parallelization efficiency".to_string());
}
if !self.config.enable_work_stealing {
suggestions.push("Enable work stealing for better load balancing".to_string());
}
if suggestions.is_empty() {
suggestions.push("Parallel configuration appears optimal".to_string());
}
suggestions
}
}
pub trait BatchUpdate {
fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()>;
}
impl BatchUpdate for ParallelAdam {
fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()> {
let mut updates = Vec::new();
for (param, grad) in batch {
match (param, grad) {
(Tensor::F32(p), Tensor::F32(g)) => {
let param_id = format!("{:p}", p.as_ptr());
updates.push((
param_id,
p.as_slice_mut().expect("array must have contiguous layout"),
g.as_slice().expect("array must have contiguous layout"),
));
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Unsupported tensor types",
"update_batch",
))
},
}
}
self.update_parallel(updates)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parallel_config() {
let config = ParallelConfig::default();
assert_eq!(config.num_threads, 0); assert!(config.enable_work_stealing);
let cpu_config = ParallelConfig::cpu_optimized();
assert_eq!(cpu_config.num_threads, num_cpus::get());
let effective_threads = config.effective_num_threads();
assert!(effective_threads > 0);
assert_eq!(effective_threads, num_cpus::get());
}
#[test]
fn test_parallel_optimizer_state() {
let config = ParallelConfig::default();
let state = ParallelOptimizerState::new(config);
assert_eq!(state.get_step(), 0);
state.step();
assert_eq!(state.get_step(), 1);
let param_state = state.get_or_create_state("test_param".to_string(), 100);
let locked_state = param_state.lock().expect("Parallel optimizer state lock poisoned");
assert_eq!(locked_state.momentum.len(), 100);
assert_eq!(locked_state.variance.len(), 100);
}
#[test]
fn test_parallel_adam() {
let optimizer = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
assert_eq!(optimizer.get_lr(), 1e-3);
assert_eq!(optimizer.betas, (0.9, 0.999));
let stats = optimizer.parallel_stats();
assert!(stats.num_threads > 0);
assert_eq!(stats.current_step, 0);
}
#[test]
fn test_should_parallelize() {
let config = ParallelConfig {
min_params_per_thread: 1000,
num_threads: 4,
..Default::default()
};
let optimizer = ParallelAdam::with_config(1e-3, (0.9, 0.999), 1e-8, 0.01, config);
let mut small_params = [0.0; 100];
let small_grads = [0.0; 100];
let small_updates = vec![(
"param1".to_string(),
&mut small_params[..],
&small_grads[..],
)];
assert!(!optimizer.should_parallelize(&small_updates));
let mut large_params = [0.0; 5000];
let large_grads = [0.0; 5000];
let large_updates = vec![(
"param1".to_string(),
&mut large_params[..],
&large_grads[..],
)];
assert!(optimizer.should_parallelize(&large_updates));
}
#[test]
fn test_parallel_stats() {
let optimizer = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
let stats = optimizer.parallel_stats();
let speedup = stats.theoretical_speedup(1000.0);
assert!(speedup > 1.0);
assert!(speedup <= stats.num_threads as f64);
let suggestions = stats.optimization_suggestions();
assert!(!suggestions.is_empty());
}
#[test]
fn test_memory_usage() {
let config = ParallelConfig::default();
let state = ParallelOptimizerState::new(config);
state.get_or_create_state("param1".to_string(), 1000);
state.get_or_create_state("param2".to_string(), 2000);
let memory_stats = state.memory_usage();
assert_eq!(memory_stats.num_parameters, 2);
assert_eq!(memory_stats.momentum_elements, 3000);
assert_eq!(memory_stats.variance_elements, 3000);
}
}