quantrs2_ml/
lstm.rs

//! Quantum Long Short-Term Memory (QLSTM) and recurrent architectures.
//!
//! This module implements quantum versions of LSTM and other recurrent neural networks
//! for processing sequential data with quantum advantages.

use scirs2_core::ndarray::{s, Array1, Array2};
use std::collections::HashMap;

use crate::error::{MLError, Result};
use crate::qnn::QNNLayer;
use crate::utils::VariationalCircuit;
use quantrs2_circuit::prelude::*;
use quantrs2_core::gate::{multi::*, single::*, GateOp};

/// Quantum LSTM cell
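///
/// # Example
///
/// A minimal single-step sketch (marked `ignore`, so it is not compiled as a
/// doctest; it assumes this module is exposed as `quantrs2_ml::lstm` and uses
/// power-of-two dimensions so the measured gate outputs match the state lengths):
///
/// ```ignore
/// use quantrs2_ml::lstm::QLSTMCell;
/// use scirs2_core::ndarray::Array1;
///
/// let cell = QLSTMCell::new(4, 4);
/// let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
/// // Small non-zero states avoid the zero-norm error in amplitude encoding.
/// let hidden = Array1::from_elem(4, 0.01);
/// let cell_state = Array1::from_elem(4, 0.01);
/// let (new_hidden, new_cell) = cell.forward(&input, &hidden, &cell_state).unwrap();
/// assert_eq!(new_hidden.len(), 4);
/// assert_eq!(new_cell.len(), 4);
/// ```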
#[derive(Debug, Clone)]
pub struct QLSTMCell {
    /// Number of qubits for hidden state
    hidden_qubits: usize,
    /// Number of qubits for cell state
    cell_qubits: usize,
    /// Input encoding qubits
    input_qubits: usize,
    /// Forget gate circuit
    forget_gate: VariationalCircuit,
    /// Input gate circuit
    input_gate: VariationalCircuit,
    /// Output gate circuit
    output_gate: VariationalCircuit,
    /// Candidate state circuit
    candidate_circuit: VariationalCircuit,
    /// Parameters
    parameters: HashMap<String, f64>,
}

impl QLSTMCell {
    /// Create a new QLSTM cell
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        let input_qubits = (input_dim as f64).log2().ceil() as usize;
        let hidden_qubits = (hidden_dim as f64).log2().ceil() as usize;
        let cell_qubits = hidden_qubits;

        // Initialize gate circuits
        let total_qubits = input_qubits + hidden_qubits;

        let forget_gate = Self::create_gate_circuit(total_qubits, "forget");
        let input_gate = Self::create_gate_circuit(total_qubits, "input");
        let output_gate = Self::create_gate_circuit(total_qubits, "output");
        let candidate_circuit = Self::create_gate_circuit(total_qubits, "candidate");

        Self {
            hidden_qubits,
            cell_qubits,
            input_qubits,
            forget_gate,
            input_gate,
            output_gate,
            candidate_circuit,
            parameters: HashMap::new(),
        }
    }

    /// Create a gate circuit for LSTM
    fn create_gate_circuit(num_qubits: usize, gate_name: &str) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Layer 1: Hadamard initialization
        for q in 0..num_qubits {
            circuit.add_gate("H", vec![q], vec![]);
        }

        // Layer 2: Parameterized rotations
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("{}_{}_ry1", gate_name, q)]);
            circuit.add_gate("RZ", vec![q], vec![format!("{}_{}_rz1", gate_name, q)]);
        }

        // Layer 3: Entangling gates
        for q in 0..num_qubits - 1 {
            circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
        }

        // Layer 4: Final rotations
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("{}_{}_ry2", gate_name, q)]);
        }

        circuit
    }

    /// Forward pass through LSTM cell
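    ///
    /// The cell emulates the classical LSTM recurrences, with the variational
    /// circuits supplying the gate activations:
    /// `C_t = f_t * C_{t-1} + i_t * C_tilde` and `h_t = o_t * tanh(C_t)`.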
    pub fn forward(
        &self,
        input_state: &Array1<f64>,
        hidden_state: &Array1<f64>,
        cell_state: &Array1<f64>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        // Encode classical states to quantum
        let input_encoded = self.encode_classical_data(input_state)?;
        let hidden_encoded = self.encode_classical_data(hidden_state)?;

        // Compute forget gate
        let forget_output =
            self.compute_gate_output(&self.forget_gate, &input_encoded, &hidden_encoded)?;

        // Compute input gate
        let input_output =
            self.compute_gate_output(&self.input_gate, &input_encoded, &hidden_encoded)?;

        // Compute candidate values
        let candidate_output =
            self.compute_gate_output(&self.candidate_circuit, &input_encoded, &hidden_encoded)?;

        // Update cell state: C_t = f_t * C_{t-1} + i_t * C_tilde
        let mut new_cell_state = Array1::zeros(cell_state.len());
        for i in 0..cell_state.len() {
            new_cell_state[i] =
                forget_output[i] * cell_state[i] + input_output[i] * candidate_output[i];
        }

        // Compute output gate
        let output_gate_values =
            self.compute_gate_output(&self.output_gate, &input_encoded, &hidden_encoded)?;

        // Compute hidden state: h_t = o_t * tanh(C_t)
        let mut new_hidden_state = Array1::zeros(hidden_state.len());
        for i in 0..hidden_state.len() {
            new_hidden_state[i] = output_gate_values[i] * new_cell_state[i].tanh();
        }

        Ok((new_hidden_state, new_cell_state))
    }

    /// Encode classical data to quantum state
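    ///
    /// Amplitude encoding: the vector is L2-normalized so that `x_i / ||x||`
    /// can serve as state amplitudes; (near-)zero-norm inputs are rejected.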
    fn encode_classical_data(&self, data: &Array1<f64>) -> Result<Vec<f64>> {
        // Amplitude encoding
        let norm: f64 = data.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-10 {
            return Err(MLError::InvalidInput("Zero norm input".to_string()));
        }

        Ok(data.iter().map(|x| x / norm).collect())
    }

    /// Compute gate output (simplified)
    fn compute_gate_output(
        &self,
        gate_circuit: &VariationalCircuit,
        input_encoded: &[f64],
        hidden_encoded: &[f64],
    ) -> Result<Array1<f64>> {
        // Simplified - would execute quantum circuit
        let output_dim = 2_usize.pow(self.hidden_qubits as u32);
        let mut output = Array1::zeros(output_dim);

        // Placeholder computation
        for i in 0..output_dim {
            output[i] = 0.5 + 0.5 * ((i as f64) * 0.1).sin();
        }

        Ok(output)
    }

    /// Get number of parameters
    pub fn num_parameters(&self) -> usize {
        self.forget_gate.num_parameters()
            + self.input_gate.num_parameters()
            + self.output_gate.num_parameters()
            + self.candidate_circuit.num_parameters()
    }
}

/// Quantum LSTM network
#[derive(Debug)]
pub struct QLSTM {
    /// LSTM cells for each layer
    cells: Vec<QLSTMCell>,
    /// Hidden dimensions
    hidden_dims: Vec<usize>,
    /// Whether to return sequences
    return_sequences: bool,
    /// Dropout rate
    dropout_rate: f64,
}

impl QLSTM {
    /// Create a new QLSTM network
    pub fn new(
        input_dim: usize,
        hidden_dims: Vec<usize>,
        return_sequences: bool,
        dropout_rate: f64,
    ) -> Self {
        let mut cells = Vec::new();

        // Create cells for each layer
        let mut prev_dim = input_dim;
        for &hidden_dim in &hidden_dims {
            cells.push(QLSTMCell::new(prev_dim, hidden_dim));
            prev_dim = hidden_dim;
        }

        Self {
            cells,
            hidden_dims,
            return_sequences,
            dropout_rate,
        }
    }

    /// Forward pass through the network
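    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`, so it is not compiled as a doctest;
    /// it assumes this module is exposed as `quantrs2_ml::lstm`):
    ///
    /// ```ignore
    /// use quantrs2_ml::lstm::QLSTM;
    /// use scirs2_core::ndarray::array;
    ///
    /// // Two stacked layers; return the hidden state at every time step.
    /// let lstm = QLSTM::new(4, vec![8, 4], true, 0.1);
    /// let sequence = array![[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5]];
    /// let output = lstm.forward(&sequence).unwrap();
    /// assert_eq!(output.nrows(), 2); // one row per time step
    /// assert_eq!(output.ncols(), 4); // size of the last hidden layer
    /// ```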
    pub fn forward(&self, input_sequence: &Array2<f64>) -> Result<Array2<f64>> {
        let seq_len = input_sequence.nrows();
        // Simplified: processes a single sequence (no batching)

        // Initialize hidden and cell states with small non-zero values
        let mut hidden_states: Vec<Array1<f64>> = self
            .hidden_dims
            .iter()
            .map(|&dim| Array1::from_elem(dim, 0.01))
            .collect();

        let mut cell_states: Vec<Array1<f64>> = self
            .hidden_dims
            .iter()
            .map(|&dim| Array1::from_elem(dim, 0.01))
            .collect();

        let mut outputs = Vec::new();

        // Process each time step
        for t in 0..seq_len {
            let input_t = input_sequence.row(t).to_owned();
            let mut layer_input = input_t;

            // Pass through each layer
            for (layer_idx, cell) in self.cells.iter().enumerate() {
                let (new_hidden, new_cell) = cell.forward(
                    &layer_input,
                    &hidden_states[layer_idx],
                    &cell_states[layer_idx],
                )?;

                hidden_states[layer_idx] = new_hidden.clone();
                cell_states[layer_idx] = new_cell;
                layer_input = new_hidden;
            }

            // Store output
            if self.return_sequences || t == seq_len - 1 {
                outputs.push(layer_input);
            }
        }

        // Convert outputs to Array2
        if outputs.is_empty() {
            return Err(MLError::InvalidInput("Empty input sequence".to_string()));
        }
        let output_dim = outputs[0].len();
        let mut output_array = Array2::zeros((outputs.len(), output_dim));
        for (i, output) in outputs.iter().enumerate() {
            output_array.row_mut(i).assign(output);
        }

        Ok(output_array)
    }

    /// Bidirectional QLSTM forward pass
    pub fn bidirectional_forward(&self, input_sequence: &Array2<f64>) -> Result<Array2<f64>> {
        // Forward pass
        let forward_output = self.forward(input_sequence)?;

        // Backward pass (reverse the sequence along the time axis)
        let reversed_input = input_sequence.slice(s![..;-1, ..]).to_owned();
        let backward_output = self.forward(&reversed_input)?;

        // Concatenate outputs
        let output_dim = forward_output.ncols() + backward_output.ncols();
        let mut combined_output = Array2::zeros((forward_output.nrows(), output_dim));

        for i in 0..forward_output.nrows() {
            for j in 0..forward_output.ncols() {
                combined_output[[i, j]] = forward_output[[i, j]];
            }
            for j in 0..backward_output.ncols() {
                combined_output[[i, forward_output.ncols() + j]] =
                    backward_output[[backward_output.nrows() - 1 - i, j]];
            }
        }

        Ok(combined_output)
    }
}

/// Quantum Gated Recurrent Unit (QGRU) cell
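///
/// # Example
///
/// A minimal sketch (marked `ignore`, so it is not compiled as a doctest; it
/// assumes this module is exposed as `quantrs2_ml::lstm`):
///
/// ```ignore
/// use quantrs2_ml::lstm::QGRUCell;
/// use scirs2_core::ndarray::Array1;
///
/// let gru = QGRUCell::new(4, 4);
/// let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
/// let hidden = Array1::zeros(4);
/// let new_hidden = gru.forward(&input, &hidden).unwrap();
/// assert_eq!(new_hidden.len(), 4);
/// ```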
#[derive(Debug)]
pub struct QGRUCell {
    /// Hidden dimension qubits
    hidden_qubits: usize,
    /// Update gate circuit
    update_gate: VariationalCircuit,
    /// Reset gate circuit
    reset_gate: VariationalCircuit,
    /// Candidate circuit
    candidate_circuit: VariationalCircuit,
}

impl QGRUCell {
    /// Create a new QGRU cell
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        let input_qubits = (input_dim as f64).log2().ceil() as usize;
        let hidden_qubits = (hidden_dim as f64).log2().ceil() as usize;
        let total_qubits = input_qubits + hidden_qubits;

        Self {
            hidden_qubits,
            update_gate: QLSTMCell::create_gate_circuit(total_qubits, "update"),
            reset_gate: QLSTMCell::create_gate_circuit(total_qubits, "reset"),
            candidate_circuit: QLSTMCell::create_gate_circuit(total_qubits, "candidate"),
        }
    }

    /// Forward pass through GRU cell
    pub fn forward(
        &self,
        input_state: &Array1<f64>,
        hidden_state: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Simplified GRU computation
        // z_t = σ(W_z · [h_{t-1}, x_t])
        // r_t = σ(W_r · [h_{t-1}, x_t])
        // h_tilde = tanh(W · [r_t * h_{t-1}, x_t])
        // h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde

        let output_dim = hidden_state.len();
        let mut new_hidden = Array1::zeros(output_dim);

        // Placeholder computation
        for i in 0..output_dim {
            new_hidden[i] = 0.9 * hidden_state[i] + 0.1 * input_state[i % input_state.len()];
        }

        Ok(new_hidden)
    }
}

/// Quantum attention mechanism for sequence-to-sequence models
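///
/// # Example
///
/// A minimal sketch (marked `ignore`, so it is not compiled as a doctest; it
/// assumes this module is exposed as `quantrs2_ml::lstm` and that `embed_dim`
/// is divisible by `num_heads`):
///
/// ```ignore
/// use quantrs2_ml::lstm::QuantumAttention;
/// use scirs2_core::ndarray::Array2;
///
/// let attention = QuantumAttention::new(16, 4);
/// let query = Array2::zeros((3, 16));
/// let key = Array2::zeros((3, 16));
/// let value = Array2::ones((3, 16));
/// let output = attention.forward(&query, &key, &value).unwrap();
/// assert_eq!(output.shape(), &[3, 16]);
/// ```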
#[derive(Debug)]
pub struct QuantumAttention {
    /// Number of attention heads
    num_heads: usize,
    /// Dimension per head
    head_dim: usize,
    /// Query circuit
    query_circuit: VariationalCircuit,
    /// Key circuit
    key_circuit: VariationalCircuit,
    /// Value circuit
    value_circuit: VariationalCircuit,
}

impl QuantumAttention {
    /// Create quantum attention layer
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        let head_dim = embed_dim / num_heads;
        let num_qubits = (embed_dim as f64).log2().ceil() as usize;

        Self {
            num_heads,
            head_dim,
            query_circuit: Self::create_projection_circuit(num_qubits, "query"),
            key_circuit: Self::create_projection_circuit(num_qubits, "key"),
            value_circuit: Self::create_projection_circuit(num_qubits, "value"),
        }
    }

    /// Create projection circuit for Q, K, V
    fn create_projection_circuit(num_qubits: usize, name: &str) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Parameterized layer
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("{}_{}_theta", name, q)]);
            circuit.add_gate("RZ", vec![q], vec![format!("{}_{}_phi", name, q)]);
        }

        // Entangling layer
        for q in 0..num_qubits - 1 {
            circuit.add_gate("CZ", vec![q, q + 1], vec![]);
        }

        circuit
    }

    /// Compute attention scores
    pub fn forward(
        &self,
        query: &Array2<f64>,
        key: &Array2<f64>,
        value: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let seq_len = query.nrows();
        let embed_dim = query.ncols();

        // Simplified attention computation
        // Would compute Q, K, V projections using quantum circuits
        // Then compute attention scores as softmax(QK^T/√d_k)V

        let mut output = Array2::zeros((seq_len, embed_dim));

        // Placeholder
        for i in 0..seq_len {
            for j in 0..embed_dim {
                output[[i, j]] = 0.5 * query[[i, j]] + 0.3 * value[[i, j]];
            }
        }

        Ok(output)
    }
}

/// Sequence-to-sequence model with quantum components
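///
/// # Example
///
/// A minimal encode/decode sketch (marked `ignore`, so it is not compiled as
/// a doctest; it assumes this module is exposed as `quantrs2_ml::lstm`, uses
/// placeholder vocabulary sizes, and decodes a single step so the attention
/// shapes line up with the encoder's final hidden state):
///
/// ```ignore
/// use quantrs2_ml::lstm::QuantumSeq2Seq;
/// use scirs2_core::ndarray::Array2;
///
/// let model = QuantumSeq2Seq::new(100, 100, 8, vec![8], true);
/// let source = Array2::from_elem((3, 8), 0.1); // 3 source steps, embed_dim = 8
/// let encoder_outputs = model.encode(&source).unwrap();
/// let decoder_input = Array2::from_elem((1, 8), 0.05); // single decoding step
/// let decoded = model.decode(&encoder_outputs, &decoder_input).unwrap();
/// assert_eq!(decoded.shape(), &[1, 8]);
/// ```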
#[derive(Debug)]
pub struct QuantumSeq2Seq {
    /// Encoder LSTM
    encoder: QLSTM,
    /// Decoder LSTM
    decoder: QLSTM,
    /// Attention mechanism
    attention: Option<QuantumAttention>,
    /// Output projection
    output_projection: QNNLayer,
}

impl QuantumSeq2Seq {
    /// Create a new seq2seq model
    pub fn new(
        input_vocab_size: usize,
        output_vocab_size: usize,
        embed_dim: usize,
        hidden_dims: Vec<usize>,
        use_attention: bool,
    ) -> Self {
        let encoder = QLSTM::new(embed_dim, hidden_dims.clone(), false, 0.1);
        let decoder = QLSTM::new(embed_dim, hidden_dims.clone(), true, 0.1);

        let attention = if use_attention {
            Some(QuantumAttention::new(
                hidden_dims.last().copied().unwrap_or(embed_dim),
                4,
            ))
        } else {
            None
        };

        let output_projection = QNNLayer::new(
            hidden_dims.last().copied().unwrap_or(embed_dim),
            output_vocab_size,
            crate::qnn::ActivationType::Linear,
        );

        Self {
            encoder,
            decoder,
            attention,
            output_projection,
        }
    }

    /// Encode input sequence
    pub fn encode(&self, input_sequence: &Array2<f64>) -> Result<Array2<f64>> {
        self.encoder.forward(input_sequence)
    }

    /// Decode with optional attention
    pub fn decode(
        &self,
        encoder_outputs: &Array2<f64>,
        decoder_input: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let decoder_outputs = self.decoder.forward(decoder_input)?;

        if let Some(attention) = &self.attention {
            // Apply attention
            attention.forward(&decoder_outputs, encoder_outputs, encoder_outputs)
        } else {
            Ok(decoder_outputs)
        }
    }
}

/// Training utilities for recurrent models
pub mod training {
    use super::*;
    use crate::autodiff::optimizers::Adam;

    /// Truncated backpropagation through time
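    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`, so it is not compiled as a doctest;
    /// it assumes this module is exposed as `quantrs2_ml::lstm`, that the
    /// optimizer lives at `quantrs2_ml::autodiff::optimizers::Adam`, and that
    /// it takes a learning rate; the `Adam::new(0.01)` call is hypothetical):
    ///
    /// ```ignore
    /// use quantrs2_ml::autodiff::optimizers::Adam;
    /// use quantrs2_ml::lstm::{training::TBPTT, QLSTM};
    /// use scirs2_core::ndarray::Array2;
    ///
    /// let mut model = QLSTM::new(4, vec![4], true, 0.0);
    /// let sequence = Array2::from_elem((6, 4), 0.1);
    /// let targets = Array2::from_elem((6, 4), 0.2);
    ///
    /// let tbptt = TBPTT::new(3, 1.0); // truncation length 3, clip at 1.0
    /// let mut optimizer = Adam::new(0.01); // hypothetical constructor
    /// let avg_loss = tbptt
    ///     .train_step(&mut model, &sequence, &targets, &mut optimizer)
    ///     .unwrap();
    /// ```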
    pub struct TBPTT {
        /// Truncation length
        truncation_length: usize,
        /// Gradient clipping value
        gradient_clip: f64,
    }

    impl TBPTT {
        pub fn new(truncation_length: usize, gradient_clip: f64) -> Self {
            Self {
                truncation_length,
                gradient_clip,
            }
        }

        /// Train QLSTM with TBPTT
        pub fn train_step(
            &self,
            model: &mut QLSTM,
            sequence: &Array2<f64>,
            targets: &Array2<f64>,
            optimizer: &mut Adam,
        ) -> Result<f64> {
            let seq_len = sequence.nrows();
            let mut total_loss = 0.0;
            let mut num_chunks = 0usize;

            // Process sequence in chunks
            for start in (0..seq_len).step_by(self.truncation_length) {
                let end = (start + self.truncation_length).min(seq_len);
                let chunk = sequence.slice(s![start..end, ..]).to_owned();
                let chunk_targets = targets.slice(s![start..end, ..]).to_owned();

                // Forward pass
                let outputs = model.forward(&chunk)?;

                // Compute loss (simplified)
                let loss = self.compute_loss(&outputs, &chunk_targets)?;
                total_loss += loss;
                num_chunks += 1;

                // Backward pass would compute gradients
                // Clip gradients
                // Update parameters
            }

            // Average the per-chunk MSE losses
            Ok(total_loss / (num_chunks.max(1) as f64))
        }

        fn compute_loss(&self, outputs: &Array2<f64>, targets: &Array2<f64>) -> Result<f64> {
            // MSE loss
            let diff = outputs - targets;
            Ok(diff.iter().map(|x| x * x).sum::<f64>() / diff.len() as f64)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_qlstm_cell() {
        let cell = QLSTMCell::new(4, 4);

        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let hidden = Array1::from_vec(vec![0.05, 0.05, 0.05, 0.05]);
        let cell_state = Array1::from_vec(vec![0.05, 0.05, 0.05, 0.05]);

        let (new_hidden, new_cell) = cell.forward(&input, &hidden, &cell_state).unwrap();

        assert_eq!(new_hidden.len(), 4);
        assert_eq!(new_cell.len(), 4);
    }

    #[test]
    fn test_qlstm_network() {
        let lstm = QLSTM::new(4, vec![8, 4], true, 0.1);

        let sequence = array![
            [0.1, 0.2, 0.3, 0.4],
            [0.2, 0.3, 0.4, 0.5],
            [0.3, 0.4, 0.5, 0.6]
        ];

        let output = lstm.forward(&sequence).unwrap();
        assert_eq!(output.nrows(), 3); // return_sequences=true
        assert_eq!(output.ncols(), 4); // Last hidden dim
    }

    #[test]
    fn test_bidirectional_lstm() {
        let lstm = QLSTM::new(4, vec![4], true, 0.0);

        let sequence = array![[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]];

        let output = lstm.bidirectional_forward(&sequence).unwrap();
        assert_eq!(output.nrows(), 2);
        assert_eq!(output.ncols(), 8); // Concatenated forward + backward
    }

    #[test]
    fn test_qgru_cell() {
        let gru = QGRUCell::new(4, 4);

        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let hidden = Array1::zeros(4);

        let new_hidden = gru.forward(&input, &hidden).unwrap();
        assert_eq!(new_hidden.len(), 4);
    }

    #[test]
    fn test_quantum_attention() {
        let attention = QuantumAttention::new(16, 4);

        let seq_len = 3;
        let embed_dim = 16;
        let query = Array2::zeros((seq_len, embed_dim));
        let key = Array2::zeros((seq_len, embed_dim));
        let value = Array2::ones((seq_len, embed_dim));

        let output = attention.forward(&query, &key, &value).unwrap();
        assert_eq!(output.shape(), &[seq_len, embed_dim]);
    }
}
627}