1use crate::error::{MLError, Result};
8use quantrs2_circuit::prelude::Circuit;
9use quantrs2_sim::statevector::StateVectorSimulator;
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::random::prelude::*;
12use std::collections::HashMap;
13use std::fmt;
14
/// Quantum key distribution protocol selected for a [`QuantumKeyDistribution`] run.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProtocolType {
    /// BB84 (Bennett & Brassard, 1984).
    BB84,

    /// E91 (Ekert, 1991).
    E91,

    /// B92 (Bennett, 1992).
    B92,

    /// BBM92 (Bennett, Brassard & Mermin, 1992).
    BBM92,

    /// SARG04 (Scarani, Acín, Ribordy & Gisin, 2004).
    SARG04,
}
33
/// A participant in a key-distribution run (e.g. "Alice" or "Bob").
#[derive(Clone, Debug)]
pub struct Party {
    /// Display name of the participant.
    pub name: String,

    /// Key bytes produced by the most recent protocol run; `None` before any run.
    pub key: Option<Vec<u8>>,

    /// Per-qubit basis choices (0 or 1), when the protocol records them.
    pub bases: Option<Vec<usize>>,

    /// Optional state data; not populated by the protocol code in this file.
    pub state: Option<Vec<f64>>,
}
49
50#[derive(Debug, Clone)]
52pub struct QuantumKeyDistribution {
53 pub protocol: ProtocolType,
55
56 pub num_qubits: usize,
58
59 pub alice: Party,
61
62 pub bob: Party,
64
65 pub error_rate: f64,
67
68 pub security_bits: usize,
70}
71
72impl QuantumKeyDistribution {
73 pub fn new(protocol: ProtocolType, num_qubits: usize) -> Self {
75 QuantumKeyDistribution {
76 protocol,
77 num_qubits,
78 alice: Party {
79 name: "Alice".to_string(),
80 key: None,
81 bases: None,
82 state: None,
83 },
84 bob: Party {
85 name: "Bob".to_string(),
86 key: None,
87 bases: None,
88 state: None,
89 },
90 error_rate: 0.0,
91 security_bits: num_qubits / 10,
92 }
93 }
94
95 pub fn with_error_rate(mut self, error_rate: f64) -> Self {
97 self.error_rate = error_rate;
98 self
99 }
100
101 pub fn with_security_bits(mut self, security_bits: usize) -> Self {
103 self.security_bits = security_bits;
104 self
105 }
106
107 pub fn distribute_key(&mut self) -> Result<usize> {
109 match self.protocol {
110 ProtocolType::BB84 => self.bb84_protocol(),
111 ProtocolType::E91 => self.e91_protocol(),
112 ProtocolType::B92 => self.b92_protocol(),
113 ProtocolType::BBM92 => self.bbm92_protocol(),
114 ProtocolType::SARG04 => self.sarg04_protocol(),
115 }
116 }
117
118 fn bb84_protocol(&mut self) -> Result<usize> {
120 let alice_bits = (0..self.num_qubits)
125 .map(|_| {
126 if thread_rng().random::<f64>() > 0.5 {
127 1u8
128 } else {
129 0u8
130 }
131 })
132 .collect::<Vec<_>>();
133
134 let alice_bases = (0..self.num_qubits)
136 .map(|_| {
137 if thread_rng().random::<f64>() > 0.5 {
138 1usize
139 } else {
140 0usize
141 }
142 })
143 .collect::<Vec<_>>();
144
145 let bob_bases = (0..self.num_qubits)
146 .map(|_| {
147 if thread_rng().random::<f64>() > 0.5 {
148 1usize
149 } else {
150 0usize
151 }
152 })
153 .collect::<Vec<_>>();
154
155 let matching_bases = alice_bases
157 .iter()
158 .zip(bob_bases.iter())
159 .enumerate()
160 .filter_map(|(i, (a, b))| if a == b { Some(i) } else { None })
161 .collect::<Vec<_>>();
162
163 let mut key_bits = Vec::new();
165 for &i in &matching_bases {
166 if thread_rng().random::<f64>() > self.error_rate {
168 key_bits.push(alice_bits[i]);
169 } else {
170 key_bits.push(alice_bits[i] ^ 1);
172 }
173 }
174
175 let mut key_bytes = Vec::new();
177 for chunk in key_bits.chunks(8) {
178 let byte = chunk
179 .iter()
180 .enumerate()
181 .fold(0u8, |acc, (i, &bit)| acc | (bit << i));
182 key_bytes.push(byte);
183 }
184
185 self.alice.key = Some(key_bytes.clone());
187 self.bob.key = Some(key_bytes);
188
189 self.alice.bases = Some(alice_bases);
191 self.bob.bases = Some(bob_bases);
192
193 Ok(matching_bases.len())
194 }
195
196 fn e91_protocol(&mut self) -> Result<usize> {
198 let key_length = self.num_qubits / 3; let key_bytes = (0..key_length / 8 + 1)
204 .map(|_| thread_rng().random::<u8>())
205 .collect::<Vec<_>>();
206
207 self.alice.key = Some(key_bytes.clone());
209 self.bob.key = Some(key_bytes);
210
211 Ok(key_length)
212 }
213
214 fn b92_protocol(&mut self) -> Result<usize> {
216 let key_length = self.num_qubits / 4; let key_bytes = (0..key_length / 8 + 1)
222 .map(|_| thread_rng().random::<u8>())
223 .collect::<Vec<_>>();
224
225 self.alice.key = Some(key_bytes.clone());
227 self.bob.key = Some(key_bytes);
228
229 Ok(key_length)
230 }
231
232 fn bbm92_protocol(&mut self) -> Result<usize> {
240 let mut rng = thread_rng();
241
242 let alice_bases: Vec<usize> = (0..self.num_qubits)
244 .map(|_| if rng.random::<f64>() > 0.5 { 1 } else { 0 })
245 .collect();
246 let bob_bases: Vec<usize> = (0..self.num_qubits)
247 .map(|_| if rng.random::<f64>() > 0.5 { 1 } else { 0 })
248 .collect();
249
250 let alice_bits: Vec<u8> = (0..self.num_qubits)
252 .map(|_| if rng.random::<f64>() > 0.5 { 1 } else { 0 })
253 .collect();
254
255 let sifted_indices: Vec<usize> = (0..self.num_qubits)
257 .filter(|&i| alice_bases[i] == bob_bases[i])
258 .collect();
259 let key_length = sifted_indices.len();
260
261 let key_bytes: Vec<u8> = sifted_indices
263 .chunks(8)
264 .map(|chunk| {
265 chunk.iter().enumerate().fold(0u8, |acc, (bit_pos, &idx)| {
266 acc | (alice_bits[idx] << bit_pos)
267 })
268 })
269 .collect();
270
271 self.alice.key = Some(key_bytes.clone());
272 self.bob.key = Some(key_bytes);
274 Ok(key_length)
275 }
276
277 fn sarg04_protocol(&mut self) -> Result<usize> {
285 let mut rng = thread_rng();
286
287 let alice_bits: Vec<u8> = (0..self.num_qubits)
289 .map(|_| if rng.random::<f64>() > 0.5 { 1 } else { 0 })
290 .collect();
291 let alice_bases: Vec<usize> = (0..self.num_qubits)
292 .map(|_| if rng.random::<f64>() > 0.5 { 1 } else { 0 })
293 .collect();
294
295 let bob_conclusive: Vec<bool> = (0..self.num_qubits)
298 .map(|_| rng.random::<f64>() > 0.5)
299 .collect();
300
301 let bob_bases: Vec<usize> = (0..self.num_qubits)
303 .map(|_| if rng.random::<f64>() > 0.5 { 1 } else { 0 })
304 .collect();
305 let sifted_indices: Vec<usize> = (0..self.num_qubits)
306 .filter(|&i| bob_conclusive[i] && alice_bases[i] == bob_bases[i])
307 .collect();
308 let key_length = sifted_indices.len();
309
310 let key_bytes: Vec<u8> = sifted_indices
311 .chunks(8)
312 .map(|chunk| {
313 chunk.iter().enumerate().fold(0u8, |acc, (bit_pos, &idx)| {
314 acc | (alice_bits[idx] << bit_pos)
315 })
316 })
317 .collect();
318
319 self.alice.key = Some(key_bytes.clone());
320 self.bob.key = Some(key_bytes);
321 Ok(key_length)
322 }
323
324 pub fn verify_keys(&self) -> bool {
326 match (&self.alice.key, &self.bob.key) {
327 (Some(alice_key), Some(bob_key)) => alice_key == bob_key,
328 _ => false,
329 }
330 }
331
332 pub fn get_alice_key(&self) -> Option<Vec<u8>> {
334 self.alice.key.clone()
335 }
336
337 pub fn get_bob_key(&self) -> Option<Vec<u8>> {
339 self.bob.key.clone()
340 }
341}
342
/// Toy signature scheme backed by randomly generated key material.
#[derive(Clone, Debug)]
pub struct QuantumSignature {
    /// Requested security level in bits; `new` derives the key length from it.
    security_bits: usize,

    /// Algorithm label, stored verbatim.
    algorithm: String,

    /// Random public key bytes (not consulted by `sign` or `verify`).
    public_key: Vec<u8>,

    /// Secret key bytes combined with the message to form a signature.
    private_key: Vec<u8>,
}
358
359impl QuantumSignature {
360 pub fn new(security_bits: usize, algorithm: &str) -> Result<Self> {
362 let public_key = (0..security_bits / 8 + 1)
367 .map(|_| thread_rng().random::<u8>())
368 .collect::<Vec<_>>();
369
370 let private_key = (0..security_bits / 8 + 1)
371 .map(|_| thread_rng().random::<u8>())
372 .collect::<Vec<_>>();
373
374 Ok(QuantumSignature {
375 security_bits,
376 algorithm: algorithm.to_string(),
377 public_key,
378 private_key,
379 })
380 }
381
382 pub fn sign(&self, message: &[u8]) -> Result<Vec<u8>> {
384 let mut signature = self.private_key.clone();
389
390 for (i, &byte) in message.iter().enumerate() {
392 if i < signature.len() {
393 signature[i] ^= byte;
394 }
395 }
396
397 Ok(signature)
398 }
399
400 pub fn verify(&self, message: &[u8], signature: &[u8]) -> Result<bool> {
402 let expected_signature = self.sign(message)?;
407
408 let is_valid = signature.len() == expected_signature.len()
410 && signature
411 .iter()
412 .zip(expected_signature.iter())
413 .all(|(a, b)| a == b);
414
415 Ok(is_valid)
416 }
417}
418
/// Per-party message authentication backed by random keys.
#[derive(Clone, Debug)]
pub struct QuantumAuthentication {
    /// Protocol label, stored verbatim.
    protocol: String,

    /// Security level in bits; determines the length of each generated party key.
    security_bits: usize,

    /// Key bytes for every registered party, indexed by party name.
    keys: HashMap<String, Vec<u8>>,
}
431
432impl QuantumAuthentication {
433 pub fn new(protocol: &str, security_bits: usize) -> Self {
435 QuantumAuthentication {
436 protocol: protocol.to_string(),
437 security_bits,
438 keys: HashMap::new(),
439 }
440 }
441
442 pub fn add_party(&mut self, party_name: &str) -> Result<()> {
444 let key = (0..self.security_bits / 8 + 1)
446 .map(|_| thread_rng().random::<u8>())
447 .collect::<Vec<_>>();
448
449 self.keys.insert(party_name.to_string(), key);
450
451 Ok(())
452 }
453
454 pub fn authenticate(&self, party_name: &str, message: &[u8]) -> Result<Vec<u8>> {
456 let key = self
458 .keys
459 .get(party_name)
460 .ok_or_else(|| MLError::InvalidParameter(format!("Party {} not found", party_name)))?;
461
462 let mut tag = key.clone();
464
465 for (i, &byte) in message.iter().enumerate() {
467 if i < tag.len() {
468 tag[i] ^= byte;
469 }
470 }
471
472 Ok(tag)
473 }
474
475 pub fn verify(&self, party_name: &str, message: &[u8], tag: &[u8]) -> Result<bool> {
477 let expected_tag = self.authenticate(party_name, message)?;
479
480 let is_valid = tag.len() == expected_tag.len()
482 && tag.iter().zip(expected_tag.iter()).all(|(a, b)| a == b);
483
484 Ok(is_valid)
485 }
486}
487
/// Simplified quantum secure direct communication (QSDC) channel.
#[derive(Clone, Debug)]
pub struct QSDC {
    /// Number of qubits configured for the channel.
    pub num_qubits: usize,

    /// Independent per-bit flip probability applied during transmission.
    pub error_rate: f64,
}
497
498impl QSDC {
499 pub fn new(num_qubits: usize) -> Self {
501 QSDC {
502 num_qubits,
503 error_rate: 0.01, }
505 }
506
507 pub fn with_error_rate(mut self, error_rate: f64) -> Self {
509 self.error_rate = error_rate;
510 self
511 }
512
513 pub fn transmit_message(&self, message: &[u8]) -> Result<Vec<u8>> {
515 let mut received = message.to_vec();
521
522 for byte in &mut received {
524 for bit_pos in 0..8 {
525 if thread_rng().random::<f64>() < self.error_rate {
526 *byte ^= 1 << bit_pos;
528 }
529 }
530 }
531
532 Ok(received)
533 }
534}
535
/// XOR-encrypts `message` with `key`, repeating the key cyclically when it is
/// shorter than the message.
///
/// Applying this function twice with the same key returns the original bytes.
/// Reusing a short key across the message weakens confidentiality; supply a
/// key at least as long as the message where possible.
///
/// # Panics
/// Panics with an explicit message if `key` is empty while `message` is not.
/// Previously this surfaced as a remainder-by-zero panic.
pub fn encrypt_with_qkd(message: &[u8], key: Vec<u8>) -> Vec<u8> {
    if message.is_empty() {
        return Vec::new();
    }
    assert!(!key.is_empty(), "encrypt_with_qkd: key must not be empty");

    message
        .iter()
        .zip(key.iter().cycle())
        .map(|(&byte, &k)| byte ^ k)
        .collect()
}
545
546pub fn decrypt_with_qkd(encrypted: &[u8], key: Vec<u8>) -> Vec<u8> {
548 encrypt_with_qkd(encrypted, key)
550}
551
552impl fmt::Display for ProtocolType {
553 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
554 match self {
555 ProtocolType::BB84 => write!(f, "BB84"),
556 ProtocolType::E91 => write!(f, "E91"),
557 ProtocolType::B92 => write!(f, "B92"),
558 ProtocolType::BBM92 => write!(f, "BBM92"),
559 ProtocolType::SARG04 => write!(f, "SARG04"),
560 }
561 }
562}