1use crate::error::MLError;
7use quantrs2_circuit::prelude::*;
8use scirs2_core::Complex64 as Complex;
9use std::f64::consts::PI;
10
/// Growable vector alias used for quantum state vectors and parameter lists.
type DVector<T> = Vec<T>;
/// Row-major dense matrix alias (a vector of rows) used for image data.
type DMatrix = DVector<DVector<f64>>;
14
/// A parameterised "convolution" filter that acts on a run of consecutive qubits.
#[derive(Debug, Clone)]
pub struct QuantumConvFilter {
    /// How many consecutive qubits one application of the filter covers.
    pub num_qubits: usize,
    /// How far (in active-qubit positions) the filter window advances between applications.
    pub stride: usize,
    /// Rotation angles consumed in order by the filter (three per qubit when built via `new`).
    pub params: Vec<f64>,
}
25
26impl QuantumConvFilter {
27 pub fn new(num_qubits: usize, stride: usize) -> Self {
29 let num_params = num_qubits * 3; let params = vec![0.1; num_params];
32
33 Self {
34 num_qubits,
35 stride,
36 params,
37 }
38 }
39
40 pub fn apply_filter<const N: usize>(
42 &self,
43 circuit: &mut Circuit<N>,
44 start_qubit: usize,
45 ) -> Result<(), MLError> {
46 let end_qubit = (start_qubit + self.num_qubits).min(N);
47
48 let mut param_idx = 0;
50 for i in start_qubit..end_qubit {
51 if param_idx < self.params.len() {
52 circuit.rx(i, self.params[param_idx])?;
53 param_idx += 1;
54 }
55 if param_idx < self.params.len() {
56 circuit.ry(i, self.params[param_idx])?;
57 param_idx += 1;
58 }
59 if param_idx < self.params.len() {
60 circuit.rz(i, self.params[param_idx])?;
61 param_idx += 1;
62 }
63 }
64
65 for i in start_qubit..(end_qubit - 1) {
67 circuit.cnot(i, i + 1)?;
68 }
69
70 Ok(())
71 }
72}
73
74#[derive(Debug, Clone)]
76pub struct QuantumPooling {
77 pub pool_size: usize,
79 pub pool_type: PoolingType,
81}
82
/// Strategy used by [`QuantumPooling::apply_pooling`] to shrink the set of active qubits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolingType {
    /// Keep the first `len / pool_size` active qubits; no gates are added.
    TraceOut,
    /// Keep every `pool_size`-th active qubit (positions 0, pool_size, 2*pool_size, ...).
    /// No measurement or reset gates are emitted.
    MeasureReset,
    /// Entangle each full group of `pool_size` active qubits with a CNOT chain, then
    /// keep the first `len / pool_size` active qubits.
    Quantum,
}
92
93impl QuantumPooling {
94 pub fn new(pool_size: usize, pool_type: PoolingType) -> Self {
96 Self {
97 pool_size,
98 pool_type,
99 }
100 }
101
102 pub fn apply_pooling<const N: usize>(
104 &self,
105 circuit: &mut Circuit<N>,
106 active_qubits: &mut Vec<usize>,
107 ) -> Result<(), MLError> {
108 match self.pool_type {
109 PoolingType::TraceOut => {
110 let new_size = active_qubits.len() / self.pool_size;
112 active_qubits.truncate(new_size);
113 }
114 PoolingType::MeasureReset => {
115 let mut new_active = Vec::new();
117 for (i, &qubit) in active_qubits.iter().enumerate() {
118 if i % self.pool_size == 0 {
119 new_active.push(qubit);
120 } else {
121 }
124 }
125 *active_qubits = new_active;
126 }
127 PoolingType::Quantum => {
128 let pool_size = self.pool_size;
130 let new_size = active_qubits.len() / pool_size;
131
132 for i in 0..new_size {
134 let start_idx = i * pool_size;
135 let end_idx = (start_idx + pool_size).min(active_qubits.len());
136
137 if end_idx > start_idx + 1 {
138 for j in start_idx..end_idx - 1 {
140 circuit.cnot(active_qubits[j], active_qubits[j + 1]);
141 }
142 }
143 }
144
145 active_qubits.truncate(new_size);
147 }
148 }
149 Ok(())
150 }
151}
152
153pub struct QCNN {
155 pub num_qubits: usize,
157 pub conv_layers: Vec<(QuantumConvFilter, QuantumPooling)>,
159 pub fc_params: Vec<f64>,
161}
162
163impl QCNN {
164 pub fn new(
166 num_qubits: usize,
167 conv_filters: Vec<(usize, usize)>, pool_sizes: Vec<usize>,
169 fc_params: usize,
170 ) -> Result<Self, MLError> {
171 if conv_filters.len() != pool_sizes.len() {
172 return Err(MLError::ModelCreationError(
173 "Number of conv filters must match number of pooling layers".to_string(),
174 ));
175 }
176
177 let mut conv_layers = Vec::new();
178 for ((filter_size, stride), pool_size) in conv_filters.into_iter().zip(pool_sizes) {
179 let filter = QuantumConvFilter::new(filter_size, stride);
180 let pooling = QuantumPooling::new(pool_size, PoolingType::TraceOut);
181 conv_layers.push((filter, pooling));
182 }
183
184 let fc_params = vec![0.1; fc_params];
185
186 Ok(Self {
187 num_qubits,
188 conv_layers,
189 fc_params,
190 })
191 }
192
193 pub fn forward(&self, input_state: &DVector<Complex>) -> Result<DVector<Complex>, MLError> {
195 const MAX_QUBITS: usize = 20;
197
198 if self.num_qubits > MAX_QUBITS {
199 return Err(MLError::InvalidParameter(format!(
200 "QCNN supports up to {} qubits",
201 MAX_QUBITS
202 )));
203 }
204
205 let mut circuit = Circuit::<MAX_QUBITS>::new();
206 let mut active_qubits: Vec<usize> = (0..self.num_qubits).collect();
207
208 for (conv_filter, pooling) in &self.conv_layers {
213 let mut pos = 0;
215 while pos + conv_filter.num_qubits <= active_qubits.len() {
216 let start_qubit = active_qubits[pos];
217 conv_filter.apply_filter(&mut circuit, start_qubit)?;
218 pos += conv_filter.stride;
219 }
220
221 pooling.apply_pooling(&mut circuit, &mut active_qubits)?;
223 }
224
225 for (i, &qubit) in active_qubits.iter().enumerate() {
227 if i < self.fc_params.len() {
228 circuit.ry(qubit, self.fc_params[i])?;
229 }
230 }
231
232 let output_size = 1 << active_qubits.len();
235 let mut output = vec![Complex::new(0.0, 0.0); output_size];
236
237 let norm = 1.0 / (output_size as f64).sqrt();
239 for i in 0..output_size {
240 output[i] = Complex::new(norm, 0.0);
241 }
242
243 Ok(output)
244 }
245
246 pub fn get_parameters(&self) -> Vec<f64> {
248 let mut params = Vec::new();
249
250 for (conv_filter, _) in &self.conv_layers {
251 params.extend(&conv_filter.params);
252 }
253 params.extend(&self.fc_params);
254
255 params
256 }
257
258 pub fn set_parameters(&mut self, params: &[f64]) -> Result<(), MLError> {
260 let mut idx = 0;
261
262 for (conv_filter, _) in &mut self.conv_layers {
263 let filter_params = conv_filter.params.len();
264 if idx + filter_params > params.len() {
265 return Err(MLError::InvalidParameter(
266 "Not enough parameters provided".to_string(),
267 ));
268 }
269 conv_filter
270 .params
271 .copy_from_slice(¶ms[idx..idx + filter_params]);
272 idx += filter_params;
273 }
274
275 let fc_params_len = self.fc_params.len();
276 if idx + fc_params_len > params.len() {
277 return Err(MLError::InvalidParameter(
278 "Not enough parameters for FC layer".to_string(),
279 ));
280 }
281 self.fc_params
282 .copy_from_slice(¶ms[idx..idx + fc_params_len]);
283
284 Ok(())
285 }
286
287 pub fn compute_gradients(
289 &mut self,
290 input_state: &DVector<Complex>,
291 target: &DVector<Complex>,
292 loss_fn: impl Fn(&DVector<Complex>, &DVector<Complex>) -> f64,
293 ) -> Result<Vec<f64>, MLError> {
294 let params = self.get_parameters();
295 let mut gradients = vec![0.0; params.len()];
296 let shift = PI / 2.0;
297
298 for i in 0..params.len() {
299 let mut params_plus = params.clone();
301 params_plus[i] += shift;
302 self.set_parameters(¶ms_plus)?;
303 let output_plus = self.forward(input_state)?;
304 let loss_plus = loss_fn(&output_plus, target);
305
306 let mut params_minus = params.clone();
308 params_minus[i] -= shift;
309 self.set_parameters(¶ms_minus)?;
310 let output_minus = self.forward(input_state)?;
311 let loss_minus = loss_fn(&output_minus, target);
312
313 gradients[i] = (loss_plus - loss_minus) / (2.0 * shift);
315 }
316
317 self.set_parameters(¶ms)?;
319
320 Ok(gradients)
321 }
322}
323
/// Amplitude encoder mapping a `height x width` image onto a `num_qubits`-qubit state.
#[derive(Debug, Clone)]
pub struct QuantumImageEncoder {
    /// Number of pixels per row.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
    /// Qubit count; the encoded state has `2^num_qubits` amplitudes.
    pub num_qubits: usize,
}
332
333impl QuantumImageEncoder {
334 pub fn new(width: usize, height: usize, num_qubits: usize) -> Self {
336 Self {
337 width,
338 height,
339 num_qubits,
340 }
341 }
342
343 pub fn encode(&self, image: &DMatrix) -> Result<DVector<Complex>, MLError> {
345 if image.len() != self.height || image[0].len() != self.width {
346 return Err(MLError::InvalidParameter(
347 "Image dimensions don't match encoder settings".to_string(),
348 ));
349 }
350
351 let pixels: Vec<f64> = image.iter().flat_map(|row| row.iter()).copied().collect();
353 let norm = pixels.iter().map(|x| x * x).sum::<f64>().sqrt();
354
355 let state_size = 1 << self.num_qubits;
357 let mut state = vec![Complex::new(0.0, 0.0); state_size];
358
359 for (i, &pixel) in pixels.iter().enumerate() {
360 if i < state_size {
361 state[i] = Complex::new(pixel / norm, 0.0);
362 }
363 }
364
365 Ok(state)
366 }
367
368 pub fn decode(&self, state: &DVector<Complex>) -> DMatrix {
370 let mut image = vec![vec![0.0; self.width]; self.height];
371 let mut idx = 0;
372
373 for i in 0..self.height {
374 for j in 0..self.width {
375 if idx < state.len() {
376 image[i][j] = state[idx].norm();
377 idx += 1;
378 }
379 }
380 }
381
382 image
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-qubit network shared by several tests: one (2-qubit, stride 1) filter,
    /// pooling factor 2, and two fully-connected angles.
    fn small_qcnn() -> QCNN {
        QCNN::new(4, vec![(2, 1)], vec![2], 2).unwrap()
    }

    /// Pools the eight qubits 0..8 with factor 2 and returns the surviving indices.
    fn pool_eight(pool_type: PoolingType) -> Vec<usize> {
        let pooling = QuantumPooling::new(2, pool_type);
        let mut circuit = Circuit::<8>::new();
        let mut active_qubits: Vec<usize> = (0..8).collect();
        pooling
            .apply_pooling(&mut circuit, &mut active_qubits)
            .unwrap();
        active_qubits
    }

    #[test]
    fn test_qcnn_creation() {
        let qcnn = QCNN::new(8, vec![(4, 2), (2, 1)], vec![2, 2], 4).unwrap();

        assert_eq!(qcnn.num_qubits, 8);
        assert_eq!(qcnn.conv_layers.len(), 2);
    }

    #[test]
    fn test_quantum_filter() {
        let filter = QuantumConvFilter::new(3, 1);

        assert_eq!(filter.num_qubits, 3);
        // Three rotation angles per qubit.
        assert_eq!(filter.params.len(), 3 * 3);
    }

    #[test]
    fn test_filter_application() {
        let filter = QuantumConvFilter::new(3, 1);
        let mut circuit = Circuit::<8>::new();

        filter.apply_filter(&mut circuit, 0).unwrap();

        assert!(circuit.num_gates() > 0);
    }

    #[test]
    fn test_pooling_trace_out() {
        assert_eq!(pool_eight(PoolingType::TraceOut).len(), 4);
    }

    #[test]
    fn test_pooling_measure_reset() {
        let survivors = pool_eight(PoolingType::MeasureReset);

        assert_eq!(survivors.len(), 4);
        assert_eq!(survivors, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_image_encoding() {
        let encoder = QuantumImageEncoder::new(2, 2, 2);
        let image = vec![vec![0.5; 2]; 2];

        let encoded = encoder.encode(&image).unwrap();
        assert_eq!(encoded.len(), 4);

        let norm: f64 = encoded.iter().map(|c| c.norm_sqr()).sum();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_image_decode() {
        let encoder = QuantumImageEncoder::new(2, 2, 2);
        let state = vec![Complex::new(0.5, 0.0); 4];

        let decoded = encoder.decode(&state);

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].len(), 2);
    }

    #[test]
    fn test_qcnn_forward() {
        let qcnn = small_qcnn();
        // 2^4 amplitudes for a 4-qubit register.
        let input_state = vec![Complex::new(1.0, 0.0); 16];

        let output = qcnn.forward(&input_state).unwrap();

        assert!(!output.is_empty());
    }

    #[test]
    fn test_parameter_management() {
        let mut qcnn = small_qcnn();
        let count = qcnn.get_parameters().len();

        let fresh: Vec<f64> = (0..count).map(|i| i as f64 * 0.1).collect();
        qcnn.set_parameters(&fresh).unwrap();

        assert_eq!(qcnn.get_parameters(), fresh);
    }

    #[test]
    fn test_gradient_computation() {
        let mut qcnn = small_qcnn();
        let input_state = vec![Complex::new(0.5, 0.0); 16];
        let target_state = vec![Complex::new(0.707, 0.0); 2];

        let squared_error = |output: &DVector<Complex>, target: &DVector<Complex>| -> f64 {
            output
                .iter()
                .zip(target)
                .map(|(o, t)| (o - t).norm_sqr())
                .sum()
        };

        let gradients = qcnn
            .compute_gradients(&input_state, &target_state, squared_error)
            .unwrap();

        assert_eq!(gradients.len(), qcnn.get_parameters().len());
    }

    #[test]
    fn test_invalid_layer_configuration() {
        // Two convolution layers but only one pooling layer.
        let result = QCNN::new(8, vec![(4, 2), (2, 1)], vec![2], 4);

        assert!(result.is_err());
    }

    #[test]
    fn test_stride_behavior() {
        let filter = QuantumConvFilter::new(2, 2);
        assert_eq!(filter.stride, 2);

        let mut circuit = Circuit::<8>::new();
        for start in [0, 2] {
            filter.apply_filter(&mut circuit, start).unwrap();
        }
    }

    #[test]
    fn test_large_image_encoding() {
        let encoder = QuantumImageEncoder::new(4, 4, 4);
        let image = vec![vec![0.25; 4]; 4];

        let encoded = encoder.encode(&image).unwrap();
        assert_eq!(encoded.len(), 16);

        let decoded = encoder.decode(&encoded);
        assert_eq!(decoded.len(), 4);
        assert_eq!(decoded[0].len(), 4);
    }
}