1use scirs2_core::random::prelude::*;
8use scirs2_core::random::ChaCha8Rng;
9use scirs2_core::random::{Rng, SeedableRng};
10use scirs2_core::RngExt;
11use std::collections::HashMap;
12use std::time::{Duration, Instant};
13use thiserror::Error;
14
15use crate::simulator::{AnnealingParams, AnnealingSolution, TemperatureSchedule};
16
17#[derive(Error, Debug)]
19pub enum ContinuousVariableError {
20 #[error("Invalid variable: {0}")]
22 InvalidVariable(String),
23
24 #[error("Invalid constraint: {0}")]
26 InvalidConstraint(String),
27
28 #[error("Discretization error: {0}")]
30 DiscretizationError(String),
31
32 #[error("Optimization failed: {0}")]
34 OptimizationFailed(String),
35
36 #[error("Numerical error: {0}")]
38 NumericalError(String),
39}
40
41pub type ContinuousVariableResult<T> = Result<T, ContinuousVariableError>;
43
/// A bounded real-valued decision variable encoded as a fixed-width binary number.
#[derive(Clone, Debug)]
pub struct ContinuousVariable {
    /// Variable name; `ContinuousOptimizationProblem::add_variable` rejects duplicates.
    pub name: String,

    /// Inclusive lower bound; the all-zeros code decodes to this value.
    pub lower_bound: f64,

    /// Inclusive upper bound; the all-ones code decodes to this value.
    pub upper_bound: f64,

    /// Number of bits in the encoding (`ContinuousVariable::new` accepts `1..=32`).
    pub precision_bits: usize,

    /// Optional free-form description.
    pub description: Option<String>,
}
62
63impl ContinuousVariable {
64 pub fn new(
66 name: String,
67 lower_bound: f64,
68 upper_bound: f64,
69 precision_bits: usize,
70 ) -> ContinuousVariableResult<Self> {
71 if lower_bound >= upper_bound {
72 return Err(ContinuousVariableError::InvalidVariable(format!(
73 "Invalid bounds: {lower_bound} >= {upper_bound}"
74 )));
75 }
76
77 if precision_bits == 0 || precision_bits > 32 {
78 return Err(ContinuousVariableError::InvalidVariable(
79 "Precision bits must be between 1 and 32".to_string(),
80 ));
81 }
82
83 Ok(Self {
84 name,
85 lower_bound,
86 upper_bound,
87 precision_bits,
88 description: None,
89 })
90 }
91
92 #[must_use]
94 pub fn with_description(mut self, description: String) -> Self {
95 self.description = Some(description);
96 self
97 }
98
99 #[must_use]
101 pub const fn num_levels(&self) -> usize {
102 2_usize.pow(self.precision_bits as u32)
103 }
104
105 #[must_use]
107 pub fn binary_to_continuous(&self, binary_value: u32) -> f64 {
108 let max_value = (1u32 << self.precision_bits) - 1;
109 let normalized = f64::from(binary_value) / f64::from(max_value);
110 self.lower_bound + normalized * (self.upper_bound - self.lower_bound)
111 }
112
113 #[must_use]
115 pub fn continuous_to_binary(&self, continuous_value: f64) -> u32 {
116 let clamped = continuous_value.clamp(self.lower_bound, self.upper_bound);
117 let normalized = (clamped - self.lower_bound) / (self.upper_bound - self.lower_bound);
118 let max_value = (1u32 << self.precision_bits) - 1;
119 (normalized * f64::from(max_value)).round() as u32
120 }
121
122 #[must_use]
124 pub fn resolution(&self) -> f64 {
125 (self.upper_bound - self.lower_bound) / (self.num_levels() - 1) as f64
126 }
127}
128
/// Objective to minimize, evaluated on a map from variable name to value.
pub type ObjectiveFunction = Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>;

/// Constraint function `g` over the same name→value map; the solver treats
/// `g(x) <= tolerance` as satisfied.
pub type ConstraintFunction = Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>;
134
135pub struct ContinuousConstraint {
137 pub name: String,
139
140 pub function: ConstraintFunction,
142
143 pub penalty_weight: f64,
145
146 pub tolerance: f64,
148}
149
150impl std::fmt::Debug for ContinuousConstraint {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 f.debug_struct("ContinuousConstraint")
153 .field("name", &self.name)
154 .field("function", &"<function>")
155 .field("penalty_weight", &self.penalty_weight)
156 .field("tolerance", &self.tolerance)
157 .finish()
158 }
159}
160
161impl ContinuousConstraint {
162 #[must_use]
164 pub fn new(name: String, function: ConstraintFunction, penalty_weight: f64) -> Self {
165 Self {
166 name,
167 function,
168 penalty_weight,
169 tolerance: 1e-6,
170 }
171 }
172
173 #[must_use]
175 pub const fn with_tolerance(mut self, tolerance: f64) -> Self {
176 self.tolerance = tolerance;
177 self
178 }
179}
180
181pub struct ContinuousOptimizationProblem {
183 variables: HashMap<String, ContinuousVariable>,
185
186 objective: ObjectiveFunction,
188
189 constraints: Vec<ContinuousConstraint>,
191
192 default_penalty_weight: f64,
194}
195
196impl std::fmt::Debug for ContinuousOptimizationProblem {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 f.debug_struct("ContinuousOptimizationProblem")
199 .field("variables", &self.variables)
200 .field("objective", &"<function>")
201 .field("constraints", &self.constraints)
202 .field("default_penalty_weight", &self.default_penalty_weight)
203 .finish()
204 }
205}
206
207impl ContinuousOptimizationProblem {
208 #[must_use]
210 pub fn new(objective: ObjectiveFunction) -> Self {
211 Self {
212 variables: HashMap::new(),
213 objective,
214 constraints: Vec::new(),
215 default_penalty_weight: 100.0,
216 }
217 }
218
219 pub fn add_variable(&mut self, variable: ContinuousVariable) -> ContinuousVariableResult<()> {
221 if self.variables.contains_key(&variable.name) {
222 return Err(ContinuousVariableError::InvalidVariable(format!(
223 "Variable '{}' already exists",
224 variable.name
225 )));
226 }
227
228 self.variables.insert(variable.name.clone(), variable);
229 Ok(())
230 }
231
232 pub fn add_constraint(&mut self, constraint: ContinuousConstraint) {
234 self.constraints.push(constraint);
235 }
236
237 pub const fn set_default_penalty_weight(&mut self, weight: f64) {
239 self.default_penalty_weight = weight;
240 }
241
242 #[must_use]
244 pub fn total_binary_variables(&self) -> usize {
245 self.variables.values().map(|v| v.precision_bits).sum()
246 }
247
248 #[must_use]
250 pub fn create_binary_mapping(&self) -> HashMap<String, Vec<usize>> {
251 let mut mapping = HashMap::new();
252 let mut current_index = 0;
253
254 for (var_name, var) in &self.variables {
255 let indices: Vec<usize> = (current_index..current_index + var.precision_bits).collect();
256 mapping.insert(var_name.clone(), indices);
257 current_index += var.precision_bits;
258 }
259
260 mapping
261 }
262
263 pub fn binary_to_continuous_solution(
265 &self,
266 binary_solution: &[i8],
267 ) -> ContinuousVariableResult<HashMap<String, f64>> {
268 let binary_mapping = self.create_binary_mapping();
269 let mut continuous_solution = HashMap::new();
270
271 for (var_name, var) in &self.variables {
272 let indices = &binary_mapping[var_name];
273
274 if indices.iter().any(|&i| i >= binary_solution.len()) {
275 return Err(ContinuousVariableError::DiscretizationError(format!(
276 "Binary solution too short for variable '{var_name}'"
277 )));
278 }
279
280 let mut binary_value = 0u32;
282 for (bit_idx, &global_idx) in indices.iter().enumerate() {
283 if binary_solution[global_idx] > 0 {
284 binary_value |= 1 << (var.precision_bits - 1 - bit_idx);
285 }
286 }
287
288 let continuous_value = var.binary_to_continuous(binary_value);
290 continuous_solution.insert(var_name.clone(), continuous_value);
291 }
292
293 Ok(continuous_solution)
294 }
295
296 #[must_use]
298 pub fn evaluate_penalized_objective(&self, continuous_solution: &HashMap<String, f64>) -> f64 {
299 let mut objective_value = (self.objective)(continuous_solution);
300
301 for constraint in &self.constraints {
303 let constraint_value = (constraint.function)(continuous_solution);
304 if constraint_value > constraint.tolerance {
305 objective_value += constraint.penalty_weight * constraint_value.powi(2);
306 }
307 }
308
309 objective_value
310 }
311}
312
313#[derive(Debug, Clone)]
315pub struct ContinuousAnnealingConfig {
316 pub annealing_params: AnnealingParams,
318
319 pub adaptive_discretization: bool,
321
322 pub max_refinement_iterations: usize,
324
325 pub refinement_tolerance: f64,
327
328 pub local_search: bool,
330
331 pub local_search_iterations: usize,
333
334 pub local_search_step_size: f64,
336}
337
338impl Default for ContinuousAnnealingConfig {
339 fn default() -> Self {
340 Self {
341 annealing_params: AnnealingParams::default(),
342 adaptive_discretization: true,
343 max_refinement_iterations: 3,
344 refinement_tolerance: 1e-4,
345 local_search: true,
346 local_search_iterations: 100,
347 local_search_step_size: 0.01,
348 }
349 }
350}
351
352#[derive(Debug, Clone)]
354pub struct ContinuousSolution {
355 pub variable_values: HashMap<String, f64>,
357
358 pub objective_value: f64,
360
361 pub constraint_violations: Vec<(String, f64)>,
363
364 pub binary_solution: Vec<i8>,
366
367 pub stats: ContinuousOptimizationStats,
369}
370
/// Diagnostics collected during a solve.
#[derive(Clone, Debug)]
pub struct ContinuousOptimizationStats {
    /// Wall-clock time for the whole solve.
    pub total_runtime: Duration,

    /// Time spent building the initial discretized problem.
    pub discretization_time: Duration,

    /// Discretized-solve time of the round whose result was kept.
    pub annealing_time: Duration,

    /// Time spent in local search (zero when it is disabled).
    pub local_search_time: Duration,

    /// Number of discretize/solve rounds that ran.
    pub refinement_iterations: usize,

    /// Grid spacing of each variable's encoding, keyed by name.
    pub final_resolution: HashMap<String, f64>,

    /// Whether the annealer reported the refinement loop as converged.
    pub converged: bool,
}
395
396pub struct ContinuousVariableAnnealer {
398 config: ContinuousAnnealingConfig,
400
401 rng: ChaCha8Rng,
403}
404
405impl ContinuousVariableAnnealer {
406 #[must_use]
408 pub fn new(config: ContinuousAnnealingConfig) -> Self {
409 let rng = match config.annealing_params.seed {
410 Some(seed) => ChaCha8Rng::seed_from_u64(seed),
411 None => ChaCha8Rng::seed_from_u64(thread_rng().random()),
412 };
413
414 Self { config, rng }
415 }
416
417 pub fn solve(
419 &mut self,
420 problem: &ContinuousOptimizationProblem,
421 ) -> ContinuousVariableResult<ContinuousSolution> {
422 let total_start = Instant::now();
423
424 let discretize_start = Instant::now();
426 let mut current_problem = self.create_discretized_problem(problem)?;
427 let discretization_time = discretize_start.elapsed();
428
429 let mut best_solution = None;
430 let mut best_objective = f64::INFINITY;
431 let mut refinement_iterations = 0;
432
433 for iteration in 0..self.config.max_refinement_iterations {
435 let anneal_start = Instant::now();
437 let binary_solution = self.solve_discretized_problem(¤t_problem)?;
438 let annealing_time = anneal_start.elapsed();
439
440 let continuous_values = problem.binary_to_continuous_solution(&binary_solution)?;
442 let objective_value = problem.evaluate_penalized_objective(&continuous_values);
443
444 let improvement = if best_objective.is_finite() {
446 best_objective - objective_value
447 } else {
448 f64::INFINITY
449 };
450
451 if objective_value < best_objective {
452 best_objective = objective_value;
453 best_solution = Some((binary_solution, continuous_values.clone(), annealing_time));
454 }
455
456 refinement_iterations += 1;
457
458 if improvement < self.config.refinement_tolerance && iteration > 0 {
460 break;
461 }
462
463 if self.config.adaptive_discretization
465 && iteration < self.config.max_refinement_iterations - 1
466 {
467 current_problem = self.refine_discretization(problem, &continuous_values)?;
468 }
469 }
470
471 let (final_binary, mut final_continuous, annealing_time) =
472 best_solution.ok_or_else(|| {
473 ContinuousVariableError::OptimizationFailed("No solution found".to_string())
474 })?;
475
476 let local_search_start = Instant::now();
478 let local_search_time = if self.config.local_search {
479 self.local_search(problem, &mut final_continuous)?;
480 local_search_start.elapsed()
481 } else {
482 Duration::from_secs(0)
483 };
484
485 let constraint_violations =
487 self.calculate_constraint_violations(problem, &final_continuous);
488
489 let final_objective = (problem.objective)(&final_continuous);
491
492 let final_resolution = problem
494 .variables
495 .iter()
496 .map(|(name, var)| (name.clone(), var.resolution()))
497 .collect();
498
499 let total_runtime = total_start.elapsed();
500
501 let stats = ContinuousOptimizationStats {
502 total_runtime,
503 discretization_time,
504 annealing_time,
505 local_search_time,
506 refinement_iterations,
507 final_resolution,
508 converged: refinement_iterations < self.config.max_refinement_iterations,
509 };
510
511 Ok(ContinuousSolution {
512 variable_values: final_continuous,
513 objective_value: final_objective,
514 constraint_violations,
515 binary_solution: final_binary,
516 stats,
517 })
518 }
519
520 const fn create_discretized_problem(
522 &self,
523 _problem: &ContinuousOptimizationProblem,
524 ) -> ContinuousVariableResult<DiscretizedProblem> {
525 Ok(DiscretizedProblem {
528 num_variables: 0,
529 q_matrix: Vec::new(),
530 })
531 }
532
533 fn solve_discretized_problem(
535 &mut self,
536 _problem: &DiscretizedProblem,
537 ) -> ContinuousVariableResult<Vec<i8>> {
538 let num_vars = 16; let solution: Vec<i8> = (0..num_vars)
542 .map(|_| if self.rng.random_bool(0.5) { 1 } else { -1 })
543 .collect();
544
545 Ok(solution)
546 }
547
548 const fn refine_discretization(
550 &self,
551 _problem: &ContinuousOptimizationProblem,
552 _current_solution: &HashMap<String, f64>,
553 ) -> ContinuousVariableResult<DiscretizedProblem> {
554 Ok(DiscretizedProblem {
556 num_variables: 0,
557 q_matrix: Vec::new(),
558 })
559 }
560
561 fn local_search(
563 &self,
564 problem: &ContinuousOptimizationProblem,
565 solution: &mut HashMap<String, f64>,
566 ) -> ContinuousVariableResult<()> {
567 let mut current_objective = problem.evaluate_penalized_objective(solution);
568
569 for _ in 0..self.config.local_search_iterations {
570 let mut improved = false;
571
572 for (var_name, var) in &problem.variables {
574 let current_value = solution[var_name];
575 let step_size =
576 (var.upper_bound - var.lower_bound) * self.config.local_search_step_size;
577
578 for direction in [-1.0_f64, 1.0] {
580 let new_value = direction
581 .mul_add(step_size, current_value)
582 .clamp(var.lower_bound, var.upper_bound);
583
584 solution.insert(var_name.clone(), new_value);
586 let new_objective = problem.evaluate_penalized_objective(solution);
587
588 if new_objective < current_objective {
589 current_objective = new_objective;
590 improved = true;
591 break; }
593 solution.insert(var_name.clone(), current_value);
595 }
596 }
597
598 if !improved {
600 break;
601 }
602 }
603
604 Ok(())
605 }
606
607 fn calculate_constraint_violations(
609 &self,
610 problem: &ContinuousOptimizationProblem,
611 solution: &HashMap<String, f64>,
612 ) -> Vec<(String, f64)> {
613 problem
614 .constraints
615 .iter()
616 .map(|constraint| {
617 let violation = (constraint.function)(solution);
618 (constraint.name.clone(), violation.max(0.0))
619 })
620 .collect()
621 }
622}
623
/// Internal binary (QUBO/Ising-style) form of a continuous problem.
///
/// NOTE(review): this appears to be a placeholder. The constructors in this
/// section leave `q_matrix` empty, and nothing here reads it.
#[derive(Debug)]
struct DiscretizedProblem {
    /// Number of binary variables.
    num_variables: usize,
    /// Coupling matrix (currently always empty).
    q_matrix: Vec<Vec<f64>>,
}
630
631pub fn create_quadratic_problem(
635 linear_coeffs: &[f64],
636 quadratic_matrix: &[Vec<f64>],
637 bounds: &[(f64, f64)],
638 precision_bits: usize,
639) -> ContinuousVariableResult<ContinuousOptimizationProblem> {
640 let linear_coeffs = linear_coeffs.to_vec();
642 let quadratic_matrix = quadratic_matrix.to_vec();
643
644 let objective: ObjectiveFunction = Box::new(move |vars: &HashMap<String, f64>| {
645 let n = linear_coeffs.len();
646 let x: Vec<f64> = (0..n).map(|i| vars[&format!("x{i}")]).collect();
647
648 let linear_term: f64 = linear_coeffs
650 .iter()
651 .zip(x.iter())
652 .map(|(c, xi)| c * xi)
653 .sum();
654
655 let mut quadratic_term = 0.0;
657 for i in 0..n {
658 for j in 0..n {
659 quadratic_term += 0.5 * quadratic_matrix[i][j] * x[i] * x[j];
660 }
661 }
662
663 linear_term + quadratic_term
664 });
665
666 let mut problem = ContinuousOptimizationProblem::new(objective);
667
668 for (i, &(lower, upper)) in bounds.iter().enumerate() {
670 let var = ContinuousVariable::new(format!("x{i}"), lower, upper, precision_bits)?;
671 problem.add_variable(var)?;
672 }
673
674 Ok(problem)
675}
676
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction stores the inputs verbatim and derives the level count from the bit width.
    #[test]
    fn test_continuous_variable_creation() {
        let var = ContinuousVariable::new("x".to_string(), 0.0, 10.0, 8)
            .expect("should create continuous variable with valid bounds");

        let ContinuousVariable {
            name,
            lower_bound,
            upper_bound,
            precision_bits,
            ..
        } = &var;
        assert_eq!(name, "x");
        assert_eq!(*lower_bound, 0.0);
        assert_eq!(*upper_bound, 10.0);
        assert_eq!(*precision_bits, 8);
        // 8 bits give 2^8 distinct codes.
        assert_eq!(var.num_levels(), 256);
    }

    /// Codes map linearly onto the interval and round-trip within one grid step.
    #[test]
    fn test_binary_continuous_conversion() {
        let var = ContinuousVariable::new("x".to_string(), 0.0, 10.0, 4)
            .expect("should create continuous variable for conversion test");

        // Endpoints: code 0 is the lower bound; the all-ones code (15) is the upper bound.
        assert_eq!(var.binary_to_continuous(0), 0.0);
        assert!((var.binary_to_continuous(15) - 10.0).abs() < 1e-10);
        assert_eq!(var.continuous_to_binary(0.0), 0);
        assert_eq!(var.continuous_to_binary(10.0), 15);

        // An interior value survives encode/decode up to the grid resolution.
        let original = 3.7;
        let round_trip = var.binary_to_continuous(var.continuous_to_binary(original));
        assert!((round_trip - original).abs() <= var.resolution());
    }

    /// The quadratic helper registers one variable per bound, named `x0`, `x1`, ….
    #[test]
    fn test_quadratic_problem_creation() {
        let problem = create_quadratic_problem(
            &[1.0, -2.0],
            &[vec![2.0, 0.0], vec![0.0, 2.0]],
            &[(0.0, 5.0), (-3.0, 3.0)],
            6,
        )
        .expect("should create quadratic problem with valid parameters");

        assert_eq!(problem.variables.len(), 2);
        for key in ["x0", "x1"] {
            assert!(problem.variables.contains_key(key));
        }
    }

    /// A constraint function returns negative slack when satisfied and positive excess when violated.
    #[test]
    fn test_constraint_evaluation() {
        // g(x, y) = x + y - 5, i.e. the constraint x + y <= 5.
        let constraint_fn: ConstraintFunction =
            Box::new(|vars: &HashMap<String, f64>| vars["x"] + vars["y"] - 5.0);
        let constraint =
            ContinuousConstraint::new("sum_constraint".to_string(), constraint_fn, 10.0);

        let mut vars = HashMap::from([("x".to_string(), 2.0), ("y".to_string(), 2.0)]);
        // Satisfied: 2 + 2 - 5 = -1.
        assert_eq!((constraint.function)(&vars), -1.0);

        vars.insert("y".to_string(), 4.0);
        // Violated by 1: 2 + 4 - 5 = 1.
        assert_eq!((constraint.function)(&vars), 1.0);
    }
}