1use ghostflow_core::Tensor;
6use crate::module::Module;
7use crate::linear::Linear;
8
/// A single LSTM cell: computes one timestep of the LSTM recurrence.
///
/// The four gates (input, forget, candidate, output) are produced by two
/// linear projections whose outputs are stacked along the feature axis,
/// hence the `4 * hidden_size` output width used in `new`.
pub struct LSTMCell {
    // Number of features in the input x_t.
    input_size: usize,
    // Number of features in the hidden/cell states.
    hidden_size: usize,

    // Input-to-hidden projection: input_size -> 4 * hidden_size.
    w_ih: Linear,
    // Hidden-to-hidden projection: hidden_size -> 4 * hidden_size.
    w_hh: Linear,

    // Training-mode flag surfaced through the `Module` trait.
    training: bool,
}
29
30impl LSTMCell {
31 pub fn new(input_size: usize, hidden_size: usize) -> Self {
37 LSTMCell {
38 input_size,
39 hidden_size,
40 w_ih: Linear::new(input_size, 4 * hidden_size),
42 w_hh: Linear::new(hidden_size, 4 * hidden_size),
43 training: true,
44 }
45 }
46
47 pub fn forward_cell(&self, input: &Tensor, hidden: &Tensor, cell: &Tensor) -> (Tensor, Tensor) {
57 let batch_size = input.dims()[0];
58
59 let gates = self.w_ih.forward(input)
61 .add(&self.w_hh.forward(hidden))
62 .unwrap();
63
64 let gates_data = gates.data_f32();
65 let hidden_data = cell.data_f32();
66
67 let mut new_cell_data = vec![0.0f32; batch_size * self.hidden_size];
68 let mut new_hidden_data = vec![0.0f32; batch_size * self.hidden_size];
69
70 for b in 0..batch_size {
71 for h in 0..self.hidden_size {
72 let base_idx = b * 4 * self.hidden_size;
73
74 let i = sigmoid(gates_data[base_idx + h]); let f = sigmoid(gates_data[base_idx + self.hidden_size + h]); let g = tanh(gates_data[base_idx + 2 * self.hidden_size + h]); let o = sigmoid(gates_data[base_idx + 3 * self.hidden_size + h]); let c_prev = hidden_data[b * self.hidden_size + h];
82 let c_new = f * c_prev + i * g;
83 new_cell_data[b * self.hidden_size + h] = c_new;
84
85 new_hidden_data[b * self.hidden_size + h] = o * tanh(c_new);
87 }
88 }
89
90 let new_hidden = Tensor::from_slice(&new_hidden_data, &[batch_size, self.hidden_size]).unwrap();
91 let new_cell = Tensor::from_slice(&new_cell_data, &[batch_size, self.hidden_size]).unwrap();
92
93 (new_hidden, new_cell)
94 }
95}
96
97impl Module for LSTMCell {
98 fn forward(&self, input: &Tensor) -> Tensor {
99 let batch_size = input.dims()[0];
100 let hidden = Tensor::zeros(&[batch_size, self.hidden_size]);
101 let cell = Tensor::zeros(&[batch_size, self.hidden_size]);
102 let (h, _) = self.forward_cell(input, &hidden, &cell);
103 h
104 }
105
106 fn parameters(&self) -> Vec<Tensor> {
107 let mut params = self.w_ih.parameters();
108 params.extend(self.w_hh.parameters());
109 params
110 }
111
112 fn train(&mut self) { self.training = true; }
113 fn eval(&mut self) { self.training = false; }
114 fn is_training(&self) -> bool { self.training }
115}
116
/// An LSTM layer that drives a single `LSTMCell` across a full sequence.
pub struct LSTM {
    // The (single) recurrent cell used for every timestep.
    cell: LSTMCell,
    // NOTE(review): stored but never read by `forward_sequence` — only one
    // layer is run regardless of this value; confirm whether stacked layers
    // are intended.
    num_layers: usize,
    // When true, a reverse-time pass is also run and its features are
    // concatenated after the forward features at each timestep.
    bidirectional: bool,
    // NOTE(review): stored but never applied during the forward pass.
    dropout: f32,
    // Training-mode flag surfaced through the `Module` trait.
    training: bool,
}
125
126impl LSTM {
127 pub fn new(
136 input_size: usize,
137 hidden_size: usize,
138 num_layers: usize,
139 bidirectional: bool,
140 dropout: f32,
141 ) -> Self {
142 LSTM {
143 cell: LSTMCell::new(input_size, hidden_size),
144 num_layers,
145 bidirectional,
146 dropout,
147 training: true,
148 }
149 }
150
151 pub fn forward_sequence(&self, input: &Tensor) -> Tensor {
159 let batch_size = input.dims()[0];
160 let seq_len = input.dims()[1];
161 let input_size = input.dims()[2];
162
163 let hidden_size = self.cell.hidden_size;
164 let num_directions = if self.bidirectional { 2 } else { 1 };
165
166 let mut h = Tensor::zeros(&[batch_size, hidden_size]);
168 let mut c = Tensor::zeros(&[batch_size, hidden_size]);
169
170 let input_data = input.data_f32();
171 let mut output_data = vec![0.0f32; batch_size * seq_len * hidden_size * num_directions];
172
173 for t in 0..seq_len {
175 let mut x_t_data = vec![0.0f32; batch_size * input_size];
177 for b in 0..batch_size {
178 for i in 0..input_size {
179 x_t_data[b * input_size + i] =
180 input_data[b * seq_len * input_size + t * input_size + i];
181 }
182 }
183 let x_t = Tensor::from_slice(&x_t_data, &[batch_size, input_size]).unwrap();
184
185 let (h_new, c_new) = self.cell.forward_cell(&x_t, &h, &c);
187 h = h_new;
188 c = c_new;
189
190 let h_data = h.data_f32();
192 for b in 0..batch_size {
193 for h_idx in 0..hidden_size {
194 output_data[b * seq_len * hidden_size * num_directions +
195 t * hidden_size * num_directions + h_idx] = h_data[b * hidden_size + h_idx];
196 }
197 }
198 }
199
200 if self.bidirectional {
202 let mut h_back = Tensor::zeros(&[batch_size, hidden_size]);
203 let mut c_back = Tensor::zeros(&[batch_size, hidden_size]);
204
205 for t in (0..seq_len).rev() {
206 let mut x_t_data = vec![0.0f32; batch_size * input_size];
207 for b in 0..batch_size {
208 for i in 0..input_size {
209 x_t_data[b * input_size + i] =
210 input_data[b * seq_len * input_size + t * input_size + i];
211 }
212 }
213 let x_t = Tensor::from_slice(&x_t_data, &[batch_size, input_size]).unwrap();
214
215 let (h_new, c_new) = self.cell.forward_cell(&x_t, &h_back, &c_back);
216 h_back = h_new;
217 c_back = c_new;
218
219 let h_data = h_back.data_f32();
220 for b in 0..batch_size {
221 for h_idx in 0..hidden_size {
222 output_data[b * seq_len * hidden_size * num_directions +
223 t * hidden_size * num_directions + hidden_size + h_idx] =
224 h_data[b * hidden_size + h_idx];
225 }
226 }
227 }
228 }
229
230 Tensor::from_slice(
231 &output_data,
232 &[batch_size, seq_len, hidden_size * num_directions]
233 ).unwrap()
234 }
235}
236
237impl Module for LSTM {
238 fn forward(&self, input: &Tensor) -> Tensor {
239 self.forward_sequence(input)
240 }
241
242 fn parameters(&self) -> Vec<Tensor> {
243 self.cell.parameters()
244 }
245
246 fn train(&mut self) {
247 self.training = true;
248 self.cell.train();
249 }
250
251 fn eval(&mut self) {
252 self.training = false;
253 self.cell.eval();
254 }
255
256 fn is_training(&self) -> bool { self.training }
257}
258
/// A single GRU cell: computes one timestep of the GRU recurrence.
///
/// The three gates (reset, update, candidate) come from two linear
/// projections whose outputs are stacked along the feature axis, hence
/// the `3 * hidden_size` output width used in `new`.
pub struct GRUCell {
    // Number of features in the input x_t.
    input_size: usize,
    // Number of features in the hidden state.
    hidden_size: usize,

    // Input-to-hidden projection: input_size -> 3 * hidden_size.
    w_ih: Linear,
    // Hidden-to-hidden projection: hidden_size -> 3 * hidden_size.
    w_hh: Linear,

    // Training-mode flag surfaced through the `Module` trait.
    training: bool,
}
277
278impl GRUCell {
279 pub fn new(input_size: usize, hidden_size: usize) -> Self {
281 GRUCell {
282 input_size,
283 hidden_size,
284 w_ih: Linear::new(input_size, 3 * hidden_size),
286 w_hh: Linear::new(hidden_size, 3 * hidden_size),
287 training: true,
288 }
289 }
290
291 pub fn forward_cell(&self, input: &Tensor, hidden: &Tensor) -> Tensor {
293 let batch_size = input.dims()[0];
294
295 let gi = self.w_ih.forward(input);
297 let gh = self.w_hh.forward(hidden);
298
299 let gi_data = gi.data_f32();
300 let gh_data = gh.data_f32();
301 let h_data = hidden.data_f32();
302
303 let mut new_hidden_data = vec![0.0f32; batch_size * self.hidden_size];
304
305 for b in 0..batch_size {
306 for h in 0..self.hidden_size {
307 let r = sigmoid(
309 gi_data[b * 3 * self.hidden_size + h] +
310 gh_data[b * 3 * self.hidden_size + h]
311 );
312
313 let z = sigmoid(
315 gi_data[b * 3 * self.hidden_size + self.hidden_size + h] +
316 gh_data[b * 3 * self.hidden_size + self.hidden_size + h]
317 );
318
319 let n = tanh(
321 gi_data[b * 3 * self.hidden_size + 2 * self.hidden_size + h] +
322 r * gh_data[b * 3 * self.hidden_size + 2 * self.hidden_size + h]
323 );
324
325 let h_prev = h_data[b * self.hidden_size + h];
327 new_hidden_data[b * self.hidden_size + h] = (1.0 - z) * n + z * h_prev;
328 }
329 }
330
331 Tensor::from_slice(&new_hidden_data, &[batch_size, self.hidden_size]).unwrap()
332 }
333}
334
335impl Module for GRUCell {
336 fn forward(&self, input: &Tensor) -> Tensor {
337 let batch_size = input.dims()[0];
338 let hidden = Tensor::zeros(&[batch_size, self.hidden_size]);
339 self.forward_cell(input, &hidden)
340 }
341
342 fn parameters(&self) -> Vec<Tensor> {
343 let mut params = self.w_ih.parameters();
344 params.extend(self.w_hh.parameters());
345 params
346 }
347
348 fn train(&mut self) { self.training = true; }
349 fn eval(&mut self) { self.training = false; }
350 fn is_training(&self) -> bool { self.training }
351}
352
/// A GRU layer that drives a single `GRUCell` across a full sequence.
pub struct GRU {
    // The (single) recurrent cell used for every timestep.
    cell: GRUCell,
    // NOTE(review): stored but never read by `forward_sequence` — only one
    // layer is run regardless of this value; confirm whether stacked layers
    // are intended.
    num_layers: usize,
    // When true, a reverse-time pass is also run and its features are
    // concatenated after the forward features at each timestep.
    bidirectional: bool,
    // NOTE(review): stored but never applied during the forward pass.
    dropout: f32,
    // Training-mode flag surfaced through the `Module` trait.
    training: bool,
}
361
362impl GRU {
363 pub fn new(
365 input_size: usize,
366 hidden_size: usize,
367 num_layers: usize,
368 bidirectional: bool,
369 dropout: f32,
370 ) -> Self {
371 GRU {
372 cell: GRUCell::new(input_size, hidden_size),
373 num_layers,
374 bidirectional,
375 dropout,
376 training: true,
377 }
378 }
379
380 pub fn forward_sequence(&self, input: &Tensor) -> Tensor {
382 let batch_size = input.dims()[0];
383 let seq_len = input.dims()[1];
384 let input_size = input.dims()[2];
385
386 let hidden_size = self.cell.hidden_size;
387 let num_directions = if self.bidirectional { 2 } else { 1 };
388
389 let mut h = Tensor::zeros(&[batch_size, hidden_size]);
390
391 let input_data = input.data_f32();
392 let mut output_data = vec![0.0f32; batch_size * seq_len * hidden_size * num_directions];
393
394 for t in 0..seq_len {
396 let mut x_t_data = vec![0.0f32; batch_size * input_size];
397 for b in 0..batch_size {
398 for i in 0..input_size {
399 x_t_data[b * input_size + i] =
400 input_data[b * seq_len * input_size + t * input_size + i];
401 }
402 }
403 let x_t = Tensor::from_slice(&x_t_data, &[batch_size, input_size]).unwrap();
404
405 h = self.cell.forward_cell(&x_t, &h);
406
407 let h_data = h.data_f32();
408 for b in 0..batch_size {
409 for h_idx in 0..hidden_size {
410 output_data[b * seq_len * hidden_size * num_directions +
411 t * hidden_size * num_directions + h_idx] = h_data[b * hidden_size + h_idx];
412 }
413 }
414 }
415
416 if self.bidirectional {
418 let mut h_back = Tensor::zeros(&[batch_size, hidden_size]);
419
420 for t in (0..seq_len).rev() {
421 let mut x_t_data = vec![0.0f32; batch_size * input_size];
422 for b in 0..batch_size {
423 for i in 0..input_size {
424 x_t_data[b * input_size + i] =
425 input_data[b * seq_len * input_size + t * input_size + i];
426 }
427 }
428 let x_t = Tensor::from_slice(&x_t_data, &[batch_size, input_size]).unwrap();
429
430 h_back = self.cell.forward_cell(&x_t, &h_back);
431
432 let h_data = h_back.data_f32();
433 for b in 0..batch_size {
434 for h_idx in 0..hidden_size {
435 output_data[b * seq_len * hidden_size * num_directions +
436 t * hidden_size * num_directions + hidden_size + h_idx] =
437 h_data[b * hidden_size + h_idx];
438 }
439 }
440 }
441 }
442
443 Tensor::from_slice(
444 &output_data,
445 &[batch_size, seq_len, hidden_size * num_directions]
446 ).unwrap()
447 }
448}
449
450impl Module for GRU {
451 fn forward(&self, input: &Tensor) -> Tensor {
452 self.forward_sequence(input)
453 }
454
455 fn parameters(&self) -> Vec<Tensor> {
456 self.cell.parameters()
457 }
458
459 fn train(&mut self) {
460 self.training = true;
461 self.cell.train();
462 }
463
464 fn eval(&mut self) {
465 self.training = false;
466 self.cell.eval();
467 }
468
469 fn is_training(&self) -> bool { self.training }
470}
471
/// Logistic sigmoid: 1 / (1 + e^{-x}).
#[inline]
fn sigmoid(x: f32) -> f32 {
    (1.0 + (-x).exp()).recip()
}
477
/// Hyperbolic tangent; thin wrapper over `f32::tanh` so gate math reads
/// symmetrically with `sigmoid`.
#[inline]
fn tanh(x: f32) -> f32 {
    f32::tanh(x)
}
482
#[cfg(test)]
mod tests {
    use super::*;

    // One LSTM step on a batch of 2 preserves the [batch, hidden] shapes.
    #[test]
    fn test_lstm_cell() {
        let cell = LSTMCell::new(10, 20);
        let input = Tensor::randn(&[2, 10]);
        let hidden = Tensor::zeros(&[2, 20]);
        let cell_state = Tensor::zeros(&[2, 20]);

        let (h, c) = cell.forward_cell(&input, &hidden, &cell_state);

        assert_eq!(h.dims(), &[2, 20]);
        assert_eq!(c.dims(), &[2, 20]);
    }

    // Unidirectional LSTM: output feature width equals hidden_size.
    #[test]
    fn test_lstm_sequence() {
        let lstm = LSTM::new(10, 20, 1, false, 0.0);
        let input = Tensor::randn(&[2, 5, 10]);

        let output = lstm.forward_sequence(&input);

        assert_eq!(output.dims(), &[2, 5, 20]);
    }

    // Bidirectional LSTM: output feature width doubles to 2 * hidden_size.
    #[test]
    fn test_lstm_bidirectional() {
        let lstm = LSTM::new(10, 20, 1, true, 0.0);
        let input = Tensor::randn(&[2, 5, 10]);

        let output = lstm.forward_sequence(&input);

        assert_eq!(output.dims(), &[2, 5, 40]);
    }

    // One GRU step on a batch of 2 preserves the [batch, hidden] shape.
    #[test]
    fn test_gru_cell() {
        let cell = GRUCell::new(10, 20);
        let input = Tensor::randn(&[2, 10]);
        let hidden = Tensor::zeros(&[2, 20]);

        let h = cell.forward_cell(&input, &hidden);

        assert_eq!(h.dims(), &[2, 20]);
    }

    // Unidirectional GRU: output feature width equals hidden_size.
    #[test]
    fn test_gru_sequence() {
        let gru = GRU::new(10, 20, 1, false, 0.0);
        let input = Tensor::randn(&[2, 5, 10]);

        let output = gru.forward_sequence(&input);

        assert_eq!(output.dims(), &[2, 5, 20]);
    }

    // Bidirectional GRU: output feature width doubles to 2 * hidden_size.
    #[test]
    fn test_gru_bidirectional() {
        let gru = GRU::new(10, 20, 1, true, 0.0);
        let input = Tensor::randn(&[2, 5, 10]);

        let output = gru.forward_sequence(&input);

        assert_eq!(output.dims(), &[2, 5, 40]);
    }
}