quantrs2_tytan/
encoding.rs

1//! Variable encoding schemes for optimization problems.
2//!
3//! This module provides various encoding schemes to represent different types
4//! of variables and constraints as binary optimization problems.
5
6use scirs2_core::ndarray::Array2;
7use std::collections::HashMap;
8
9#[cfg(feature = "dwave")]
10use crate::symbol::Expression;
11
/// Variable encoding scheme
///
/// Each variant describes how one integer (or boolean) variable is laid out
/// over a group of binary variables. Schemes compare by value, so they can be
/// used directly in assertions and as hash-map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EncodingScheme {
    /// One-hot encoding: one bit per value, exactly one bit is 1
    OneHot { num_values: usize },
    /// Binary encoding: ceil(log2(n)) bits for n values
    Binary { num_values: usize },
    /// Gray code encoding: same bit count as binary, reflected Gray code order
    GrayCode { num_values: usize },
    /// Domain wall encoding: string of 1s followed by 0s (n - 1 bits for n values)
    DomainWall { num_values: usize },
    /// Unary/thermometer encoding: first k bits are 1 (n - 1 bits for n values)
    Unary { num_values: usize },
    /// Order encoding over the inclusive range `[min_value, max_value]`:
    /// bit i is 1 if value >= min_value + i
    OrderEncoding { min_value: i32, max_value: i32 },
    /// Direct binary (for binary variables)
    Direct,
}
30
31/// Encoded variable representation
32#[derive(Debug, Clone)]
33pub struct EncodedVariable {
34    /// Original variable name
35    pub name: String,
36    /// Encoding scheme used
37    pub scheme: EncodingScheme,
38    /// Binary variable names
39    pub binary_vars: Vec<String>,
40    /// Encoding constraints (as penalty terms)
41    #[cfg(feature = "dwave")]
42    pub constraints: Option<Expression>,
43}
44
45impl EncodedVariable {
46    /// Create new encoded variable
47    pub fn new(name: &str, scheme: EncodingScheme) -> Self {
48        let binary_vars = Self::generate_binary_vars(name, &scheme);
49        Self {
50            name: name.to_string(),
51            scheme,
52            binary_vars,
53            #[cfg(feature = "dwave")]
54            constraints: None,
55        }
56    }
57
58    /// Generate binary variable names based on encoding
59    fn generate_binary_vars(name: &str, scheme: &EncodingScheme) -> Vec<String> {
60        match scheme {
61            EncodingScheme::OneHot { num_values } => {
62                (0..*num_values).map(|i| format!("{name}_{i}")).collect()
63            }
64            EncodingScheme::Binary { num_values } => {
65                let num_bits = (*num_values as f64).log2().ceil() as usize;
66                (0..num_bits).map(|i| format!("{name}_bit{i}")).collect()
67            }
68            EncodingScheme::GrayCode { num_values } => {
69                let num_bits = (*num_values as f64).log2().ceil() as usize;
70                (0..num_bits).map(|i| format!("{name}_gray{i}")).collect()
71            }
72            EncodingScheme::DomainWall { num_values } => (0..*num_values - 1)
73                .map(|i| format!("{name}_dw{i}"))
74                .collect(),
75            EncodingScheme::Unary { num_values } => (0..*num_values - 1)
76                .map(|i| format!("{name}_u{i}"))
77                .collect(),
78            EncodingScheme::OrderEncoding {
79                min_value,
80                max_value,
81            } => {
82                let range = max_value - min_value;
83                (0..range).map(|i| format!("{name}_ord{i}")).collect()
84            }
85            EncodingScheme::Direct => vec![name.to_string()],
86        }
87    }
88
89    /// Decode binary values to original value
90    pub fn decode(&self, binary_values: &HashMap<String, bool>) -> Option<i32> {
91        match &self.scheme {
92            EncodingScheme::OneHot { .. } => {
93                for (i, var) in self.binary_vars.iter().enumerate() {
94                    if binary_values.get(var).copied().unwrap_or(false) {
95                        return Some(i as i32);
96                    }
97                }
98                None // Invalid: no bit set
99            }
100            EncodingScheme::Binary { .. } => {
101                let mut value = 0;
102                for (i, var) in self.binary_vars.iter().enumerate() {
103                    if binary_values.get(var).copied().unwrap_or(false) {
104                        value |= 1 << i;
105                    }
106                }
107                Some(value)
108            }
109            EncodingScheme::GrayCode { .. } => {
110                let mut gray = 0;
111                for (i, var) in self.binary_vars.iter().enumerate() {
112                    if binary_values.get(var).copied().unwrap_or(false) {
113                        gray |= 1 << i;
114                    }
115                }
116                // Convert Gray code to binary
117                let mut binary = gray;
118                binary ^= binary >> 16;
119                binary ^= binary >> 8;
120                binary ^= binary >> 4;
121                binary ^= binary >> 2;
122                binary ^= binary >> 1;
123                Some(binary)
124            }
125            EncodingScheme::DomainWall { num_values } => {
126                let mut value = *num_values as i32 - 1;
127                for (i, var) in self.binary_vars.iter().enumerate() {
128                    if !binary_values.get(var).copied().unwrap_or(false) {
129                        value = i as i32;
130                        break;
131                    }
132                }
133                Some(value)
134            }
135            EncodingScheme::Unary { .. } => {
136                let mut value = 0;
137                for var in &self.binary_vars {
138                    if binary_values.get(var).copied().unwrap_or(false) {
139                        value += 1;
140                    } else {
141                        break;
142                    }
143                }
144                Some(value)
145            }
146            EncodingScheme::OrderEncoding { min_value, .. } => {
147                let mut value = *min_value;
148                for var in &self.binary_vars {
149                    if binary_values.get(var).copied().unwrap_or(false) {
150                        value += 1;
151                    }
152                }
153                Some(value - 1)
154            }
155            EncodingScheme::Direct => binary_values.get(&self.name).map(|&b| i32::from(b)),
156        }
157    }
158
159    /// Encode value to binary representation
160    pub fn encode(&self, value: i32) -> HashMap<String, bool> {
161        let mut binary_values = HashMap::new();
162
163        match &self.scheme {
164            EncodingScheme::OneHot { num_values: _ } => {
165                for (i, var) in self.binary_vars.iter().enumerate() {
166                    binary_values.insert(var.clone(), i == value as usize);
167                }
168            }
169            EncodingScheme::Binary { .. } => {
170                for (i, var) in self.binary_vars.iter().enumerate() {
171                    binary_values.insert(var.clone(), (value & (1 << i)) != 0);
172                }
173            }
174            EncodingScheme::GrayCode { .. } => {
175                // Convert to Gray code
176                let gray = value ^ (value >> 1);
177                for (i, var) in self.binary_vars.iter().enumerate() {
178                    binary_values.insert(var.clone(), (gray & (1 << i)) != 0);
179                }
180            }
181            EncodingScheme::DomainWall { num_values: _ } => {
182                for (i, var) in self.binary_vars.iter().enumerate() {
183                    binary_values.insert(var.clone(), i < value as usize);
184                }
185            }
186            EncodingScheme::Unary { .. } => {
187                for (i, var) in self.binary_vars.iter().enumerate() {
188                    binary_values.insert(var.clone(), i < value as usize);
189                }
190            }
191            EncodingScheme::OrderEncoding { min_value, .. } => {
192                let adjusted = value - min_value + 1;
193                for (i, var) in self.binary_vars.iter().enumerate() {
194                    binary_values.insert(var.clone(), i < adjusted as usize);
195                }
196            }
197            EncodingScheme::Direct => {
198                binary_values.insert(self.name.clone(), value != 0);
199            }
200        }
201
202        binary_values
203    }
204
205    /// Get penalty matrix for encoding constraints
206    pub fn get_penalty_matrix(&self, var_indices: &HashMap<String, usize>) -> Array2<f64> {
207        let n = var_indices.len();
208        let mut penalty = Array2::zeros((n, n));
209
210        match &self.scheme {
211            EncodingScheme::OneHot { .. } => {
212                // Exactly one bit must be 1
213                // Penalty: (sum(xi) - 1)^2 = sum(xi)^2 - 2*sum(xi) + 1
214
215                // Get indices of our binary variables
216                let indices: Vec<usize> = self
217                    .binary_vars
218                    .iter()
219                    .filter_map(|var| var_indices.get(var).copied())
220                    .collect();
221
222                // Quadratic terms: xi * xj for i != j
223                for &i in &indices {
224                    for &j in &indices {
225                        if i != j {
226                            penalty[[i, j]] += 1.0;
227                        }
228                    }
229                }
230
231                // Linear terms: -2 * xi
232                for &i in &indices {
233                    penalty[[i, i]] -= 2.0;
234                }
235            }
236            EncodingScheme::DomainWall { .. } => {
237                // Domain wall constraint: xi >= xi+1
238                // Penalty for violation: xi+1 * (1 - xi)
239
240                let indices: Vec<usize> = self
241                    .binary_vars
242                    .iter()
243                    .filter_map(|var| var_indices.get(var).copied())
244                    .collect();
245
246                for i in 0..indices.len() - 1 {
247                    let idx1 = indices[i];
248                    let idx2 = indices[i + 1];
249
250                    // Penalty term: x_{i+1} - x_i * x_{i+1}
251                    penalty[[idx2, idx2]] += 1.0;
252                    penalty[[idx1, idx2]] -= 1.0;
253                    penalty[[idx2, idx1]] -= 1.0;
254                }
255            }
256            EncodingScheme::Unary { .. } => {
257                // Unary constraint: xi >= xi+1
258                // Same as domain wall
259                let indices: Vec<usize> = self
260                    .binary_vars
261                    .iter()
262                    .filter_map(|var| var_indices.get(var).copied())
263                    .collect();
264
265                for i in 0..indices.len() - 1 {
266                    let idx1 = indices[i];
267                    let idx2 = indices[i + 1];
268
269                    penalty[[idx2, idx2]] += 1.0;
270                    penalty[[idx1, idx2]] -= 1.0;
271                    penalty[[idx2, idx1]] -= 1.0;
272                }
273            }
274            _ => {
275                // No encoding constraints for binary, gray code, or direct
276            }
277        }
278
279        penalty
280    }
281}
282
/// Encoding optimizer that selects best encoding for variables
///
/// Holds the integer domain of every registered variable and an undirected
/// adjacency list of the constraints between them.
#[derive(Debug, Clone)]
pub struct EncodingOptimizer {
    /// Variable domains as inclusive `(min, max)` bounds, keyed by variable name
    domains: HashMap<String, (i32, i32)>,
    /// Constraint information: for each variable, the variables it shares a constraint with
    constraint_graph: HashMap<String, Vec<String>>,
}
290
291impl Default for EncodingOptimizer {
292    fn default() -> Self {
293        Self::new()
294    }
295}
296
297impl EncodingOptimizer {
298    /// Create new encoding optimizer
299    pub fn new() -> Self {
300        Self {
301            domains: HashMap::new(),
302            constraint_graph: HashMap::new(),
303        }
304    }
305
306    /// Add variable with domain
307    pub fn add_variable(&mut self, name: &str, min_value: i32, max_value: i32) {
308        self.domains
309            .insert(name.to_string(), (min_value, max_value));
310    }
311
312    /// Add constraint between variables
313    pub fn add_constraint(&mut self, var1: &str, var2: &str) {
314        self.constraint_graph
315            .entry(var1.to_string())
316            .or_default()
317            .push(var2.to_string());
318        self.constraint_graph
319            .entry(var2.to_string())
320            .or_default()
321            .push(var1.to_string());
322    }
323
324    /// Select optimal encoding for each variable
325    pub fn optimize_encodings(&self) -> HashMap<String, EncodingScheme> {
326        let mut encodings = HashMap::new();
327
328        for (var, &(min_val, max_val)) in &self.domains {
329            let domain_size = (max_val - min_val + 1) as usize;
330            let neighbors = self.constraint_graph.get(var).map_or(0, |v| v.len());
331
332            // Heuristics for encoding selection
333            let encoding = if domain_size == 2 {
334                // Binary variable
335                EncodingScheme::Direct
336            } else if domain_size <= 4 && neighbors > 3 {
337                // Small domain with many constraints: one-hot
338                EncodingScheme::OneHot {
339                    num_values: domain_size,
340                }
341            } else if domain_size <= 8 {
342                // Medium domain: binary or gray code
343                if self.has_ordering_constraints(var) {
344                    EncodingScheme::GrayCode {
345                        num_values: domain_size,
346                    }
347                } else {
348                    EncodingScheme::Binary {
349                        num_values: domain_size,
350                    }
351                }
352            } else if self.has_ordering_constraints(var) {
353                // Large ordered domain: order encoding or domain wall
354                if domain_size <= 32 {
355                    EncodingScheme::OrderEncoding {
356                        min_value: min_val,
357                        max_value: max_val,
358                    }
359                } else {
360                    EncodingScheme::DomainWall {
361                        num_values: domain_size,
362                    }
363                }
364            } else {
365                // Large unordered domain: binary
366                EncodingScheme::Binary {
367                    num_values: domain_size,
368                }
369            };
370
371            encodings.insert(var.clone(), encoding);
372        }
373
374        encodings
375    }
376
377    /// Check if variable has ordering constraints
378    const fn has_ordering_constraints(&self, _var: &str) -> bool {
379        // Simplified: would check actual constraint types
380        false
381    }
382}
383
/// Auxiliary variable generator for complex encodings
///
/// Produces names of the form `<prefix>_<counter>`. `Clone` is deliberately
/// not derived: two copies would hand out the same names.
#[derive(Debug)]
pub struct AuxiliaryVariableGenerator {
    /// Counter for generating unique names; the number used for the next name
    counter: usize,
    /// Prefix for auxiliary variables
    prefix: String,
}
391
392impl AuxiliaryVariableGenerator {
393    /// Create new generator
394    pub fn new(prefix: &str) -> Self {
395        Self {
396            counter: 0,
397            prefix: prefix.to_string(),
398        }
399    }
400
401    /// Generate new auxiliary variable name
402    pub fn next(&mut self) -> String {
403        let name = format!("{}_{}", self.prefix, self.counter);
404        self.counter += 1;
405        name
406    }
407
408    /// Generate auxiliary variables for product encoding
409    pub fn product_encoding(
410        &mut self,
411        _var1: &str,
412        _var2: &str,
413        enc1: &EncodedVariable,
414        enc2: &EncodedVariable,
415    ) -> Vec<(String, Vec<String>)> {
416        let mut auxiliaries = Vec::new();
417
418        // Generate auxiliary for each pair of binary variables
419        for bin1 in &enc1.binary_vars {
420            for bin2 in &enc2.binary_vars {
421                let aux = self.next();
422                auxiliaries.push((aux.clone(), vec![bin1.clone(), bin2.clone()]));
423            }
424        }
425
426        auxiliaries
427    }
428}
429
430/// Convert integer program to QUBO using encodings
431pub struct EncodingConverter {
432    /// Variable encodings
433    encodings: HashMap<String, EncodedVariable>,
434    /// Auxiliary variable generator
435    aux_gen: AuxiliaryVariableGenerator,
436}
437
438impl Default for EncodingConverter {
439    fn default() -> Self {
440        Self::new()
441    }
442}
443
444impl EncodingConverter {
445    /// Create new converter
446    pub fn new() -> Self {
447        Self {
448            encodings: HashMap::new(),
449            aux_gen: AuxiliaryVariableGenerator::new("aux"),
450        }
451    }
452
453    /// Add encoded variable
454    pub fn add_variable(&mut self, encoded: EncodedVariable) {
455        self.encodings.insert(encoded.name.clone(), encoded);
456    }
457
458    /// Get all binary variables
459    pub fn get_binary_variables(&self) -> Vec<String> {
460        let mut vars = Vec::new();
461        for encoded in self.encodings.values() {
462            vars.extend(encoded.binary_vars.clone());
463        }
464        vars
465    }
466
467    /// Build QUBO matrix with encoding penalties
468    pub fn build_qubo_matrix(&self, _base_matrix: Array2<f64>) -> Array2<f64> {
469        let binary_vars = self.get_binary_variables();
470        let var_indices: HashMap<String, usize> = binary_vars
471            .iter()
472            .enumerate()
473            .map(|(i, v)| (v.clone(), i))
474            .collect();
475
476        let n = binary_vars.len();
477        let mut qubo = Array2::zeros((n, n));
478
479        // Add encoding penalties
480        for encoded in self.encodings.values() {
481            let penalty = encoded.get_penalty_matrix(&var_indices);
482            qubo = qubo + penalty;
483        }
484
485        // Add base problem matrix (would need proper mapping)
486        // This is simplified - real implementation would map original to binary vars
487
488        qubo
489    }
490}
491
/// Compare different encodings
///
/// Returns rough size and connectivity figures for encoding a variable with
/// `domain_size` values, keyed by `"one-hot"`, `"binary"` and `"domain-wall"`.
/// `constraint_density` only scales the connectivity estimate of the binary
/// encoding.
///
/// An empty domain (`domain_size == 0`) yields zero bits, zero constraints and
/// zero for every ratio instead of underflowing or dividing by zero.
pub fn compare_encodings(
    domain_size: usize,
    constraint_density: f64,
) -> HashMap<String, EncodingMetrics> {
    let size = domain_size as f64;
    let one_less = domain_size.saturating_sub(1);
    // 1 / n, with the empty domain mapped to 0 rather than infinity.
    let inverse_size = if domain_size == 0 { 0.0 } else { 1.0 / size };

    // ceil(log2(n)) in integer arithmetic; 0 or 1 value needs no bits.
    let binary_bits = if domain_size <= 1 {
        0
    } else {
        (usize::BITS - one_less.leading_zeros()) as usize
    };
    let binary_efficiency = if domain_size == 0 {
        0.0
    } else {
        size.log2() / size
    };

    let mut results = HashMap::with_capacity(3);

    // One-hot encoding: one bit per value, every pair of bits is coupled.
    results.insert(
        "one-hot".to_string(),
        EncodingMetrics {
            num_bits: domain_size,
            num_constraints: domain_size.saturating_mul(one_less) / 2,
            avg_connectivity: one_less as f64,
            space_efficiency: inverse_size,
        },
    );

    // Binary encoding: no validity constraints at all.
    results.insert(
        "binary".to_string(),
        EncodingMetrics {
            num_bits: binary_bits,
            num_constraints: 0,
            avg_connectivity: constraint_density * binary_bits as f64,
            space_efficiency: binary_efficiency,
        },
    );

    // Domain wall encoding: n - 1 bits, each coupled to its chain neighbors.
    results.insert(
        "domain-wall".to_string(),
        EncodingMetrics {
            num_bits: one_less,
            num_constraints: one_less,
            avg_connectivity: 2.0,
            space_efficiency: inverse_size,
        },
    );

    results
}

/// Size and connectivity estimates for one encoding, as produced by
/// [`compare_encodings`].
#[derive(Debug, Clone)]
pub struct EncodingMetrics {
    /// Number of binary variables used
    pub num_bits: usize,
    /// Number of constraint terms the encoding needs to stay valid
    pub num_constraints: usize,
    /// Estimated average number of couplings per bit
    pub avg_connectivity: f64,
    /// Ratio reported by `compare_encodings` for how compactly the domain is stored
    pub space_efficiency: f64,
}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// One-hot: one bit per value, and encode/decode round-trips.
    #[test]
    fn test_one_hot_encoding() {
        let encoded = EncodedVariable::new("x", EncodingScheme::OneHot { num_values: 4 });
        assert_eq!(encoded.binary_vars.len(), 4);

        // Encode value 2
        let binary = encoded.encode(2);
        assert!(!binary["x_0"]);
        assert!(!binary["x_1"]);
        assert!(binary["x_2"]);
        assert!(!binary["x_3"]);

        // Decode back
        let value = encoded
            .decode(&binary)
            .expect("Failed to decode one-hot value");
        assert_eq!(value, 2);

        // No bit set is not a valid one-hot assignment.
        assert_eq!(encoded.decode(&HashMap::new()), None);
    }

    /// Binary: least significant bit first.
    #[test]
    fn test_binary_encoding() {
        let encoded = EncodedVariable::new("y", EncodingScheme::Binary { num_values: 8 });
        assert_eq!(encoded.binary_vars.len(), 3); // log2(8) = 3

        // Encode value 5 (binary: 101)
        let binary = encoded.encode(5);
        assert!(binary["y_bit0"]);
        assert!(!binary["y_bit1"]);
        assert!(binary["y_bit2"]);

        let value = encoded
            .decode(&binary)
            .expect("Failed to decode binary value");
        assert_eq!(value, 5);
    }

    /// Gray code: every value of the domain survives a round trip.
    #[test]
    fn test_gray_code_round_trip() {
        let encoded = EncodedVariable::new("g", EncodingScheme::GrayCode { num_values: 8 });
        assert_eq!(encoded.binary_vars.len(), 3);

        for value in 0..8 {
            let binary = encoded.encode(value);
            assert_eq!(encoded.decode(&binary), Some(value));
        }
    }

    /// Domain wall: a run of 1s whose length is the value.
    #[test]
    fn test_domain_wall_encoding() {
        let encoded = EncodedVariable::new("z", EncodingScheme::DomainWall { num_values: 5 });
        assert_eq!(encoded.binary_vars.len(), 4);

        // Encode value 2 (domain wall: 1100)
        let binary = encoded.encode(2);
        assert!(binary["z_dw0"]);
        assert!(binary["z_dw1"]);
        assert!(!binary["z_dw2"]);
        assert!(!binary["z_dw3"]);

        let value = encoded
            .decode(&binary)
            .expect("Failed to decode domain wall value");
        assert_eq!(value, 2);
    }

    /// Unary: every value of the domain survives a round trip.
    #[test]
    fn test_unary_round_trip() {
        let encoded = EncodedVariable::new("u", EncodingScheme::Unary { num_values: 4 });
        assert_eq!(encoded.binary_vars.len(), 3);

        for value in 0..4 {
            let binary = encoded.encode(value);
            assert_eq!(encoded.decode(&binary), Some(value));
        }
    }

    /// The optimizer picks direct for two-valued domains and binary for
    /// unordered domains with few constraints.
    #[test]
    fn test_encoding_optimizer() {
        let mut optimizer = EncodingOptimizer::new();
        optimizer.add_variable("small", 0, 3);
        optimizer.add_variable("large", 0, 100);
        optimizer.add_variable("binary", 0, 1);

        let encodings = optimizer.optimize_encodings();
        assert_eq!(encodings.len(), 3);

        // Binary variable should use direct encoding
        assert!(
            matches!(&encodings["binary"], EncodingScheme::Direct),
            "Expected direct encoding for binary variable"
        );
        // No constraints were added, so both other domains fall back to binary.
        assert!(matches!(
            &encodings["small"],
            EncodingScheme::Binary { num_values: 4 }
        ));
        assert!(matches!(
            &encodings["large"],
            EncodingScheme::Binary { num_values: 101 }
        ));
    }
}