1use crate::error::{MLError, Result};
2use quantrs2_circuit::prelude::Circuit;
3use quantrs2_sim::statevector::StateVectorSimulator;
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::prelude::*;
6use std::collections::HashMap;
7use std::fmt;
8
/// Key-distribution protocol selected by `QuantumKeyDistribution::distribute_key`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum ProtocolType {
    /// BB84 prepare-and-measure protocol; simulated with basis sifting.
    BB84,
    /// E91; currently a placeholder that installs a random shared key.
    E91,
    /// B92; currently a placeholder that installs a random shared key.
    B92,
    /// BBM92; `distribute_key` reports it as not implemented.
    BBM92,
    /// SARG04; `distribute_key` reports it as not implemented.
    SARG04,
}
27
/// One participant (e.g. Alice or Bob) in a key-distribution session.
#[derive(Debug)]
#[derive(Clone)]
pub struct Party {
    /// Human-readable participant name.
    pub name: String,

    /// Key bytes held by this party; `None` until a key has been distributed.
    pub key: Option<Vec<u8>>,

    /// Per-qubit basis choices (`0` or `1`) recorded by protocols that use them.
    pub bases: Option<Vec<usize>>,

    /// Optional per-party state vector.
    /// NOTE(review): not written by the protocol code in this module — confirm its intended use.
    pub state: Option<Vec<f64>>,
}
43
44#[derive(Debug, Clone)]
46pub struct QuantumKeyDistribution {
47 pub protocol: ProtocolType,
49
50 pub num_qubits: usize,
52
53 pub alice: Party,
55
56 pub bob: Party,
58
59 pub error_rate: f64,
61
62 pub security_bits: usize,
64}
65
66impl QuantumKeyDistribution {
67 pub fn new(protocol: ProtocolType, num_qubits: usize) -> Self {
69 QuantumKeyDistribution {
70 protocol,
71 num_qubits,
72 alice: Party {
73 name: "Alice".to_string(),
74 key: None,
75 bases: None,
76 state: None,
77 },
78 bob: Party {
79 name: "Bob".to_string(),
80 key: None,
81 bases: None,
82 state: None,
83 },
84 error_rate: 0.0,
85 security_bits: num_qubits / 10,
86 }
87 }
88
89 pub fn with_error_rate(mut self, error_rate: f64) -> Self {
91 self.error_rate = error_rate;
92 self
93 }
94
95 pub fn with_security_bits(mut self, security_bits: usize) -> Self {
97 self.security_bits = security_bits;
98 self
99 }
100
101 pub fn distribute_key(&mut self) -> Result<usize> {
103 match self.protocol {
104 ProtocolType::BB84 => self.bb84_protocol(),
105 ProtocolType::E91 => self.e91_protocol(),
106 ProtocolType::B92 => self.b92_protocol(),
107 ProtocolType::BBM92 => Err(MLError::NotImplemented(
108 "BBM92 protocol not implemented yet".to_string(),
109 )),
110 ProtocolType::SARG04 => Err(MLError::NotImplemented(
111 "SARG04 protocol not implemented yet".to_string(),
112 )),
113 }
114 }
115
116 fn bb84_protocol(&mut self) -> Result<usize> {
118 let alice_bits = (0..self.num_qubits)
123 .map(|_| {
124 if thread_rng().gen::<f64>() > 0.5 {
125 1u8
126 } else {
127 0u8
128 }
129 })
130 .collect::<Vec<_>>();
131
132 let alice_bases = (0..self.num_qubits)
134 .map(|_| {
135 if thread_rng().gen::<f64>() > 0.5 {
136 1usize
137 } else {
138 0usize
139 }
140 })
141 .collect::<Vec<_>>();
142
143 let bob_bases = (0..self.num_qubits)
144 .map(|_| {
145 if thread_rng().gen::<f64>() > 0.5 {
146 1usize
147 } else {
148 0usize
149 }
150 })
151 .collect::<Vec<_>>();
152
153 let matching_bases = alice_bases
155 .iter()
156 .zip(bob_bases.iter())
157 .enumerate()
158 .filter_map(|(i, (a, b))| if a == b { Some(i) } else { None })
159 .collect::<Vec<_>>();
160
161 let mut key_bits = Vec::new();
163 for &i in &matching_bases {
164 if thread_rng().gen::<f64>() > self.error_rate {
166 key_bits.push(alice_bits[i]);
167 } else {
168 key_bits.push(alice_bits[i] ^ 1);
170 }
171 }
172
173 let mut key_bytes = Vec::new();
175 for chunk in key_bits.chunks(8) {
176 let byte = chunk
177 .iter()
178 .enumerate()
179 .fold(0u8, |acc, (i, &bit)| acc | (bit << i));
180 key_bytes.push(byte);
181 }
182
183 self.alice.key = Some(key_bytes.clone());
185 self.bob.key = Some(key_bytes);
186
187 self.alice.bases = Some(alice_bases);
189 self.bob.bases = Some(bob_bases);
190
191 Ok(matching_bases.len())
192 }
193
194 fn e91_protocol(&mut self) -> Result<usize> {
196 let key_length = self.num_qubits / 3; let key_bytes = (0..key_length / 8 + 1)
202 .map(|_| thread_rng().gen::<u8>())
203 .collect::<Vec<_>>();
204
205 self.alice.key = Some(key_bytes.clone());
207 self.bob.key = Some(key_bytes);
208
209 Ok(key_length)
210 }
211
212 fn b92_protocol(&mut self) -> Result<usize> {
214 let key_length = self.num_qubits / 4; let key_bytes = (0..key_length / 8 + 1)
220 .map(|_| thread_rng().gen::<u8>())
221 .collect::<Vec<_>>();
222
223 self.alice.key = Some(key_bytes.clone());
225 self.bob.key = Some(key_bytes);
226
227 Ok(key_length)
228 }
229
230 pub fn verify_keys(&self) -> bool {
232 match (&self.alice.key, &self.bob.key) {
233 (Some(alice_key), Some(bob_key)) => alice_key == bob_key,
234 _ => false,
235 }
236 }
237
238 pub fn get_alice_key(&self) -> Option<Vec<u8>> {
240 self.alice.key.clone()
241 }
242
243 pub fn get_bob_key(&self) -> Option<Vec<u8>> {
245 self.bob.key.clone()
246 }
247}
248
/// Key material and parameters for the module's toy message-signing scheme.
#[derive(Debug)]
#[derive(Clone)]
pub struct QuantumSignature {
    /// Requested security level in bits (drives key length at construction).
    security_bits: usize,

    /// Name of the signature algorithm, as supplied by the caller.
    algorithm: String,

    /// Public key bytes.
    public_key: Vec<u8>,

    /// Private key bytes used when signing.
    private_key: Vec<u8>,
}
264
265impl QuantumSignature {
266 pub fn new(security_bits: usize, algorithm: &str) -> Result<Self> {
268 let public_key = (0..security_bits / 8 + 1)
273 .map(|_| thread_rng().gen::<u8>())
274 .collect::<Vec<_>>();
275
276 let private_key = (0..security_bits / 8 + 1)
277 .map(|_| thread_rng().gen::<u8>())
278 .collect::<Vec<_>>();
279
280 Ok(QuantumSignature {
281 security_bits,
282 algorithm: algorithm.to_string(),
283 public_key,
284 private_key,
285 })
286 }
287
288 pub fn sign(&self, message: &[u8]) -> Result<Vec<u8>> {
290 let mut signature = self.private_key.clone();
295
296 for (i, &byte) in message.iter().enumerate() {
298 if i < signature.len() {
299 signature[i] ^= byte;
300 }
301 }
302
303 Ok(signature)
304 }
305
306 pub fn verify(&self, message: &[u8], signature: &[u8]) -> Result<bool> {
308 let expected_signature = self.sign(message)?;
313
314 let is_valid = signature.len() == expected_signature.len()
316 && signature
317 .iter()
318 .zip(expected_signature.iter())
319 .all(|(a, b)| a == b);
320
321 Ok(is_valid)
322 }
323}
324
/// Per-party shared keys for the module's toy message-authentication scheme.
#[derive(Debug)]
#[derive(Clone)]
pub struct QuantumAuthentication {
    /// Name of the authentication protocol, as supplied by the caller.
    protocol: String,

    /// Requested security level in bits (drives per-party key length).
    security_bits: usize,

    /// Authentication key for each registered party, keyed by party name.
    keys: HashMap<String, Vec<u8>>,
}
337
338impl QuantumAuthentication {
339 pub fn new(protocol: &str, security_bits: usize) -> Self {
341 QuantumAuthentication {
342 protocol: protocol.to_string(),
343 security_bits,
344 keys: HashMap::new(),
345 }
346 }
347
348 pub fn add_party(&mut self, party_name: &str) -> Result<()> {
350 let key = (0..self.security_bits / 8 + 1)
352 .map(|_| thread_rng().gen::<u8>())
353 .collect::<Vec<_>>();
354
355 self.keys.insert(party_name.to_string(), key);
356
357 Ok(())
358 }
359
360 pub fn authenticate(&self, party_name: &str, message: &[u8]) -> Result<Vec<u8>> {
362 let key = self
364 .keys
365 .get(party_name)
366 .ok_or_else(|| MLError::InvalidParameter(format!("Party {} not found", party_name)))?;
367
368 let mut tag = key.clone();
370
371 for (i, &byte) in message.iter().enumerate() {
373 if i < tag.len() {
374 tag[i] ^= byte;
375 }
376 }
377
378 Ok(tag)
379 }
380
381 pub fn verify(&self, party_name: &str, message: &[u8], tag: &[u8]) -> Result<bool> {
383 let expected_tag = self.authenticate(party_name, message)?;
385
386 let is_valid = tag.len() == expected_tag.len()
388 && tag.iter().zip(expected_tag.iter()).all(|(a, b)| a == b);
389
390 Ok(is_valid)
391 }
392}
393
/// Simulated quantum secure direct communication channel.
#[derive(Debug)]
#[derive(Clone)]
pub struct QSDC {
    /// Number of qubits available to the channel.
    pub num_qubits: usize,

    /// Independent per-bit flip probability applied by `transmit_message`.
    pub error_rate: f64,
}
403
404impl QSDC {
405 pub fn new(num_qubits: usize) -> Self {
407 QSDC {
408 num_qubits,
409 error_rate: 0.01, }
411 }
412
413 pub fn with_error_rate(mut self, error_rate: f64) -> Self {
415 self.error_rate = error_rate;
416 self
417 }
418
419 pub fn transmit_message(&self, message: &[u8]) -> Result<Vec<u8>> {
421 let mut received = message.to_vec();
427
428 for byte in &mut received {
430 for bit_pos in 0..8 {
431 if thread_rng().gen::<f64>() < self.error_rate {
432 *byte ^= 1 << bit_pos;
434 }
435 }
436 }
437
438 Ok(received)
439 }
440}
441
/// XOR-encrypts `message` with `key`, repeating the key when it is shorter than the message.
///
/// This is a one-time pad only when `key.len() >= message.len()` and the key is never
/// reused; a repeating key offers much weaker protection.
///
/// # Panics
/// Panics if `key` is empty while `message` is not (there is nothing to XOR with).
pub fn encrypt_with_qkd(message: &[u8], key: Vec<u8>) -> Vec<u8> {
    if message.is_empty() {
        return Vec::new();
    }
    assert!(!key.is_empty(), "encrypt_with_qkd: key must not be empty");

    message
        .iter()
        .zip(key.iter().cycle())
        .map(|(&byte, &k)| byte ^ k)
        .collect()
}
451
452pub fn decrypt_with_qkd(encrypted: &[u8], key: Vec<u8>) -> Vec<u8> {
454 encrypt_with_qkd(encrypted, key)
456}
457
458impl fmt::Display for ProtocolType {
459 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
460 match self {
461 ProtocolType::BB84 => write!(f, "BB84"),
462 ProtocolType::E91 => write!(f, "E91"),
463 ProtocolType::B92 => write!(f, "B92"),
464 ProtocolType::BBM92 => write!(f, "BBM92"),
465 ProtocolType::SARG04 => write!(f, "SARG04"),
466 }
467 }
468}