1use crate::error::MLError;
7use quantrs2_circuit::prelude::*;
8use quantrs2_core::prelude::*;
9use scirs2_core::random::prelude::*;
10use scirs2_core::Complex64 as Complex;
11use std::f64::consts::PI;
12
13pub struct QVAE {
15 pub num_data_qubits: usize,
17 pub num_latent_qubits: usize,
19 pub num_ancilla_qubits: usize,
21 pub encoder_params: Vec<f64>,
23 pub decoder_params: Vec<f64>,
25}
26
27impl QVAE {
28 pub fn new(
30 num_data_qubits: usize,
31 num_latent_qubits: usize,
32 num_ancilla_qubits: usize,
33 ) -> Result<Self, MLError> {
34 if num_latent_qubits >= num_data_qubits {
35 return Err(MLError::InvalidParameter(
36 "Latent space must be smaller than data space".to_string(),
37 ));
38 }
39
40 let encoder_depth = 3;
42 let decoder_depth = 3;
43
44 let encoder_params = vec![0.1; num_data_qubits * encoder_depth * 3];
45 let decoder_params = vec![0.1; num_data_qubits * decoder_depth * 3];
46
47 Ok(Self {
48 num_data_qubits,
49 num_latent_qubits,
50 num_ancilla_qubits,
51 encoder_params,
52 decoder_params,
53 })
54 }
55
56 pub fn total_qubits(&self) -> usize {
58 self.num_data_qubits + self.num_latent_qubits + self.num_ancilla_qubits
59 }
60
61 pub fn encode<const N: usize>(
63 &self,
64 circuit: &mut Circuit<N>,
65 data_start: usize,
66 latent_start: usize,
67 ) -> Result<(), MLError> {
68 if data_start + self.num_data_qubits > N {
70 return Err(MLError::InvalidParameter(
71 "Data qubits exceed circuit size".to_string(),
72 ));
73 }
74 if latent_start + self.num_latent_qubits > N {
75 return Err(MLError::InvalidParameter(
76 "Latent qubits exceed circuit size".to_string(),
77 ));
78 }
79
80 let mut param_idx = 0;
82 let depth = self.encoder_params.len() / (self.num_data_qubits * 3);
83
84 for layer in 0..depth {
85 for i in 0..self.num_data_qubits {
87 let q = data_start + i;
88 if param_idx < self.encoder_params.len() {
89 circuit.rx(q, self.encoder_params[param_idx])?;
90 param_idx += 1;
91 }
92 if param_idx < self.encoder_params.len() {
93 circuit.ry(q, self.encoder_params[param_idx])?;
94 param_idx += 1;
95 }
96 if param_idx < self.encoder_params.len() {
97 circuit.rz(q, self.encoder_params[param_idx])?;
98 param_idx += 1;
99 }
100 }
101
102 for i in 0..self.num_data_qubits - 1 {
104 circuit.cnot(data_start + i, data_start + i + 1)?;
105 }
106
107 if layer == depth - 1 {
109 for i in 0..self.num_latent_qubits {
110 let data_q = data_start + (i % self.num_data_qubits);
111 let latent_q = latent_start + i;
112 circuit.cnot(data_q, latent_q)?;
113 }
114 }
115 }
116
117 Ok(())
118 }
119
120 pub fn decode<const N: usize>(
122 &self,
123 circuit: &mut Circuit<N>,
124 latent_start: usize,
125 output_start: usize,
126 ) -> Result<(), MLError> {
127 if latent_start + self.num_latent_qubits > N {
129 return Err(MLError::InvalidParameter(
130 "Latent qubits exceed circuit size".to_string(),
131 ));
132 }
133 if output_start + self.num_data_qubits > N {
134 return Err(MLError::InvalidParameter(
135 "Output qubits exceed circuit size".to_string(),
136 ));
137 }
138
139 let mut param_idx = 0;
141 let depth = self.decoder_params.len() / (self.num_data_qubits * 3);
142
143 for layer in 0..depth {
144 if layer == 0 {
146 for i in 0..self.num_latent_qubits {
147 let latent_q = latent_start + i;
148 let output_q = output_start + (i % self.num_data_qubits);
149 circuit.cnot(latent_q, output_q)?;
150 }
151 }
152
153 for i in 0..self.num_data_qubits {
155 let q = output_start + i;
156 if param_idx < self.decoder_params.len() {
157 circuit.rx(q, self.decoder_params[param_idx])?;
158 param_idx += 1;
159 }
160 if param_idx < self.decoder_params.len() {
161 circuit.ry(q, self.decoder_params[param_idx])?;
162 param_idx += 1;
163 }
164 if param_idx < self.decoder_params.len() {
165 circuit.rz(q, self.decoder_params[param_idx])?;
166 param_idx += 1;
167 }
168 }
169
170 for i in 0..self.num_data_qubits - 1 {
172 circuit.cnot(output_start + i, output_start + i + 1)?;
173 }
174 }
175
176 Ok(())
177 }
178
179 pub fn build_circuit<const N: usize>(&self) -> Result<Circuit<N>, MLError> {
181 if N < self.total_qubits() {
182 return Err(MLError::InvalidParameter(format!(
183 "Circuit needs at least {} qubits",
184 self.total_qubits()
185 )));
186 }
187
188 let mut circuit = Circuit::<N>::new();
189
190 let data_start = 0;
192 let latent_start = self.num_data_qubits;
193 let output_start = self.num_data_qubits + self.num_latent_qubits;
194
195 self.encode(&mut circuit, data_start, latent_start)?;
197
198 self.decode(&mut circuit, latent_start, output_start)?;
200
201 Ok(circuit)
202 }
203
204 pub fn reconstruction_fidelity(
206 &self,
207 input_state: &[Complex],
208 output_state: &[Complex],
209 ) -> Result<f64, MLError> {
210 if input_state.len() != output_state.len() {
211 return Err(MLError::InvalidParameter(
212 "State dimensions mismatch".to_string(),
213 ));
214 }
215
216 let inner_product: Complex = input_state
218 .iter()
219 .zip(output_state.iter())
220 .map(|(a, b)| a.conj() * b)
221 .sum();
222
223 Ok(inner_product.norm_sqr())
225 }
226
227 pub fn get_parameters(&self) -> Vec<f64> {
229 let mut params = self.encoder_params.clone();
230 params.extend(&self.decoder_params);
231 params
232 }
233
234 pub fn set_parameters(&mut self, params: &[f64]) -> Result<(), MLError> {
236 let encoder_size = self.encoder_params.len();
237 let decoder_size = self.decoder_params.len();
238
239 if params.len() != encoder_size + decoder_size {
240 return Err(MLError::InvalidParameter(format!(
241 "Expected {} parameters, got {}",
242 encoder_size + decoder_size,
243 params.len()
244 )));
245 }
246
247 self.encoder_params.copy_from_slice(¶ms[..encoder_size]);
248 self.decoder_params.copy_from_slice(¶ms[encoder_size..]);
249
250 Ok(())
251 }
252
253 pub fn compute_loss(&self, input_states: &[Vec<Complex>], lambda: f64) -> Result<f64, MLError> {
255 let mut total_loss = 0.0;
258
259 for _input in input_states {
260 total_loss += 1.0; }
264
265 let reg_term: f64 = self.get_parameters().iter().map(|p| p * p).sum::<f64>() * lambda;
267
268 Ok(total_loss / input_states.len() as f64 + reg_term)
269 }
270}
271
272pub struct ClassicalAutoencoder {
274 pub input_dim: usize,
276 pub latent_dim: usize,
278 pub encoder_weights: Vec<Vec<f64>>,
280 pub decoder_weights: Vec<Vec<f64>>,
282}
283
284impl ClassicalAutoencoder {
285 pub fn new(input_dim: usize, latent_dim: usize) -> Self {
287 let mut rng = scirs2_core::random::ChaCha8Rng::seed_from_u64(42);
288
289 let encoder_weights = (0..latent_dim)
291 .map(|_| {
292 (0..input_dim)
293 .map(|_| rng.gen::<f64>() * 0.1 - 0.05)
294 .collect()
295 })
296 .collect();
297
298 let decoder_weights = (0..input_dim)
299 .map(|_| {
300 (0..latent_dim)
301 .map(|_| rng.gen::<f64>() * 0.1 - 0.05)
302 .collect()
303 })
304 .collect();
305
306 Self {
307 input_dim,
308 latent_dim,
309 encoder_weights,
310 decoder_weights,
311 }
312 }
313
314 pub fn encode(&self, input: &[f64]) -> Vec<f64> {
316 let mut latent = vec![0.0; self.latent_dim];
317
318 for i in 0..self.latent_dim {
319 for j in 0..self.input_dim {
320 latent[i] += self.encoder_weights[i][j] * input[j];
321 }
322 latent[i] = latent[i].tanh();
324 }
325
326 latent
327 }
328
329 pub fn decode(&self, latent: &[f64]) -> Vec<f64> {
331 let mut output = vec![0.0; self.input_dim];
332
333 for i in 0..self.input_dim {
334 for j in 0..self.latent_dim {
335 output[i] += self.decoder_weights[i][j] * latent[j];
336 }
337 output[i] = 1.0 / (1.0 + (-output[i]).exp());
339 }
340
341 output
342 }
343
344 pub fn forward(&self, input: &[f64]) -> Vec<f64> {
346 let latent = self.encode(input);
347 self.decode(&latent)
348 }
349}
350
351pub struct HybridAutoencoder {
353 pub quantum_encoder: QVAE,
355 pub classical_decoder: ClassicalAutoencoder,
357}
358
359impl HybridAutoencoder {
360 pub fn new(
362 num_data_qubits: usize,
363 num_latent_qubits: usize,
364 classical_latent_dim: usize,
365 ) -> Result<Self, MLError> {
366 let quantum_encoder = QVAE::new(num_data_qubits, num_latent_qubits, 0)?;
367
368 let quantum_latent_dim = 1 << num_latent_qubits;
370 let classical_decoder = ClassicalAutoencoder::new(quantum_latent_dim, classical_latent_dim);
371
372 Ok(Self {
373 quantum_encoder,
374 classical_decoder,
375 })
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 4 data qubits, 2 latent qubits, no ancillas.
    fn four_by_two() -> QVAE {
        QVAE::new(4, 2, 0).expect("Failed to create QVAE")
    }

    #[test]
    fn test_qvae_creation() {
        let qvae = four_by_two();
        assert_eq!(
            (qvae.num_data_qubits, qvae.num_latent_qubits, qvae.total_qubits()),
            (4, 2, 6)
        );
    }

    #[test]
    fn test_qvae_invalid_params() {
        // A latent register wider than the data register must be rejected.
        assert!(QVAE::new(4, 5, 0).is_err());
    }

    #[test]
    fn test_classical_autoencoder() {
        let reconstruction = ClassicalAutoencoder::new(10, 3).forward(&[0.5; 10]);

        assert_eq!(reconstruction.len(), 10);
        // Sigmoid output must stay inside the unit interval.
        assert!(reconstruction.iter().all(|v| (0.0..=1.0).contains(v)));
    }

    #[test]
    fn test_parameter_management() {
        let mut qvae = four_by_two();
        let replacement = vec![0.2; qvae.get_parameters().len()];

        qvae.set_parameters(&replacement)
            .expect("Failed to set parameters");

        assert_eq!(qvae.get_parameters(), replacement);
    }

    #[test]
    fn test_reconstruction_fidelity() {
        let qvae = QVAE::new(2, 1, 0).expect("Failed to create QVAE");
        // Uniform superposition over two qubits; a state has unit fidelity with itself.
        let uniform = vec![Complex::new(0.5, 0.0); 4];

        let fidelity = qvae
            .reconstruction_fidelity(&uniform, &uniform)
            .expect("Fidelity computation should succeed");

        assert!((fidelity - 1.0).abs() < 1e-10);
    }
}