quantrs2_ml/keras_api/rnn.rs

//! RNN layers for Keras-like API (LSTM, GRU, Bidirectional)

use super::{ActivationFunction, Dense, KerasLayer};
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// LSTM layer (Keras-compatible)
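///
/// A minimal usage sketch (hedged: it assumes the `KerasLayer` trait methods
/// `build`/`call` as defined in this module; shapes, unit counts, and the layer
/// name are illustrative only):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let mut lstm = LSTM::new(32).return_sequences(true).name("encoder_lstm");
/// lstm.build(&[8, 10, 16])?; // input shape: (batch, seq_len, features)
/// let x = ArrayD::<f64>::zeros(IxDyn(&[8, 10, 16]));
/// let y = lstm.call(&x)?;    // output shape: [8, 10, 32]
/// ```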
pub struct LSTM {
    /// Number of units (hidden size)
    units: usize,
    /// Return sequences
    return_sequences: bool,
    /// Return state
    return_state: bool,
    /// Go backwards
    go_backwards: bool,
    /// Dropout rate
    dropout: f64,
    /// Recurrent dropout
    recurrent_dropout: f64,
    /// Activation function
    activation: ActivationFunction,
    /// Recurrent activation
    recurrent_activation: ActivationFunction,
    /// Weights
    weights: Option<(ArrayD<f64>, ArrayD<f64>, ArrayD<f64>)>,
    /// Built flag
    built: bool,
    /// Layer name
    layer_name: Option<String>,
}

impl LSTM {
    /// Create new LSTM layer
    pub fn new(units: usize) -> Self {
        Self {
            units,
            return_sequences: false,
            return_state: false,
            go_backwards: false,
            dropout: 0.0,
            recurrent_dropout: 0.0,
            activation: ActivationFunction::Tanh,
            recurrent_activation: ActivationFunction::Sigmoid,
            weights: None,
            built: false,
            layer_name: None,
        }
    }

    /// Set return sequences
    pub fn return_sequences(mut self, return_sequences: bool) -> Self {
        self.return_sequences = return_sequences;
        self
    }

    /// Set return state
    pub fn return_state(mut self, return_state: bool) -> Self {
        self.return_state = return_state;
        self
    }

    /// Set go backwards
    pub fn go_backwards(mut self, go_backwards: bool) -> Self {
        self.go_backwards = go_backwards;
        self
    }

    /// Set dropout
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Set recurrent dropout
    pub fn recurrent_dropout(mut self, recurrent_dropout: f64) -> Self {
        self.recurrent_dropout = recurrent_dropout;
        self
    }

    /// Set layer name
    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}

impl KerasLayer for LSTM {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::ModelNotTrained(
                "Layer not built. Call build() first.".to_string(),
            ));
        }

        let (kernel, recurrent_kernel, bias) = self
            .weights
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("LSTM weights not initialized".to_string()))?;

        let shape = input.shape();
        let (batch_size, seq_len, features) = (shape[0], shape[1], shape[2]);

        let mut h: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, self.units]));
        let mut c: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, self.units]));

        let mut outputs = Vec::with_capacity(seq_len);

        let sequence: Vec<usize> = if self.go_backwards {
            (0..seq_len).rev().collect()
        } else {
            (0..seq_len).collect()
        };

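        // Gate pre-activations are packed along the last axis of `gates` as
        // [i | f | g | o] (input, forget, candidate, output), each block
        // `units` wide. The cell applies the standard LSTM update
        //     c_t = f * c_{t-1} + i * g,    h_t = o * tanh(c_t)
        // with hard-coded sigmoid/tanh gates; the `activation`,
        // `recurrent_activation`, `dropout` and `recurrent_dropout` fields are
        // not consulted in this simplified forward pass.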
        for t in sequence {
            let mut gates = ArrayD::zeros(IxDyn(&[batch_size, 4 * self.units]));

            for b in 0..batch_size {
                for g in 0..4 * self.units {
                    let mut sum = bias[[g]];
                    for f in 0..features.min(kernel.shape()[0]) {
                        sum += input[[b, t, f]] * kernel[[f, g]];
                    }
                    for j in 0..self.units {
                        sum += h[[b, j]] * recurrent_kernel[[j, g]];
                    }
                    gates[[b, g]] = sum;
                }
            }

            for b in 0..batch_size {
                for j in 0..self.units {
                    let i = 1.0 / (1.0 + (-gates[[b, j]]).exp());
                    let f = 1.0 / (1.0 + (-gates[[b, self.units + j]]).exp());
                    let g = gates[[b, 2 * self.units + j]].tanh();
                    let o = 1.0 / (1.0 + (-gates[[b, 3 * self.units + j]]).exp());

                    c[[b, j]] = f * c[[b, j]] + i * g;
                    h[[b, j]] = o * c[[b, j]].tanh();
                }
            }

            outputs.push(h.clone());
        }

        if self.go_backwards {
            outputs.reverse();
        }

        if self.return_sequences {
            let mut result = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.units]));
            for (t, h_t) in outputs.iter().enumerate() {
                for b in 0..batch_size {
                    for j in 0..self.units {
                        result[[b, t, j]] = h_t[[b, j]];
                    }
                }
            }
            Ok(result)
        } else {
            // `h` already holds the state after the final processed timestep for
            // both forward and reversed iteration, so return it directly
            // (the reversed `outputs` list would otherwise yield the first step).
            Ok(h)
        }
    }

    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        let input_dim = *input_shape
            .last()
            .ok_or_else(|| MLError::InvalidConfiguration("Invalid input shape".to_string()))?;

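        // Glorot/Xavier-style uniform scale; note it uses `input_dim + units`
        // rather than the full 4 * units fan-out of the stacked gate kernel.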
        let scale = (6.0 / (input_dim + self.units) as f64).sqrt();
        let kernel = ArrayD::from_shape_fn(IxDyn(&[input_dim, 4 * self.units]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let recurrent_kernel = ArrayD::from_shape_fn(IxDyn(&[self.units, 4 * self.units]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let bias = ArrayD::zeros(IxDyn(&[4 * self.units]));

        self.weights = Some((kernel, recurrent_kernel, bias));
        self.built = true;

        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        if self.return_sequences {
            vec![input_shape[0], input_shape[1], self.units]
        } else {
            vec![input_shape[0], self.units]
        }
    }

    fn count_params(&self) -> usize {
        if let Some((kernel, recurrent_kernel, bias)) = &self.weights {
            kernel.len() + recurrent_kernel.len() + bias.len()
        } else {
            0
        }
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        if let Some((k, rk, b)) = &self.weights {
            vec![k.clone(), rk.clone(), b.clone()]
        } else {
            vec![]
        }
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if weights.len() == 3 {
            self.weights = Some((weights[0].clone(), weights[1].clone(), weights[2].clone()));
            Ok(())
        } else {
            Err(MLError::InvalidConfiguration(
                "LSTM requires 3 weight arrays".to_string(),
            ))
        }
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("lstm")
    }
}

/// GRU layer (Keras-compatible)
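///
/// A minimal usage sketch (hedged: shapes and unit counts are illustrative, and
/// it assumes the `KerasLayer` trait methods `build`/`call` from this module):
///
/// ```ignore
/// let mut gru = GRU::new(16).go_backwards(true);
/// gru.build(&[4, 20, 8])?; // input shape: (batch, seq_len, features)
/// let x = ArrayD::<f64>::zeros(IxDyn(&[4, 20, 8]));
/// let last_state = gru.call(&x)?; // output shape: [4, 16]
/// ```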
pub struct GRU {
    /// Number of units
    units: usize,
    /// Return sequences
    return_sequences: bool,
    /// Return state
    return_state: bool,
    /// Go backwards
    go_backwards: bool,
    /// Dropout
    dropout: f64,
    /// Recurrent dropout
    recurrent_dropout: f64,
    /// Weights
    weights: Option<(ArrayD<f64>, ArrayD<f64>, ArrayD<f64>)>,
    /// Built flag
    built: bool,
    /// Layer name
    layer_name: Option<String>,
}

impl GRU {
    /// Create new GRU layer
    pub fn new(units: usize) -> Self {
        Self {
            units,
            return_sequences: false,
            return_state: false,
            go_backwards: false,
            dropout: 0.0,
            recurrent_dropout: 0.0,
            weights: None,
            built: false,
            layer_name: None,
        }
    }

    /// Set return sequences
    pub fn return_sequences(mut self, return_sequences: bool) -> Self {
        self.return_sequences = return_sequences;
        self
    }

    /// Set return state
    pub fn return_state(mut self, return_state: bool) -> Self {
        self.return_state = return_state;
        self
    }

    /// Set go backwards
    pub fn go_backwards(mut self, go_backwards: bool) -> Self {
        self.go_backwards = go_backwards;
        self
    }

    /// Set dropout
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Set layer name
    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}

impl KerasLayer for GRU {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::ModelNotTrained(
                "Layer not built. Call build() first.".to_string(),
            ));
        }

        let (kernel, recurrent_kernel, bias) = self
            .weights
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("GRU weights not initialized".to_string()))?;

        let shape = input.shape();
        let (batch_size, seq_len, features) = (shape[0], shape[1], shape[2]);

        let mut h: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, self.units]));
        let mut outputs = Vec::with_capacity(seq_len);

        let sequence: Vec<usize> = if self.go_backwards {
            (0..seq_len).rev().collect()
        } else {
            (0..seq_len).collect()
        };

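        // Gate pre-activations are packed along the last axis of `gates` as
        // [r | z | n] (reset, update, candidate), each block `units` wide.
        // The cell follows the standard GRU update
        //     n = tanh(W_n x + b_n + r * (U_n h)),    h' = (1 - z) * n + z * h,
        // with the reset gate applied to the candidate's recurrent term;
        // `dropout` and `recurrent_dropout` are not applied in this forward pass.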
        for t in sequence {
            let mut gates: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, 3 * self.units]));
            // Snapshot of the hidden state at the start of the timestep, so the
            // candidate's recurrent term is unaffected by the in-place updates below.
            let h_prev = h.clone();

            for b in 0..batch_size {
                for g in 0..3 * self.units {
                    let mut sum = bias[[g]];
                    for f in 0..features.min(kernel.shape()[0]) {
                        sum += input[[b, t, f]] * kernel[[f, g]];
                    }
                    // The reset and update gates take the recurrent term here; the
                    // candidate's recurrent term is added below, gated by `r`.
                    if g < 2 * self.units {
                        for j in 0..self.units {
                            sum += h_prev[[b, j]] * recurrent_kernel[[j, g]];
                        }
                    }
                    gates[[b, g]] = sum;
                }
            }

            for b in 0..batch_size {
                for j in 0..self.units {
                    let r = 1.0 / (1.0 + (-gates[[b, j]]).exp());
                    let z = 1.0 / (1.0 + (-gates[[b, self.units + j]]).exp());

                    let mut recurrent_n = 0.0;
                    for k in 0..self.units {
                        recurrent_n += h_prev[[b, k]] * recurrent_kernel[[k, 2 * self.units + j]];
                    }
                    let n = (gates[[b, 2 * self.units + j]] + r * recurrent_n).tanh();

                    h[[b, j]] = (1.0 - z) * n + z * h_prev[[b, j]];
                }
            }

            outputs.push(h.clone());
        }

        if self.go_backwards {
            outputs.reverse();
        }

        if self.return_sequences {
            let mut result = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.units]));
            for (t, h_t) in outputs.iter().enumerate() {
                for b in 0..batch_size {
                    for j in 0..self.units {
                        result[[b, t, j]] = h_t[[b, j]];
                    }
                }
            }
            Ok(result)
        } else {
            // As in the LSTM above, `h` holds the state after the final processed
            // timestep regardless of iteration direction.
            Ok(h)
        }
    }

    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        let input_dim = *input_shape
            .last()
            .ok_or_else(|| MLError::InvalidConfiguration("Invalid input shape".to_string()))?;

        let scale = (6.0 / (input_dim + self.units) as f64).sqrt();
        let kernel = ArrayD::from_shape_fn(IxDyn(&[input_dim, 3 * self.units]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let recurrent_kernel = ArrayD::from_shape_fn(IxDyn(&[self.units, 3 * self.units]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let bias = ArrayD::zeros(IxDyn(&[3 * self.units]));

        self.weights = Some((kernel, recurrent_kernel, bias));
        self.built = true;

        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        if self.return_sequences {
            vec![input_shape[0], input_shape[1], self.units]
        } else {
            vec![input_shape[0], self.units]
        }
    }

    fn count_params(&self) -> usize {
        if let Some((kernel, recurrent_kernel, bias)) = &self.weights {
            kernel.len() + recurrent_kernel.len() + bias.len()
        } else {
            0
        }
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        if let Some((k, rk, b)) = &self.weights {
            vec![k.clone(), rk.clone(), b.clone()]
        } else {
            vec![]
        }
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if weights.len() == 3 {
            self.weights = Some((weights[0].clone(), weights[1].clone(), weights[2].clone()));
            Ok(())
        } else {
            Err(MLError::InvalidConfiguration(
                "GRU requires 3 weight arrays".to_string(),
            ))
        }
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("gru")
    }
}

/// Bidirectional wrapper
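///
/// A minimal usage sketch (hedged: the backward copy is supplied explicitly via
/// the `backward_layer` setter defined further down in this module, since `new`
/// only installs a placeholder for the backward pass; shapes are illustrative):
///
/// ```ignore
/// let mut bi = Bidirectional::new(Box::new(LSTM::new(32).return_sequences(true)))
///     .backward_layer(Box::new(LSTM::new(32).return_sequences(true)))
///     .merge_mode("concat");
/// bi.build(&[8, 10, 16])?;
/// let x = ArrayD::<f64>::zeros(IxDyn(&[8, 10, 16]));
/// let y = bi.call(&x)?; // concat merge: shape [8, 10, 64]
/// ```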
pub struct Bidirectional {
    /// Forward layer
    forward_layer: Box<dyn KerasLayer>,
    /// Backward layer
    backward_layer: Box<dyn KerasLayer>,
    /// Merge mode
    merge_mode: String,
    /// Built flag
    built: bool,
    /// Layer name
    layer_name: Option<String>,
}

impl Bidirectional {
    /// Create new Bidirectional wrapper
    pub fn new(layer: Box<dyn KerasLayer>) -> Self {
        Self {
            forward_layer: layer,
            backward_layer: Box::new(Dense::new(1)),
            merge_mode: "concat".to_string(),
            built: false,
            layer_name: None,
        }
    }

    /// Set merge mode
    pub fn merge_mode(mut self, merge_mode: &str) -> Self {
        self.merge_mode = merge_mode.to_string();
        self
    }

    /// Set layer name
    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
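
    /// Set the backward layer explicitly.
    ///
    /// Added as a sketch: `new` installs a placeholder `Dense::new(1)` for the
    /// backward pass because a `Box<dyn KerasLayer>` cannot be cloned, so callers
    /// should supply a mirrored RNN (e.g. a second `LSTM`/`GRU` with the same
    /// configuration) before building the wrapper.
    pub fn backward_layer(mut self, layer: Box<dyn KerasLayer>) -> Self {
        self.backward_layer = layer;
        self
    }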
}

impl KerasLayer for Bidirectional {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        let forward_output = self.forward_layer.call(input)?;

        let shape = input.shape();
        let mut reversed = input.clone();
        let seq_len = shape[1];
        for b in 0..shape[0] {
            for t in 0..seq_len {
                for f in 0..shape[2] {
                    reversed[[b, t, f]] = input[[b, seq_len - 1 - t, f]];
                }
            }
        }

        let backward_output = self.backward_layer.call(&reversed)?;

        match self.merge_mode.as_str() {
            "sum" => Ok(&forward_output + &backward_output),
            "mul" => Ok(&forward_output * &backward_output),
            "ave" => Ok((&forward_output + &backward_output) / 2.0),
            _ => {
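                // "concat" (the default): stack the two outputs along the feature
                // axis. This path assumes 3-D, sequence-shaped outputs from both
                // wrapped layers (i.e. `return_sequences == true`).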
                let fwd_shape = forward_output.shape();
                let bwd_shape = backward_output.shape();
                let mut output = ArrayD::zeros(IxDyn(&[
                    fwd_shape[0],
                    fwd_shape.get(1).copied().unwrap_or(1),
                    fwd_shape.last().copied().unwrap_or(0) + bwd_shape.last().copied().unwrap_or(0),
                ]));

                let fwd_last = *fwd_shape.last().unwrap_or(&0);
                for b in 0..fwd_shape[0] {
                    for s in 0..fwd_shape.get(1).copied().unwrap_or(1) {
                        for f in 0..fwd_last {
                            output[[b, s, f]] = forward_output[[b, s, f]];
                        }
                        for f in 0..*bwd_shape.last().unwrap_or(&0) {
                            output[[b, s, fwd_last + f]] = backward_output[[b, s, f]];
                        }
                    }
                }
                Ok(output)
            }
        }
    }

    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        self.forward_layer.build(input_shape)?;
        self.backward_layer.build(input_shape)?;
        self.built = true;
        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        let fwd_shape = self.forward_layer.compute_output_shape(input_shape);
        match self.merge_mode.as_str() {
            "sum" | "mul" | "ave" => fwd_shape,
            _ => {
                let mut out = fwd_shape.clone();
                if let Some(last) = out.last_mut() {
                    *last *= 2;
                }
                out
            }
        }
    }

    fn count_params(&self) -> usize {
        self.forward_layer.count_params() + self.backward_layer.count_params()
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        let mut weights = self.forward_layer.get_weights();
        weights.extend(self.backward_layer.get_weights());
        weights
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        // Mirror get_weights(): split the flat list at the forward layer's weight count.
        let n_fwd = self.forward_layer.get_weights().len().min(weights.len());
        self.forward_layer.set_weights(weights[..n_fwd].to_vec())?;
        self.backward_layer.set_weights(weights[n_fwd..].to_vec())?;
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("bidirectional")
    }
}
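
// A small smoke test as a sketch of the expected output shapes; it assumes the
// `KerasLayer` trait and the `scirs2_core::ndarray` re-export imported above,
// and only checks dimensions, not numerical values.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lstm_output_shapes() {
        let (batch, seq, feat, units) = (2, 5, 3, 4);
        let x = ArrayD::<f64>::zeros(IxDyn(&[batch, seq, feat]));

        // return_sequences = true -> (batch, seq_len, units)
        let mut lstm_seq = LSTM::new(units).return_sequences(true);
        lstm_seq.build(&[batch, seq, feat]).unwrap();
        assert_eq!(
            lstm_seq.call(&x).unwrap().shape().to_vec(),
            vec![batch, seq, units]
        );

        // default -> (batch, units)
        let mut lstm_last = LSTM::new(units);
        lstm_last.build(&[batch, seq, feat]).unwrap();
        assert_eq!(
            lstm_last.call(&x).unwrap().shape().to_vec(),
            vec![batch, units]
        );
    }
}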