1use scirs2_core::random::prelude::*;
8use scirs2_core::random::ChaCha8Rng;
9use scirs2_core::random::{Rng, SeedableRng};
10use std::collections::HashMap;
11use std::time::{Duration, Instant};
12use thiserror::Error;
13
14use crate::simulator::{AnnealingParams, AnnealingSolution, TemperatureSchedule};
15
/// Errors reported by the continuous-variable annealing API in this file.
///
/// Each variant carries a human-readable message that is interpolated into the
/// `Display` output declared by its `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub enum ContinuousVariableError {
    /// A variable definition was rejected. In this file it is raised for bad
    /// bounds, an out-of-range precision, or a duplicate variable name.
    #[error("Invalid variable: {0}")]
    InvalidVariable(String),

    /// A constraint definition was rejected.
    // NOTE(review): no code visible in this file constructs this variant.
    #[error("Invalid constraint: {0}")]
    InvalidConstraint(String),

    /// A binary solution could not be mapped back to continuous values, e.g.
    /// because it is shorter than the number of bits the problem needs.
    #[error("Discretization error: {0}")]
    DiscretizationError(String),

    /// The solver finished without retaining any candidate solution.
    #[error("Optimization failed: {0}")]
    OptimizationFailed(String),

    /// A numerical problem was detected.
    // NOTE(review): no code visible in this file constructs this variant.
    #[error("Numerical error: {0}")]
    NumericalError(String),
}
39
/// Result alias whose error type is fixed to [`ContinuousVariableError`].
pub type ContinuousVariableResult<T> = Result<T, ContinuousVariableError>;
42
/// A bounded real-valued variable together with the number of bits used to
/// discretize it.
///
/// All fields are public, so the checks performed by `ContinuousVariable::new`
/// only hold for values built through that constructor.
#[derive(Debug, Clone)]
pub struct ContinuousVariable {
    /// Name of the variable; used as the key in solution maps.
    pub name: String,

    /// Smallest representable value (binary code 0 maps to it).
    pub lower_bound: f64,

    /// Largest representable value (the all-ones binary code maps to it).
    pub upper_bound: f64,

    /// Number of bits in the binary encoding; the constructor accepts 1..=32.
    pub precision_bits: usize,

    /// Optional free-form description; `None` until `with_description` is called.
    pub description: Option<String>,
}
61
62impl ContinuousVariable {
63 pub fn new(
65 name: String,
66 lower_bound: f64,
67 upper_bound: f64,
68 precision_bits: usize,
69 ) -> ContinuousVariableResult<Self> {
70 if lower_bound >= upper_bound {
71 return Err(ContinuousVariableError::InvalidVariable(format!(
72 "Invalid bounds: {lower_bound} >= {upper_bound}"
73 )));
74 }
75
76 if precision_bits == 0 || precision_bits > 32 {
77 return Err(ContinuousVariableError::InvalidVariable(
78 "Precision bits must be between 1 and 32".to_string(),
79 ));
80 }
81
82 Ok(Self {
83 name,
84 lower_bound,
85 upper_bound,
86 precision_bits,
87 description: None,
88 })
89 }
90
91 #[must_use]
93 pub fn with_description(mut self, description: String) -> Self {
94 self.description = Some(description);
95 self
96 }
97
98 #[must_use]
100 pub const fn num_levels(&self) -> usize {
101 2_usize.pow(self.precision_bits as u32)
102 }
103
104 #[must_use]
106 pub fn binary_to_continuous(&self, binary_value: u32) -> f64 {
107 let max_value = (1u32 << self.precision_bits) - 1;
108 let normalized = f64::from(binary_value) / f64::from(max_value);
109 self.lower_bound + normalized * (self.upper_bound - self.lower_bound)
110 }
111
112 #[must_use]
114 pub fn continuous_to_binary(&self, continuous_value: f64) -> u32 {
115 let clamped = continuous_value.clamp(self.lower_bound, self.upper_bound);
116 let normalized = (clamped - self.lower_bound) / (self.upper_bound - self.lower_bound);
117 let max_value = (1u32 << self.precision_bits) - 1;
118 (normalized * f64::from(max_value)).round() as u32
119 }
120
121 #[must_use]
123 pub fn resolution(&self) -> f64 {
124 (self.upper_bound - self.lower_bound) / (self.num_levels() - 1) as f64
125 }
126}
127
/// Boxed objective callback: receives a map from variable name to value and
/// returns the objective as an `f64`. The solver in this file keeps the
/// candidate with the smallest value, i.e. the objective is minimized.
pub type ObjectiveFunction = Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>;

/// Boxed constraint callback: receives a map from variable name to value and
/// returns an `f64`. A result above the constraint's tolerance is treated as a
/// violation and penalized.
pub type ConstraintFunction = Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>;
133
/// A named constraint enforced through a quadratic penalty.
pub struct ContinuousConstraint {
    /// Name of the constraint; reported alongside its violation in solutions.
    pub name: String,

    /// Callback evaluated on a candidate solution; values above `tolerance`
    /// count as violations.
    pub function: ConstraintFunction,

    /// Multiplier applied to the squared constraint value when it is violated.
    pub penalty_weight: f64,

    /// Largest constraint value that is still treated as satisfied.
    pub tolerance: f64,
}
148
149impl std::fmt::Debug for ContinuousConstraint {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 f.debug_struct("ContinuousConstraint")
152 .field("name", &self.name)
153 .field("function", &"<function>")
154 .field("penalty_weight", &self.penalty_weight)
155 .field("tolerance", &self.tolerance)
156 .finish()
157 }
158}
159
160impl ContinuousConstraint {
161 #[must_use]
163 pub fn new(name: String, function: ConstraintFunction, penalty_weight: f64) -> Self {
164 Self {
165 name,
166 function,
167 penalty_weight,
168 tolerance: 1e-6,
169 }
170 }
171
172 #[must_use]
174 pub const fn with_tolerance(mut self, tolerance: f64) -> Self {
175 self.tolerance = tolerance;
176 self
177 }
178}
179
/// A continuous optimization problem: a set of bounded variables, an
/// objective callback and a list of penalized constraints.
pub struct ContinuousOptimizationProblem {
    /// Variables keyed by their name.
    variables: HashMap<String, ContinuousVariable>,

    /// Objective callback evaluated on candidate solutions.
    objective: ObjectiveFunction,

    /// Constraints, in the order they were added.
    constraints: Vec<ContinuousConstraint>,

    /// Penalty weight stored on the problem (initialised to 100.0 by `new`).
    // NOTE(review): nothing visible in this file reads this value; penalties
    // use each constraint's own `penalty_weight`.
    default_penalty_weight: f64,
}
194
195impl std::fmt::Debug for ContinuousOptimizationProblem {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 f.debug_struct("ContinuousOptimizationProblem")
198 .field("variables", &self.variables)
199 .field("objective", &"<function>")
200 .field("constraints", &self.constraints)
201 .field("default_penalty_weight", &self.default_penalty_weight)
202 .finish()
203 }
204}
205
206impl ContinuousOptimizationProblem {
207 #[must_use]
209 pub fn new(objective: ObjectiveFunction) -> Self {
210 Self {
211 variables: HashMap::new(),
212 objective,
213 constraints: Vec::new(),
214 default_penalty_weight: 100.0,
215 }
216 }
217
218 pub fn add_variable(&mut self, variable: ContinuousVariable) -> ContinuousVariableResult<()> {
220 if self.variables.contains_key(&variable.name) {
221 return Err(ContinuousVariableError::InvalidVariable(format!(
222 "Variable '{}' already exists",
223 variable.name
224 )));
225 }
226
227 self.variables.insert(variable.name.clone(), variable);
228 Ok(())
229 }
230
231 pub fn add_constraint(&mut self, constraint: ContinuousConstraint) {
233 self.constraints.push(constraint);
234 }
235
236 pub const fn set_default_penalty_weight(&mut self, weight: f64) {
238 self.default_penalty_weight = weight;
239 }
240
241 #[must_use]
243 pub fn total_binary_variables(&self) -> usize {
244 self.variables.values().map(|v| v.precision_bits).sum()
245 }
246
247 #[must_use]
249 pub fn create_binary_mapping(&self) -> HashMap<String, Vec<usize>> {
250 let mut mapping = HashMap::new();
251 let mut current_index = 0;
252
253 for (var_name, var) in &self.variables {
254 let indices: Vec<usize> = (current_index..current_index + var.precision_bits).collect();
255 mapping.insert(var_name.clone(), indices);
256 current_index += var.precision_bits;
257 }
258
259 mapping
260 }
261
262 pub fn binary_to_continuous_solution(
264 &self,
265 binary_solution: &[i8],
266 ) -> ContinuousVariableResult<HashMap<String, f64>> {
267 let binary_mapping = self.create_binary_mapping();
268 let mut continuous_solution = HashMap::new();
269
270 for (var_name, var) in &self.variables {
271 let indices = &binary_mapping[var_name];
272
273 if indices.iter().any(|&i| i >= binary_solution.len()) {
274 return Err(ContinuousVariableError::DiscretizationError(format!(
275 "Binary solution too short for variable '{var_name}'"
276 )));
277 }
278
279 let mut binary_value = 0u32;
281 for (bit_idx, &global_idx) in indices.iter().enumerate() {
282 if binary_solution[global_idx] > 0 {
283 binary_value |= 1 << (var.precision_bits - 1 - bit_idx);
284 }
285 }
286
287 let continuous_value = var.binary_to_continuous(binary_value);
289 continuous_solution.insert(var_name.clone(), continuous_value);
290 }
291
292 Ok(continuous_solution)
293 }
294
295 #[must_use]
297 pub fn evaluate_penalized_objective(&self, continuous_solution: &HashMap<String, f64>) -> f64 {
298 let mut objective_value = (self.objective)(continuous_solution);
299
300 for constraint in &self.constraints {
302 let constraint_value = (constraint.function)(continuous_solution);
303 if constraint_value > constraint.tolerance {
304 objective_value += constraint.penalty_weight * constraint_value.powi(2);
305 }
306 }
307
308 objective_value
309 }
310}
311
/// Settings for [`ContinuousVariableAnnealer`].
#[derive(Debug, Clone)]
pub struct ContinuousAnnealingConfig {
    /// Parameters of the underlying annealer; its `seed` also seeds the
    /// annealer's random number generator.
    pub annealing_params: AnnealingParams,

    /// When `true`, the discretization is refined between annealing rounds.
    pub adaptive_discretization: bool,

    /// Upper bound on the number of annealing rounds.
    pub max_refinement_iterations: usize,

    /// After the first round, the loop stops once a round improves the best
    /// penalized objective by less than this amount.
    pub refinement_tolerance: f64,

    /// When `true`, a local search polishes the best solution.
    pub local_search: bool,

    /// Upper bound on the number of local-search passes over all variables.
    pub local_search_iterations: usize,

    /// Local-search step, expressed as a fraction of each variable's range.
    pub local_search_step_size: f64,
}
336
337impl Default for ContinuousAnnealingConfig {
338 fn default() -> Self {
339 Self {
340 annealing_params: AnnealingParams::default(),
341 adaptive_discretization: true,
342 max_refinement_iterations: 3,
343 refinement_tolerance: 1e-4,
344 local_search: true,
345 local_search_iterations: 100,
346 local_search_step_size: 0.01,
347 }
348 }
349}
350
/// Result returned by `ContinuousVariableAnnealer::solve`.
#[derive(Debug, Clone)]
pub struct ContinuousSolution {
    /// Final value of every variable, keyed by variable name.
    pub variable_values: HashMap<String, f64>,

    /// Objective evaluated at `variable_values`, without constraint penalties.
    pub objective_value: f64,

    /// One `(constraint name, violation)` pair per constraint; the violation
    /// is the constraint value floored at zero.
    pub constraint_violations: Vec<(String, f64)>,

    /// Binary solution of the best annealing round. It is not updated by the
    /// local search, so it may not decode to `variable_values` exactly.
    pub binary_solution: Vec<i8>,

    /// Timing and convergence information for the run.
    pub stats: ContinuousOptimizationStats,
}
369
/// Timing and convergence information collected by the solver.
#[derive(Debug, Clone)]
pub struct ContinuousOptimizationStats {
    /// Wall-clock time of the whole `solve` call.
    pub total_runtime: Duration,

    /// Time spent building the initial discretized problem.
    pub discretization_time: Duration,

    /// Annealing time of the round that produced the returned solution.
    pub annealing_time: Duration,

    /// Time spent in the local search; zero when local search is disabled.
    pub local_search_time: Duration,

    /// Number of annealing rounds that were executed.
    pub refinement_iterations: usize,

    /// Resolution of each variable, keyed by variable name.
    pub final_resolution: HashMap<String, f64>,

    /// `true` when the loop stopped before using all allowed rounds.
    pub converged: bool,
}
394
/// Solver for [`ContinuousOptimizationProblem`]s.
pub struct ContinuousVariableAnnealer {
    /// Settings the solver was created with.
    config: ContinuousAnnealingConfig,

    /// Random number generator used when producing binary solutions.
    rng: ChaCha8Rng,
}
403
404impl ContinuousVariableAnnealer {
405 #[must_use]
407 pub fn new(config: ContinuousAnnealingConfig) -> Self {
408 let rng = match config.annealing_params.seed {
409 Some(seed) => ChaCha8Rng::seed_from_u64(seed),
410 None => ChaCha8Rng::seed_from_u64(thread_rng().gen()),
411 };
412
413 Self { config, rng }
414 }
415
416 pub fn solve(
418 &mut self,
419 problem: &ContinuousOptimizationProblem,
420 ) -> ContinuousVariableResult<ContinuousSolution> {
421 let total_start = Instant::now();
422
423 let discretize_start = Instant::now();
425 let mut current_problem = self.create_discretized_problem(problem)?;
426 let discretization_time = discretize_start.elapsed();
427
428 let mut best_solution = None;
429 let mut best_objective = f64::INFINITY;
430 let mut refinement_iterations = 0;
431
432 for iteration in 0..self.config.max_refinement_iterations {
434 let anneal_start = Instant::now();
436 let binary_solution = self.solve_discretized_problem(¤t_problem)?;
437 let annealing_time = anneal_start.elapsed();
438
439 let continuous_values = problem.binary_to_continuous_solution(&binary_solution)?;
441 let objective_value = problem.evaluate_penalized_objective(&continuous_values);
442
443 let improvement = if best_objective.is_finite() {
445 best_objective - objective_value
446 } else {
447 f64::INFINITY
448 };
449
450 if objective_value < best_objective {
451 best_objective = objective_value;
452 best_solution = Some((binary_solution, continuous_values.clone(), annealing_time));
453 }
454
455 refinement_iterations += 1;
456
457 if improvement < self.config.refinement_tolerance && iteration > 0 {
459 break;
460 }
461
462 if self.config.adaptive_discretization
464 && iteration < self.config.max_refinement_iterations - 1
465 {
466 current_problem = self.refine_discretization(problem, &continuous_values)?;
467 }
468 }
469
470 let (final_binary, mut final_continuous, annealing_time) =
471 best_solution.ok_or_else(|| {
472 ContinuousVariableError::OptimizationFailed("No solution found".to_string())
473 })?;
474
475 let local_search_start = Instant::now();
477 let local_search_time = if self.config.local_search {
478 self.local_search(problem, &mut final_continuous)?;
479 local_search_start.elapsed()
480 } else {
481 Duration::from_secs(0)
482 };
483
484 let constraint_violations =
486 self.calculate_constraint_violations(problem, &final_continuous);
487
488 let final_objective = (problem.objective)(&final_continuous);
490
491 let final_resolution = problem
493 .variables
494 .iter()
495 .map(|(name, var)| (name.clone(), var.resolution()))
496 .collect();
497
498 let total_runtime = total_start.elapsed();
499
500 let stats = ContinuousOptimizationStats {
501 total_runtime,
502 discretization_time,
503 annealing_time,
504 local_search_time,
505 refinement_iterations,
506 final_resolution,
507 converged: refinement_iterations < self.config.max_refinement_iterations,
508 };
509
510 Ok(ContinuousSolution {
511 variable_values: final_continuous,
512 objective_value: final_objective,
513 constraint_violations,
514 binary_solution: final_binary,
515 stats,
516 })
517 }
518
519 const fn create_discretized_problem(
521 &self,
522 _problem: &ContinuousOptimizationProblem,
523 ) -> ContinuousVariableResult<DiscretizedProblem> {
524 Ok(DiscretizedProblem {
527 num_variables: 0,
528 q_matrix: Vec::new(),
529 })
530 }
531
532 fn solve_discretized_problem(
534 &mut self,
535 _problem: &DiscretizedProblem,
536 ) -> ContinuousVariableResult<Vec<i8>> {
537 let num_vars = 16; let solution: Vec<i8> = (0..num_vars)
541 .map(|_| if self.rng.gen_bool(0.5) { 1 } else { -1 })
542 .collect();
543
544 Ok(solution)
545 }
546
547 const fn refine_discretization(
549 &self,
550 _problem: &ContinuousOptimizationProblem,
551 _current_solution: &HashMap<String, f64>,
552 ) -> ContinuousVariableResult<DiscretizedProblem> {
553 Ok(DiscretizedProblem {
555 num_variables: 0,
556 q_matrix: Vec::new(),
557 })
558 }
559
560 fn local_search(
562 &self,
563 problem: &ContinuousOptimizationProblem,
564 solution: &mut HashMap<String, f64>,
565 ) -> ContinuousVariableResult<()> {
566 let mut current_objective = problem.evaluate_penalized_objective(solution);
567
568 for _ in 0..self.config.local_search_iterations {
569 let mut improved = false;
570
571 for (var_name, var) in &problem.variables {
573 let current_value = solution[var_name];
574 let step_size =
575 (var.upper_bound - var.lower_bound) * self.config.local_search_step_size;
576
577 for direction in [-1.0_f64, 1.0] {
579 let new_value = direction
580 .mul_add(step_size, current_value)
581 .clamp(var.lower_bound, var.upper_bound);
582
583 solution.insert(var_name.clone(), new_value);
585 let new_objective = problem.evaluate_penalized_objective(solution);
586
587 if new_objective < current_objective {
588 current_objective = new_objective;
589 improved = true;
590 break; }
592 solution.insert(var_name.clone(), current_value);
594 }
595 }
596
597 if !improved {
599 break;
600 }
601 }
602
603 Ok(())
604 }
605
606 fn calculate_constraint_violations(
608 &self,
609 problem: &ContinuousOptimizationProblem,
610 solution: &HashMap<String, f64>,
611 ) -> Vec<(String, f64)> {
612 problem
613 .constraints
614 .iter()
615 .map(|constraint| {
616 let violation = (constraint.function)(solution);
617 (constraint.name.clone(), violation.max(0.0))
618 })
619 .collect()
620 }
621}
622
/// Internal problem representation passed from the discretization step to
/// the annealing step of `ContinuousVariableAnnealer`.
#[derive(Debug)]
struct DiscretizedProblem {
    /// Number of binary variables in the discretized problem.
    num_variables: usize,
    /// Coefficient matrix of the discretized problem.
    // NOTE(review): presumably a QUBO-style coupling matrix; every
    // constructor visible in this file leaves it empty — confirm intent.
    q_matrix: Vec<Vec<f64>>,
}
629
630pub fn create_quadratic_problem(
634 linear_coeffs: &[f64],
635 quadratic_matrix: &[Vec<f64>],
636 bounds: &[(f64, f64)],
637 precision_bits: usize,
638) -> ContinuousVariableResult<ContinuousOptimizationProblem> {
639 let linear_coeffs = linear_coeffs.to_vec();
641 let quadratic_matrix = quadratic_matrix.to_vec();
642
643 let objective: ObjectiveFunction = Box::new(move |vars: &HashMap<String, f64>| {
644 let n = linear_coeffs.len();
645 let x: Vec<f64> = (0..n).map(|i| vars[&format!("x{i}")]).collect();
646
647 let linear_term: f64 = linear_coeffs
649 .iter()
650 .zip(x.iter())
651 .map(|(c, xi)| c * xi)
652 .sum();
653
654 let mut quadratic_term = 0.0;
656 for i in 0..n {
657 for j in 0..n {
658 quadratic_term += 0.5 * quadratic_matrix[i][j] * x[i] * x[j];
659 }
660 }
661
662 linear_term + quadratic_term
663 });
664
665 let mut problem = ContinuousOptimizationProblem::new(objective);
666
667 for (i, &(lower, upper)) in bounds.iter().enumerate() {
669 let var = ContinuousVariable::new(format!("x{i}"), lower, upper, precision_bits)?;
670 problem.add_variable(var)?;
671 }
672
673 Ok(problem)
674}
675
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built variable exposes exactly the values it was given.
    #[test]
    fn test_continuous_variable_creation() {
        let variable = ContinuousVariable::new(String::from("x"), 0.0, 10.0, 8)
            .expect("should create continuous variable with valid bounds");

        assert_eq!(variable.name, "x");
        assert_eq!(variable.lower_bound, 0.0);
        assert_eq!(variable.upper_bound, 10.0);
        assert_eq!(variable.precision_bits, 8);
        assert_eq!(variable.num_levels(), 256);
    }

    /// Encoding and decoding agree at both ends of the range and round-trip
    /// within one resolution step.
    #[test]
    fn test_binary_continuous_conversion() {
        let variable = ContinuousVariable::new(String::from("x"), 0.0, 10.0, 4)
            .expect("should create continuous variable for conversion test");

        // Binary -> continuous at both ends of the range.
        assert_eq!(variable.binary_to_continuous(0), 0.0);
        assert!((variable.binary_to_continuous(15) - 10.0).abs() < 1e-10);

        // Continuous -> binary at both ends of the range.
        assert_eq!(variable.continuous_to_binary(0.0), 0);
        assert_eq!(variable.continuous_to_binary(10.0), 15);

        // A round trip loses at most one resolution step.
        let original = 3.7;
        let encoded = variable.continuous_to_binary(original);
        let decoded = variable.binary_to_continuous(encoded);
        assert!((decoded - original).abs() <= variable.resolution());
    }

    /// The quadratic helper creates one `x{i}` variable per bound.
    #[test]
    fn test_quadratic_problem_creation() {
        let linear = vec![1.0, -2.0];
        let quadratic = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let limits = vec![(0.0, 5.0), (-3.0, 3.0)];

        let problem = create_quadratic_problem(&linear, &quadratic, &limits, 6)
            .expect("should create quadratic problem with valid parameters");

        assert_eq!(problem.variables.len(), 2);
        for expected in ["x0", "x1"] {
            assert!(problem.variables.contains_key(expected));
        }
    }

    /// The constraint callback returns `x + y - 5` for the supplied values.
    #[test]
    fn test_constraint_evaluation() {
        let sum_minus_five: ConstraintFunction = Box::new(|vars| vars["x"] + vars["y"] - 5.0);
        let constraint =
            ContinuousConstraint::new("sum_constraint".to_string(), sum_minus_five, 10.0);

        let mut assignment = HashMap::from([("x".to_string(), 2.0), ("y".to_string(), 2.0)]);

        // 2 + 2 - 5 = -1
        assert_eq!((constraint.function)(&assignment), -1.0);

        // 2 + 4 - 5 = 1
        assignment.insert("y".to_string(), 4.0);
        assert_eq!((constraint.function)(&assignment), 1.0);
    }
}