use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::super::recurrent::{
    GruCell, LstmCell, RnnCell, gru_forward_sequence, lstm_forward_sequence, rnn_forward_sequence,
};
use crate::ModelError;

#[derive(Debug, Clone)]
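/// Single-layer Elman-style RNN: an input-to-hidden projection, a
/// hidden-to-hidden recurrence, and a hidden bias. Parameters live on the
/// layer as plain tensors and are mirrored into the autograd graph via
/// `register_params` before training-mode `forward` calls.
///
/// A minimal usage sketch (assumes `graph: &mut Graph`, `input_node: NodeId`,
/// and `seq: Tensor` are already in scope; adjust paths to your layout):
///
/// ```ignore
/// let mut layer = RnnLayer::new(8, 16, 42);
/// layer.register_params(&mut graph);
/// let out_node = layer.forward(&mut graph, input_node)?; // training path
/// let out = layer.forward_inference(&seq)?;              // graph-free path
/// ```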
pub struct RnnLayer {
    pub input_size: usize,
    pub hidden_size: usize,
    pub w_ih: Tensor,
    pub w_hh: Tensor,
    pub b_h: Tensor,
    w_ih_node: Option<NodeId>,
    w_hh_node: Option<NodeId>,
    b_h_node: Option<NodeId>,
}

impl RnnLayer {
    pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
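        // Initialize weights uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)].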
        let scale = 1.0 / (hidden_size as f32).sqrt();
        let w_ih = Self::pseudo_random_tensor(vec![input_size, hidden_size], scale, seed);
        let w_hh =
            Self::pseudo_random_tensor(vec![hidden_size, hidden_size], scale, seed.wrapping_add(1));
        let b_h = Tensor::from_vec(vec![hidden_size], vec![0.0; hidden_size]).expect("valid bias");
        Self {
            input_size,
            hidden_size,
            w_ih,
            w_hh,
            b_h,
            w_ih_node: None,
            w_hh_node: None,
            b_h_node: None,
        }
    }

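    /// Deterministic pseudo-random init: a 64-bit linear congruential generator
    /// (Knuth's MMIX constants) whose high bits are mapped to uniform values in
    /// `[-scale, scale]`.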
    pub(crate) fn pseudo_random_tensor(shape: Vec<usize>, scale: f32, seed: u64) -> Tensor {
        let len: usize = shape.iter().product();
        let mut data = Vec::with_capacity(len);
        let mut s = seed;
        for _ in 0..len {
            // Advance the LCG state.
            s = s
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Take the top 32 bits and map them to a uniform value in [-scale, scale].
            let v = ((s >> 32) as f32 / u32::MAX as f32 - 0.5) * 2.0 * scale;
            data.push(v);
        }
        Tensor::from_vec(shape, data).expect("valid tensor")
    }

    pub fn w_ih_node(&self) -> Option<NodeId> {
        self.w_ih_node
    }

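    /// Mirrors the layer's tensors into `graph` as trainable variables. Must be
    /// called before `forward`, which otherwise fails with `ParamsNotRegistered`.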
    pub fn register_params(&mut self, graph: &mut Graph) {
        self.w_ih_node = Some(graph.variable(self.w_ih.clone()));
        self.w_hh_node = Some(graph.variable(self.w_hh.clone()));
        self.b_h_node = Some(graph.variable(self.b_h.clone()));
    }

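    /// Training-mode forward pass: runs the fused RNN op on the autograd graph.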
    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w_ih = self
            .w_ih_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Rnn" })?;
        let w_hh = self
            .w_hh_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Rnn" })?;
        let bias = self
            .b_h_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Rnn" })?;
        graph
            .rnn_forward(input, w_ih, w_hh, bias)
            .map_err(Into::into)
    }

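    /// Graph-free forward pass over a single unbatched `[seq_len, input_size]`
    /// sequence; returns the `[seq_len, hidden_size]` hidden states.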
    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 2 || shape[1] != self.input_size {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.input_size,
                got: shape.to_vec(),
            });
        }
        let seq_len = shape[0];
        let batched = input.reshape(vec![1, seq_len, self.input_size])?;
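        // Build a throwaway cell from the layer's tensors and run the sequence.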
        let cell = RnnCell {
            w_ih: self.w_ih.clone(),
            w_hh: self.w_hh.clone(),
            bias: self.b_h.clone(),
            hidden_size: self.hidden_size,
        };
        let (output, _) = rnn_forward_sequence(&cell, &batched, None)?;
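        // Drop the leading batch dimension before returning.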
        output
            .reshape(vec![seq_len, self.hidden_size])
            .map_err(Into::into)
    }
}

#[derive(Debug, Clone)]
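/// LSTM layer. The input-to-hidden and hidden-to-hidden weights hold the four
/// gates concatenated along the output dimension, hence the `4 * hidden_size`
/// columns allocated in `new`.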
pub struct LstmLayer {
    pub input_size: usize,
    pub hidden_size: usize,
    pub cell: LstmCell,
    w_ih_node: Option<NodeId>,
    w_hh_node: Option<NodeId>,
    bias_node: Option<NodeId>,
}

impl LstmLayer {
    pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
        let h4 = 4 * hidden_size;
        let scale = 1.0 / (hidden_size as f32).sqrt();
        let w_ih = RnnLayer::pseudo_random_tensor(vec![input_size, h4], scale, seed);
        let w_hh =
            RnnLayer::pseudo_random_tensor(vec![hidden_size, h4], scale, seed.wrapping_add(1));
        let bias = Tensor::from_vec(vec![h4], vec![0.0; h4]).expect("valid bias");
        let cell = LstmCell {
            w_ih,
            w_hh,
            bias,
            hidden_size,
        };
        Self {
            input_size,
            hidden_size,
            cell,
            w_ih_node: None,
            w_hh_node: None,
            bias_node: None,
        }
    }

    pub fn w_ih_node(&self) -> Option<NodeId> {
        self.w_ih_node
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.w_ih_node = Some(graph.variable(self.cell.w_ih.clone()));
        self.w_hh_node = Some(graph.variable(self.cell.w_hh.clone()));
        self.bias_node = Some(graph.variable(self.cell.bias.clone()));
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w_ih = self
            .w_ih_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Lstm" })?;
        let w_hh = self
            .w_hh_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Lstm" })?;
        let bias = self
            .bias_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Lstm" })?;
        graph
            .lstm_forward(input, w_ih, w_hh, bias)
            .map_err(Into::into)
    }

    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 2 || shape[1] != self.input_size {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.input_size,
                got: shape.to_vec(),
            });
        }
        let seq_len = shape[0];
        let batched = input.reshape(vec![1, seq_len, self.input_size])?;
        let (output, _, _) = lstm_forward_sequence(&self.cell, &batched, None, None)?;
        output
            .reshape(vec![seq_len, self.hidden_size])
            .map_err(Into::into)
    }
}

#[derive(Debug, Clone)]
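/// GRU layer. Gate weights for the three gates are concatenated along the
/// output dimension (`3 * hidden_size` columns), and the input and hidden
/// projections carry separate biases (`bias_ih`, `bias_hh`).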
pub struct GruLayer {
    pub input_size: usize,
    pub hidden_size: usize,
    pub cell: GruCell,
    w_ih_node: Option<NodeId>,
    w_hh_node: Option<NodeId>,
    bias_ih_node: Option<NodeId>,
    bias_hh_node: Option<NodeId>,
}

impl GruLayer {
    pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
        let h3 = 3 * hidden_size;
        let scale = 1.0 / (hidden_size as f32).sqrt();
        let w_ih = RnnLayer::pseudo_random_tensor(vec![input_size, h3], scale, seed);
        let w_hh =
            RnnLayer::pseudo_random_tensor(vec![hidden_size, h3], scale, seed.wrapping_add(1));
        let bias_ih = Tensor::from_vec(vec![h3], vec![0.0; h3]).expect("valid bias");
        let bias_hh = Tensor::from_vec(vec![h3], vec![0.0; h3]).expect("valid bias");
        let cell = GruCell {
            w_ih,
            w_hh,
            bias_ih,
            bias_hh,
            hidden_size,
        };
        Self {
            input_size,
            hidden_size,
            cell,
            w_ih_node: None,
            w_hh_node: None,
            bias_ih_node: None,
            bias_hh_node: None,
        }
    }

    pub fn w_ih_node(&self) -> Option<NodeId> {
        self.w_ih_node
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.w_ih_node = Some(graph.variable(self.cell.w_ih.clone()));
        self.w_hh_node = Some(graph.variable(self.cell.w_hh.clone()));
        self.bias_ih_node = Some(graph.variable(self.cell.bias_ih.clone()));
        self.bias_hh_node = Some(graph.variable(self.cell.bias_hh.clone()));
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w_ih = self
            .w_ih_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Gru" })?;
        let w_hh = self
            .w_hh_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Gru" })?;
        let bias_ih = self
            .bias_ih_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Gru" })?;
        let bias_hh = self
            .bias_hh_node
            .ok_or(ModelError::ParamsNotRegistered { layer: "Gru" })?;
        graph
            .gru_forward(input, w_ih, w_hh, bias_ih, bias_hh)
            .map_err(Into::into)
    }

    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 2 || shape[1] != self.input_size {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.input_size,
                got: shape.to_vec(),
            });
        }
        let seq_len = shape[0];
        let batched = input.reshape(vec![1, seq_len, self.input_size])?;
        let (output, _) = gru_forward_sequence(&self.cell, &batched, None)?;
        output
            .reshape(vec![seq_len, self.hidden_size])
            .map_err(Into::into)
    }
}