use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};
use crate::{common::StateMemoryStats, traits::StatefulOptimizer};
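/// Configuration for the [`DeepDistributedQP`] optimizer.
///
/// `penalty_parameter` is used both as the soft-threshold level in the
/// proximal step and as the dual ascent rate; `relaxation_parameter` scales
/// how strongly nodes are pulled toward the consensus average; consensus
/// averaging runs once every `consensus_frequency` iterations.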
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DeepDistributedQPConfig {
pub learning_rate: f32,
pub num_consensus_nodes: usize,
pub max_iterations: usize,
pub tolerance: f32,
pub relaxation_parameter: f32,
pub penalty_parameter: f32,
pub step_size: f32,
pub adaptive_step_size: bool,
pub network_hidden_dims: Vec<usize>,
pub warm_start: bool,
pub consensus_frequency: usize,
pub max_problem_size: usize,
}
impl Default for DeepDistributedQPConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
num_consensus_nodes: 4,
max_iterations: 100,
tolerance: 1e-6,
relaxation_parameter: 1.6,
penalty_parameter: 1.0,
step_size: 1.0,
adaptive_step_size: true,
network_hidden_dims: vec![64, 32],
warm_start: true,
consensus_frequency: 10,
max_problem_size: 10000,
}
}
}
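/// Per-node solver state: a local primal iterate, dual variables, the latest
/// constraint residuals, and the node's distance from the consensus average.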
#[derive(Clone, Debug)]
struct ConsensusNode {
local_variables: Tensor,
dual_variables: Tensor,
constraint_residuals: Tensor,
consensus_error: f32,
#[allow(dead_code)]
node_id: usize,
}
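/// A small MLP (ReLU hidden layers, linear scalar output) that maps solver
/// statistics to a step size. Inputs are normalized with the stored mean and
/// standard deviation; the output is scaled by `output_scale`.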
#[derive(Clone, Debug)]
struct PolicyNetwork {
weights: Vec<Tensor>,
biases: Vec<Tensor>,
input_mean: Tensor,
input_std: Tensor,
output_scale: f32,
}
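/// Per-parameter solver state: the consensus nodes, an optional learned
/// step-size policy, the last solution for warm starts, the cached problem
/// data, and convergence/timing diagnostics.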
#[derive(Clone, Debug)]
pub struct DeepDistributedQPState {
consensus_nodes: Vec<ConsensusNode>,
policy_network: Option<PolicyNetwork>,
previous_solution: Option<Tensor>,
#[allow(dead_code)]
problem_matrix_p: Option<Tensor>,
problem_vector_q: Option<Tensor>,
#[allow(dead_code)]
constraint_matrix_a: Option<Tensor>,
#[allow(dead_code)]
constraint_vector_b: Option<Tensor>,
iteration: usize,
convergence_history: Vec<f32>,
solve_times: Vec<f32>,
#[allow(dead_code)]
problem_size: usize,
}
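/// Deep Distributed QP optimizer: each parameter update is posed as a small
/// QP that is solved by `num_consensus_nodes` local operator-splitting
/// solvers with periodic consensus averaging and an optional learned
/// step-size policy.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file:
///
/// ```ignore
/// let mut optimizer = DeepDistributedQP::new(1e-3, 4, 100, 1e-6);
/// let mut parameter = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
/// let gradient = Tensor::from_slice(&[0.1, 0.2, 0.1], &[3])?;
/// optimizer.update(&mut parameter, &gradient)?;
/// ```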
#[derive(Clone, Debug)]
pub struct DeepDistributedQP {
config: DeepDistributedQPConfig,
states: HashMap<String, DeepDistributedQPState>,
step: usize,
memory_stats: StateMemoryStats,
global_consensus: Option<Tensor>,
problems_solved: usize,
cumulative_speedup: f32,
}
impl DeepDistributedQP {
pub fn new(
learning_rate: f32,
num_consensus_nodes: usize,
max_iterations: usize,
tolerance: f32,
) -> Self {
Self {
config: DeepDistributedQPConfig {
learning_rate,
num_consensus_nodes,
max_iterations,
tolerance,
..Default::default()
},
states: HashMap::new(),
step: 0,
memory_stats: StateMemoryStats {
momentum_elements: 0,
variance_elements: 0,
third_moment_elements: 0,
total_bytes: 0,
num_parameters: 0,
},
global_consensus: None,
problems_solved: 0,
cumulative_speedup: 1.0,
}
}
pub fn for_large_scale() -> Self {
Self {
config: DeepDistributedQPConfig {
learning_rate: 5e-4,
num_consensus_nodes: 8,
max_iterations: 500,
tolerance: 1e-8,
relaxation_parameter: 1.8,
penalty_parameter: 0.5,
step_size: 0.8,
adaptive_step_size: true,
network_hidden_dims: vec![128, 64, 32],
warm_start: true,
consensus_frequency: 5,
max_problem_size: 50000,
},
states: HashMap::new(),
step: 0,
memory_stats: StateMemoryStats {
momentum_elements: 0,
variance_elements: 0,
third_moment_elements: 0,
total_bytes: 0,
num_parameters: 0,
},
global_consensus: None,
problems_solved: 0,
cumulative_speedup: 1.0,
}
}
pub fn for_portfolio_optimization() -> Self {
Self {
config: DeepDistributedQPConfig {
learning_rate: 1e-3,
num_consensus_nodes: 6,
max_iterations: 200,
tolerance: 1e-7,
relaxation_parameter: 1.5,
penalty_parameter: 2.0,
step_size: 1.2,
adaptive_step_size: true,
network_hidden_dims: vec![64, 32, 16],
warm_start: true,
consensus_frequency: 15,
max_problem_size: 5000,
},
states: HashMap::new(),
step: 0,
memory_stats: StateMemoryStats {
momentum_elements: 0,
variance_elements: 0,
third_moment_elements: 0,
total_bytes: 0,
num_parameters: 0,
},
global_consensus: None,
problems_solved: 0,
cumulative_speedup: 1.0,
}
}
pub fn with_config(config: DeepDistributedQPConfig) -> Self {
Self {
config,
states: HashMap::new(),
step: 0,
memory_stats: StateMemoryStats {
momentum_elements: 0,
variance_elements: 0,
third_moment_elements: 0,
total_bytes: 0,
num_parameters: 0,
},
global_consensus: None,
problems_solved: 0,
cumulative_speedup: 1.0,
}
}
fn initialize_consensus_nodes(&self, problem_size: usize) -> Result<Vec<ConsensusNode>> {
let mut nodes = Vec::with_capacity(self.config.num_consensus_nodes);
for node_id in 0..self.config.num_consensus_nodes {
nodes.push(ConsensusNode {
local_variables: Tensor::zeros(&[problem_size])?,
dual_variables: Tensor::zeros(&[problem_size])?,
constraint_residuals: Tensor::zeros(&[problem_size])?,
consensus_error: f32::INFINITY,
node_id,
});
}
Ok(nodes)
}
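    /// Builds the step-size policy network: Glorot-scaled random hidden
    /// layers of the configured widths, plus a small-initialized scalar
    /// output layer.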
fn create_policy_network(&self, input_size: usize) -> Result<PolicyNetwork> {
let mut weights = Vec::new();
let mut biases = Vec::new();
let mut prev_size = input_size;
for &hidden_size in &self.config.network_hidden_dims {
let scale = (2.0 / (prev_size + hidden_size) as f32).sqrt();
let weight = Tensor::randn(&[prev_size, hidden_size])?.mul_scalar(scale)?;
let bias = Tensor::zeros(&[hidden_size])?;
weights.push(weight);
biases.push(bias);
prev_size = hidden_size;
}
let output_weight = Tensor::randn(&[prev_size, 1])?.mul_scalar(0.01)?;
let output_bias = Tensor::zeros(&[1])?;
weights.push(output_weight);
biases.push(output_bias);
Ok(PolicyNetwork {
weights,
biases,
input_mean: Tensor::zeros(&[input_size])?,
input_std: Tensor::ones(&[input_size])?,
output_scale: 1.0,
})
}
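    /// Runs the policy network: normalize the input, apply each hidden layer
    /// with ReLU, apply the linear output layer, scale by `output_scale`,
    /// and squeeze the result back down to a vector.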
fn policy_forward(&self, network: &PolicyNetwork, input: &Tensor) -> Result<Tensor> {
let normalized_input = input.sub(&network.input_mean)?.div(&network.input_std)?;
let input_shape = normalized_input.shape();
let batch_size = 1;
let feature_size = input_shape.iter().product::<usize>();
let reshaped_input = normalized_input.reshape(&[batch_size, feature_size])?;
let mut x = reshaped_input;
        for i in 0..network.weights.len() - 1 {
            x = x.matmul(&network.weights[i])?.add(&network.biases[i])?;
            x = x.relu()?;
        }
let output_idx = network.weights.len() - 1;
x = x.matmul(&network.weights[output_idx])?.add(&network.biases[output_idx])?;
let output = x.mul_scalar(network.output_scale)?;
let final_output = if output.shape().len() == 2 && output.shape()[0] == 1 {
output.reshape(&[output.shape()[1]])?
} else {
output
};
Ok(final_output)
}
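    /// One ADMM-style splitting step for a node: a proximal gradient step
    /// `x <- soft_threshold(x - alpha * g, alpha * rho)` followed by dual
    /// ascent `lambda <- lambda + rho * r` on the constraint residuals,
    /// where `rho` is the penalty parameter.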
fn operator_splitting_update(
&self,
node: &mut ConsensusNode,
gradient: &Tensor,
step_size: f32,
) -> Result<()> {
let gradient_step = node.local_variables.sub(&gradient.mul_scalar(step_size)?)?;
let threshold = step_size * self.config.penalty_parameter;
node.local_variables = self.soft_threshold(&gradient_step, threshold)?;
        let constraint_violation = node.constraint_residuals.clone();
        node.dual_variables = node
.dual_variables
.add(&constraint_violation.mul_scalar(self.config.penalty_parameter)?)?;
Ok(())
}
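    /// Elementwise soft-threshold operator
    /// `S_t(x) = sign(x) * max(|x| - t, 0)`,
    /// implemented as `relu(x - t) - relu(-x - t)`.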
fn soft_threshold(&self, input: &Tensor, threshold: f32) -> Result<Tensor> {
let positive_part = input.sub_scalar(threshold)?.relu()?;
let negative_part = input.add_scalar(threshold)?.neg()?.relu()?.neg()?;
positive_part.add(&negative_part)
}
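    /// Averages the nodes' local variables, nudges every node toward that
    /// average by `0.1 * relaxation_parameter`, and returns the mean
    /// consensus error (L2 distance from the average).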
fn consensus_update(&self, nodes: &mut [ConsensusNode]) -> Result<f32> {
let num_nodes = nodes.len();
if num_nodes < 2 {
return Ok(0.0);
}
let mut consensus_sum = nodes[0].local_variables.clone();
for node in nodes.iter().skip(1) {
consensus_sum = consensus_sum.add(&node.local_variables)?;
}
let consensus_avg = consensus_sum.div_scalar(num_nodes as f32)?;
let mut total_consensus_error = 0.0f32;
for node in nodes.iter_mut() {
let consensus_diff = consensus_avg.sub(&node.local_variables)?;
let consensus_error = consensus_diff.norm()?;
let update = consensus_diff.mul_scalar(self.config.relaxation_parameter)?;
node.local_variables = node.local_variables.add(&update.mul_scalar(0.1)?)?;
node.consensus_error = consensus_error;
total_consensus_error += consensus_error;
}
Ok(total_consensus_error / num_nodes as f32)
}
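    /// Queries the policy network with the feature vector
    /// `[||g||, ||x||, ||lambda||, consensus_error]` and clamps the
    /// predicted step size to `[0.001, 2.0]`.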
fn adaptive_step_size(
&self,
network: &PolicyNetwork,
node: &ConsensusNode,
gradient: &Tensor,
) -> Result<f32> {
let grad_norm = gradient.norm()?;
let var_norm = node.local_variables.norm()?;
let dual_norm = node.dual_variables.norm()?;
let consensus_error = node.consensus_error;
let features =
Tensor::from_slice(&[grad_norm, var_norm, dual_norm, consensus_error], &[4])?;
let step_size_tensor = self.policy_forward(network, &features)?;
        let step_size = step_size_tensor.data()?[0];
        let step_size = step_size.clamp(0.001, 2.0);
Ok(step_size)
}
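    /// Core solver: lazily initializes per-parameter state, warm-starts the
    /// nodes from the previous solution, runs up to `max_iterations` of
    /// per-node splitting updates with consensus averaging every
    /// `consensus_frequency` iterations, and returns the node-averaged
    /// solution.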
fn solve_distributed_qp(&mut self, param_id: &str, gradient: &Tensor) -> Result<Tensor> {
let problem_size = gradient.len();
let param_key = param_id.to_string();
        let state_exists = self.states.contains_key(&param_key);
if !state_exists {
            let consensus_nodes = self.initialize_consensus_nodes(problem_size)?;
let new_state = DeepDistributedQPState {
consensus_nodes,
policy_network: None,
previous_solution: None,
problem_matrix_p: None,
problem_vector_q: Some(gradient.clone()),
constraint_matrix_a: None,
constraint_vector_b: None,
iteration: 0,
convergence_history: Vec::new(),
solve_times: Vec::new(),
problem_size,
};
self.states.insert(param_key.clone(), new_state);
}
        let state = self
            .states
            .get_mut(&param_key)
            .expect("state must exist for param_key after insert");
        let needs_policy_network = state.policy_network.is_none();
        let needs_consensus_nodes = state.consensus_nodes.is_empty();
        if needs_policy_network {
            let policy_network = self.create_policy_network(4)?;
            let state = self
                .states
                .get_mut(&param_key)
                .expect("state must exist for param_key after insert");
            state.policy_network = Some(policy_network);
        }
if needs_consensus_nodes {
let consensus_nodes = self.initialize_consensus_nodes(problem_size)?;
            let state = self
                .states
                .get_mut(&param_key)
                .expect("state must exist for param_key after insert");
state.consensus_nodes = consensus_nodes;
}
        let state = self
            .states
            .get_mut(&param_key)
            .expect("state must exist for param_key after insert");
if let (true, Some(prev_solution)) =
(self.config.warm_start, state.previous_solution.as_ref())
{
for node in &mut state.consensus_nodes {
node.local_variables = prev_solution.clone();
}
}
let start_time = std::time::Instant::now();
for iteration in 0..self.config.max_iterations {
            let state = self
                .states
                .get_mut(&param_key)
                .expect("state must exist for param_key after insert");
state.iteration = iteration;
let adaptive_step = self.config.adaptive_step_size;
let consensus_frequency = self.config.consensus_frequency;
let tolerance = self.config.tolerance;
let step_size = self.config.step_size;
let mut consensus_nodes = state.consensus_nodes.clone();
let policy_network = state.policy_network.clone();
for node in &mut consensus_nodes {
let actual_step_size = if adaptive_step {
if let Some(ref network) = policy_network {
self.adaptive_step_size(network, node, gradient)?
} else {
step_size
}
} else {
step_size
};
self.operator_splitting_update(node, gradient, actual_step_size)?;
}
            let state = self
                .states
                .get_mut(&param_key)
                .expect("state must exist for param_key after insert");
            state.consensus_nodes = consensus_nodes;
if iteration % consensus_frequency == 0 {
                let state = self
                    .states
                    .get_mut(&param_key)
                    .expect("state must exist for param_key after insert");
                let mut nodes = state.consensus_nodes.clone();
let consensus_error = self.consensus_update(&mut nodes)?;
                let state = self
                    .states
                    .get_mut(&param_key)
                    .expect("state must exist for param_key after insert");
                state.consensus_nodes = nodes;
                state.convergence_history.push(consensus_error);
if consensus_error < tolerance {
break;
}
}
}
let solve_time = start_time.elapsed().as_secs_f32();
        let state = self
            .states
            .get_mut(&param_key)
            .expect("state must exist for param_key after insert");
state.solve_times.push(solve_time);
let mut solution = state.consensus_nodes[0].local_variables.clone();
for node in state.consensus_nodes.iter().skip(1) {
solution = solution.add(&node.local_variables)?;
}
solution = solution.div_scalar(state.consensus_nodes.len() as f32)?;
state.previous_solution = Some(solution.clone());
self.problems_solved += 1;
        // Heuristic speedup metric: assume a sequential baseline that takes
        // twice the measured distributed solve time.
        let baseline_time = solve_time * 2.0;
        let current_speedup = baseline_time / solve_time.max(1e-6);
self.cumulative_speedup = (self.cumulative_speedup * (self.problems_solved - 1) as f32
+ current_speedup)
/ self.problems_solved as f32;
Ok(solution)
}
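    /// Per-parameter diagnostics: `(last iteration, average solve time in
    /// seconds, last consensus error, converged)`.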
pub fn qp_solver_stats(&self) -> HashMap<String, (usize, f32, f32, bool)> {
self.states
.iter()
.map(|(name, state)| {
let avg_solve_time = if !state.solve_times.is_empty() {
state.solve_times.iter().sum::<f32>() / state.solve_times.len() as f32
} else {
0.0
};
let last_consensus_error =
state.convergence_history.last().copied().unwrap_or(f32::INFINITY);
let converged = last_consensus_error < self.config.tolerance;
(
name.clone(),
(
state.iteration,
avg_solve_time,
last_consensus_error,
converged,
),
)
})
.collect()
}
pub fn cumulative_speedup(&self) -> f32 {
self.cumulative_speedup
}
pub fn distributed_memory_usage(&self) -> usize {
self.states
.values()
.map(|state| {
let nodes_memory = state
.consensus_nodes
.iter()
.map(|node| {
node.local_variables.memory_usage()
+ node.dual_variables.memory_usage()
+ node.constraint_residuals.memory_usage()
})
.sum::<usize>();
let network_memory = if let Some(ref network) = state.policy_network {
network.weights.iter().map(|w| w.memory_usage()).sum::<usize>()
+ network.biases.iter().map(|b| b.memory_usage()).sum::<usize>()
+ network.input_mean.memory_usage()
+ network.input_std.memory_usage()
} else {
0
};
nodes_memory + network_memory
})
.sum()
}
}
impl Optimizer for DeepDistributedQP {
fn update(&mut self, parameter: &mut Tensor, gradient: &Tensor) -> Result<()> {
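        // NOTE: the Optimizer trait provides no stable parameter handle, so
        // the id below is derived heuristically from the state count, the
        // parameter's shape, and a hash of its first few values.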
let param_id = format!(
"param_{}_{:?}_{}",
self.states.len(),
parameter.shape(),
parameter
.data_f32()
.unwrap_or_default()
.get(0..5)
.unwrap_or(&[])
.iter()
.fold(0u64, |acc, &x| acc.wrapping_add(x.to_bits() as u64))
);
        let qp_solution = self.solve_distributed_qp(&param_id, gradient)?;
let update = qp_solution.mul_scalar(self.config.learning_rate)?;
*parameter = parameter.sub(&update)?;
Ok(())
}
fn zero_grad(&mut self) {
for state in self.states.values_mut() {
state.problem_vector_q = None;
}
}
fn step(&mut self) {
self.step += 1;
}
fn get_lr(&self) -> f32 {
self.config.learning_rate
}
fn set_lr(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
}
impl StatefulOptimizer for DeepDistributedQP {
type Config = DeepDistributedQPConfig;
type State = StateMemoryStats;
fn config(&self) -> &Self::Config {
&self.config
}
fn state(&self) -> &Self::State {
&self.memory_stats
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.memory_stats
}
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
state_dict.insert("step".to_string(), Tensor::scalar(self.step as f32)?);
state_dict.insert(
"problems_solved".to_string(),
Tensor::scalar(self.problems_solved as f32)?,
);
Ok(state_dict)
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
if let Some(step_tensor) = state.get("step") {
self.step = step_tensor.to_scalar()? as usize;
}
if let Some(problems_tensor) = state.get("problems_solved") {
self.problems_solved = problems_tensor.to_scalar()? as usize;
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
self.memory_stats.clone()
}
fn reset_state(&mut self) {
self.states.clear();
self.step = 0;
self.problems_solved = 0;
self.cumulative_speedup = 1.0;
self.global_consensus = None;
}
fn num_parameters(&self) -> usize {
self.states.len()
}
}
impl DeepDistributedQP {
pub fn num_workers(&self) -> usize {
self.config.num_consensus_nodes
}
pub fn learning_rate(&self) -> f32 {
self.config.learning_rate
}
pub fn communication_rounds(&self) -> usize {
self.config.max_iterations / self.config.consensus_frequency
}
pub fn synchronization_overhead(&self) -> f32 {
1.0 / self.config.consensus_frequency as f32
}
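    /// Registers a QP `min 0.5 x'Px + q'x` with optional constraint data,
    /// seeds the inequality residuals `Gx - h` on every node, and delegates
    /// to the distributed solver. `p`, `a`, and `b` are currently stored but
    /// unused; only `q` and the residuals drive the solve.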
pub fn solve_qp(
&mut self,
problem_id: &str,
        p: &Tensor,
        q: &Tensor,
        a: Option<&Tensor>,
        b: Option<&Tensor>,
        g: Option<&Tensor>,
        h: Option<&Tensor>,
    ) -> Result<Tensor> {
let problem_key = problem_id.to_string();
let state_exists = self.states.contains_key(&problem_key);
if !state_exists {
            let consensus_nodes = self.initialize_consensus_nodes(q.len())?;
let new_state = DeepDistributedQPState {
consensus_nodes,
policy_network: None,
previous_solution: None,
problem_matrix_p: Some(p.clone()),
problem_vector_q: Some(q.clone()),
constraint_matrix_a: a.cloned(),
constraint_vector_b: b.cloned(),
iteration: 0,
convergence_history: Vec::new(),
solve_times: Vec::new(),
problem_size: q.len(),
};
self.states.insert(problem_key.clone(), new_state);
}
let state = self.states.get_mut(&problem_key).expect("state must exist for problem_key");
if let Some(constraint_mat) = g {
for node in &mut state.consensus_nodes {
node.constraint_residuals = constraint_mat.matmul(&node.local_variables)?;
if let Some(h_vec) = h {
node.constraint_residuals = node.constraint_residuals.sub(h_vec)?;
}
}
}
self.solve_distributed_qp(problem_id, q)
}
pub fn set_policy_weights(
&mut self,
param_id: &str,
weights: Vec<Tensor>,
biases: Vec<Tensor>,
) -> Result<()> {
if let Some(state) = self.states.get_mut(param_id) {
if let Some(ref mut network) = state.policy_network {
network.weights = weights;
network.biases = biases;
}
}
Ok(())
}
pub fn train_policy(
&mut self,
param_id: &str,
experience_data: &[(Tensor, f32)],
) -> Result<()> {
if let Some(state) = self.states.get_mut(param_id) {
if let Some(ref mut network) = state.policy_network {
                if !experience_data.is_empty() {
                    // Placeholder update: a full implementation would fit the
                    // network to the (features, observed step size) pairs.
                    let _features: Vec<_> =
                        experience_data.iter().map(|(f, _)| f.clone()).collect();
                    network.output_scale *= 1.01;
                }
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deep_distributed_qp_creation() {
let optimizer = DeepDistributedQP::new(1e-3, 4, 100, 1e-6);
assert_eq!(optimizer.learning_rate(), 1e-3);
assert_eq!(optimizer.config.num_consensus_nodes, 4);
assert_eq!(optimizer.config.max_iterations, 100);
}
#[test]
fn test_deep_distributed_qp_presets() {
let large_scale = DeepDistributedQP::for_large_scale();
assert_eq!(large_scale.config.num_consensus_nodes, 8);
assert_eq!(large_scale.config.max_iterations, 500);
let portfolio = DeepDistributedQP::for_portfolio_optimization();
assert_eq!(portfolio.config.num_consensus_nodes, 6);
assert_eq!(portfolio.config.penalty_parameter, 2.0);
}
#[test]
fn test_consensus_nodes_initialization() -> Result<()> {
let optimizer = DeepDistributedQP::new(1e-3, 3, 50, 1e-6);
let nodes = optimizer.initialize_consensus_nodes(5)?;
assert_eq!(nodes.len(), 3);
for (i, node) in nodes.iter().enumerate() {
assert_eq!(node.node_id, i);
assert_eq!(node.local_variables.shape(), &[5]);
}
Ok(())
}
#[test]
fn test_policy_network_creation() -> Result<()> {
let optimizer = DeepDistributedQP::new(1e-3, 4, 100, 1e-6);
let network = optimizer.create_policy_network(4)?;
        assert_eq!(network.weights.len(), 3); // two hidden layers + one output layer
        assert_eq!(network.biases.len(), 3);
assert_eq!(network.input_mean.shape(), &[4]);
Ok(())
}
#[test]
fn test_soft_threshold() -> Result<()> {
let optimizer = DeepDistributedQP::new(1e-3, 4, 100, 1e-6);
let input = Tensor::from_slice(&[-2.0, -0.5, 0.0, 0.5, 2.0], &[5])?;
let threshold = 1.0;
let result = optimizer.soft_threshold(&input, threshold)?;
let result_vec = result.data()?;
assert!((result_vec[0] - (-1.0)).abs() < 1e-5);
assert!(result_vec[1].abs() < 1e-5);
assert!(result_vec[2].abs() < 1e-5);
assert!(result_vec[3].abs() < 1e-5);
assert!((result_vec[4] - 1.0).abs() < 1e-5);
Ok(())
}
#[test]
fn test_simple_qp_solve() -> Result<()> {
let mut optimizer = DeepDistributedQP::new(0.1, 2, 20, 1e-4);
let mut parameter = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
let gradient = Tensor::from_slice(&[0.1, 0.2, 0.1], &[3])?;
optimizer.update(&mut parameter, &gradient)?;
optimizer.step();
Ok(())
}
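    // A sketch test (added here, assuming the same Tensor API as the tests
    // above): consensus_update preserves the node average while pulling each
    // node toward it, so repeated calls should not increase the mean error.
    #[test]
    fn test_consensus_update_moves_nodes_together() -> Result<()> {
        let optimizer = DeepDistributedQP::new(1e-3, 2, 10, 1e-6);
        let mut nodes = optimizer.initialize_consensus_nodes(2)?;
        nodes[0].local_variables = Tensor::from_slice(&[0.0, 0.0], &[2])?;
        nodes[1].local_variables = Tensor::from_slice(&[1.0, 1.0], &[2])?;
        let first_error = optimizer.consensus_update(&mut nodes)?;
        let second_error = optimizer.consensus_update(&mut nodes)?;
        assert!(first_error > 0.0);
        assert!(second_error <= first_error);
        Ok(())
    }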
#[test]
fn test_qp_solver_stats() -> Result<()> {
let mut optimizer = DeepDistributedQP::new(1e-3, 2, 10, 1e-4);
let mut param = Tensor::from_slice(&[1.0, 2.0], &[2])?;
let grad = Tensor::from_slice(&[0.1, 0.1], &[2])?;
optimizer.update(&mut param, &grad)?;
let stats = optimizer.qp_solver_stats();
assert_eq!(stats.len(), 1);
        let (iterations, solve_time, _consensus_error, _converged) =
            stats.values().next().expect("stats should contain one entry");
assert!(*iterations <= 10);
assert!(*solve_time >= 0.0);
Ok(())
}
#[test]
fn test_memory_usage() -> Result<()> {
let mut optimizer = DeepDistributedQP::new(1e-3, 3, 10, 1e-4);
let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4])?;
let grad = Tensor::from_slice(&[0.1, 0.1, 0.1, 0.1], &[4])?;
let memory_before = optimizer.distributed_memory_usage();
optimizer.update(&mut param, &grad)?;
let memory_after = optimizer.distributed_memory_usage();
assert!(memory_after >= memory_before);
Ok(())
}
}