use super::{Parameter, QuantumModule};
use crate::error::Result;
use crate::scirs2_integration::SciRS2Array;
use scirs2_core::ndarray::{s, ArrayD, IxDyn};

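/// Logistic sigmoid, shared by the LSTM and GRU gate activations below
/// (the same expression previously appeared inline at each gate).
#[inline]
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}
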
/// Hidden and cell state of a `QuantumLSTM`.
#[derive(Debug, Clone)]
pub struct LSTMState {
    /// Hidden state, shape `[batch_size, hidden_size]`.
    pub h: SciRS2Array,
    /// Cell state, shape `[batch_size, hidden_size]`.
    pub c: SciRS2Array,
}

/// An LSTM (long short-term memory) layer over `SciRS2Array` tensors.
pub struct QuantumLSTM {
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    bidirectional: bool,
    dropout: f64,
    batch_first: bool,
    weights: Vec<Parameter>,
    training: bool,
}

impl QuantumLSTM {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // Uniform init in [-0.05, 0.05); the four gates (input, forget,
        // cell candidate, output) are stacked along the first axis.
        let weight_ih = ArrayD::from_shape_fn(IxDyn(&[4 * hidden_size, input_size]), |_| {
            fastrand::f64() * 0.1 - 0.05
        });
        let weight_hh = ArrayD::from_shape_fn(IxDyn(&[4 * hidden_size, hidden_size]), |_| {
            fastrand::f64() * 0.1 - 0.05
        });
        let bias_ih = ArrayD::zeros(IxDyn(&[4 * hidden_size]));
        let bias_hh = ArrayD::zeros(IxDyn(&[4 * hidden_size]));

        Self {
            input_size,
            hidden_size,
            num_layers: 1,
            bidirectional: false,
            dropout: 0.0,
            batch_first: true,
            weights: vec![
                Parameter::new(SciRS2Array::with_grad(weight_ih), "weight_ih_l0"),
                Parameter::new(SciRS2Array::with_grad(weight_hh), "weight_hh_l0"),
                Parameter::new(SciRS2Array::with_grad(bias_ih), "bias_ih_l0"),
                Parameter::new(SciRS2Array::with_grad(bias_hh), "bias_hh_l0"),
            ],
            training: true,
        }
    }

    /// Sets the number of stacked layers. Stored for configuration; the
    /// current forward pass runs a single layer.
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = num_layers;
        self
    }

    /// Enables bidirectional processing. Stored for configuration; the
    /// current forward pass is unidirectional.
    pub fn bidirectional(mut self, bidirectional: bool) -> Self {
        self.bidirectional = bidirectional;
        self
    }

    /// Sets the inter-layer dropout probability. Stored for configuration;
    /// not applied by the current forward pass.
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// If `true` (the default), tensors are laid out `[batch, seq, feature]`;
    /// otherwise `[seq, batch, feature]`.
    pub fn batch_first(mut self, batch_first: bool) -> Self {
        self.batch_first = batch_first;
        self
    }

    /// Runs the LSTM over `input`, returning the full output sequence and
    /// the final hidden/cell state.
    pub fn forward_with_state(
        &mut self,
        input: &SciRS2Array,
        initial_state: Option<LSTMState>,
    ) -> Result<(SciRS2Array, LSTMState)> {
        let shape = input.data.shape();
        let (batch_size, seq_len, _input_size) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // Start from the provided state or from zeros.
        let (mut h, mut c) = match initial_state {
            Some(state) => (state.h.data, state.c.data),
            None => (
                ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size])),
                ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size])),
            ),
        };

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let x_t = if self.batch_first {
                input.data.slice(s![.., t, ..]).to_owned()
            } else {
                input.data.slice(s![t, .., ..]).to_owned()
            };

            let weight_ih = &self.weights[0].data.data;
            let weight_hh = &self.weights[1].data.data;
            let bias_ih = &self.weights[2].data.data;
            let bias_hh = &self.weights[3].data.data;

            // Guard against inputs narrower than `input_size`.
            let in_dim = self
                .input_size
                .min(x_t.shape().last().copied().unwrap_or(self.input_size));

            // Pre-activations for all four gates:
            // W_ih * x_t + b_ih + W_hh * h_{t-1} + b_hh.
            let mut gates = ArrayD::zeros(IxDyn(&[batch_size, 4 * self.hidden_size]));

            for b in 0..batch_size {
                for g in 0..4 * self.hidden_size {
                    let mut sum = bias_ih[[g]] + bias_hh[[g]];
                    for i in 0..in_dim {
                        sum += x_t[[b, i]] * weight_ih[[g, i]];
                    }
                    for j in 0..self.hidden_size {
                        sum += h[[b, j]] * weight_hh[[g, j]];
                    }
                    gates[[b, g]] = sum;
                }
            }

            let mut i_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut f_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut g_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut o_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));

            // Split the stacked pre-activations into the input (i),
            // forget (f), cell candidate (g), and output (o) gates.
            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    i_gate[[b, j]] = sigmoid(gates[[b, j]]);
                    f_gate[[b, j]] = sigmoid(gates[[b, self.hidden_size + j]]);
                    g_gate[[b, j]] = gates[[b, 2 * self.hidden_size + j]].tanh();
                    o_gate[[b, j]] = sigmoid(gates[[b, 3 * self.hidden_size + j]]);
                }
            }

            // c_t = f * c_{t-1} + i * g, then h_t = o * tanh(c_t).
            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    c[[b, j]] = f_gate[[b, j]] * c[[b, j]] + i_gate[[b, j]] * g_gate[[b, j]];
                    h[[b, j]] = o_gate[[b, j]] * c[[b, j]].tanh();
                }
            }

            outputs.push(h.clone());
        }

        // Stack the per-step hidden states into the output tensor.
        let output_shape = if self.batch_first {
            IxDyn(&[batch_size, seq_len, self.hidden_size])
        } else {
            IxDyn(&[seq_len, batch_size, self.hidden_size])
        };
        let mut output = ArrayD::zeros(output_shape);

        for (t, h_t) in outputs.iter().enumerate() {
            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    if self.batch_first {
                        output[[b, t, j]] = h_t[[b, j]];
                    } else {
                        output[[t, b, j]] = h_t[[b, j]];
                    }
                }
            }
        }

        let final_state = LSTMState {
            h: SciRS2Array::new(h, input.requires_grad),
            c: SciRS2Array::new(c, input.requires_grad),
        };

        Ok((SciRS2Array::new(output, input.requires_grad), final_state))
    }
}

impl QuantumModule for QuantumLSTM {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let (output, _) = self.forward_with_state(input, None)?;
        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.weights.clone()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for w in &mut self.weights {
            w.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "LSTM"
    }
}

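// Example usage (a sketch; assumes the `SciRS2Array` constructors used above):
//
//     let mut lstm = QuantumLSTM::new(8, 16).batch_first(true);
//     let x = SciRS2Array::with_grad(ArrayD::zeros(IxDyn(&[2, 5, 8])));
//     let (out, state) = lstm.forward_with_state(&x, None)?;
//     // `out` has shape [2, 5, 16]; `state.h` and `state.c` are [2, 16].
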
/// A GRU (gated recurrent unit) layer over `SciRS2Array` tensors.
pub struct QuantumGRU {
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    bidirectional: bool,
    dropout: f64,
    batch_first: bool,
    weights: Vec<Parameter>,
    training: bool,
}

impl QuantumGRU {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // Uniform init in [-0.05, 0.05); the three gates (reset, update,
        // new/candidate) are stacked along the first axis.
        let weight_ih = ArrayD::from_shape_fn(IxDyn(&[3 * hidden_size, input_size]), |_| {
            fastrand::f64() * 0.1 - 0.05
        });
        let weight_hh = ArrayD::from_shape_fn(IxDyn(&[3 * hidden_size, hidden_size]), |_| {
            fastrand::f64() * 0.1 - 0.05
        });
        let bias_ih = ArrayD::zeros(IxDyn(&[3 * hidden_size]));
        let bias_hh = ArrayD::zeros(IxDyn(&[3 * hidden_size]));

        Self {
            input_size,
            hidden_size,
            num_layers: 1,
            bidirectional: false,
            dropout: 0.0,
            batch_first: true,
            weights: vec![
                Parameter::new(SciRS2Array::with_grad(weight_ih), "weight_ih_l0"),
                Parameter::new(SciRS2Array::with_grad(weight_hh), "weight_hh_l0"),
                Parameter::new(SciRS2Array::with_grad(bias_ih), "bias_ih_l0"),
                Parameter::new(SciRS2Array::with_grad(bias_hh), "bias_hh_l0"),
            ],
            training: true,
        }
    }

    /// Sets the number of stacked layers. Stored for configuration; the
    /// current forward pass runs a single layer.
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = num_layers;
        self
    }

    /// Enables bidirectional processing. Stored for configuration; the
    /// current forward pass is unidirectional.
    pub fn bidirectional(mut self, bidirectional: bool) -> Self {
        self.bidirectional = bidirectional;
        self
    }

    /// If `true` (the default), tensors are laid out `[batch, seq, feature]`;
    /// otherwise `[seq, batch, feature]`.
    pub fn batch_first(mut self, batch_first: bool) -> Self {
        self.batch_first = batch_first;
        self
    }

    /// Runs the GRU over `input`, returning the full output sequence and
    /// the final hidden state.
    pub fn forward_with_hidden(
        &mut self,
        input: &SciRS2Array,
        initial_hidden: Option<SciRS2Array>,
    ) -> Result<(SciRS2Array, SciRS2Array)> {
        let shape = input.data.shape();
        let (batch_size, seq_len, _) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // Start from the provided hidden state or from zeros.
        let mut h = match initial_hidden {
            Some(state) => state.data,
            None => ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size])),
        };

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let x_t = if self.batch_first {
                input.data.slice(s![.., t, ..]).to_owned()
            } else {
                input.data.slice(s![t, .., ..]).to_owned()
            };

            let weight_ih = &self.weights[0].data.data;
            let weight_hh = &self.weights[1].data.data;
            let bias_ih = &self.weights[2].data.data;
            let bias_hh = &self.weights[3].data.data;

            // Guard against inputs narrower than `input_size`.
            let in_dim = self
                .input_size
                .min(x_t.shape().last().copied().unwrap_or(self.input_size));

            // Input and recurrent pre-activations are kept separate so the
            // reset gate can scale the recurrent part of the candidate state:
            // gates_x = W_ih * x_t + b_ih, gates_h = W_hh * h_{t-1} + b_hh.
            let mut gates_x = ArrayD::zeros(IxDyn(&[batch_size, 3 * self.hidden_size]));
            let mut gates_h = ArrayD::zeros(IxDyn(&[batch_size, 3 * self.hidden_size]));

            for b in 0..batch_size {
                for g in 0..3 * self.hidden_size {
                    let mut sum_x = bias_ih[[g]];
                    for i in 0..in_dim {
                        sum_x += x_t[[b, i]] * weight_ih[[g, i]];
                    }
                    gates_x[[b, g]] = sum_x;

                    let mut sum_h = bias_hh[[g]];
                    for j in 0..self.hidden_size {
                        sum_h += h[[b, j]] * weight_hh[[g, j]];
                    }
                    gates_h[[b, g]] = sum_h;
                }
            }

            let mut r_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut z_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));
            let mut n_gate = ArrayD::zeros(IxDyn(&[batch_size, self.hidden_size]));

            // r = sigmoid(x_r + h_r), z = sigmoid(x_z + h_z),
            // n = tanh(x_n + r * h_n).
            let hs = self.hidden_size;
            for b in 0..batch_size {
                for j in 0..hs {
                    r_gate[[b, j]] = sigmoid(gates_x[[b, j]] + gates_h[[b, j]]);
                    z_gate[[b, j]] = sigmoid(gates_x[[b, hs + j]] + gates_h[[b, hs + j]]);
                    n_gate[[b, j]] = (gates_x[[b, 2 * hs + j]]
                        + r_gate[[b, j]] * gates_h[[b, 2 * hs + j]])
                        .tanh();
                }
            }

            // h_t = (1 - z) * n + z * h_{t-1}.
            for b in 0..batch_size {
                for j in 0..hs {
                    h[[b, j]] =
                        (1.0 - z_gate[[b, j]]) * n_gate[[b, j]] + z_gate[[b, j]] * h[[b, j]];
                }
            }

            outputs.push(h.clone());
        }

        // Stack the per-step hidden states into the output tensor.
        let output_shape = if self.batch_first {
            IxDyn(&[batch_size, seq_len, self.hidden_size])
        } else {
            IxDyn(&[seq_len, batch_size, self.hidden_size])
        };
        let mut output = ArrayD::zeros(output_shape);

        for (t, h_t) in outputs.iter().enumerate() {
            for b in 0..batch_size {
                for j in 0..self.hidden_size {
                    if self.batch_first {
                        output[[b, t, j]] = h_t[[b, j]];
                    } else {
                        output[[t, b, j]] = h_t[[b, j]];
                    }
                }
            }
        }

        Ok((
            SciRS2Array::new(output, input.requires_grad),
            SciRS2Array::new(h, input.requires_grad),
        ))
    }
}

impl QuantumModule for QuantumGRU {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let (output, _) = self.forward_with_hidden(input, None)?;
        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.weights.clone()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        for w in &mut self.weights {
            w.data.zero_grad();
        }
    }

    fn name(&self) -> &str {
        "GRU"
    }
}
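
#[cfg(test)]
mod tests {
    use super::*;

    // Shape checks for the recurrent forward passes: a minimal sketch that
    // assumes `SciRS2Array::with_grad` wraps an `ArrayD<f64>`, as in the
    // constructors above.
    #[test]
    fn lstm_output_shapes() {
        let mut lstm = QuantumLSTM::new(3, 4);
        // batch_first layout: [batch=2, seq=5, feature=3].
        let input = SciRS2Array::with_grad(ArrayD::zeros(IxDyn(&[2, 5, 3])));
        let (output, state) = lstm.forward_with_state(&input, None).unwrap();
        assert_eq!(output.data.shape(), &[2, 5, 4]);
        assert_eq!(state.h.data.shape(), &[2, 4]);
        assert_eq!(state.c.data.shape(), &[2, 4]);
    }

    #[test]
    fn gru_output_shapes() {
        let mut gru = QuantumGRU::new(3, 4);
        let input = SciRS2Array::with_grad(ArrayD::zeros(IxDyn(&[2, 5, 3])));
        let (output, hidden) = gru.forward_with_hidden(&input, None).unwrap();
        assert_eq!(output.data.shape(), &[2, 5, 4]);
        assert_eq!(hidden.data.shape(), &[2, 4]);
    }
}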