1use crate::encryption::{
37 EncryptionError, EncryptionKey, EncryptionNonce, decrypt, encrypt, generate_nonce,
38};
39use crate::hash::hash;
40use rand::{RngCore, thread_rng};
41use serde::{Deserialize, Serialize};
42use std::collections::HashMap;
43use thiserror::Error;
44
45pub type GarbledCircuitResult<T> = Result<T, GarbledCircuitError>;
47
48#[derive(Debug, Error)]
50pub enum GarbledCircuitError {
51 #[error("Encryption error: {0}")]
52 Encryption(#[from] EncryptionError),
53
54 #[error("Serialization error: {0}")]
55 Serialization(String),
56
57 #[error("Deserialization error: {0}")]
58 Deserialization(String),
59
60 #[error("Invalid input: {0}")]
61 InvalidInput(String),
62}
63
64#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
66pub struct WireLabel {
67 data: [u8; 16],
69 permute_bit: bool,
71}
72
73impl WireLabel {
74 pub fn random() -> Self {
76 let mut rng = thread_rng();
77 let mut data = [0u8; 16];
78 rng.fill_bytes(&mut data);
79 let mut random_byte = [0u8; 1];
80 rng.fill_bytes(&mut random_byte);
81 Self {
82 data,
83 permute_bit: random_byte[0] & 1 == 1,
84 }
85 }
86
87 pub fn data(&self) -> &[u8; 16] {
89 &self.data
90 }
91
92 pub fn permute_bit(&self) -> bool {
94 self.permute_bit
95 }
96
97 pub fn xor(&self, other: &WireLabel) -> WireLabel {
99 let mut result = [0u8; 16];
100 for (i, item) in result.iter_mut().enumerate() {
101 *item = self.data[i] ^ other.data[i];
102 }
103 WireLabel {
104 data: result,
105 permute_bit: self.permute_bit ^ other.permute_bit,
106 }
107 }
108
109 pub fn from_bytes(data: [u8; 16], permute_bit: bool) -> Self {
111 Self { data, permute_bit }
112 }
113}
114
115#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
117pub enum GateType {
118 And,
119 Or,
120 Xor,
121 Not,
122}
123
124#[derive(Clone, Debug, Serialize, Deserialize)]
126pub struct Gate {
127 gate_type: GateType,
128 input_a: usize,
129 input_b: Option<usize>, output: usize,
131}
132
133impl Gate {
134 pub fn new(gate_type: GateType, input_a: usize, input_b: usize, output: usize) -> Self {
136 Self {
137 gate_type,
138 input_a,
139 input_b: Some(input_b),
140 output,
141 }
142 }
143
144 pub fn not(input: usize, output: usize) -> Self {
146 Self {
147 gate_type: GateType::Not,
148 input_a: input,
149 input_b: None,
150 output,
151 }
152 }
153}
154
155#[derive(Clone, Debug, Serialize, Deserialize)]
157pub struct GarbledGate {
158 gate_type: GateType,
159 encrypted_table: Vec<Vec<u8>>,
161 nonces: Vec<EncryptionNonce>,
163}
164
165#[derive(Clone, Debug)]
167pub struct Circuit {
168 gates: Vec<Gate>,
169 num_wires: usize,
170 input_wires: Vec<usize>,
171 output_wires: Vec<usize>,
172}
173
174impl Circuit {
175 pub fn new() -> Self {
177 Self {
178 gates: Vec::new(),
179 num_wires: 0,
180 input_wires: Vec::new(),
181 output_wires: Vec::new(),
182 }
183 }
184
185 pub fn add_wire(&mut self) -> usize {
187 let idx = self.num_wires;
188 self.num_wires += 1;
189 idx
190 }
191
192 pub fn add_input_wire(&mut self) -> usize {
194 let idx = self.add_wire();
195 self.input_wires.push(idx);
196 idx
197 }
198
199 pub fn set_output_wire(&mut self, wire: usize) {
201 self.output_wires.push(wire);
202 }
203
204 pub fn add_gate(&mut self, gate: Gate) {
206 self.gates.push(gate);
207 }
208
209 pub fn num_inputs(&self) -> usize {
211 self.input_wires.len()
212 }
213
214 pub fn garble(&self) -> GarbledCircuit {
216 let mut rng = thread_rng();
217
218 let mut global_offset = [0u8; 16];
220 rng.fill_bytes(&mut global_offset);
221 let global_offset = WireLabel::from_bytes(global_offset, true);
222
223 let mut wire_labels: HashMap<usize, (WireLabel, WireLabel)> = HashMap::new();
225
226 for i in 0..self.num_wires {
228 let label_0 = WireLabel::random();
229 let label_1 = label_0.xor(&global_offset); wire_labels.insert(i, (label_0, label_1));
231 }
232
233 for gate in &self.gates {
235 if gate.gate_type == GateType::Xor {
236 if let Some(input_b) = gate.input_b {
237 let (a0, _a1) = wire_labels[&gate.input_a];
238 let (b0, _b1) = wire_labels[&input_b];
239 let out_0 = a0.xor(&b0);
241 let out_1 = out_0.xor(&global_offset);
242 wire_labels.insert(gate.output, (out_0, out_1));
243 }
244 }
245 }
246
247 let mut garbled_gates = Vec::new();
249 for gate in &self.gates {
250 let garbled_gate = self.garble_gate(gate, &wire_labels);
251 garbled_gates.push(garbled_gate);
252 }
253
254 let input_labels: Vec<(WireLabel, WireLabel)> = self
256 .input_wires
257 .iter()
258 .map(|&wire| wire_labels[&wire])
259 .collect();
260
261 let output_labels: Vec<(WireLabel, WireLabel)> = self
263 .output_wires
264 .iter()
265 .map(|&wire| wire_labels[&wire])
266 .collect();
267
268 GarbledCircuit {
269 gates: garbled_gates,
270 input_labels,
271 output_labels,
272 num_inputs: self.input_wires.len(),
273 gate_topology: self.gates.clone(),
274 }
275 }
276
277 #[allow(clippy::too_many_arguments)]
279 fn garble_gate(
280 &self,
281 gate: &Gate,
282 wire_labels: &HashMap<usize, (WireLabel, WireLabel)>,
283 ) -> GarbledGate {
284 match gate.gate_type {
285 GateType::Xor => {
286 GarbledGate {
288 gate_type: GateType::Xor,
289 encrypted_table: Vec::new(), nonces: Vec::new(),
291 }
292 }
293 GateType::Not => {
294 GarbledGate {
296 gate_type: GateType::Not,
297 encrypted_table: Vec::new(),
298 nonces: Vec::new(),
299 }
300 }
301 GateType::And | GateType::Or => {
302 let input_b = gate.input_b.expect("Binary gate must have two inputs");
304
305 let (a0, a1) = wire_labels[&gate.input_a];
306 let (b0, b1) = wire_labels[&input_b];
307 let (out0, out1) = wire_labels[&gate.output];
308
309 let truth_table = match gate.gate_type {
311 GateType::And => [
312 (a0, b0, out0),
313 (a0, b1, out0),
314 (a1, b0, out0),
315 (a1, b1, out1),
316 ],
317 GateType::Or => [
318 (a0, b0, out0),
319 (a0, b1, out1),
320 (a1, b0, out1),
321 (a1, b1, out1),
322 ],
323 _ => unreachable!(),
324 };
325
326 let mut encrypted_table = Vec::new();
328 let mut nonces = Vec::new();
329 for (i, (label_a, label_b, output_label)) in truth_table.iter().enumerate() {
330 let mut key_material = Vec::new();
332 key_material.extend_from_slice(label_a.data());
333 key_material.extend_from_slice(label_b.data());
334 key_material.extend_from_slice(&[i as u8]);
335
336 let key_hash = hash(&key_material);
337 let key: EncryptionKey = key_hash;
338
339 let nonce = generate_nonce();
341
342 let mut plaintext = Vec::new();
344 plaintext.extend_from_slice(output_label.data());
345 plaintext.push(if output_label.permute_bit() { 1 } else { 0 });
346
347 let encrypted = encrypt(&plaintext, &key, &nonce).expect("Encryption failed");
348 encrypted_table.push(encrypted);
349 nonces.push(nonce);
350 }
351
352 GarbledGate {
353 gate_type: gate.gate_type,
354 encrypted_table,
355 nonces,
356 }
357 }
358 }
359 }
360}
361
362impl Default for Circuit {
363 fn default() -> Self {
364 Self::new()
365 }
366}
367
368#[derive(Clone, Debug, Serialize, Deserialize)]
370pub struct GarbledCircuit {
371 gates: Vec<GarbledGate>,
372 input_labels: Vec<(WireLabel, WireLabel)>,
373 output_labels: Vec<(WireLabel, WireLabel)>,
374 num_inputs: usize,
375 gate_topology: Vec<Gate>,
376}
377
378impl GarbledCircuit {
379 pub fn evaluate(&self, inputs: &[bool]) -> bool {
381 if inputs.len() != self.num_inputs {
382 panic!("Invalid number of inputs");
383 }
384
385 let mut wire_values: HashMap<usize, WireLabel> = HashMap::new();
387 for (i, &input_val) in inputs.iter().enumerate() {
388 let (label_0, label_1) = self.input_labels[i];
389 wire_values.insert(i, if input_val { label_1 } else { label_0 });
390 }
391
392 for (gate, garbled_gate) in self.gate_topology.iter().zip(&self.gates) {
394 let output_label = self.evaluate_gate(gate, garbled_gate, &wire_values);
395 wire_values.insert(gate.output, output_label);
396 }
397
398 let output_wire_idx = if self.gate_topology.is_empty() {
401 0 } else {
403 self.gate_topology.last().unwrap().output
404 };
405
406 let output_label = wire_values
407 .get(&output_wire_idx)
408 .copied()
409 .unwrap_or(self.output_labels[0].0);
410 let (_label_0, label_1) = self.output_labels[0];
411
412 output_label == label_1
414 }
415
416 fn evaluate_gate(
418 &self,
419 gate: &Gate,
420 garbled_gate: &GarbledGate,
421 wire_values: &HashMap<usize, WireLabel>,
422 ) -> WireLabel {
423 match garbled_gate.gate_type {
424 GateType::Xor => {
425 let input_b = gate.input_b.expect("XOR gate must have two inputs");
426 let label_a = wire_values[&gate.input_a];
427 let label_b = wire_values[&input_b];
428 label_a.xor(&label_b) }
430 GateType::Not => {
431 wire_values[&gate.input_a]
434 }
435 GateType::And | GateType::Or => {
436 let input_b = gate.input_b.expect("Binary gate must have two inputs");
437 let label_a = wire_values[&gate.input_a];
438 let label_b = wire_values[&input_b];
439
440 for row_index in 0..4 {
442 let mut key_material = Vec::new();
444 key_material.extend_from_slice(label_a.data());
445 key_material.extend_from_slice(label_b.data());
446 key_material.extend_from_slice(&[row_index as u8]);
447
448 let key_hash = hash(&key_material);
449 let key: EncryptionKey = key_hash;
450
451 let encrypted = &garbled_gate.encrypted_table[row_index];
453 let nonce = &garbled_gate.nonces[row_index];
454
455 if let Ok(decrypted) = decrypt(encrypted, &key, nonce) {
456 let mut label_data = [0u8; 16];
457 label_data.copy_from_slice(&decrypted[..16]);
458 let permute_bit = if decrypted.len() > 16 {
460 decrypted[16] == 1
461 } else {
462 false
463 };
464 return WireLabel::from_bytes(label_data, permute_bit);
465 }
466 }
467
468 panic!("Failed to decrypt any row for gate");
470 }
471 }
472 }
473
474 pub fn get_input_labels(&self, input_index: usize) -> (WireLabel, WireLabel) {
476 self.input_labels[input_index]
477 }
478
479 pub fn to_bytes(&self) -> GarbledCircuitResult<Vec<u8>> {
481 crate::codec::encode(self).map_err(|e| GarbledCircuitError::Serialization(format!("{}", e)))
482 }
483
484 pub fn from_bytes(bytes: &[u8]) -> GarbledCircuitResult<Self> {
486 crate::codec::decode(bytes)
487 .map_err(|e| GarbledCircuitError::Deserialization(format!("{}", e)))
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Garbles a circuit computing `lhs <gate_type> rhs` over two fresh input wires.
    fn garble_binary(gate_type: GateType) -> GarbledCircuit {
        let mut circuit = Circuit::new();
        let lhs = circuit.add_input_wire();
        let rhs = circuit.add_input_wire();
        let out = circuit.add_wire();
        circuit.add_gate(Gate::new(gate_type, lhs, rhs, out));
        circuit.set_output_wire(out);
        circuit.garble()
    }

    /// Asserts that each `(inputs, expected)` row evaluates to `expected`.
    fn check_rows<const N: usize>(garbled: &GarbledCircuit, rows: &[([bool; N], bool)]) {
        for (inputs, expected) in rows {
            assert_eq!(garbled.evaluate(inputs), *expected, "inputs: {:?}", inputs);
        }
    }

    #[test]
    fn test_wire_label_xor() {
        let left = WireLabel::random();
        let right = WireLabel::random();
        // XORing with the same label twice must restore the original bytes.
        let round_trip = left.xor(&right).xor(&right);
        assert_eq!(round_trip.data(), left.data());
    }

    #[test]
    fn test_simple_and_circuit() {
        check_rows(
            &garble_binary(GateType::And),
            &[
                ([false, false], false),
                ([false, true], false),
                ([true, false], false),
                ([true, true], true),
            ],
        );
    }

    #[test]
    fn test_simple_or_circuit() {
        check_rows(
            &garble_binary(GateType::Or),
            &[
                ([false, false], false),
                ([false, true], true),
                ([true, false], true),
                ([true, true], true),
            ],
        );
    }

    #[test]
    fn test_simple_xor_circuit() {
        check_rows(
            &garble_binary(GateType::Xor),
            &[
                ([false, false], false),
                ([false, true], true),
                ([true, false], true),
                ([true, true], false),
            ],
        );
    }

    #[test]
    fn test_multi_gate_circuit() {
        // (a AND b) OR c
        let mut circuit = Circuit::new();
        let inputs: Vec<usize> = (0..3).map(|_| circuit.add_input_wire()).collect();
        let and_out = circuit.add_wire();
        let or_out = circuit.add_wire();
        circuit.add_gate(Gate::new(GateType::And, inputs[0], inputs[1], and_out));
        circuit.add_gate(Gate::new(GateType::Or, and_out, inputs[2], or_out));
        circuit.set_output_wire(or_out);

        check_rows(
            &circuit.garble(),
            &[
                ([false, false, false], false),
                ([true, true, false], true),
                ([false, false, true], true),
                ([true, false, false], false),
            ],
        );
    }

    #[test]
    fn test_serialization() {
        let bytes = garble_binary(GateType::And).to_bytes().unwrap();
        let restored = GarbledCircuit::from_bytes(&bytes).unwrap();

        assert!(restored.evaluate(&[true, true]));
        assert!(!restored.evaluate(&[false, true]));
    }

    #[test]
    fn test_complex_circuit() {
        // ((a XOR b) AND c) OR d
        let mut circuit = Circuit::new();
        let inputs: Vec<usize> = (0..4).map(|_| circuit.add_input_wire()).collect();
        let xor_out = circuit.add_wire();
        let and_out = circuit.add_wire();
        let or_out = circuit.add_wire();
        circuit.add_gate(Gate::new(GateType::Xor, inputs[0], inputs[1], xor_out));
        circuit.add_gate(Gate::new(GateType::And, xor_out, inputs[2], and_out));
        circuit.add_gate(Gate::new(GateType::Or, and_out, inputs[3], or_out));
        circuit.set_output_wire(or_out);

        check_rows(
            &circuit.garble(),
            &[
                ([true, false, true, false], true),
                ([false, false, true, false], false),
                ([false, false, true, true], true),
            ],
        );
    }

    #[test]
    fn test_wire_label_generation() {
        let first = WireLabel::random();
        let second = WireLabel::random();
        assert_ne!(first.data(), second.data());
    }

    #[test]
    fn test_circuit_with_multiple_outputs() {
        assert!(garble_binary(GateType::And).evaluate(&[true, true]));
    }

    #[test]
    fn test_gate_types() {
        let expectations = [
            (Gate::new(GateType::And, 0, 1, 2), GateType::And),
            (Gate::new(GateType::Or, 0, 1, 2), GateType::Or),
            (Gate::new(GateType::Xor, 0, 1, 2), GateType::Xor),
            (Gate::not(0, 1), GateType::Not),
        ];
        for (gate, kind) in &expectations {
            assert_eq!(gate.gate_type, *kind);
        }
    }

    #[test]
    fn test_free_xor_optimization() {
        let garbled = garble_binary(GateType::Xor);
        // Free-XOR gates carry no ciphertexts.
        assert!(garbled.gates[0].encrypted_table.is_empty());
    }

    #[test]
    fn test_get_input_labels() {
        let (zero, one) = garble_binary(GateType::And).get_input_labels(0);
        assert_ne!(zero.data(), one.data());
    }

    #[test]
    fn test_point_and_permute() {
        let first = WireLabel::random();
        let second = WireLabel::random();
        let combined = first.xor(&second);
        assert_eq!(
            combined.permute_bit(),
            first.permute_bit() ^ second.permute_bit()
        );
    }

    #[test]
    fn test_circuit_default() {
        let circuit: Circuit = Default::default();
        assert_eq!(circuit.num_wires, 0);
        assert!(circuit.gates.is_empty());
    }
}