1use crate::error::MLError;
7use num_complex::Complex64 as Complex;
8use quantrs2_circuit::prelude::*;
9use quantrs2_core::prelude::*;
10use std::f64::consts::PI;
11
/// Quantum variational autoencoder built from parameterised rotation layers.
///
/// The encoder and decoder angles are stored as flat vectors. Each layer holds
/// three angles (RX, RY, RZ) per data qubit.
#[derive(Debug, Clone, PartialEq)]
pub struct QVAE {
    /// Number of qubits holding the input data (also the width of the output register).
    pub num_data_qubits: usize,
    /// Number of qubits in the compressed latent register.
    pub num_latent_qubits: usize,
    /// Number of auxiliary qubits, counted by `total_qubits`.
    pub num_ancilla_qubits: usize,
    /// Encoder rotation angles, laid out layer by layer, three per data qubit.
    pub encoder_params: Vec<f64>,
    /// Decoder rotation angles, laid out layer by layer, three per output qubit.
    pub decoder_params: Vec<f64>,
}
25
26impl QVAE {
27 pub fn new(
29 num_data_qubits: usize,
30 num_latent_qubits: usize,
31 num_ancilla_qubits: usize,
32 ) -> Result<Self, MLError> {
33 if num_latent_qubits >= num_data_qubits {
34 return Err(MLError::InvalidParameter(
35 "Latent space must be smaller than data space".to_string(),
36 ));
37 }
38
39 let encoder_depth = 3;
41 let decoder_depth = 3;
42
43 let encoder_params = vec![0.1; num_data_qubits * encoder_depth * 3];
44 let decoder_params = vec![0.1; num_data_qubits * decoder_depth * 3];
45
46 Ok(Self {
47 num_data_qubits,
48 num_latent_qubits,
49 num_ancilla_qubits,
50 encoder_params,
51 decoder_params,
52 })
53 }
54
55 pub fn total_qubits(&self) -> usize {
57 self.num_data_qubits + self.num_latent_qubits + self.num_ancilla_qubits
58 }
59
60 pub fn encode<const N: usize>(
62 &self,
63 circuit: &mut Circuit<N>,
64 data_start: usize,
65 latent_start: usize,
66 ) -> Result<(), MLError> {
67 if data_start + self.num_data_qubits > N {
69 return Err(MLError::InvalidParameter(
70 "Data qubits exceed circuit size".to_string(),
71 ));
72 }
73 if latent_start + self.num_latent_qubits > N {
74 return Err(MLError::InvalidParameter(
75 "Latent qubits exceed circuit size".to_string(),
76 ));
77 }
78
79 let mut param_idx = 0;
81 let depth = self.encoder_params.len() / (self.num_data_qubits * 3);
82
83 for layer in 0..depth {
84 for i in 0..self.num_data_qubits {
86 let q = data_start + i;
87 if param_idx < self.encoder_params.len() {
88 circuit.rx(q, self.encoder_params[param_idx])?;
89 param_idx += 1;
90 }
91 if param_idx < self.encoder_params.len() {
92 circuit.ry(q, self.encoder_params[param_idx])?;
93 param_idx += 1;
94 }
95 if param_idx < self.encoder_params.len() {
96 circuit.rz(q, self.encoder_params[param_idx])?;
97 param_idx += 1;
98 }
99 }
100
101 for i in 0..self.num_data_qubits - 1 {
103 circuit.cnot(data_start + i, data_start + i + 1)?;
104 }
105
106 if layer == depth - 1 {
108 for i in 0..self.num_latent_qubits {
109 let data_q = data_start + (i % self.num_data_qubits);
110 let latent_q = latent_start + i;
111 circuit.cnot(data_q, latent_q)?;
112 }
113 }
114 }
115
116 Ok(())
117 }
118
119 pub fn decode<const N: usize>(
121 &self,
122 circuit: &mut Circuit<N>,
123 latent_start: usize,
124 output_start: usize,
125 ) -> Result<(), MLError> {
126 if latent_start + self.num_latent_qubits > N {
128 return Err(MLError::InvalidParameter(
129 "Latent qubits exceed circuit size".to_string(),
130 ));
131 }
132 if output_start + self.num_data_qubits > N {
133 return Err(MLError::InvalidParameter(
134 "Output qubits exceed circuit size".to_string(),
135 ));
136 }
137
138 let mut param_idx = 0;
140 let depth = self.decoder_params.len() / (self.num_data_qubits * 3);
141
142 for layer in 0..depth {
143 if layer == 0 {
145 for i in 0..self.num_latent_qubits {
146 let latent_q = latent_start + i;
147 let output_q = output_start + (i % self.num_data_qubits);
148 circuit.cnot(latent_q, output_q)?;
149 }
150 }
151
152 for i in 0..self.num_data_qubits {
154 let q = output_start + i;
155 if param_idx < self.decoder_params.len() {
156 circuit.rx(q, self.decoder_params[param_idx])?;
157 param_idx += 1;
158 }
159 if param_idx < self.decoder_params.len() {
160 circuit.ry(q, self.decoder_params[param_idx])?;
161 param_idx += 1;
162 }
163 if param_idx < self.decoder_params.len() {
164 circuit.rz(q, self.decoder_params[param_idx])?;
165 param_idx += 1;
166 }
167 }
168
169 for i in 0..self.num_data_qubits - 1 {
171 circuit.cnot(output_start + i, output_start + i + 1)?;
172 }
173 }
174
175 Ok(())
176 }
177
178 pub fn build_circuit<const N: usize>(&self) -> Result<Circuit<N>, MLError> {
180 if N < self.total_qubits() {
181 return Err(MLError::InvalidParameter(format!(
182 "Circuit needs at least {} qubits",
183 self.total_qubits()
184 )));
185 }
186
187 let mut circuit = Circuit::<N>::new();
188
189 let data_start = 0;
191 let latent_start = self.num_data_qubits;
192 let output_start = self.num_data_qubits + self.num_latent_qubits;
193
194 self.encode(&mut circuit, data_start, latent_start)?;
196
197 self.decode(&mut circuit, latent_start, output_start)?;
199
200 Ok(circuit)
201 }
202
203 pub fn reconstruction_fidelity(
205 &self,
206 input_state: &[Complex],
207 output_state: &[Complex],
208 ) -> Result<f64, MLError> {
209 if input_state.len() != output_state.len() {
210 return Err(MLError::InvalidParameter(
211 "State dimensions mismatch".to_string(),
212 ));
213 }
214
215 let inner_product: Complex = input_state
217 .iter()
218 .zip(output_state.iter())
219 .map(|(a, b)| a.conj() * b)
220 .sum();
221
222 Ok(inner_product.norm_sqr())
224 }
225
226 pub fn get_parameters(&self) -> Vec<f64> {
228 let mut params = self.encoder_params.clone();
229 params.extend(&self.decoder_params);
230 params
231 }
232
233 pub fn set_parameters(&mut self, params: &[f64]) -> Result<(), MLError> {
235 let encoder_size = self.encoder_params.len();
236 let decoder_size = self.decoder_params.len();
237
238 if params.len() != encoder_size + decoder_size {
239 return Err(MLError::InvalidParameter(format!(
240 "Expected {} parameters, got {}",
241 encoder_size + decoder_size,
242 params.len()
243 )));
244 }
245
246 self.encoder_params.copy_from_slice(¶ms[..encoder_size]);
247 self.decoder_params.copy_from_slice(¶ms[encoder_size..]);
248
249 Ok(())
250 }
251
252 pub fn compute_loss(&self, input_states: &[Vec<Complex>], lambda: f64) -> Result<f64, MLError> {
254 let mut total_loss = 0.0;
257
258 for _input in input_states {
259 total_loss += 1.0; }
263
264 let reg_term: f64 = self.get_parameters().iter().map(|p| p * p).sum::<f64>() * lambda;
266
267 Ok(total_loss / input_states.len() as f64 + reg_term)
268 }
269}
270
/// Single-layer classical autoencoder with a tanh encoder and a sigmoid decoder.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassicalAutoencoder {
    /// Length of the input (and reconstructed output) vector.
    pub input_dim: usize,
    /// Length of the latent code.
    pub latent_dim: usize,
    /// Encoder weight matrix: `latent_dim` rows of `input_dim` weights.
    pub encoder_weights: Vec<Vec<f64>>,
    /// Decoder weight matrix: `input_dim` rows of `latent_dim` weights.
    pub decoder_weights: Vec<Vec<f64>>,
}
282
283impl ClassicalAutoencoder {
284 pub fn new(input_dim: usize, latent_dim: usize) -> Self {
286 let mut rng = fastrand::Rng::with_seed(42);
287
288 let encoder_weights = (0..latent_dim)
290 .map(|_| (0..input_dim).map(|_| rng.f64() * 0.1 - 0.05).collect())
291 .collect();
292
293 let decoder_weights = (0..input_dim)
294 .map(|_| (0..latent_dim).map(|_| rng.f64() * 0.1 - 0.05).collect())
295 .collect();
296
297 Self {
298 input_dim,
299 latent_dim,
300 encoder_weights,
301 decoder_weights,
302 }
303 }
304
305 pub fn encode(&self, input: &[f64]) -> Vec<f64> {
307 let mut latent = vec![0.0; self.latent_dim];
308
309 for i in 0..self.latent_dim {
310 for j in 0..self.input_dim {
311 latent[i] += self.encoder_weights[i][j] * input[j];
312 }
313 latent[i] = latent[i].tanh();
315 }
316
317 latent
318 }
319
320 pub fn decode(&self, latent: &[f64]) -> Vec<f64> {
322 let mut output = vec![0.0; self.input_dim];
323
324 for i in 0..self.input_dim {
325 for j in 0..self.latent_dim {
326 output[i] += self.decoder_weights[i][j] * latent[j];
327 }
328 output[i] = 1.0 / (1.0 + (-output[i]).exp());
330 }
331
332 output
333 }
334
335 pub fn forward(&self, input: &[f64]) -> Vec<f64> {
337 let latent = self.encode(input);
338 self.decode(&latent)
339 }
340}
341
/// Autoencoder that pairs a quantum encoder with a classical network.
pub struct HybridAutoencoder {
    /// Quantum encoder; `HybridAutoencoder::new` builds it with no ancilla qubits.
    pub quantum_encoder: QVAE,
    /// Classical network whose input dimension is `2^num_latent_qubits` when built by `new`.
    /// NOTE(review): despite the name, this is a full `ClassicalAutoencoder`
    /// (it has both `encode` and `decode`); confirm which half callers use.
    pub classical_decoder: ClassicalAutoencoder,
}
349
350impl HybridAutoencoder {
351 pub fn new(
353 num_data_qubits: usize,
354 num_latent_qubits: usize,
355 classical_latent_dim: usize,
356 ) -> Result<Self, MLError> {
357 let quantum_encoder = QVAE::new(num_data_qubits, num_latent_qubits, 0)?;
358
359 let quantum_latent_dim = 1 << num_latent_qubits;
361 let classical_decoder = ClassicalAutoencoder::new(quantum_latent_dim, classical_latent_dim);
362
363 Ok(Self {
364 quantum_encoder,
365 classical_decoder,
366 })
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qvae_creation() {
        let model = QVAE::new(4, 2, 0).expect("4 data / 2 latent qubits is a valid configuration");

        assert_eq!(model.num_data_qubits, 4);
        assert_eq!(model.num_latent_qubits, 2);
        assert_eq!(model.total_qubits(), 6);
    }

    #[test]
    fn test_qvae_invalid_params() {
        // A latent register larger than the data register must be rejected.
        assert!(QVAE::new(4, 5, 0).is_err());
    }

    #[test]
    fn test_classical_autoencoder() {
        let autoencoder = ClassicalAutoencoder::new(10, 3);
        let reconstruction = autoencoder.forward(&[0.5; 10]);

        assert_eq!(reconstruction.len(), 10);
        // The sigmoid output layer keeps every value in [0, 1].
        assert!(reconstruction.iter().all(|v| (0.0..=1.0).contains(v)));
    }

    #[test]
    fn test_parameter_management() {
        let mut model = QVAE::new(4, 2, 0).unwrap();
        let replacement = vec![0.2; model.get_parameters().len()];

        model.set_parameters(&replacement).unwrap();

        assert_eq!(model.get_parameters(), replacement);
    }

    #[test]
    fn test_reconstruction_fidelity() {
        let model = QVAE::new(2, 1, 0).unwrap();
        // Uniform superposition over two qubits: four amplitudes of 0.5.
        let uniform = vec![Complex::new(0.5, 0.0); 4];

        let fidelity = model.reconstruction_fidelity(&uniform, &uniform).unwrap();

        assert!((fidelity - 1.0).abs() < 1e-10);
    }
}