quantrs2_ml/pytorch_api/rnn.rs

//! RNN layers for PyTorch-like API (LSTM, GRU)

use super::{Parameter, QuantumModule};
use crate::error::Result;
use crate::scirs2_integration::SciRS2Array;
use scirs2_core::ndarray::{s, ArrayD, IxDyn};

/// LSTM cell state
#[derive(Debug, Clone)]
pub struct LSTMState {
    /// Hidden state
    pub h: SciRS2Array,
    /// Cell state
    pub c: SciRS2Array,
}

/// LSTM layer
pub struct QuantumLSTM {
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    bidirectional: bool,
    dropout: f64,
    batch_first: bool,
    weights: Vec<Parameter>,
    training: bool,
}

impl QuantumLSTM {
    /// Create new LSTM
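    ///
    /// # Example
    ///
    /// A minimal usage sketch (marked `ignore` because it assumes the
    /// surrounding crate context; `SciRS2Array` and `ArrayD` are the same
    /// types used throughout this file):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{ArrayD, IxDyn};
    ///
    /// // Batch of 2 sequences, 5 timesteps, 8 features each (batch_first).
    /// let mut lstm = QuantumLSTM::new(8, 16);
    /// let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[2, 5, 8])), false);
    /// let (out, state) = lstm.forward_with_state(&x, None)?;
    /// assert_eq!(out.data.shape(), &[2, 5, 16]);
    /// assert_eq!(state.h.data.shape(), &[2, 16]);
    /// ```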
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // PyTorch-style layout: the four gates (input, forget, cell, output)
        // are stacked along the first axis, hence the 4 * hidden_size rows.
        // Weights start as small uniform values in [-0.05, 0.05); biases at zero.
        let weight_ih = ArrayD::from_shape_fn(IxDyn(&[4 * hidden_size, input_size]), |_| {
            fastrand::f64() * 0.1 - 0.05
        });
        let weight_hh = ArrayD::from_shape_fn(IxDyn(&[4 * hidden_size, hidden_size]), |_| {
            fastrand::f64() * 0.1 - 0.05
        });
        let bias_ih = ArrayD::zeros(IxDyn(&[4 * hidden_size]));
        let bias_hh = ArrayD::zeros(IxDyn(&[4 * hidden_size]));

        Self {
            input_size,
            hidden_size,
            num_layers: 1,
            bidirectional: false,
            dropout: 0.0,
            batch_first: true,
            weights: vec![
                Parameter::new(SciRS2Array::with_grad(weight_ih), "weight_ih_l0"),
                Parameter::new(SciRS2Array::with_grad(weight_hh), "weight_hh_l0"),
                Parameter::new(SciRS2Array::with_grad(bias_ih), "bias_ih_l0"),
                Parameter::new(SciRS2Array::with_grad(bias_hh), "bias_hh_l0"),
            ],
            training: true,
        }
    }

    /// Set number of layers
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = num_layers;
        self
    }

    /// Set bidirectional
    pub fn bidirectional(mut self, bidirectional: bool) -> Self {
        self.bidirectional = bidirectional;
        self
    }

    /// Set dropout
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Set batch first
    pub fn batch_first(mut self, batch_first: bool) -> Self {
        self.batch_first = batch_first;
        self
    }

    /// Forward pass with optional initial state
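    ///
    /// Input is `[batch, seq, input]` when `batch_first` is set, otherwise
    /// `[seq, batch, input]`. Returns the hidden state at every timestep plus
    /// the final `(h, c)` pair. Note: the current implementation runs a single
    /// unidirectional layer; `num_layers`, `bidirectional`, and `dropout` are
    /// stored but not yet applied here.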
    pub fn forward_with_state(
        &mut self,
        input: &SciRS2Array,
        initial_state: Option<LSTMState>,
    ) -> Result<(SciRS2Array, LSTMState)> {
        let shape = input.data.shape();
        let (batch_size, seq_len, _input_size) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // Default to zero-initialized hidden and cell states.
        let (mut h, mut c) = match initial_state {
            Some(state) => (state.h.data, state.c.data),
            None => (
                ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size])),
                ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size])),
            ),
        };

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let x_t = if self.batch_first {
                input.data.slice(s![.., t, ..]).to_owned()
            } else {
                input.data.slice(s![t, .., ..]).to_owned()
            };

            let weight_ih = &self.weights[0].data.data;
            let weight_hh = &self.weights[1].data.data;
            let bias_ih = &self.weights[2].data.data;
            let bias_hh = &self.weights[3].data.data;

            // Guard against inputs whose trailing dimension is smaller than
            // the declared input size.
            let in_dim = self
                .input_size
                .min(x_t.shape().last().copied().unwrap_or(self.input_size));

            // Pre-activation gate values, stacked as [i | f | g | o] along
            // axis 1. Both bias parameters participate in the sum, per the
            // standard LSTM formulation.
            let mut gates = ArrayD::zeros(IxDyn(&[batch_size, 4 * self.hidden_size]));

            for b in 0..batch_size {
                for g in 0..4 * self.hidden_size {
                    let mut sum = bias_ih[[g]] + bias_hh[[g]];
                    for i in 0..in_dim {
                        sum += x_t[[b, i]] * weight_ih[[g, i]];
                    }
                    for j in 0..self.hidden_size {
                        sum += h[[b, j]] * weight_hh[[g, j]];
                    }
                    gates[[b, g]] = sum;
                }
            }

            // Gate activations: sigmoid for i/f/o, tanh for the cell candidate g.
            let mut i_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut f_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut g_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut o_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));

            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    i_gate[[b, j]] = 1.0 / (1.0 + (-gates[[b, j]]).exp());
                    f_gate[[b, j]] = 1.0 / (1.0 + (-gates[[b, self.hidden_size + j]]).exp());
                    g_gate[[b, j]] = gates[[b, 2 * self.hidden_size + j]].tanh();
                    o_gate[[b, j]] = 1.0 / (1.0 + (-gates[[b, 3 * self.hidden_size + j]]).exp());
                }
            }

            // Cell and hidden update: c_t = f * c_{t-1} + i * g, h_t = o * tanh(c_t).
            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    c[[b, j]] = f_gate[[b, j]] * c[[b, j]] + i_gate[[b, j]] * g_gate[[b, j]];
                    h[[b, j]] = o_gate[[b, j]] * c[[b, j]].tanh();
                }
            }

            outputs.push(h.clone());
        }

        // Stack the per-timestep hidden states into the output tensor.
        let output_shape = if self.batch_first {
            IxDyn(&[batch_size, seq_len, self.hidden_size])
        } else {
            IxDyn(&[seq_len, batch_size, self.hidden_size])
        };
        let mut output = ArrayD::zeros(output_shape);

        for (t, h_t) in outputs.iter().enumerate() {
            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    if self.batch_first {
                        output[[b, t, j]] = h_t[[b, j]];
                    } else {
                        output[[t, b, j]] = h_t[[b, j]];
                    }
                }
            }
        }

        let final_state = LSTMState {
            h: SciRS2Array::new(h, input.requires_grad),
            c: SciRS2Array::new(c, input.requires_grad),
        };

        Ok((SciRS2Array::new(output, input.requires_grad), final_state))
    }
}

impl QuantumModule for QuantumLSTM {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let (output, _) = self.forward_with_state(input, None)?;
        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.weights.clone()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for w in &mut self.weights {
            w.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "LSTM"
    }
}

/// GRU layer
pub struct QuantumGRU {
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    bidirectional: bool,
    dropout: f64,
    batch_first: bool,
    weights: Vec<Parameter>,
    training: bool,
}

impl QuantumGRU {
    /// Create new GRU
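    ///
    /// # Example
    ///
    /// A minimal usage sketch (marked `ignore` because it assumes the
    /// surrounding crate context):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{ArrayD, IxDyn};
    ///
    /// // Batch of 2 sequences, 5 timesteps, 8 features each (batch_first).
    /// let mut gru = QuantumGRU::new(8, 16);
    /// let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[2, 5, 8])), false);
    /// let (out, h_n) = gru.forward_with_hidden(&x, None)?;
    /// assert_eq!(out.data.shape(), &[2, 5, 16]);
    /// assert_eq!(h_n.data.shape(), &[2, 16]);
    /// ```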
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // PyTorch-style layout: the three gates (reset, update, new) are
        // stacked along the first axis, hence the 3 * hidden_size rows.
        // Weights start as small uniform values in [-0.05, 0.05); biases at zero.
        let weight_ih = ArrayD::from_shape_fn(IxDyn(&[3 * hidden_size, input_size]), |_| {
            fastrand::f64() * 0.1 - 0.05
        });
        let weight_hh = ArrayD::from_shape_fn(IxDyn(&[3 * hidden_size, hidden_size]), |_| {
            fastrand::f64() * 0.1 - 0.05
        });
        let bias_ih = ArrayD::zeros(IxDyn(&[3 * hidden_size]));
        let bias_hh = ArrayD::zeros(IxDyn(&[3 * hidden_size]));

        Self {
            input_size,
            hidden_size,
            num_layers: 1,
            bidirectional: false,
            dropout: 0.0,
            batch_first: true,
            weights: vec![
                Parameter::new(SciRS2Array::with_grad(weight_ih), "weight_ih_l0"),
                Parameter::new(SciRS2Array::with_grad(weight_hh), "weight_hh_l0"),
                Parameter::new(SciRS2Array::with_grad(bias_ih), "bias_ih_l0"),
                Parameter::new(SciRS2Array::with_grad(bias_hh), "bias_hh_l0"),
            ],
            training: true,
        }
    }

    /// Set number of layers
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = num_layers;
        self
    }

    /// Set bidirectional
    pub fn bidirectional(mut self, bidirectional: bool) -> Self {
        self.bidirectional = bidirectional;
        self
    }

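    /// Set dropout (provided for parity with the `QuantumLSTM` builder; the
    /// `dropout` field is stored but not yet applied in the forward pass)
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }
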
    /// Set batch first
    pub fn batch_first(mut self, batch_first: bool) -> Self {
        self.batch_first = batch_first;
        self
    }

    /// Forward pass with optional initial hidden state
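    ///
    /// Input is `[batch, seq, input]` when `batch_first` is set, otherwise
    /// `[seq, batch, input]`. Returns the hidden state at every timestep plus
    /// the final hidden state. As with the LSTM, a single unidirectional layer
    /// is run regardless of the `num_layers` / `bidirectional` settings.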
    pub fn forward_with_hidden(
        &mut self,
        input: &SciRS2Array,
        initial_hidden: Option<SciRS2Array>,
    ) -> Result<(SciRS2Array, SciRS2Array)> {
        let shape = input.data.shape();
        let (batch_size, seq_len, _) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // Default to a zero-initialized hidden state.
        let mut h = match initial_hidden {
            Some(state) => state.data,
            None => ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size])),
        };

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let x_t = if self.batch_first {
                input.data.slice(s![.., t, ..]).to_owned()
            } else {
                input.data.slice(s![t, .., ..]).to_owned()
            };

            let weight_ih = &self.weights[0].data.data;
            let weight_hh = &self.weights[1].data.data;
            let bias_ih = &self.weights[2].data.data;
            let bias_hh = &self.weights[3].data.data;

            // Guard against inputs whose trailing dimension is smaller than
            // the declared input size.
            let in_dim = self
                .input_size
                .min(x_t.shape().last().copied().unwrap_or(self.input_size));

            // Input and hidden contributions are kept separate because the
            // new-gate computation scales only the hidden part by the reset
            // gate. Both are stacked as [r | z | n] along axis 1.
            let mut gates_ih = ArrayD::zeros(IxDyn(&[batch_size, 3 * self.hidden_size]));
            let mut gates_hh = ArrayD::zeros(IxDyn(&[batch_size, 3 * self.hidden_size]));

            for b in 0..batch_size {
                for g in 0..3 * self.hidden_size {
                    let mut sum_ih = bias_ih[[g]];
                    for i in 0..in_dim {
                        sum_ih += x_t[[b, i]] * weight_ih[[g, i]];
                    }
                    gates_ih[[b, g]] = sum_ih;

                    let mut sum_hh = bias_hh[[g]];
                    for j in 0..self.hidden_size {
                        sum_hh += h[[b, j]] * weight_hh[[g, j]];
                    }
                    gates_hh[[b, g]] = sum_hh;
                }
            }

            let mut r_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut z_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut n_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));

            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    // Reset and update gates see the full pre-activation sum.
                    r_gate[[b, j]] =
                        1.0 / (1.0 + (-(gates_ih[[b, j]] + gates_hh[[b, j]])).exp());
                    z_gate[[b, j]] = 1.0
                        / (1.0
                            + (-(gates_ih[[b, self.hidden_size + j]]
                                + gates_hh[[b, self.hidden_size + j]]))
                            .exp());
                    // New gate: n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn)),
                    // i.e. the reset gate scales only the hidden contribution.
                    n_gate[[b, j]] = (gates_ih[[b, 2 * self.hidden_size + j]]
                        + r_gate[[b, j]] * gates_hh[[b, 2 * self.hidden_size + j]])
                    .tanh();
                }
            }

            // Hidden update: h_t = (1 - z_t) * n_t + z_t * h_{t-1}.
            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    h[[b, j]] =
                        (1.0 - z_gate[[b, j]]) * n_gate[[b, j]] + z_gate[[b, j]] * h[[b, j]];
                }
            }

            outputs.push(h.clone());
        }

        // Stack the per-timestep hidden states into the output tensor.
        let output_shape = if self.batch_first {
            IxDyn(&[batch_size, seq_len, self.hidden_size])
        } else {
            IxDyn(&[seq_len, batch_size, self.hidden_size])
        };
        let mut output = ArrayD::zeros(output_shape);

        for (t, h_t) in outputs.iter().enumerate() {
            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    if self.batch_first {
                        output[[b, t, j]] = h_t[[b, j]];
                    } else {
                        output[[t, b, j]] = h_t[[b, j]];
                    }
                }
            }
        }

        Ok((
            SciRS2Array::new(output, input.requires_grad),
            SciRS2Array::new(h, input.requires_grad),
        ))
    }
}

impl QuantumModule for QuantumGRU {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let (output, _) = self.forward_with_hidden(input, None)?;
        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.weights.clone()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for w in &mut self.weights {
            w.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "GRU"
    }
}