use scirs2_core::random::prelude::*;
use scirs2_core::random::ChaCha8Rng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::RngExt;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use thiserror::Error;
use crate::simulator::{AnnealingParams, AnnealingSolution, TemperatureSchedule};
/// Errors reported by the continuous-variable optimization API.
///
/// Every variant carries a human-readable message that is appended to the
/// variant's fixed prefix when the error is displayed.
#[derive(Error, Debug)]
pub enum ContinuousVariableError {
/// A variable definition was rejected (for example bad bounds, an
/// unsupported precision, or a duplicate name).
#[error("Invalid variable: {0}")]
InvalidVariable(String),
/// A constraint definition was rejected.
// NOTE(review): no code visible in this file constructs this variant.
#[error("Invalid constraint: {0}")]
InvalidConstraint(String),
/// A binary solution could not be mapped back to continuous values.
#[error("Discretization error: {0}")]
DiscretizationError(String),
/// The solver finished without producing any usable solution.
#[error("Optimization failed: {0}")]
OptimizationFailed(String),
/// A numerical problem occurred during evaluation.
// NOTE(review): no code visible in this file constructs this variant.
#[error("Numerical error: {0}")]
NumericalError(String),
}
/// Shorthand for a `Result` whose error type is [`ContinuousVariableError`].
pub type ContinuousVariableResult<T> = Result<T, ContinuousVariableError>;
/// A real-valued decision variable that is represented by a fixed number of
/// binary digits spread evenly across `[lower_bound, upper_bound]`.
#[derive(Debug, Clone)]
pub struct ContinuousVariable {
/// Name of the variable; used as the key in problem and solution maps.
pub name: String,
/// Value represented by the all-zero binary code.
pub lower_bound: f64,
/// Value represented by the all-one binary code.
pub upper_bound: f64,
/// Number of binary digits used for the encoding
/// (`ContinuousVariable::new` accepts 1 through 32).
pub precision_bits: usize,
/// Optional free-form description; `None` unless set by
/// `with_description`.
pub description: Option<String>,
}
impl ContinuousVariable {
    /// Creates a variable covering `[lower_bound, upper_bound]` that is encoded
    /// with `precision_bits` binary digits.
    ///
    /// # Errors
    ///
    /// Returns [`ContinuousVariableError::InvalidVariable`] when
    /// * either bound is NaN or infinite,
    /// * `lower_bound >= upper_bound`, or
    /// * `precision_bits` is outside `1..=32`.
    pub fn new(
        name: String,
        lower_bound: f64,
        upper_bound: f64,
        precision_bits: usize,
    ) -> ContinuousVariableResult<Self> {
        // NaN compares false against everything, so it would slip through the
        // ordering check below; infinite bounds make the interpolation NaN.
        if !lower_bound.is_finite() || !upper_bound.is_finite() {
            return Err(ContinuousVariableError::InvalidVariable(format!(
                "Bounds must be finite: [{lower_bound}, {upper_bound}]"
            )));
        }
        if lower_bound >= upper_bound {
            return Err(ContinuousVariableError::InvalidVariable(format!(
                "Invalid bounds: {lower_bound} >= {upper_bound}"
            )));
        }
        if precision_bits == 0 || precision_bits > 32 {
            return Err(ContinuousVariableError::InvalidVariable(
                "Precision bits must be between 1 and 32".to_string(),
            ));
        }
        Ok(Self {
            name,
            lower_bound,
            upper_bound,
            precision_bits,
            description: None,
        })
    }

    /// Attaches a free-form description and returns the updated variable.
    #[must_use]
    pub fn with_description(mut self, description: String) -> Self {
        self.description = Some(description);
        self
    }

    /// Number of distinct values the encoding can represent
    /// (`2^precision_bits`).
    ///
    /// The result saturates at `usize::MAX` instead of overflowing, which only
    /// matters on targets where `usize` is 32 bits wide or when the public
    /// `precision_bits` field was set outside the validated range.
    #[must_use]
    pub const fn num_levels(&self) -> usize {
        2_usize.saturating_pow(self.precision_bits as u32)
    }

    /// Largest binary code of the encoding, i.e. `2^precision_bits - 1`.
    ///
    /// Computed without ever shifting by the full width of `u32`, so the
    /// permitted `precision_bits == 32` case yields `u32::MAX` rather than an
    /// overflowing shift. Out-of-range field values degrade gracefully:
    /// zero bits gives `0`, more than 32 bits gives `u32::MAX`.
    fn max_code(&self) -> u32 {
        match u32::try_from(self.precision_bits) {
            Ok(0) => 0,
            Ok(bits) if bits < u32::BITS => (1u32 << bits) - 1,
            _ => u32::MAX,
        }
    }

    /// Maps a binary code onto the continuous range by linear interpolation:
    /// code `0` is `lower_bound` and the largest code is `upper_bound`.
    ///
    /// Codes above the largest code are not clamped and therefore extrapolate
    /// past `upper_bound`.
    #[must_use]
    pub fn binary_to_continuous(&self, binary_value: u32) -> f64 {
        let max_code = self.max_code();
        if max_code == 0 {
            // A single-level encoding has only one representable value.
            return self.lower_bound;
        }
        let normalized = f64::from(binary_value) / f64::from(max_code);
        self.lower_bound + normalized * (self.upper_bound - self.lower_bound)
    }

    /// Maps a continuous value to the nearest binary code.
    ///
    /// The input is clamped into `[lower_bound, upper_bound]` first, so values
    /// outside the range map to the first or last code.
    #[must_use]
    pub fn continuous_to_binary(&self, continuous_value: f64) -> u32 {
        let clamped = continuous_value.clamp(self.lower_bound, self.upper_bound);
        let normalized = (clamped - self.lower_bound) / (self.upper_bound - self.lower_bound);
        // `as u32` saturates, so rounding can never wrap around.
        (normalized * f64::from(self.max_code())).round() as u32
    }

    /// Distance between two neighbouring representable values.
    #[must_use]
    pub fn resolution(&self) -> f64 {
        (self.upper_bound - self.lower_bound) / f64::from(self.max_code())
    }
}
/// Boxed, thread-safe callback that maps variable values (keyed by variable
/// name) to the objective value.
pub type ObjectiveFunction = Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>;
/// Boxed, thread-safe callback that maps variable values (keyed by variable
/// name) to a constraint value; values above the constraint's tolerance are
/// penalized by `ContinuousOptimizationProblem::evaluate_penalized_objective`.
pub type ConstraintFunction = Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>;
/// A named constraint that is enforced through a quadratic penalty.
pub struct ContinuousConstraint {
/// Name used when reporting violations.
pub name: String,
/// Callback producing the constraint value for a candidate solution.
pub function: ConstraintFunction,
/// Multiplier applied to the squared constraint value when penalizing.
pub penalty_weight: f64,
/// Constraint values up to and including this threshold are not penalized.
pub tolerance: f64,
}
impl std::fmt::Debug for ContinuousConstraint {
    /// Formats the constraint as a struct, printing a fixed placeholder in
    /// place of the callback, which has no `Debug` representation of its own.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let function_placeholder = "<function>";
        let mut builder = f.debug_struct("ContinuousConstraint");
        builder.field("name", &self.name);
        builder.field("function", &function_placeholder);
        builder.field("penalty_weight", &self.penalty_weight);
        builder.field("tolerance", &self.tolerance);
        builder.finish()
    }
}
impl ContinuousConstraint {
    /// Builds a constraint from its name, callback and penalty weight.
    ///
    /// The tolerance starts at `1e-6`; use [`Self::with_tolerance`] to change it.
    #[must_use]
    pub fn new(name: String, function: ConstraintFunction, penalty_weight: f64) -> Self {
        let tolerance = 1e-6;
        Self {
            name,
            function,
            penalty_weight,
            tolerance,
        }
    }

    /// Returns the constraint with its tolerance replaced by `tolerance`.
    #[must_use]
    pub const fn with_tolerance(mut self, tolerance: f64) -> Self {
        // Assigning through `mut self` keeps the value whole, so nothing has
        // to be dropped inside this `const fn`.
        self.tolerance = tolerance;
        self
    }
}
/// An optimization problem over named continuous variables: an objective to
/// evaluate plus optional penalty-enforced constraints.
pub struct ContinuousOptimizationProblem {
// Registered variables, keyed by variable name.
variables: HashMap<String, ContinuousVariable>,
// Callback that evaluates the (unpenalized) objective.
objective: ObjectiveFunction,
// Constraints in the order they were added.
constraints: Vec<ContinuousConstraint>,
// Initialized to 100.0 and changeable through `set_default_penalty_weight`.
// NOTE(review): no code visible in this file reads this value; each
// constraint's own `penalty_weight` is what the penalty uses.
default_penalty_weight: f64,
}
impl std::fmt::Debug for ContinuousOptimizationProblem {
    /// Formats the problem as a struct, printing a fixed placeholder in place
    /// of the objective callback, which has no `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let objective_placeholder = "<function>";
        let mut builder = f.debug_struct("ContinuousOptimizationProblem");
        builder.field("variables", &self.variables);
        builder.field("objective", &objective_placeholder);
        builder.field("constraints", &self.constraints);
        builder.field("default_penalty_weight", &self.default_penalty_weight);
        builder.finish()
    }
}
impl ContinuousOptimizationProblem {
    /// Creates a problem with the given objective, no variables, no
    /// constraints and a default penalty weight of `100.0`.
    #[must_use]
    pub fn new(objective: ObjectiveFunction) -> Self {
        Self {
            variables: HashMap::new(),
            objective,
            constraints: Vec::new(),
            default_penalty_weight: 100.0,
        }
    }

    /// Registers a variable under its own name.
    ///
    /// # Errors
    ///
    /// Returns [`ContinuousVariableError::InvalidVariable`] if a variable with
    /// the same name is already registered.
    pub fn add_variable(&mut self, variable: ContinuousVariable) -> ContinuousVariableResult<()> {
        if self.variables.contains_key(&variable.name) {
            return Err(ContinuousVariableError::InvalidVariable(format!(
                "Variable '{}' already exists",
                variable.name
            )));
        }
        self.variables.insert(variable.name.clone(), variable);
        Ok(())
    }

    /// Appends a constraint; constraints are evaluated in insertion order.
    pub fn add_constraint(&mut self, constraint: ContinuousConstraint) {
        self.constraints.push(constraint);
    }

    /// Replaces the stored default penalty weight.
    pub const fn set_default_penalty_weight(&mut self, weight: f64) {
        self.default_penalty_weight = weight;
    }

    /// Total number of binary digits needed to encode every variable.
    #[must_use]
    pub fn total_binary_variables(&self) -> usize {
        self.variables.values().map(|v| v.precision_bits).sum()
    }

    /// Variables in ascending (lexicographic) name order.
    ///
    /// `HashMap` iteration order is unspecified and differs between maps and
    /// between process runs, so the bit layout is derived from this ordering
    /// instead to keep it stable and reproducible.
    fn ordered_variables(&self) -> Vec<(&String, &ContinuousVariable)> {
        let mut ordered: Vec<_> = self.variables.iter().collect();
        ordered.sort_unstable_by(|a, b| a.0.cmp(b.0));
        ordered
    }

    /// Assigns every variable a contiguous run of indices into the flat binary
    /// solution vector.
    ///
    /// Variables are laid out back to back in ascending name order, each
    /// occupying `precision_bits` consecutive indices.
    #[must_use]
    pub fn create_binary_mapping(&self) -> HashMap<String, Vec<usize>> {
        let mut mapping = HashMap::with_capacity(self.variables.len());
        let mut next_index = 0;
        for (name, var) in self.ordered_variables() {
            let end = next_index + var.precision_bits;
            mapping.insert(name.clone(), (next_index..end).collect());
            next_index = end;
        }
        mapping
    }

    /// Decodes a flat binary solution into one continuous value per variable.
    ///
    /// The layout is the one described by [`Self::create_binary_mapping`].
    /// Within a variable the first entry is the most significant bit, and an
    /// entry counts as a set bit when it is strictly positive (so both `0/1`
    /// and `-1/+1` encodings are accepted).
    ///
    /// # Errors
    ///
    /// Returns [`ContinuousVariableError::DiscretizationError`] if
    /// `binary_solution` is too short to cover a variable.
    pub fn binary_to_continuous_solution(
        &self,
        binary_solution: &[i8],
    ) -> ContinuousVariableResult<HashMap<String, f64>> {
        let mut continuous_solution = HashMap::with_capacity(self.variables.len());
        let mut offset = 0;
        for (name, var) in self.ordered_variables() {
            let end = offset + var.precision_bits;
            let bits = binary_solution.get(offset..end).ok_or_else(|| {
                ContinuousVariableError::DiscretizationError(format!(
                    "Binary solution too short for variable '{name}'"
                ))
            })?;
            // Shift-and-or accumulation puts the first entry in the highest
            // position, i.e. the encoding is most-significant-bit first.
            let binary_value = bits
                .iter()
                .fold(0u32, |code, &bit| (code << 1) | u32::from(bit > 0));
            continuous_solution.insert(name.clone(), var.binary_to_continuous(binary_value));
            offset = end;
        }
        Ok(continuous_solution)
    }

    /// Evaluates the objective and adds a quadratic penalty for every
    /// constraint whose value exceeds its tolerance.
    ///
    /// Each violated constraint contributes `penalty_weight * value^2`, using
    /// the constraint's own weight.
    #[must_use]
    pub fn evaluate_penalized_objective(&self, continuous_solution: &HashMap<String, f64>) -> f64 {
        let mut objective_value = (self.objective)(continuous_solution);
        for constraint in &self.constraints {
            let constraint_value = (constraint.function)(continuous_solution);
            if constraint_value > constraint.tolerance {
                objective_value += constraint.penalty_weight * constraint_value.powi(2);
            }
        }
        objective_value
    }
}
/// Settings that control `ContinuousVariableAnnealer::solve`.
#[derive(Debug, Clone)]
pub struct ContinuousAnnealingConfig {
/// Parameters for the annealing run; its `seed` (if any) seeds the
/// annealer's random number generator.
pub annealing_params: AnnealingParams,
/// When `true`, the discretization is refined between solver iterations.
pub adaptive_discretization: bool,
/// Upper bound on the number of anneal/refine iterations.
pub max_refinement_iterations: usize,
/// After the first iteration, the loop stops once the objective improves
/// by less than this amount.
pub refinement_tolerance: f64,
/// When `true`, a local search polishes the best solution found.
pub local_search: bool,
/// Maximum number of local-search sweeps over all variables.
pub local_search_iterations: usize,
/// Local-search step, expressed as a fraction of each variable's range.
pub local_search_step_size: f64,
}
impl Default for ContinuousAnnealingConfig {
    /// Default configuration: adaptive discretization with up to 3 refinement
    /// iterations (tolerance `1e-4`), followed by a local search of at most
    /// 100 sweeps with a step of 1% of each variable's range.
    fn default() -> Self {
        let annealing_params = AnnealingParams::default();
        Self {
            annealing_params,
            // Refinement loop.
            adaptive_discretization: true,
            max_refinement_iterations: 3,
            refinement_tolerance: 1e-4,
            // Local-search polish.
            local_search: true,
            local_search_iterations: 100,
            local_search_step_size: 0.01,
        }
    }
}
/// Result returned by `ContinuousVariableAnnealer::solve`.
#[derive(Debug, Clone)]
pub struct ContinuousSolution {
/// Final value of every variable, keyed by variable name.
pub variable_values: HashMap<String, f64>,
/// Objective evaluated at `variable_values`, without constraint penalties.
pub objective_value: f64,
/// One `(constraint name, violation)` pair per constraint; the violation
/// is the constraint value clamped below at zero.
pub constraint_violations: Vec<(String, f64)>,
/// Binary vector of the best annealing result; it is not updated by the
/// subsequent local search.
pub binary_solution: Vec<i8>,
/// Timing and convergence information for the run.
pub stats: ContinuousOptimizationStats,
}
/// Timing and convergence statistics collected by
/// `ContinuousVariableAnnealer::solve`.
#[derive(Debug, Clone)]
pub struct ContinuousOptimizationStats {
/// Wall-clock time of the whole `solve` call.
pub total_runtime: Duration,
/// Time spent building the initial discretized problem.
pub discretization_time: Duration,
/// Time of the annealing run that produced the best solution.
pub annealing_time: Duration,
/// Time spent in local search; zero when local search is disabled.
pub local_search_time: Duration,
/// Number of anneal/refine iterations that were executed.
pub refinement_iterations: usize,
/// Resolution of each variable's encoding, keyed by variable name.
pub final_resolution: HashMap<String, f64>,
/// `true` when the loop stopped before reaching the iteration limit.
pub converged: bool,
}
/// Solver for `ContinuousOptimizationProblem`s that works on a binary
/// discretization of the variables.
pub struct ContinuousVariableAnnealer {
// Settings supplied at construction time.
config: ContinuousAnnealingConfig,
// Random number generator; seeded from the configuration when it provides
// a seed, otherwise from the thread-local generator.
rng: ChaCha8Rng,
}
impl ContinuousVariableAnnealer {
#[must_use]
pub fn new(config: ContinuousAnnealingConfig) -> Self {
let rng = match config.annealing_params.seed {
Some(seed) => ChaCha8Rng::seed_from_u64(seed),
None => ChaCha8Rng::seed_from_u64(thread_rng().random()),
};
Self { config, rng }
}
pub fn solve(
&mut self,
problem: &ContinuousOptimizationProblem,
) -> ContinuousVariableResult<ContinuousSolution> {
let total_start = Instant::now();
let discretize_start = Instant::now();
let mut current_problem = self.create_discretized_problem(problem)?;
let discretization_time = discretize_start.elapsed();
let mut best_solution = None;
let mut best_objective = f64::INFINITY;
let mut refinement_iterations = 0;
for iteration in 0..self.config.max_refinement_iterations {
let anneal_start = Instant::now();
let binary_solution = self.solve_discretized_problem(¤t_problem)?;
let annealing_time = anneal_start.elapsed();
let continuous_values = problem.binary_to_continuous_solution(&binary_solution)?;
let objective_value = problem.evaluate_penalized_objective(&continuous_values);
let improvement = if best_objective.is_finite() {
best_objective - objective_value
} else {
f64::INFINITY
};
if objective_value < best_objective {
best_objective = objective_value;
best_solution = Some((binary_solution, continuous_values.clone(), annealing_time));
}
refinement_iterations += 1;
if improvement < self.config.refinement_tolerance && iteration > 0 {
break;
}
if self.config.adaptive_discretization
&& iteration < self.config.max_refinement_iterations - 1
{
current_problem = self.refine_discretization(problem, &continuous_values)?;
}
}
let (final_binary, mut final_continuous, annealing_time) =
best_solution.ok_or_else(|| {
ContinuousVariableError::OptimizationFailed("No solution found".to_string())
})?;
let local_search_start = Instant::now();
let local_search_time = if self.config.local_search {
self.local_search(problem, &mut final_continuous)?;
local_search_start.elapsed()
} else {
Duration::from_secs(0)
};
let constraint_violations =
self.calculate_constraint_violations(problem, &final_continuous);
let final_objective = (problem.objective)(&final_continuous);
let final_resolution = problem
.variables
.iter()
.map(|(name, var)| (name.clone(), var.resolution()))
.collect();
let total_runtime = total_start.elapsed();
let stats = ContinuousOptimizationStats {
total_runtime,
discretization_time,
annealing_time,
local_search_time,
refinement_iterations,
final_resolution,
converged: refinement_iterations < self.config.max_refinement_iterations,
};
Ok(ContinuousSolution {
variable_values: final_continuous,
objective_value: final_objective,
constraint_violations,
binary_solution: final_binary,
stats,
})
}
const fn create_discretized_problem(
&self,
_problem: &ContinuousOptimizationProblem,
) -> ContinuousVariableResult<DiscretizedProblem> {
Ok(DiscretizedProblem {
num_variables: 0,
q_matrix: Vec::new(),
})
}
fn solve_discretized_problem(
&mut self,
_problem: &DiscretizedProblem,
) -> ContinuousVariableResult<Vec<i8>> {
let num_vars = 16; let solution: Vec<i8> = (0..num_vars)
.map(|_| if self.rng.random_bool(0.5) { 1 } else { -1 })
.collect();
Ok(solution)
}
const fn refine_discretization(
&self,
_problem: &ContinuousOptimizationProblem,
_current_solution: &HashMap<String, f64>,
) -> ContinuousVariableResult<DiscretizedProblem> {
Ok(DiscretizedProblem {
num_variables: 0,
q_matrix: Vec::new(),
})
}
fn local_search(
&self,
problem: &ContinuousOptimizationProblem,
solution: &mut HashMap<String, f64>,
) -> ContinuousVariableResult<()> {
let mut current_objective = problem.evaluate_penalized_objective(solution);
for _ in 0..self.config.local_search_iterations {
let mut improved = false;
for (var_name, var) in &problem.variables {
let current_value = solution[var_name];
let step_size =
(var.upper_bound - var.lower_bound) * self.config.local_search_step_size;
for direction in [-1.0_f64, 1.0] {
let new_value = direction
.mul_add(step_size, current_value)
.clamp(var.lower_bound, var.upper_bound);
solution.insert(var_name.clone(), new_value);
let new_objective = problem.evaluate_penalized_objective(solution);
if new_objective < current_objective {
current_objective = new_objective;
improved = true;
break; }
solution.insert(var_name.clone(), current_value);
}
}
if !improved {
break;
}
}
Ok(())
}
fn calculate_constraint_violations(
&self,
problem: &ContinuousOptimizationProblem,
solution: &HashMap<String, f64>,
) -> Vec<(String, f64)> {
problem
.constraints
.iter()
.map(|constraint| {
let violation = (constraint.function)(solution);
(constraint.name.clone(), violation.max(0.0))
})
.collect()
}
}
/// Binary form of a continuous problem that is handed to the annealing step.
// NOTE(review): the constructors visible in this file always fill this with
// `num_variables: 0` and an empty `q_matrix`, and nothing reads the fields,
// so it currently looks like a placeholder — confirm before relying on it.
#[derive(Debug)]
struct DiscretizedProblem {
// Number of binary variables in the discretized problem.
num_variables: usize,
// Presumably the pairwise coefficient matrix of the binary problem —
// TODO confirm once the discretization is implemented.
q_matrix: Vec<Vec<f64>>,
}
pub fn create_quadratic_problem(
linear_coeffs: &[f64],
quadratic_matrix: &[Vec<f64>],
bounds: &[(f64, f64)],
precision_bits: usize,
) -> ContinuousVariableResult<ContinuousOptimizationProblem> {
let linear_coeffs = linear_coeffs.to_vec();
let quadratic_matrix = quadratic_matrix.to_vec();
let objective: ObjectiveFunction = Box::new(move |vars: &HashMap<String, f64>| {
let n = linear_coeffs.len();
let x: Vec<f64> = (0..n).map(|i| vars[&format!("x{i}")]).collect();
let linear_term: f64 = linear_coeffs
.iter()
.zip(x.iter())
.map(|(c, xi)| c * xi)
.sum();
let mut quadratic_term = 0.0;
for i in 0..n {
for j in 0..n {
quadratic_term += 0.5 * quadratic_matrix[i][j] * x[i] * x[j];
}
}
linear_term + quadratic_term
});
let mut problem = ContinuousOptimizationProblem::new(objective);
for (i, &(lower, upper)) in bounds.iter().enumerate() {
let var = ContinuousVariable::new(format!("x{i}"), lower, upper, precision_bits)?;
problem.add_variable(var)?;
}
Ok(problem)
}
// Unit tests for variable encoding, problem construction and constraint
// callbacks.
#[cfg(test)]
mod tests {
use super::*;
// A valid definition keeps its fields and reports 2^8 levels for 8 bits.
#[test]
fn test_continuous_variable_creation() {
let var = ContinuousVariable::new("x".to_string(), 0.0, 10.0, 8)
.expect("should create continuous variable with valid bounds");
assert_eq!(var.name, "x");
assert_eq!(var.lower_bound, 0.0);
assert_eq!(var.upper_bound, 10.0);
assert_eq!(var.precision_bits, 8);
assert_eq!(var.num_levels(), 256);
}
// The end points map exactly onto the first and last code, and a round trip
// through the 4-bit encoding loses at most one resolution step.
#[test]
fn test_binary_continuous_conversion() {
let var = ContinuousVariable::new("x".to_string(), 0.0, 10.0, 4)
.expect("should create continuous variable for conversion test");
assert_eq!(var.binary_to_continuous(0), 0.0);
assert!((var.binary_to_continuous(15) - 10.0).abs() < 1e-10);
assert_eq!(var.continuous_to_binary(0.0), 0);
assert_eq!(var.continuous_to_binary(10.0), 15);
let continuous_val = 3.7;
let binary_val = var.continuous_to_binary(continuous_val);
let recovered_val = var.binary_to_continuous(binary_val);
assert!((recovered_val - continuous_val).abs() <= var.resolution());
}
// One variable named `x<i>` is created per bounds entry.
#[test]
fn test_quadratic_problem_creation() {
let linear_coeffs = vec![1.0, -2.0];
let quadratic_matrix = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
let bounds = vec![(0.0, 5.0), (-3.0, 3.0)];
let problem = create_quadratic_problem(&linear_coeffs, &quadratic_matrix, &bounds, 6)
.expect("should create quadratic problem with valid parameters");
assert_eq!(problem.variables.len(), 2);
assert!(problem.variables.contains_key("x0"));
assert!(problem.variables.contains_key("x1"));
}
// The callback computes x + y - 5: negative when the sum is below 5,
// positive when it is above.
#[test]
fn test_constraint_evaluation() {
let constraint_fn: ConstraintFunction = Box::new(|vars| {
vars["x"] + vars["y"] - 5.0 });
let constraint =
ContinuousConstraint::new("sum_constraint".to_string(), constraint_fn, 10.0);
let mut vars = HashMap::new();
vars.insert("x".to_string(), 2.0);
vars.insert("y".to_string(), 2.0);
// 2 + 2 - 5
let violation = (constraint.function)(&vars);
assert_eq!(violation, -1.0);
// 2 + 4 - 5
vars.insert("y".to_string(), 4.0);
let violation = (constraint.function)(&vars);
assert_eq!(violation, 1.0); }
}