1use scirs2_core::ndarray::Array2;
7use std::collections::HashMap;
8
9#[cfg(feature = "dwave")]
10use crate::symbol::Expression;
11
/// Strategy used to represent a bounded integer variable with binary variables.
///
/// The variants carrying `num_values` describe a domain of `num_values` values;
/// `OrderEncoding` describes the inclusive range `min_value..=max_value`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncodingScheme {
    /// One bit per value; a valid assignment has exactly one bit set.
    OneHot { num_values: usize },
    /// Base-2 representation using `ceil(log2(num_values))` bits.
    Binary { num_values: usize },
    /// Reflected Gray code using `ceil(log2(num_values))` bits.
    GrayCode { num_values: usize },
    /// Domain-wall encoding using `num_values - 1` bits (a run of ones, then zeros).
    DomainWall { num_values: usize },
    /// Unary (thermometer) encoding using `num_values - 1` bits.
    Unary { num_values: usize },
    /// Order encoding of `min_value..=max_value` using `max_value - min_value` bits.
    OrderEncoding { min_value: i32, max_value: i32 },
    /// The variable is already binary and is used as a single bit.
    Direct,
}
30
31#[derive(Debug, Clone)]
33pub struct EncodedVariable {
34 pub name: String,
36 pub scheme: EncodingScheme,
38 pub binary_vars: Vec<String>,
40 #[cfg(feature = "dwave")]
42 pub constraints: Option<Expression>,
43}
44
45impl EncodedVariable {
46 pub fn new(name: &str, scheme: EncodingScheme) -> Self {
48 let binary_vars = Self::generate_binary_vars(name, &scheme);
49 Self {
50 name: name.to_string(),
51 scheme,
52 binary_vars,
53 #[cfg(feature = "dwave")]
54 constraints: None,
55 }
56 }
57
58 fn generate_binary_vars(name: &str, scheme: &EncodingScheme) -> Vec<String> {
60 match scheme {
61 EncodingScheme::OneHot { num_values } => {
62 (0..*num_values).map(|i| format!("{name}_{i}")).collect()
63 }
64 EncodingScheme::Binary { num_values } => {
65 let num_bits = (*num_values as f64).log2().ceil() as usize;
66 (0..num_bits).map(|i| format!("{name}_bit{i}")).collect()
67 }
68 EncodingScheme::GrayCode { num_values } => {
69 let num_bits = (*num_values as f64).log2().ceil() as usize;
70 (0..num_bits).map(|i| format!("{name}_gray{i}")).collect()
71 }
72 EncodingScheme::DomainWall { num_values } => (0..*num_values - 1)
73 .map(|i| format!("{name}_dw{i}"))
74 .collect(),
75 EncodingScheme::Unary { num_values } => (0..*num_values - 1)
76 .map(|i| format!("{name}_u{i}"))
77 .collect(),
78 EncodingScheme::OrderEncoding {
79 min_value,
80 max_value,
81 } => {
82 let range = max_value - min_value;
83 (0..range).map(|i| format!("{name}_ord{i}")).collect()
84 }
85 EncodingScheme::Direct => vec![name.to_string()],
86 }
87 }
88
89 pub fn decode(&self, binary_values: &HashMap<String, bool>) -> Option<i32> {
91 match &self.scheme {
92 EncodingScheme::OneHot { .. } => {
93 for (i, var) in self.binary_vars.iter().enumerate() {
94 if binary_values.get(var).copied().unwrap_or(false) {
95 return Some(i as i32);
96 }
97 }
98 None }
100 EncodingScheme::Binary { .. } => {
101 let mut value = 0;
102 for (i, var) in self.binary_vars.iter().enumerate() {
103 if binary_values.get(var).copied().unwrap_or(false) {
104 value |= 1 << i;
105 }
106 }
107 Some(value)
108 }
109 EncodingScheme::GrayCode { .. } => {
110 let mut gray = 0;
111 for (i, var) in self.binary_vars.iter().enumerate() {
112 if binary_values.get(var).copied().unwrap_or(false) {
113 gray |= 1 << i;
114 }
115 }
116 let mut binary = gray;
118 binary ^= binary >> 16;
119 binary ^= binary >> 8;
120 binary ^= binary >> 4;
121 binary ^= binary >> 2;
122 binary ^= binary >> 1;
123 Some(binary)
124 }
125 EncodingScheme::DomainWall { num_values } => {
126 let mut value = *num_values as i32 - 1;
127 for (i, var) in self.binary_vars.iter().enumerate() {
128 if !binary_values.get(var).copied().unwrap_or(false) {
129 value = i as i32;
130 break;
131 }
132 }
133 Some(value)
134 }
135 EncodingScheme::Unary { .. } => {
136 let mut value = 0;
137 for var in &self.binary_vars {
138 if binary_values.get(var).copied().unwrap_or(false) {
139 value += 1;
140 } else {
141 break;
142 }
143 }
144 Some(value)
145 }
146 EncodingScheme::OrderEncoding { min_value, .. } => {
147 let mut value = *min_value;
148 for var in &self.binary_vars {
149 if binary_values.get(var).copied().unwrap_or(false) {
150 value += 1;
151 }
152 }
153 Some(value - 1)
154 }
155 EncodingScheme::Direct => binary_values.get(&self.name).map(|&b| i32::from(b)),
156 }
157 }
158
159 pub fn encode(&self, value: i32) -> HashMap<String, bool> {
161 let mut binary_values = HashMap::new();
162
163 match &self.scheme {
164 EncodingScheme::OneHot { num_values: _ } => {
165 for (i, var) in self.binary_vars.iter().enumerate() {
166 binary_values.insert(var.clone(), i == value as usize);
167 }
168 }
169 EncodingScheme::Binary { .. } => {
170 for (i, var) in self.binary_vars.iter().enumerate() {
171 binary_values.insert(var.clone(), (value & (1 << i)) != 0);
172 }
173 }
174 EncodingScheme::GrayCode { .. } => {
175 let gray = value ^ (value >> 1);
177 for (i, var) in self.binary_vars.iter().enumerate() {
178 binary_values.insert(var.clone(), (gray & (1 << i)) != 0);
179 }
180 }
181 EncodingScheme::DomainWall { num_values: _ } => {
182 for (i, var) in self.binary_vars.iter().enumerate() {
183 binary_values.insert(var.clone(), i < value as usize);
184 }
185 }
186 EncodingScheme::Unary { .. } => {
187 for (i, var) in self.binary_vars.iter().enumerate() {
188 binary_values.insert(var.clone(), i < value as usize);
189 }
190 }
191 EncodingScheme::OrderEncoding { min_value, .. } => {
192 let adjusted = value - min_value + 1;
193 for (i, var) in self.binary_vars.iter().enumerate() {
194 binary_values.insert(var.clone(), i < adjusted as usize);
195 }
196 }
197 EncodingScheme::Direct => {
198 binary_values.insert(self.name.clone(), value != 0);
199 }
200 }
201
202 binary_values
203 }
204
205 pub fn get_penalty_matrix(&self, var_indices: &HashMap<String, usize>) -> Array2<f64> {
207 let n = var_indices.len();
208 let mut penalty = Array2::zeros((n, n));
209
210 match &self.scheme {
211 EncodingScheme::OneHot { .. } => {
212 let indices: Vec<usize> = self
217 .binary_vars
218 .iter()
219 .filter_map(|var| var_indices.get(var).copied())
220 .collect();
221
222 for &i in &indices {
224 for &j in &indices {
225 if i != j {
226 penalty[[i, j]] += 1.0;
227 }
228 }
229 }
230
231 for &i in &indices {
233 penalty[[i, i]] -= 2.0;
234 }
235 }
236 EncodingScheme::DomainWall { .. } => {
237 let indices: Vec<usize> = self
241 .binary_vars
242 .iter()
243 .filter_map(|var| var_indices.get(var).copied())
244 .collect();
245
246 for i in 0..indices.len() - 1 {
247 let idx1 = indices[i];
248 let idx2 = indices[i + 1];
249
250 penalty[[idx2, idx2]] += 1.0;
252 penalty[[idx1, idx2]] -= 1.0;
253 penalty[[idx2, idx1]] -= 1.0;
254 }
255 }
256 EncodingScheme::Unary { .. } => {
257 let indices: Vec<usize> = self
260 .binary_vars
261 .iter()
262 .filter_map(|var| var_indices.get(var).copied())
263 .collect();
264
265 for i in 0..indices.len() - 1 {
266 let idx1 = indices[i];
267 let idx2 = indices[i + 1];
268
269 penalty[[idx2, idx2]] += 1.0;
270 penalty[[idx1, idx2]] -= 1.0;
271 penalty[[idx2, idx1]] -= 1.0;
272 }
273 }
274 _ => {
275 }
277 }
278
279 penalty
280 }
281}
282
/// Picks an encoding scheme per integer variable from its domain and its
/// constraint neighbourhood.
pub struct EncodingOptimizer {
    /// Inclusive `(min, max)` domain, keyed by variable name.
    domains: HashMap<String, (i32, i32)>,
    /// Adjacency lists: variable name -> names of variables it is constrained with.
    constraint_graph: HashMap<String, Vec<String>>,
}
290
291impl Default for EncodingOptimizer {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
297impl EncodingOptimizer {
298 pub fn new() -> Self {
300 Self {
301 domains: HashMap::new(),
302 constraint_graph: HashMap::new(),
303 }
304 }
305
306 pub fn add_variable(&mut self, name: &str, min_value: i32, max_value: i32) {
308 self.domains
309 .insert(name.to_string(), (min_value, max_value));
310 }
311
312 pub fn add_constraint(&mut self, var1: &str, var2: &str) {
314 self.constraint_graph
315 .entry(var1.to_string())
316 .or_default()
317 .push(var2.to_string());
318 self.constraint_graph
319 .entry(var2.to_string())
320 .or_default()
321 .push(var1.to_string());
322 }
323
324 pub fn optimize_encodings(&self) -> HashMap<String, EncodingScheme> {
326 let mut encodings = HashMap::new();
327
328 for (var, &(min_val, max_val)) in &self.domains {
329 let domain_size = (max_val - min_val + 1) as usize;
330 let neighbors = self.constraint_graph.get(var).map_or(0, |v| v.len());
331
332 let encoding = if domain_size == 2 {
334 EncodingScheme::Direct
336 } else if domain_size <= 4 && neighbors > 3 {
337 EncodingScheme::OneHot {
339 num_values: domain_size,
340 }
341 } else if domain_size <= 8 {
342 if self.has_ordering_constraints(var) {
344 EncodingScheme::GrayCode {
345 num_values: domain_size,
346 }
347 } else {
348 EncodingScheme::Binary {
349 num_values: domain_size,
350 }
351 }
352 } else if self.has_ordering_constraints(var) {
353 if domain_size <= 32 {
355 EncodingScheme::OrderEncoding {
356 min_value: min_val,
357 max_value: max_val,
358 }
359 } else {
360 EncodingScheme::DomainWall {
361 num_values: domain_size,
362 }
363 }
364 } else {
365 EncodingScheme::Binary {
367 num_values: domain_size,
368 }
369 };
370
371 encodings.insert(var.clone(), encoding);
372 }
373
374 encodings
375 }
376
377 const fn has_ordering_constraints(&self, _var: &str) -> bool {
379 false
381 }
382}
383
/// Produces fresh, numbered auxiliary variable names.
pub struct AuxiliaryVariableGenerator {
    /// Number appended to the next generated name.
    counter: usize,
    /// Text placed before the number in every generated name.
    prefix: String,
}
391
392impl AuxiliaryVariableGenerator {
393 pub fn new(prefix: &str) -> Self {
395 Self {
396 counter: 0,
397 prefix: prefix.to_string(),
398 }
399 }
400
401 pub fn next(&mut self) -> String {
403 let name = format!("{}_{}", self.prefix, self.counter);
404 self.counter += 1;
405 name
406 }
407
408 pub fn product_encoding(
410 &mut self,
411 _var1: &str,
412 _var2: &str,
413 enc1: &EncodedVariable,
414 enc2: &EncodedVariable,
415 ) -> Vec<(String, Vec<String>)> {
416 let mut auxiliaries = Vec::new();
417
418 for bin1 in &enc1.binary_vars {
420 for bin2 in &enc2.binary_vars {
421 let aux = self.next();
422 auxiliaries.push((aux.clone(), vec![bin1.clone(), bin2.clone()]));
423 }
424 }
425
426 auxiliaries
427 }
428}
429
430pub struct EncodingConverter {
432 encodings: HashMap<String, EncodedVariable>,
434 aux_gen: AuxiliaryVariableGenerator,
436}
437
438impl Default for EncodingConverter {
439 fn default() -> Self {
440 Self::new()
441 }
442}
443
444impl EncodingConverter {
445 pub fn new() -> Self {
447 Self {
448 encodings: HashMap::new(),
449 aux_gen: AuxiliaryVariableGenerator::new("aux"),
450 }
451 }
452
453 pub fn add_variable(&mut self, encoded: EncodedVariable) {
455 self.encodings.insert(encoded.name.clone(), encoded);
456 }
457
458 pub fn get_binary_variables(&self) -> Vec<String> {
460 let mut vars = Vec::new();
461 for encoded in self.encodings.values() {
462 vars.extend(encoded.binary_vars.clone());
463 }
464 vars
465 }
466
467 pub fn build_qubo_matrix(&self, _base_matrix: Array2<f64>) -> Array2<f64> {
469 let binary_vars = self.get_binary_variables();
470 let var_indices: HashMap<String, usize> = binary_vars
471 .iter()
472 .enumerate()
473 .map(|(i, v)| (v.clone(), i))
474 .collect();
475
476 let n = binary_vars.len();
477 let mut qubo = Array2::zeros((n, n));
478
479 for encoded in self.encodings.values() {
481 let penalty = encoded.get_penalty_matrix(&var_indices);
482 qubo = qubo + penalty;
483 }
484
485 qubo
489 }
490}
491
/// Compares the resource cost of one-hot, binary and domain-wall encodings for a
/// variable with `domain_size` values.
///
/// Returns metrics keyed by `"one-hot"`, `"binary"` and `"domain-wall"`.
/// `constraint_density` only affects the binary encoding's `avg_connectivity`.
/// An empty domain (`domain_size == 0`) yields zero bits and zero constraints
/// instead of underflowing.
pub fn compare_encodings(
    domain_size: usize,
    constraint_density: f64,
) -> HashMap<String, EncodingMetrics> {
    let size = domain_size as f64;
    // `saturating_sub` keeps `domain_size == 0` from underflowing (a debug-build panic).
    let below = domain_size.saturating_sub(1);

    let one_hot = EncodingMetrics {
        num_bits: domain_size,
        // One pairwise term per unordered pair of bits.
        num_constraints: domain_size * below / 2,
        avg_connectivity: size - 1.0,
        space_efficiency: 1.0 / size,
    };

    let binary_bits = size.log2().ceil() as usize;
    let binary = EncodingMetrics {
        num_bits: binary_bits,
        num_constraints: 0,
        avg_connectivity: constraint_density * binary_bits as f64,
        space_efficiency: size.log2() / size,
    };

    let domain_wall = EncodingMetrics {
        num_bits: below,
        num_constraints: below,
        avg_connectivity: 2.0,
        space_efficiency: 1.0 / size,
    };

    HashMap::from([
        ("one-hot".to_string(), one_hot),
        ("binary".to_string(), binary),
        ("domain-wall".to_string(), domain_wall),
    ])
}

/// Resource metrics for one encoding, as reported by [`compare_encodings`].
#[derive(Debug, Clone, PartialEq)]
pub struct EncodingMetrics {
    /// Number of binary variables the encoding uses.
    pub num_bits: usize,
    /// Number of penalty constraints the encoding needs.
    pub num_constraints: usize,
    /// Average connectivity estimate, as computed by [`compare_encodings`].
    pub avg_connectivity: f64,
    /// Space-efficiency score, as computed by [`compare_encodings`].
    pub space_efficiency: f64,
}
547
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_one_hot_encoding() {
        let encoded = EncodedVariable::new("x", EncodingScheme::OneHot { num_values: 4 });
        assert_eq!(encoded.binary_vars.len(), 4);

        let binary = encoded.encode(2);
        assert!(!binary["x_0"]);
        assert!(!binary["x_1"]);
        assert!(binary["x_2"]);
        assert!(!binary["x_3"]);

        let value = encoded
            .decode(&binary)
            .expect("Failed to decode one-hot value");
        assert_eq!(value, 2);
    }

    #[test]
    fn test_binary_encoding() {
        let encoded = EncodedVariable::new("y", EncodingScheme::Binary { num_values: 8 });
        assert_eq!(encoded.binary_vars.len(), 3);

        let binary = encoded.encode(5);
        assert!(binary["y_bit0"]);
        assert!(!binary["y_bit1"]);
        assert!(binary["y_bit2"]);

        let value = encoded
            .decode(&binary)
            .expect("Failed to decode binary value");
        assert_eq!(value, 5);
    }

    #[test]
    fn test_gray_code_round_trip() {
        let encoded = EncodedVariable::new("g", EncodingScheme::GrayCode { num_values: 8 });
        assert_eq!(encoded.binary_vars.len(), 3);

        for value in 0..8 {
            let bits = encoded.encode(value);
            assert_eq!(encoded.decode(&bits), Some(value));
        }

        // Consecutive values must differ in exactly one bit.
        for value in 0..7 {
            let (a, b) = (encoded.encode(value), encoded.encode(value + 1));
            let changed = encoded
                .binary_vars
                .iter()
                .filter(|var| a[var.as_str()] != b[var.as_str()])
                .count();
            assert_eq!(changed, 1);
        }
    }

    #[test]
    fn test_domain_wall_encoding() {
        let encoded = EncodedVariable::new("z", EncodingScheme::DomainWall { num_values: 5 });
        assert_eq!(encoded.binary_vars.len(), 4);

        let binary = encoded.encode(2);
        assert!(binary["z_dw0"]);
        assert!(binary["z_dw1"]);
        assert!(!binary["z_dw2"]);
        assert!(!binary["z_dw3"]);

        let value = encoded
            .decode(&binary)
            .expect("Failed to decode domain wall value");
        assert_eq!(value, 2);

        for value in 0..5 {
            assert_eq!(encoded.decode(&encoded.encode(value)), Some(value));
        }
    }

    #[test]
    fn test_unary_round_trip() {
        let encoded = EncodedVariable::new("u", EncodingScheme::Unary { num_values: 4 });
        assert_eq!(encoded.binary_vars.len(), 3);

        for value in 0..4 {
            assert_eq!(encoded.decode(&encoded.encode(value)), Some(value));
        }
    }

    #[test]
    fn test_encoding_optimizer() {
        let mut optimizer = EncodingOptimizer::new();
        optimizer.add_variable("small", 0, 3);
        optimizer.add_variable("large", 0, 100);
        optimizer.add_variable("binary", 0, 1);

        let encodings = optimizer.optimize_encodings();

        assert!(
            matches!(encodings["binary"], EncodingScheme::Direct),
            "Expected direct encoding for binary variable"
        );
        assert!(matches!(
            encodings["small"],
            EncodingScheme::Binary { num_values: 4 }
        ));
        assert!(matches!(
            encodings["large"],
            EncodingScheme::Binary { num_values: 101 }
        ));
    }
}