yscv_model/layers/recurrent.rs

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::super::recurrent::{
    GruCell, LstmCell, RnnCell, gru_forward_sequence, lstm_forward_sequence, rnn_forward_sequence,
};
use crate::ModelError;

/// RNN layer wrapping `rnn_forward_sequence`.
///
/// Input: `[seq_len, input_size]`, output: `[seq_len, hidden_size]`.
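///
/// # Example
///
/// A minimal training-graph sketch (marked `ignore`: the `Graph` constructor
/// and the tensor behind the input node are assumptions, not part of this
/// file):
///
/// ```ignore
/// use yscv_autograd::Graph;
///
/// let mut graph = Graph::new(); // assumed constructor
/// let mut layer = RnnLayer::new(8, 16, 42);
/// layer.register_params(&mut graph);
///
/// let input = graph.variable(input_tensor); // shape [seq_len, 8]
/// let hidden = layer.forward(&mut graph, input).unwrap(); // shape [seq_len, 16]
/// ```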
#[derive(Debug, Clone)]
pub struct RnnLayer {
    pub input_size: usize,
    pub hidden_size: usize,
    pub w_ih: Tensor,
    pub w_hh: Tensor,
    pub b_h: Tensor,
    w_ih_node: Option<NodeId>,
    w_hh_node: Option<NodeId>,
    b_h_node: Option<NodeId>,
}

impl RnnLayer {
    /// Creates an RNN layer with small random weights seeded by `seed`.
    pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
        let scale = 1.0 / (hidden_size as f32).sqrt();
        let w_ih = Self::pseudo_random_tensor(vec![input_size, hidden_size], scale, seed);
        let w_hh =
            Self::pseudo_random_tensor(vec![hidden_size, hidden_size], scale, seed.wrapping_add(1));
        let b_h = Tensor::from_vec(vec![hidden_size], vec![0.0; hidden_size]).expect("valid bias");
        Self {
            input_size,
            hidden_size,
            w_ih,
            w_hh,
            b_h,
            w_ih_node: None,
            w_hh_node: None,
            b_h_node: None,
        }
    }

    /// Fills a tensor of the given `shape` with values drawn roughly uniformly
    /// from `[-scale, scale]`, using a 64-bit LCG so that initialization is
    /// deterministic for a given `seed`.
    pub(crate) fn pseudo_random_tensor(shape: Vec<usize>, scale: f32, seed: u64) -> Tensor {
        let len: usize = shape.iter().product();
        let mut data = Vec::with_capacity(len);
        let mut s = seed;
        for _ in 0..len {
            // Advance the LCG state (Knuth's MMIX multiplier and increment).
            s = s
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Map the top 32 bits of the state to [0, 1], then shift and
            // rescale to the symmetric range [-scale, scale].
            let v = ((s >> 32) as u32 as f32 / u32::MAX as f32 - 0.5) * 2.0 * scale;
            data.push(v);
        }
        Tensor::from_vec(shape, data).expect("valid tensor")
    }
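
    // Note: `pseudo_random_tensor` is pure, so the same (shape, scale, seed)
    // triple always produces the same tensor, e.g. (a sketch):
    //
    //     let a = RnnLayer::pseudo_random_tensor(vec![2, 2], 0.5, 9);
    //     let b = RnnLayer::pseudo_random_tensor(vec![2, 2], 0.5, 9);
    //     // `a` and `b` hold identical data.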

    pub fn w_ih_node(&self) -> Option<NodeId> {
        self.w_ih_node
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.w_ih_node = Some(graph.variable(self.w_ih.clone()));
        self.w_hh_node = Some(graph.variable(self.w_hh.clone()));
        self.b_h_node = Some(graph.variable(self.b_h.clone()));
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w_ih = self
            .w_ih_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Rnn" })?;
        let w_hh = self
            .w_hh_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Rnn" })?;
        let bias = self
            .b_h_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Rnn" })?;
        graph
            .rnn_forward(input, w_ih, w_hh, bias)
            .map_err(Into::into)
    }

    /// Forward inference: input `[seq_len, input_size]` -> `[seq_len, hidden_size]`.
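    ///
    /// # Example
    ///
    /// A single-sequence inference sketch (marked `ignore`; shapes follow the
    /// contract documented above):
    ///
    /// ```ignore
    /// let layer = RnnLayer::new(2, 4, 7);
    /// let input = Tensor::from_vec(vec![3, 2], vec![0.0; 6]).unwrap(); // [seq_len=3, 2]
    /// let output = layer.forward_inference(&input).unwrap();
    /// assert_eq!(output.shape(), &[3, 4]); // [seq_len, hidden_size]
    /// ```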
    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 2 {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.input_size,
                got: shape.to_vec(),
            });
        }
        let seq_len = shape[0];
        // Wrap as [1, seq_len, input_size] for rnn_forward_sequence
        let batched = input.reshape(vec![1, seq_len, self.input_size])?;
        let cell = RnnCell {
            w_ih: self.w_ih.clone(),
            w_hh: self.w_hh.clone(),
            bias: self.b_h.clone(),
            hidden_size: self.hidden_size,
        };
        let (output, _) = rnn_forward_sequence(&cell, &batched, None)?;
        // output is [1, seq_len, hidden_size], reshape to [seq_len, hidden_size]
        output
            .reshape(vec![seq_len, self.hidden_size])
            .map_err(Into::into)
    }
}

/// LSTM layer wrapping `lstm_forward_sequence`.
///
/// Input: `[seq_len, input_size]`, output: `[seq_len, hidden_size]`.
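///
/// # Example
///
/// A construction-and-inference sketch (marked `ignore`; the
/// `[input_size, 4 * hidden_size]` gate layout follows `LstmLayer::new` below):
///
/// ```ignore
/// let layer = LstmLayer::new(2, 4, 7);
/// assert_eq!(layer.cell.w_ih.shape(), &[2, 16]); // [input_size, 4 * hidden_size]
///
/// let input = Tensor::from_vec(vec![3, 2], vec![0.0; 6]).unwrap();
/// let output = layer.forward_inference(&input).unwrap(); // [3, 4]
/// ```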
#[derive(Debug, Clone)]
pub struct LstmLayer {
    pub input_size: usize,
    pub hidden_size: usize,
    pub cell: LstmCell,
    w_ih_node: Option<NodeId>,
    w_hh_node: Option<NodeId>,
    bias_node: Option<NodeId>,
}

impl LstmLayer {
    /// Creates an LSTM layer with small random weights seeded by `seed`.
    pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
        let h4 = 4 * hidden_size;
        let scale = 1.0 / (hidden_size as f32).sqrt();
        let w_ih = RnnLayer::pseudo_random_tensor(vec![input_size, h4], scale, seed);
        let w_hh =
            RnnLayer::pseudo_random_tensor(vec![hidden_size, h4], scale, seed.wrapping_add(1));
        let bias = Tensor::from_vec(vec![h4], vec![0.0; h4]).expect("valid bias");
        let cell = LstmCell {
            w_ih,
            w_hh,
            bias,
            hidden_size,
        };
        Self {
            input_size,
            hidden_size,
            cell,
            w_ih_node: None,
            w_hh_node: None,
            bias_node: None,
        }
    }

    pub fn w_ih_node(&self) -> Option<NodeId> {
        self.w_ih_node
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.w_ih_node = Some(graph.variable(self.cell.w_ih.clone()));
        self.w_hh_node = Some(graph.variable(self.cell.w_hh.clone()));
        self.bias_node = Some(graph.variable(self.cell.bias.clone()));
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w_ih = self
            .w_ih_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Lstm" })?;
        let w_hh = self
            .w_hh_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Lstm" })?;
        let bias = self
            .bias_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Lstm" })?;
        graph
            .lstm_forward(input, w_ih, w_hh, bias)
            .map_err(Into::into)
    }

    /// Forward inference: input `[seq_len, input_size]` -> `[seq_len, hidden_size]`.
    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 2 {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.input_size,
                got: shape.to_vec(),
            });
        }
        let seq_len = shape[0];
        // Wrap as [1, seq_len, input_size] for lstm_forward_sequence
        let batched = input.reshape(vec![1, seq_len, self.input_size])?;
        let (output, _, _) = lstm_forward_sequence(&self.cell, &batched, None, None)?;
        // output is [1, seq_len, hidden_size], reshape to [seq_len, hidden_size]
        output
            .reshape(vec![seq_len, self.hidden_size])
            .map_err(Into::into)
    }
}

/// GRU layer wrapping `gru_forward_sequence`.
///
/// Input: `[seq_len, input_size]`, output: `[seq_len, hidden_size]`.
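///
/// # Example
///
/// A construction-and-inference sketch (marked `ignore`; the
/// `[input_size, 3 * hidden_size]` gate layout follows `GruLayer::new` below):
///
/// ```ignore
/// let layer = GruLayer::new(2, 4, 7);
/// assert_eq!(layer.cell.w_ih.shape(), &[2, 12]); // [input_size, 3 * hidden_size]
///
/// let input = Tensor::from_vec(vec![3, 2], vec![0.0; 6]).unwrap();
/// let output = layer.forward_inference(&input).unwrap(); // [3, 4]
/// ```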
#[derive(Debug, Clone)]
pub struct GruLayer {
    pub input_size: usize,
    pub hidden_size: usize,
    pub cell: GruCell,
    w_ih_node: Option<NodeId>,
    w_hh_node: Option<NodeId>,
    bias_ih_node: Option<NodeId>,
    bias_hh_node: Option<NodeId>,
}

impl GruLayer {
    /// Creates a GRU layer with small random weights seeded by `seed`.
    pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
        let h3 = 3 * hidden_size;
        let scale = 1.0 / (hidden_size as f32).sqrt();
        let w_ih = RnnLayer::pseudo_random_tensor(vec![input_size, h3], scale, seed);
        let w_hh =
            RnnLayer::pseudo_random_tensor(vec![hidden_size, h3], scale, seed.wrapping_add(1));
        let bias_ih = Tensor::from_vec(vec![h3], vec![0.0; h3]).expect("valid bias");
        let bias_hh = Tensor::from_vec(vec![h3], vec![0.0; h3]).expect("valid bias");
        let cell = GruCell {
            w_ih,
            w_hh,
            bias_ih,
            bias_hh,
            hidden_size,
        };
        Self {
            input_size,
            hidden_size,
            cell,
            w_ih_node: None,
            w_hh_node: None,
            bias_ih_node: None,
            bias_hh_node: None,
        }
    }

    pub fn w_ih_node(&self) -> Option<NodeId> {
        self.w_ih_node
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.w_ih_node = Some(graph.variable(self.cell.w_ih.clone()));
        self.w_hh_node = Some(graph.variable(self.cell.w_hh.clone()));
        self.bias_ih_node = Some(graph.variable(self.cell.bias_ih.clone()));
        self.bias_hh_node = Some(graph.variable(self.cell.bias_hh.clone()));
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w_ih = self
            .w_ih_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Gru" })?;
        let w_hh = self
            .w_hh_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Gru" })?;
        let bias_ih = self
            .bias_ih_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Gru" })?;
        let bias_hh = self
            .bias_hh_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Gru" })?;
        graph
            .gru_forward(input, w_ih, w_hh, bias_ih, bias_hh)
            .map_err(Into::into)
    }

    /// Forward inference: input `[seq_len, input_size]` -> `[seq_len, hidden_size]`.
    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 2 {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.input_size,
                got: shape.to_vec(),
            });
        }
        let seq_len = shape[0];
        // Wrap as [1, seq_len, input_size] for gru_forward_sequence
        let batched = input.reshape(vec![1, seq_len, self.input_size])?;
        let (output, _) = gru_forward_sequence(&self.cell, &batched, None)?;
        // output is [1, seq_len, hidden_size], reshape to [seq_len, hidden_size]
        output
            .reshape(vec![seq_len, self.hidden_size])
            .map_err(Into::into)
    }
}
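
// A minimal smoke-test sketch for the shape contracts above, written against
// the APIs exactly as they are used in this file; the slice comparison form
// of `shape()` is an assumption inferred from that usage.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_inference_maps_seq_len_by_hidden_size() {
        // A 3-step sequence with 2 input features.
        let input = Tensor::from_vec(vec![3, 2], vec![0.5; 6]).expect("valid input");

        let rnn = RnnLayer::new(2, 4, 1);
        let out = rnn.forward_inference(&input).expect("rnn forward");
        assert_eq!(out.shape(), &[3, 4]);

        let lstm = LstmLayer::new(2, 4, 2);
        let out = lstm.forward_inference(&input).expect("lstm forward");
        assert_eq!(out.shape(), &[3, 4]);

        let gru = GruLayer::new(2, 4, 3);
        let out = gru.forward_inference(&input).expect("gru forward");
        assert_eq!(out.shape(), &[3, 4]);
    }
}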