1use scirs2_core::ndarray::{s, Array1, Array2, Array3};
7use std::collections::HashMap;
8
9use crate::error::{MLError, Result};
10use crate::qnn::QNNLayer;
11use crate::utils::VariationalCircuit;
12use quantrs2_circuit::prelude::*;
13use quantrs2_core::gate::{multi::*, single::*, GateOp};
14
#[derive(Debug, Clone)]
/// A single quantum LSTM cell: four variational circuits playing the roles of
/// the classical LSTM gates (forget, input, output) plus the candidate update.
pub struct QLSTMCell {
    /// Number of qubits used to encode the hidden state (ceil(log2(hidden_dim))).
    hidden_qubits: usize,
    /// Number of qubits used to encode the cell state (same as `hidden_qubits`).
    cell_qubits: usize,
    /// Number of qubits used to encode the input (ceil(log2(input_dim))).
    input_qubits: usize,
    /// Variational circuit implementing the forget gate.
    forget_gate: VariationalCircuit,
    /// Variational circuit implementing the input gate.
    input_gate: VariationalCircuit,
    /// Variational circuit implementing the output gate.
    output_gate: VariationalCircuit,
    /// Variational circuit producing the candidate cell-state update.
    candidate_circuit: VariationalCircuit,
    /// Named circuit parameters (currently unpopulated by `new`).
    parameters: HashMap<String, f64>,
}
35
36impl QLSTMCell {
37 pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
39 let input_qubits = (input_dim as f64).log2().ceil() as usize;
40 let hidden_qubits = (hidden_dim as f64).log2().ceil() as usize;
41 let cell_qubits = hidden_qubits;
42
43 let total_qubits = input_qubits + hidden_qubits;
45
46 let forget_gate = Self::create_gate_circuit(total_qubits, "forget");
47 let input_gate = Self::create_gate_circuit(total_qubits, "input");
48 let output_gate = Self::create_gate_circuit(total_qubits, "output");
49 let candidate_circuit = Self::create_gate_circuit(total_qubits, "candidate");
50
51 Self {
52 hidden_qubits,
53 cell_qubits,
54 input_qubits,
55 forget_gate,
56 input_gate,
57 output_gate,
58 candidate_circuit,
59 parameters: HashMap::new(),
60 }
61 }
62
63 fn create_gate_circuit(num_qubits: usize, gate_name: &str) -> VariationalCircuit {
65 let mut circuit = VariationalCircuit::new(num_qubits);
66
67 for q in 0..num_qubits {
69 circuit.add_gate("H", vec![q], vec![]);
70 }
71
72 for q in 0..num_qubits {
74 circuit.add_gate("RY", vec![q], vec![format!("{}_{}_ry1", gate_name, q)]);
75 circuit.add_gate("RZ", vec![q], vec![format!("{}_{}_rz1", gate_name, q)]);
76 }
77
78 for q in 0..num_qubits - 1 {
80 circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
81 }
82
83 for q in 0..num_qubits {
85 circuit.add_gate("RY", vec![q], vec![format!("{}_{}_ry2", gate_name, q)]);
86 }
87
88 circuit
89 }
90
91 pub fn forward(
93 &self,
94 input_state: &Array1<f64>,
95 hidden_state: &Array1<f64>,
96 cell_state: &Array1<f64>,
97 ) -> Result<(Array1<f64>, Array1<f64>)> {
98 let input_encoded = self.encode_classical_data(input_state)?;
100 let hidden_encoded = self.encode_classical_data(hidden_state)?;
101
102 let forget_output =
104 self.compute_gate_output(&self.forget_gate, &input_encoded, &hidden_encoded)?;
105
106 let input_output =
108 self.compute_gate_output(&self.input_gate, &input_encoded, &hidden_encoded)?;
109
110 let candidate_output =
112 self.compute_gate_output(&self.candidate_circuit, &input_encoded, &hidden_encoded)?;
113
114 let mut new_cell_state = Array1::zeros(cell_state.len());
116 for i in 0..cell_state.len() {
117 new_cell_state[i] =
118 forget_output[i] * cell_state[i] + input_output[i] * candidate_output[i];
119 }
120
121 let output_gate_values =
123 self.compute_gate_output(&self.output_gate, &input_encoded, &hidden_encoded)?;
124
125 let mut new_hidden_state = Array1::zeros(hidden_state.len());
127 for i in 0..hidden_state.len() {
128 new_hidden_state[i] = output_gate_values[i] * new_cell_state[i].tanh();
129 }
130
131 Ok((new_hidden_state, new_cell_state))
132 }
133
134 fn encode_classical_data(&self, data: &Array1<f64>) -> Result<Vec<f64>> {
136 let norm: f64 = data.iter().map(|x| x * x).sum::<f64>().sqrt();
138 if norm < 1e-10 {
139 return Err(MLError::InvalidInput("Zero norm input".to_string()));
140 }
141
142 Ok(data.iter().map(|x| x / norm).collect())
143 }
144
145 fn compute_gate_output(
147 &self,
148 gate_circuit: &VariationalCircuit,
149 input_encoded: &[f64],
150 hidden_encoded: &[f64],
151 ) -> Result<Array1<f64>> {
152 let output_dim = 2_usize.pow(self.hidden_qubits as u32);
154 let mut output = Array1::zeros(output_dim);
155
156 for i in 0..output_dim {
158 output[i] = 0.5 + 0.5 * ((i as f64) * 0.1).sin();
159 }
160
161 Ok(output)
162 }
163
164 pub fn num_parameters(&self) -> usize {
166 self.forget_gate.num_parameters()
167 + self.input_gate.num_parameters()
168 + self.output_gate.num_parameters()
169 + self.candidate_circuit.num_parameters()
170 }
171}
172
#[derive(Debug)]
/// A (possibly multi-layer) quantum LSTM network built from stacked
/// `QLSTMCell`s.
pub struct QLSTM {
    /// One cell per layer; layer i feeds its hidden state into layer i+1.
    cells: Vec<QLSTMCell>,
    /// Hidden dimension of each layer, in order.
    hidden_dims: Vec<usize>,
    /// If true, `forward` returns the output at every timestep;
    /// otherwise only the final timestep.
    return_sequences: bool,
    /// Dropout rate — stored but not currently applied in `forward`.
    dropout_rate: f64,
}
185
186impl QLSTM {
187 pub fn new(
189 input_dim: usize,
190 hidden_dims: Vec<usize>,
191 return_sequences: bool,
192 dropout_rate: f64,
193 ) -> Self {
194 let mut cells = Vec::new();
195
196 let mut prev_dim = input_dim;
198 for &hidden_dim in &hidden_dims {
199 cells.push(QLSTMCell::new(prev_dim, hidden_dim));
200 prev_dim = hidden_dim;
201 }
202
203 Self {
204 cells,
205 hidden_dims,
206 return_sequences,
207 dropout_rate,
208 }
209 }
210
211 pub fn forward(&self, input_sequence: &Array2<f64>) -> Result<Array2<f64>> {
213 let seq_len = input_sequence.nrows();
214 let batch_size = 1; let mut hidden_states: Vec<Array1<f64>> = self
218 .hidden_dims
219 .iter()
220 .map(|&dim| Array1::from_elem(dim, 0.01))
221 .collect();
222
223 let mut cell_states: Vec<Array1<f64>> = self
224 .hidden_dims
225 .iter()
226 .map(|&dim| Array1::from_elem(dim, 0.01))
227 .collect();
228
229 let mut outputs = Vec::new();
230
231 for t in 0..seq_len {
233 let input_t = input_sequence.row(t).to_owned();
234 let mut layer_input = input_t;
235
236 for (layer_idx, cell) in self.cells.iter().enumerate() {
238 let (new_hidden, new_cell) = cell.forward(
239 &layer_input,
240 &hidden_states[layer_idx],
241 &cell_states[layer_idx],
242 )?;
243
244 hidden_states[layer_idx] = new_hidden.clone();
245 cell_states[layer_idx] = new_cell;
246 layer_input = new_hidden;
247 }
248
249 if self.return_sequences || t == seq_len - 1 {
251 outputs.push(layer_input);
252 }
253 }
254
255 let output_dim = outputs[0].len();
257 let mut output_array = Array2::zeros((outputs.len(), output_dim));
258 for (i, output) in outputs.iter().enumerate() {
259 output_array.row_mut(i).assign(output);
260 }
261
262 Ok(output_array)
263 }
264
265 pub fn bidirectional_forward(&self, input_sequence: &Array2<f64>) -> Result<Array2<f64>> {
267 let forward_output = self.forward(input_sequence)?;
269
270 let mut reversed_input = input_sequence.clone();
272 for i in 0..input_sequence.nrows() / 2 {
273 let j = input_sequence.nrows() - 1 - i;
274 for k in 0..input_sequence.ncols() {
275 let tmp = reversed_input[[i, k]];
276 reversed_input[[i, k]] = reversed_input[[j, k]];
277 reversed_input[[j, k]] = tmp;
278 }
279 }
280 let backward_output = self.forward(&reversed_input)?;
281
282 let output_dim = forward_output.ncols() + backward_output.ncols();
284 let mut combined_output = Array2::zeros((forward_output.nrows(), output_dim));
285
286 for i in 0..forward_output.nrows() {
287 for j in 0..forward_output.ncols() {
288 combined_output[[i, j]] = forward_output[[i, j]];
289 }
290 for j in 0..backward_output.ncols() {
291 combined_output[[i, forward_output.ncols() + j]] =
292 backward_output[[backward_output.nrows() - 1 - i, j]];
293 }
294 }
295
296 Ok(combined_output)
297 }
298}
299
#[derive(Debug)]
/// A quantum GRU cell: update/reset gates plus a candidate circuit, sharing
/// the gate-circuit ansatz of `QLSTMCell`.
pub struct QGRUCell {
    /// Number of qubits encoding the hidden state (ceil(log2(hidden_dim))).
    hidden_qubits: usize,
    /// Variational circuit for the update gate.
    update_gate: VariationalCircuit,
    /// Variational circuit for the reset gate.
    reset_gate: VariationalCircuit,
    /// Variational circuit producing the candidate hidden state.
    candidate_circuit: VariationalCircuit,
}
312
313impl QGRUCell {
314 pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
316 let input_qubits = (input_dim as f64).log2().ceil() as usize;
317 let hidden_qubits = (hidden_dim as f64).log2().ceil() as usize;
318 let total_qubits = input_qubits + hidden_qubits;
319
320 Self {
321 hidden_qubits,
322 update_gate: QLSTMCell::create_gate_circuit(total_qubits, "update"),
323 reset_gate: QLSTMCell::create_gate_circuit(total_qubits, "reset"),
324 candidate_circuit: QLSTMCell::create_gate_circuit(total_qubits, "candidate"),
325 }
326 }
327
328 pub fn forward(
330 &self,
331 input_state: &Array1<f64>,
332 hidden_state: &Array1<f64>,
333 ) -> Result<Array1<f64>> {
334 let output_dim = hidden_state.len();
341 let mut new_hidden = Array1::zeros(output_dim);
342
343 for i in 0..output_dim {
345 new_hidden[i] = 0.9 * hidden_state[i] + 0.1 * input_state[i % input_state.len()];
346 }
347
348 Ok(new_hidden)
349 }
350}
351
#[derive(Debug)]
/// Multi-head attention with variational-circuit Q/K/V projections.
pub struct QuantumAttention {
    /// Number of attention heads.
    num_heads: usize,
    /// Per-head feature width (embed_dim / num_heads, integer division).
    head_dim: usize,
    /// Variational circuit projecting queries.
    query_circuit: VariationalCircuit,
    /// Variational circuit projecting keys.
    key_circuit: VariationalCircuit,
    /// Variational circuit projecting values.
    value_circuit: VariationalCircuit,
}
366
367impl QuantumAttention {
368 pub fn new(embed_dim: usize, num_heads: usize) -> Self {
370 let head_dim = embed_dim / num_heads;
371 let num_qubits = (embed_dim as f64).log2().ceil() as usize;
372
373 Self {
374 num_heads,
375 head_dim,
376 query_circuit: Self::create_projection_circuit(num_qubits, "query"),
377 key_circuit: Self::create_projection_circuit(num_qubits, "key"),
378 value_circuit: Self::create_projection_circuit(num_qubits, "value"),
379 }
380 }
381
382 fn create_projection_circuit(num_qubits: usize, name: &str) -> VariationalCircuit {
384 let mut circuit = VariationalCircuit::new(num_qubits);
385
386 for q in 0..num_qubits {
388 circuit.add_gate("RY", vec![q], vec![format!("{}_{}_theta", name, q)]);
389 circuit.add_gate("RZ", vec![q], vec![format!("{}_{}_phi", name, q)]);
390 }
391
392 for q in 0..num_qubits - 1 {
394 circuit.add_gate("CZ", vec![q, q + 1], vec![]);
395 }
396
397 circuit
398 }
399
400 pub fn forward(
402 &self,
403 query: &Array2<f64>,
404 key: &Array2<f64>,
405 value: &Array2<f64>,
406 ) -> Result<Array2<f64>> {
407 let seq_len = query.nrows();
408 let embed_dim = query.ncols();
409
410 let mut output = Array2::zeros((seq_len, embed_dim));
415
416 for i in 0..seq_len {
418 for j in 0..embed_dim {
419 output[[i, j]] = 0.5 * query[[i, j]] + 0.3 * value[[i, j]];
420 }
421 }
422
423 Ok(output)
424 }
425}
426
#[derive(Debug)]
/// Encoder-decoder sequence-to-sequence model with optional quantum attention.
pub struct QuantumSeq2Seq {
    /// Encoder QLSTM (returns only the final timestep).
    encoder: QLSTM,
    /// Decoder QLSTM (returns the full sequence).
    decoder: QLSTM,
    /// Optional attention over encoder outputs.
    attention: Option<QuantumAttention>,
    /// Final classical projection to the output vocabulary.
    output_projection: QNNLayer,
}
439
440impl QuantumSeq2Seq {
441 pub fn new(
443 input_vocab_size: usize,
444 output_vocab_size: usize,
445 embed_dim: usize,
446 hidden_dims: Vec<usize>,
447 use_attention: bool,
448 ) -> Self {
449 let encoder = QLSTM::new(embed_dim, hidden_dims.clone(), false, 0.1);
450 let decoder = QLSTM::new(embed_dim, hidden_dims.clone(), true, 0.1);
451
452 let attention = if use_attention {
453 Some(QuantumAttention::new(
454 hidden_dims.last().copied().unwrap_or(embed_dim),
455 4,
456 ))
457 } else {
458 None
459 };
460
461 let output_projection = QNNLayer::new(
462 hidden_dims.last().copied().unwrap_or(embed_dim),
463 output_vocab_size,
464 crate::qnn::ActivationType::Linear,
465 );
466
467 Self {
468 encoder,
469 decoder,
470 attention,
471 output_projection,
472 }
473 }
474
475 pub fn encode(&self, input_sequence: &Array2<f64>) -> Result<Array2<f64>> {
477 self.encoder.forward(input_sequence)
478 }
479
480 pub fn decode(
482 &self,
483 encoder_outputs: &Array2<f64>,
484 decoder_input: &Array2<f64>,
485 ) -> Result<Array2<f64>> {
486 let decoder_outputs = self.decoder.forward(decoder_input)?;
487
488 if let Some(attention) = &self.attention {
489 attention.forward(&decoder_outputs, encoder_outputs, encoder_outputs)
491 } else {
492 Ok(decoder_outputs)
493 }
494 }
495}
496
/// Training utilities for the recurrent models in this module.
pub mod training {
    use super::*;
    use crate::autodiff::{optimizers::Adam, QuantumAutoDiff};

    /// Truncated backpropagation through time: processes a long sequence in
    /// fixed-length chunks.
    pub struct TBPTT {
        // Maximum chunk length per truncated segment.
        truncation_length: usize,
        // Gradient clipping threshold.
        // NOTE(review): stored but never applied — no backward pass exists yet.
        gradient_clip: f64,
    }

    impl TBPTT {
        /// Creates a TBPTT trainer with the given chunk length and clip value.
        pub fn new(truncation_length: usize, gradient_clip: f64) -> Self {
            Self {
                truncation_length,
                gradient_clip,
            }
        }

        /// Runs one training step over `sequence`, chunk by chunk, and
        /// returns an aggregate loss.
        ///
        /// NOTE(review): forward-only — `optimizer` is never invoked and no
        /// gradients are computed, so this currently only measures loss.
        /// NOTE(review): `compute_loss` requires `outputs` and `targets` to
        /// have matching shapes; with `return_sequences == false` the model
        /// emits a single row per chunk while `chunk_targets` has chunk-many
        /// rows — confirm callers always use `return_sequences == true`.
        /// NOTE(review): the sum of per-chunk mean losses is divided by
        /// `seq_len`, not by the number of chunks — verify this is the
        /// intended normalization.
        pub fn train_step(
            &self,
            model: &mut QLSTM,
            sequence: &Array2<f64>,
            targets: &Array2<f64>,
            optimizer: &mut Adam,
        ) -> Result<f64> {
            let seq_len = sequence.nrows();
            let mut total_loss = 0.0;

            // Walk the sequence in truncation_length-sized windows.
            for start in (0..seq_len).step_by(self.truncation_length) {
                let end = (start + self.truncation_length).min(seq_len);
                let chunk = sequence.slice(s![start..end, ..]).to_owned();
                let chunk_targets = targets.slice(s![start..end, ..]).to_owned();

                let outputs = model.forward(&chunk)?;

                let loss = self.compute_loss(&outputs, &chunk_targets)?;
                total_loss += loss;

            }

            Ok(total_loss / (seq_len as f64))
        }

        /// Mean squared error over all elements of `outputs` vs `targets`.
        fn compute_loss(&self, outputs: &Array2<f64>, targets: &Array2<f64>) -> Result<f64> {
            let diff = outputs - targets;
            Ok(diff.iter().map(|x| x * x).sum::<f64>() / diff.len() as f64)
        }
    }
}
557
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A single QLSTM cell step preserves the hidden/cell dimensions.
    #[test]
    fn test_qlstm_cell() {
        let cell = QLSTMCell::new(4, 4);

        let x = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let h0 = Array1::from_vec(vec![0.05, 0.05, 0.05, 0.05]);
        let c0 = Array1::from_vec(vec![0.05, 0.05, 0.05, 0.05]);

        let (h1, c1) = cell.forward(&x, &h0, &c0).unwrap();

        assert_eq!(h1.len(), 4);
        assert_eq!(c1.len(), 4);
    }

    /// A stacked QLSTM with return_sequences keeps one output row per
    /// timestep, sized by the last hidden layer.
    #[test]
    fn test_qlstm_network() {
        let net = QLSTM::new(4, vec![8, 4], true, 0.1);

        let seq = array![
            [0.1, 0.2, 0.3, 0.4],
            [0.2, 0.3, 0.4, 0.5],
            [0.3, 0.4, 0.5, 0.6]
        ];

        let out = net.forward(&seq).unwrap();
        assert_eq!(out.nrows(), 3);
        assert_eq!(out.ncols(), 4);
    }

    /// The bidirectional pass doubles the feature width.
    #[test]
    fn test_bidirectional_lstm() {
        let net = QLSTM::new(4, vec![4], true, 0.0);

        let seq = array![[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]];

        let out = net.bidirectional_forward(&seq).unwrap();
        assert_eq!(out.nrows(), 2);
        assert_eq!(out.ncols(), 8);
    }

    /// A QGRU step preserves the hidden dimension.
    #[test]
    fn test_qgru_cell() {
        let cell = QGRUCell::new(4, 4);

        let x = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let h0 = Array1::zeros(4);

        let h1 = cell.forward(&x, &h0).unwrap();
        assert_eq!(h1.len(), 4);
    }

    /// Attention output has the same (seq_len, embed_dim) shape as the query.
    #[test]
    fn test_quantum_attention() {
        let attn = QuantumAttention::new(16, 4);

        let (seq_len, embed_dim) = (3, 16);
        let q = Array2::zeros((seq_len, embed_dim));
        let k = Array2::zeros((seq_len, embed_dim));
        let v = Array2::ones((seq_len, embed_dim));

        let out = attn.forward(&q, &k, &v).unwrap();
        assert_eq!(out.shape(), &[seq_len, embed_dim]);
    }
}