use super::{ActivationFunction, Dense, KerasLayer};
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// Long Short-Term Memory layer, mirroring the Keras `LSTM` API.
pub struct LSTM {
    /// Dimensionality of the output space.
    units: usize,
    /// Return the full output sequence instead of only the last step.
    return_sequences: bool,
    /// Stored for API parity; `call` does not yet return the final state.
    return_state: bool,
    /// Process the input sequence in reverse time order.
    go_backwards: bool,
    /// Input dropout rate (stored; not applied by this forward pass).
    dropout: f64,
    /// Recurrent dropout rate (stored; not applied by this forward pass).
    recurrent_dropout: f64,
    /// Output/candidate activation (stored; `call` currently hardcodes tanh).
    activation: ActivationFunction,
    /// Gate activation (stored; `call` currently hardcodes sigmoid).
    recurrent_activation: ActivationFunction,
    /// `(kernel, recurrent_kernel, bias)`, created by `build`.
    weights: Option<(ArrayD<f64>, ArrayD<f64>, ArrayD<f64>)>,
    built: bool,
    layer_name: Option<String>,
}

impl LSTM {
    pub fn new(units: usize) -> Self {
        Self {
            units,
            return_sequences: false,
            return_state: false,
            go_backwards: false,
            dropout: 0.0,
            recurrent_dropout: 0.0,
            activation: ActivationFunction::Tanh,
            recurrent_activation: ActivationFunction::Sigmoid,
            weights: None,
            built: false,
            layer_name: None,
        }
    }

    pub fn return_sequences(mut self, return_sequences: bool) -> Self {
        self.return_sequences = return_sequences;
        self
    }

    pub fn return_state(mut self, return_state: bool) -> Self {
        self.return_state = return_state;
        self
    }

    pub fn go_backwards(mut self, go_backwards: bool) -> Self {
        self.go_backwards = go_backwards;
        self
    }

    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    pub fn recurrent_dropout(mut self, recurrent_dropout: f64) -> Self {
        self.recurrent_dropout = recurrent_dropout;
        self
    }

    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}

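// Usage sketch (illustrative, not from the original source): build the
// layer against a known input shape, then run a forward pass.
//
//     let mut lstm = LSTM::new(32).return_sequences(true);
//     lstm.build(&[8, 10, 4])?; // (batch, time, features)
//     let x = ArrayD::<f64>::zeros(IxDyn(&[8, 10, 4]));
//     let y = lstm.call(&x)?;   // shape (8, 10, 32)
//
// The recurrences implemented by `call` below are the standard LSTM cell
// equations, with the four gates packed along the last axis of the
// kernels in the order [i | f | g | o]:
//
//     i_t = sigmoid(W_i x_t + U_i h_{t-1} + b_i)    (input gate)
//     f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)    (forget gate)
//     g_t = tanh(W_g x_t + U_g h_{t-1} + b_g)       (candidate)
//     o_t = sigmoid(W_o x_t + U_o h_{t-1} + b_o)    (output gate)
//     c_t = f_t * c_{t-1} + i_t * g_t
//     h_t = o_t * tanh(c_t)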
impl KerasLayer for LSTM {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::ModelNotTrained(
                "Layer not built. Call build() first.".to_string(),
            ));
        }

        let (kernel, recurrent_kernel, bias) = self
            .weights
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("LSTM weights not initialized".to_string()))?;

        // Expect rank-3 input of shape (batch, time, features).
        let shape = input.shape();
        let (batch_size, seq_len, features) = (shape[0], shape[1], shape[2]);

        // Hidden state h and cell state c, both of shape (batch, units).
        let mut h: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, self.units]));
        let mut c: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, self.units]));

        let mut outputs = Vec::with_capacity(seq_len);

        let sequence: Vec<usize> = if self.go_backwards {
            (0..seq_len).rev().collect()
        } else {
            (0..seq_len).collect()
        };

        for t in sequence {
            // Pre-activations for all four gates, packed as [i | f | g | o].
            let mut gates = ArrayD::zeros(IxDyn(&[batch_size, 4 * self.units]));

            for b in 0..batch_size {
                for g in 0..4 * self.units {
                    let mut sum = bias[[g]];
                    // `min` guards against inputs wider than the built kernel.
                    for f in 0..features.min(kernel.shape()[0]) {
                        sum += input[[b, t, f]] * kernel[[f, g]];
                    }
                    for j in 0..self.units {
                        sum += h[[b, j]] * recurrent_kernel[[j, g]];
                    }
                    gates[[b, g]] = sum;
                }
            }

            for b in 0..batch_size {
                for j in 0..self.units {
                    // Sigmoid for the gates, tanh for the candidate (the
                    // Keras defaults, applied directly here).
                    let i = 1.0 / (1.0 + (-gates[[b, j]]).exp());
                    let f = 1.0 / (1.0 + (-gates[[b, self.units + j]]).exp());
                    let g = gates[[b, 2 * self.units + j]].tanh();
                    let o = 1.0 / (1.0 + (-gates[[b, 3 * self.units + j]]).exp());

                    c[[b, j]] = f * c[[b, j]] + i * g;
                    h[[b, j]] = o * c[[b, j]].tanh();
                }
            }

            outputs.push(h.clone());
        }

        if self.go_backwards {
            // Restore chronological order after a backwards pass.
            outputs.reverse();
        }

        if self.return_sequences {
            let mut result = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.units]));
            for (t, h_t) in outputs.iter().enumerate() {
                for b in 0..batch_size {
                    for j in 0..self.units {
                        result[[b, t, j]] = h_t[[b, j]];
                    }
                }
            }
            Ok(result)
        } else {
            Ok(outputs.last().cloned().unwrap_or(h))
        }
    }

    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        let input_dim = *input_shape
            .last()
            .ok_or_else(|| MLError::InvalidConfiguration("Invalid input shape".to_string()))?;

        // Glorot-style uniform initialization: sqrt(6 / (fan_in + fan_out)).
        let scale = (6.0 / (input_dim + self.units) as f64).sqrt();
        let kernel = ArrayD::from_shape_fn(IxDyn(&[input_dim, 4 * self.units]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let recurrent_kernel = ArrayD::from_shape_fn(IxDyn(&[self.units, 4 * self.units]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let bias = ArrayD::zeros(IxDyn(&[4 * self.units]));

        self.weights = Some((kernel, recurrent_kernel, bias));
        self.built = true;

        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        if self.return_sequences {
            vec![input_shape[0], input_shape[1], self.units]
        } else {
            vec![input_shape[0], self.units]
        }
    }

    fn count_params(&self) -> usize {
        if let Some((kernel, recurrent_kernel, bias)) = &self.weights {
            kernel.len() + recurrent_kernel.len() + bias.len()
        } else {
            0
        }
    }
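
    // Worked example (illustrative): for input_dim = 4 and units = 32 the
    // arrays built above hold
    //     kernel:            4 * (4 * 32) =  512
    //     recurrent_kernel: 32 * (4 * 32) = 4096
    //     bias:                   4 * 32  =  128
    // so count_params() = 4 * units * (input_dim + units + 1) = 4736.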

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        if let Some((k, rk, b)) = &self.weights {
            vec![k.clone(), rk.clone(), b.clone()]
        } else {
            vec![]
        }
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if weights.len() == 3 {
            self.weights = Some((weights[0].clone(), weights[1].clone(), weights[2].clone()));
            Ok(())
        } else {
            Err(MLError::InvalidConfiguration(
                "LSTM requires 3 weight arrays".to_string(),
            ))
        }
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("lstm")
    }
}

/// Gated Recurrent Unit layer, mirroring the Keras `GRU` API.
pub struct GRU {
    /// Dimensionality of the output space.
    units: usize,
    /// Return the full output sequence instead of only the last step.
    return_sequences: bool,
    /// Stored for API parity; `call` does not yet return the final state.
    return_state: bool,
    /// Process the input sequence in reverse time order.
    go_backwards: bool,
    /// Input dropout rate (stored; not applied by this forward pass).
    dropout: f64,
    /// Recurrent dropout rate (stored; not applied by this forward pass).
    recurrent_dropout: f64,
    /// `(kernel, recurrent_kernel, bias)`, created by `build`.
    weights: Option<(ArrayD<f64>, ArrayD<f64>, ArrayD<f64>)>,
    built: bool,
    layer_name: Option<String>,
}

impl GRU {
    pub fn new(units: usize) -> Self {
        Self {
            units,
            return_sequences: false,
            return_state: false,
            go_backwards: false,
            dropout: 0.0,
            recurrent_dropout: 0.0,
            weights: None,
            built: false,
            layer_name: None,
        }
    }

    pub fn return_sequences(mut self, return_sequences: bool) -> Self {
        self.return_sequences = return_sequences;
        self
    }

    pub fn return_state(mut self, return_state: bool) -> Self {
        self.return_state = return_state;
        self
    }

    pub fn go_backwards(mut self, go_backwards: bool) -> Self {
        self.go_backwards = go_backwards;
        self
    }

    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    // Setter added so `recurrent_dropout` can be configured like its LSTM
    // counterpart; the field was previously unreachable from the builder.
    pub fn recurrent_dropout(mut self, recurrent_dropout: f64) -> Self {
        self.recurrent_dropout = recurrent_dropout;
        self
    }

    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}
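
// GRU cell recurrences implemented by `call` below (standard formulation,
// with the gates packed along the last axis in the order [r | z | n]):
//
//     r_t = sigmoid(W_r x_t + U_r h_{t-1} + b_r)    (reset gate)
//     z_t = sigmoid(W_z x_t + U_z h_{t-1} + b_z)    (update gate)
//     n_t = tanh(W_n x_t + r_t * (U_n h_{t-1}) + b_n)
//     h_t = (1 - z_t) * n_t + z_t * h_{t-1}
//
// Parameter count is 3 * units * (input_dim + units + 1); e.g. with
// input_dim = 8 and units = 16 that is 3 * 16 * 25 = 1200.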
impl KerasLayer for GRU {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::ModelNotTrained(
                "Layer not built. Call build() first.".to_string(),
            ));
        }

        let (kernel, recurrent_kernel, bias) = self
            .weights
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("GRU weights not initialized".to_string()))?;

        // Expect rank-3 input of shape (batch, time, features).
        let shape = input.shape();
        let (batch_size, seq_len, features) = (shape[0], shape[1], shape[2]);

        let mut h: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, self.units]));
        let mut outputs = Vec::with_capacity(seq_len);

        let sequence: Vec<usize> = if self.go_backwards {
            (0..seq_len).rev().collect()
        } else {
            (0..seq_len).collect()
        };

        for t in sequence {
            // Pre-activations packed as [r | z | n]. The candidate gate's
            // recurrent term is accumulated separately so the reset gate can
            // scale it before the tanh; previously it was folded in un-gated
            // and the raw hidden state was added on top, double-counting the
            // recurrent contribution.
            let mut gates: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, 3 * self.units]));
            let mut n_recurrent: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, self.units]));

            for b in 0..batch_size {
                for g in 0..3 * self.units {
                    let mut sum = bias[[g]];
                    // `min` guards against inputs wider than the built kernel.
                    for f in 0..features.min(kernel.shape()[0]) {
                        sum += input[[b, t, f]] * kernel[[f, g]];
                    }
                    if g < 2 * self.units {
                        // Reset and update gates take the recurrent term directly.
                        for j in 0..self.units {
                            sum += h[[b, j]] * recurrent_kernel[[j, g]];
                        }
                    } else {
                        // Candidate gate: keep the recurrent term aside.
                        let mut rec = 0.0;
                        for j in 0..self.units {
                            rec += h[[b, j]] * recurrent_kernel[[j, g]];
                        }
                        n_recurrent[[b, g - 2 * self.units]] = rec;
                    }
                    gates[[b, g]] = sum;
                }
            }

            for b in 0..batch_size {
                for j in 0..self.units {
                    let r = 1.0 / (1.0 + (-gates[[b, j]]).exp());
                    let z = 1.0 / (1.0 + (-gates[[b, self.units + j]]).exp());
                    // n = tanh(W_n x + b_n + r * (U_n h)).
                    let n = (gates[[b, 2 * self.units + j]] + r * n_recurrent[[b, j]]).tanh();

                    h[[b, j]] = (1.0 - z) * n + z * h[[b, j]];
                }
            }

            outputs.push(h.clone());
        }

        if self.go_backwards {
            // Restore chronological order after a backwards pass.
            outputs.reverse();
        }

        if self.return_sequences {
            let mut result = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.units]));
            for (t, h_t) in outputs.iter().enumerate() {
                for b in 0..batch_size {
                    for j in 0..self.units {
                        result[[b, t, j]] = h_t[[b, j]];
                    }
                }
            }
            Ok(result)
        } else {
            Ok(outputs.last().cloned().unwrap_or(h))
        }
    }

    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        let input_dim = *input_shape
            .last()
            .ok_or_else(|| MLError::InvalidConfiguration("Invalid input shape".to_string()))?;

        // Glorot-style uniform initialization: sqrt(6 / (fan_in + fan_out)).
        let scale = (6.0 / (input_dim + self.units) as f64).sqrt();
        let kernel = ArrayD::from_shape_fn(IxDyn(&[input_dim, 3 * self.units]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let recurrent_kernel = ArrayD::from_shape_fn(IxDyn(&[self.units, 3 * self.units]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let bias = ArrayD::zeros(IxDyn(&[3 * self.units]));

        self.weights = Some((kernel, recurrent_kernel, bias));
        self.built = true;

        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        if self.return_sequences {
            vec![input_shape[0], input_shape[1], self.units]
        } else {
            vec![input_shape[0], self.units]
        }
    }

    fn count_params(&self) -> usize {
        if let Some((kernel, recurrent_kernel, bias)) = &self.weights {
            kernel.len() + recurrent_kernel.len() + bias.len()
        } else {
            0
        }
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        if let Some((k, rk, b)) = &self.weights {
            vec![k.clone(), rk.clone(), b.clone()]
        } else {
            vec![]
        }
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if weights.len() == 3 {
            self.weights = Some((weights[0].clone(), weights[1].clone(), weights[2].clone()));
            Ok(())
        } else {
            Err(MLError::InvalidConfiguration(
                "GRU requires 3 weight arrays".to_string(),
            ))
        }
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("gru")
    }
}

/// Wrapper that runs one layer over the input and another over the
/// time-reversed input, then merges the two outputs.
pub struct Bidirectional {
    forward_layer: Box<dyn KerasLayer>,
    backward_layer: Box<dyn KerasLayer>,
    /// One of "sum", "mul", "ave", or "concat" (the default).
    merge_mode: String,
    built: bool,
    layer_name: Option<String>,
}

impl Bidirectional {
    pub fn new(layer: Box<dyn KerasLayer>) -> Self {
        Self {
            forward_layer: layer,
            // Placeholder: boxed layers cannot be cloned, so until a real
            // backward layer is supplied via `backward_layer` the reverse
            // direction is a stand-in Dense(1) and merge shapes will not
            // line up.
            backward_layer: Box::new(Dense::new(1)),
            merge_mode: "concat".to_string(),
            built: false,
            layer_name: None,
        }
    }

    pub fn backward_layer(mut self, layer: Box<dyn KerasLayer>) -> Self {
        self.backward_layer = layer;
        self
    }

    pub fn merge_mode(mut self, merge_mode: &str) -> Self {
        self.merge_mode = merge_mode.to_string();
        self
    }

    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}
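
// Usage sketch (illustrative; the backward layer is configured explicitly
// because boxed layers cannot be cloned):
//
//     let bidi = Bidirectional::new(Box::new(LSTM::new(16).return_sequences(true)))
//         .backward_layer(Box::new(LSTM::new(16).return_sequences(true)));
//     // With the default "concat" merge mode the output feature dimension
//     // is 2 * 16 = 32.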
impl KerasLayer for Bidirectional {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        let forward_output = self.forward_layer.call(input)?;

        // Feed the backward layer a time-reversed copy of the input;
        // assumes rank-3 input of shape (batch, time, features).
        let shape = input.shape();
        let mut reversed = input.clone();
        let seq_len = shape[1];
        for b in 0..shape[0] {
            for t in 0..seq_len {
                for f in 0..shape[2] {
                    reversed[[b, t, f]] = input[[b, seq_len - 1 - t, f]];
                }
            }
        }

        let backward_output = self.backward_layer.call(&reversed)?;

        // "sum", "mul", and "ave" require both outputs to share a shape;
        // anything else falls through to "concat".
        match self.merge_mode.as_str() {
            "sum" => Ok(&forward_output + &backward_output),
            "mul" => Ok(&forward_output * &backward_output),
            "ave" => Ok((&forward_output + &backward_output) / 2.0),
            _ => {
                // Concatenate along the last axis; written index-by-index so
                // it works for both rank-2 and rank-3 outputs.
                let fwd_last = forward_output.shape().last().copied().unwrap_or(0);
                let bwd_last = backward_output.shape().last().copied().unwrap_or(0);
                let last_axis = forward_output.ndim() - 1;
                let mut out_shape = forward_output.shape().to_vec();
                out_shape[last_axis] = fwd_last + bwd_last;
                let mut output = ArrayD::zeros(IxDyn(&out_shape));

                // Forward features fill the first half of the last axis...
                for (idx, &v) in forward_output.indexed_iter() {
                    output[idx.clone()] = v;
                }
                // ...and backward features are written after them.
                for (idx, &v) in backward_output.indexed_iter() {
                    let mut idx = idx.clone();
                    idx[last_axis] += fwd_last;
                    output[idx] = v;
                }
                Ok(output)
            }
        }
    }

    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        self.forward_layer.build(input_shape)?;
        self.backward_layer.build(input_shape)?;
        self.built = true;
        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        let fwd_shape = self.forward_layer.compute_output_shape(input_shape);
        match self.merge_mode.as_str() {
            "sum" | "mul" | "ave" => fwd_shape,
            _ => {
                // "concat" doubles the feature dimension.
                let mut out = fwd_shape.clone();
                if let Some(last) = out.last_mut() {
                    *last *= 2;
                }
                out
            }
        }
    }

    fn count_params(&self) -> usize {
        self.forward_layer.count_params() + self.backward_layer.count_params()
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        let mut weights = self.forward_layer.get_weights();
        weights.extend(self.backward_layer.get_weights());
        weights
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        // Split the flat list between the two wrapped layers, matching the
        // order produced by `get_weights` (previously this was a silent
        // no-op that dropped the weights).
        let n_fwd = self.forward_layer.get_weights().len();
        if weights.len() < n_fwd {
            return Err(MLError::InvalidConfiguration(
                "Bidirectional received fewer weight arrays than expected".to_string(),
            ));
        }
        let (fwd, bwd) = weights.split_at(n_fwd);
        self.forward_layer.set_weights(fwd.to_vec())?;
        self.backward_layer.set_weights(bwd.to_vec())?;
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("bidirectional")
    }
}
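
#[cfg(test)]
mod tests {
    use super::*;

    // Shape-level smoke tests (added sketch, not from the original source;
    // assumes `MLError` implements `Debug` so `unwrap` is available).

    #[test]
    fn lstm_returns_full_sequence_shape() {
        let mut lstm = LSTM::new(5).return_sequences(true);
        lstm.build(&[2, 7, 3]).unwrap();
        let x = ArrayD::<f64>::zeros(IxDyn(&[2, 7, 3]));
        let y = lstm.call(&x).unwrap();
        assert_eq!(y.shape(), &[2, 7, 5]);
        assert_eq!(lstm.count_params(), 4 * 5 * (3 + 5 + 1));
    }

    #[test]
    fn gru_returns_last_output_shape() {
        let mut gru = GRU::new(4);
        gru.build(&[2, 6, 3]).unwrap();
        let x = ArrayD::<f64>::zeros(IxDyn(&[2, 6, 3]));
        let y = gru.call(&x).unwrap();
        assert_eq!(y.shape(), &[2, 4]);
        assert_eq!(gru.count_params(), 3 * 4 * (3 + 4 + 1));
    }
}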