use super::layers::{QuantumLayerNorm, QuantumLinear};
use super::{Parameter, QuantumModule};
use crate::error::{MLError, Result};
use crate::scirs2_integration::SciRS2Array;
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// Multi-head attention over `SciRS2Array` tensors, following the scaled
/// dot-product formulation with learnable Q/K/V and output projections.
pub struct QuantumMultiheadAttention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    q_proj: Parameter,
    k_proj: Parameter,
    v_proj: Parameter,
    out_proj: Parameter,
    /// Dropout probability; stored via the builder but not applied in this
    /// reference forward pass.
    dropout: f64,
    training: bool,
}

impl QuantumMultiheadAttention {
    /// Builds the module; `embed_dim` must be divisible by `num_heads`.
    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self> {
        if embed_dim % num_heads != 0 {
            return Err(MLError::InvalidConfiguration(
                "embed_dim must be divisible by num_heads".to_string(),
            ));
        }

        let head_dim = embed_dim / num_heads;
        // Uniform initialization in [-scale, scale] with scale = 1/sqrt(embed_dim).
        let scale = (1.0 / (embed_dim as f64)).sqrt();

        let q_proj = ArrayD::from_shape_fn(IxDyn(&[embed_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let k_proj = ArrayD::from_shape_fn(IxDyn(&[embed_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let v_proj = ArrayD::from_shape_fn(IxDyn(&[embed_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let out_proj = ArrayD::from_shape_fn(IxDyn(&[embed_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            q_proj: Parameter::new(SciRS2Array::with_grad(q_proj), "q_proj"),
            k_proj: Parameter::new(SciRS2Array::with_grad(k_proj), "k_proj"),
            v_proj: Parameter::new(SciRS2Array::with_grad(v_proj), "v_proj"),
            out_proj: Parameter::new(SciRS2Array::with_grad(out_proj), "out_proj"),
            dropout: 0.0,
            training: true,
        })
    }

    /// Builder-style setter for the dropout probability.
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Scaled dot-product attention over separate query/key/value tensors.
    /// Inputs are assumed to be 3-D with shape `[batch, seq_len, embed_dim]`.
    /// Returns the attended output and the head-averaged attention weights.
    pub fn forward_qkv(
        &self,
        query: &SciRS2Array,
        key: &SciRS2Array,
        value: &SciRS2Array,
        attn_mask: Option<&ArrayD<f64>>,
    ) -> Result<(SciRS2Array, SciRS2Array)> {
        let shape = query.data.shape();
        let (batch_size, seq_len, _) = (shape[0], shape[1], shape[2]);
        let scale = (self.head_dim as f64).sqrt();

        // Linear projections Q = X W_q^T, K = X W_k^T, V = X W_v^T, computed
        // with naive loops; weight matrices are stored as [out, in].
        let mut q = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));
        let mut k = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));
        let mut v = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));

        for b in 0..batch_size {
            for s in 0..seq_len {
                for e_out in 0..self.embed_dim {
                    let mut q_sum = 0.0;
                    let mut k_sum = 0.0;
                    let mut v_sum = 0.0;
                    for e_in in 0..self.embed_dim {
                        q_sum += query.data[[b, s, e_in]] * self.q_proj.data.data[[e_out, e_in]];
                        k_sum += key.data[[b, s, e_in]] * self.k_proj.data.data[[e_out, e_in]];
                        v_sum += value.data[[b, s, e_in]] * self.v_proj.data.data[[e_out, e_in]];
                    }
                    q[[b, s, e_out]] = q_sum;
                    k[[b, s, e_out]] = k_sum;
                    v[[b, s, e_out]] = v_sum;
                }
            }
        }

        // Per-head scaled dot-product scores:
        // scores[b, h, i, j] = (q_i . k_j) / sqrt(head_dim)
        let mut attn_scores =
            ArrayD::zeros(IxDyn(&[batch_size, self.num_heads, seq_len, seq_len]));

        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        let mut score = 0.0;
                        for d in 0..self.head_dim {
                            let idx = h * self.head_dim + d;
                            score += q[[b, i, idx]] * k[[b, j, idx]];
                        }
                        attn_scores[[b, h, i, j]] = score / scale;
                    }
                }
            }
        }

        // Optional mask: positions where mask[i, j] == 0 are blocked by
        // setting their scores to negative infinity before the softmax.
        if let Some(mask) = attn_mask {
            for b in 0..batch_size {
                for h in 0..self.num_heads {
                    for i in 0..seq_len {
                        for j in 0..seq_len {
                            if mask[[i, j]] == 0.0 {
                                attn_scores[[b, h, i, j]] = f64::NEG_INFINITY;
                            }
                        }
                    }
                }
            }
        }

        // Numerically stable row-wise softmax: subtract each row's max
        // before exponentiating.
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    let max_score = (0..seq_len)
                        .map(|j| attn_scores[[b, h, i, j]])
                        .fold(f64::NEG_INFINITY, f64::max);
                    let mut sum_exp = 0.0;
                    for j in 0..seq_len {
                        attn_scores[[b, h, i, j]] = (attn_scores[[b, h, i, j]] - max_score).exp();
                        sum_exp += attn_scores[[b, h, i, j]];
                    }
                    for j in 0..seq_len {
                        attn_scores[[b, h, i, j]] /= sum_exp;
                    }
                }
            }
        }

        // Weighted sum of values per head, written back into a contiguous
        // [batch, seq_len, embed_dim] layout.
        let mut attn_output = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    for d in 0..self.head_dim {
                        let mut sum = 0.0;
                        for j in 0..seq_len {
                            sum += attn_scores[[b, h, i, j]] * v[[b, j, h * self.head_dim + d]];
                        }
                        attn_output[[b, i, h * self.head_dim + d]] = sum;
                    }
                }
            }
        }

        // Final output projection across the concatenated heads.
        let mut output = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));
        for b in 0..batch_size {
            for s in 0..seq_len {
                for e_out in 0..self.embed_dim {
                    let mut sum = 0.0;
                    for e_in in 0..self.embed_dim {
                        sum += attn_output[[b, s, e_in]] * self.out_proj.data.data[[e_out, e_in]];
                    }
                    output[[b, s, e_out]] = sum;
                }
            }
        }

        // Average the attention weights over heads for the returned map.
        let mut avg_attn = ArrayD::zeros(IxDyn(&[batch_size, seq_len, seq_len]));
        for b in 0..batch_size {
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let mut sum = 0.0;
                    for h in 0..self.num_heads {
                        sum += attn_scores[[b, h, i, j]];
                    }
                    avg_attn[[b, i, j]] = sum / self.num_heads as f64;
                }
            }
        }

        // Gradient tracking on the output follows the query; the attention
        // map is returned as a plain, non-differentiable array.
        Ok((
            SciRS2Array::new(output, query.requires_grad),
            SciRS2Array::new(avg_attn, false),
        ))
    }
}

impl QuantumModule for QuantumMultiheadAttention {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Self-attention: query, key, and value all come from the same input.
        let (output, _) = self.forward_qkv(input, input, input, None)?;
        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.q_proj.clone(),
            self.k_proj.clone(),
            self.v_proj.clone(),
            self.out_proj.clone(),
        ]
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.q_proj.data.zero_grad();
        self.k_proj.data.zero_grad();
        self.v_proj.data.zero_grad();
        self.out_proj.data.zero_grad();
    }

    fn name(&self) -> &str {
        "MultiheadAttention"
    }
}
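
// A minimal smoke-test sketch for the attention module. It relies only on the
// API visible in this file (`SciRS2Array::new`, `forward`, `forward_qkv`); the
// shapes and tolerance are illustrative, not part of any existing test suite.
#[cfg(test)]
mod attention_tests {
    use super::*;

    #[test]
    fn output_shape_matches_input() {
        let mut attn = QuantumMultiheadAttention::new(8, 2).expect("valid config");
        let input = SciRS2Array::new(ArrayD::zeros(IxDyn(&[2, 4, 8])), false);
        let output = attn.forward(&input).expect("forward pass");
        assert_eq!(output.data.shape(), &[2, 4, 8]);
    }

    #[test]
    fn attention_rows_are_normalized() {
        let attn = QuantumMultiheadAttention::new(4, 2).expect("valid config");
        let x = SciRS2Array::new(
            ArrayD::from_shape_fn(IxDyn(&[1, 3, 4]), |_| fastrand::f64() - 0.5),
            false,
        );
        let (_, weights) = attn.forward_qkv(&x, &x, &x, None).expect("forward pass");
        // Each row of the head-averaged softmax map should sum to 1.
        for i in 0..3 {
            let row_sum: f64 = (0..3).map(|j| weights.data[[0, i, j]]).sum();
            assert!((row_sum - 1.0).abs() < 1e-9);
        }
    }
}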

/// Transformer encoder layer: self-attention followed by a feed-forward
/// block, each wrapped in a residual connection and layer normalization
/// (post-norm ordering).
pub struct QuantumTransformerEncoderLayer {
    self_attn: QuantumMultiheadAttention,
    linear1: QuantumLinear,
    linear2: QuantumLinear,
    norm1: QuantumLayerNorm,
    norm2: QuantumLayerNorm,
    dropout: f64,
    training: bool,
}

impl QuantumTransformerEncoderLayer {
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Result<Self> {
        Ok(Self {
            self_attn: QuantumMultiheadAttention::new(d_model, nhead)?,
            linear1: QuantumLinear::new(d_model, dim_feedforward)?,
            linear2: QuantumLinear::new(dim_feedforward, d_model)?,
            norm1: QuantumLayerNorm::new(vec![d_model]),
            norm2: QuantumLayerNorm::new(vec![d_model]),
            dropout: 0.1,
            training: true,
        })
    }

    /// Builder-style setter for the dropout probability.
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }
}

impl QuantumModule for QuantumTransformerEncoderLayer {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Self-attention sub-layer with residual connection and layer norm.
        let attn_output = self.self_attn.forward(input)?;
        let residual1 = SciRS2Array::new(&input.data + &attn_output.data, input.requires_grad);
        let normed1 = self.norm1.forward(&residual1)?;

        // Feed-forward sub-layer (linear -> ReLU -> linear), again with a
        // residual connection and layer norm.
        let ff_output = self.linear1.forward(&normed1)?;
        let ff_activated =
            SciRS2Array::new(ff_output.data.mapv(|x| x.max(0.0)), ff_output.requires_grad);
        let ff_output2 = self.linear2.forward(&ff_activated)?;

        let residual2 = SciRS2Array::new(&normed1.data + &ff_output2.data, input.requires_grad);
        self.norm2.forward(&residual2)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.self_attn.parameters();
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
        self.self_attn.train(mode);
        self.linear1.train(mode);
        self.linear2.train(mode);
        self.norm1.train(mode);
        self.norm2.train(mode);
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.self_attn.zero_grad();
        self.linear1.zero_grad();
        self.linear2.zero_grad();
        self.norm1.zero_grad();
        self.norm2.zero_grad();
    }

    fn name(&self) -> &str {
        "TransformerEncoderLayer"
    }
}
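
// A usage sketch for the encoder layer, assuming `QuantumLinear` and
// `QuantumLayerNorm` preserve the trailing dimension as configured here;
// the d_model/nhead/dim_feedforward values are illustrative.
#[cfg(test)]
mod encoder_tests {
    use super::*;

    #[test]
    fn encoder_layer_preserves_shape() {
        let mut layer =
            QuantumTransformerEncoderLayer::new(8, 2, 16).expect("valid config");
        let input = SciRS2Array::new(
            ArrayD::from_shape_fn(IxDyn(&[1, 5, 8]), |_| fastrand::f64() - 0.5),
            false,
        );
        let output = layer.forward(&input).expect("forward pass");
        // Residual connections and layer norm keep the [batch, seq, d_model] shape.
        assert_eq!(output.data.shape(), input.data.shape());
    }
}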

/// Sinusoidal positional encoding added to token embeddings, following the
/// fixed sin/cos scheme of "Attention Is All You Need".
pub struct PositionalEncoding {
    d_model: usize,
    max_len: usize,
    dropout: f64,
    encoding: ArrayD<f64>,
    training: bool,
}

impl PositionalEncoding {
    pub fn new(d_model: usize, max_len: usize) -> Self {
        let mut encoding = ArrayD::zeros(IxDyn(&[max_len, d_model]));

        // PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        // PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        for pos in 0..max_len {
            for i in 0..d_model {
                let angle = pos as f64 / 10000.0_f64.powf(2.0 * (i / 2) as f64 / d_model as f64);
                encoding[[pos, i]] = if i % 2 == 0 { angle.sin() } else { angle.cos() };
            }
        }

        Self {
            d_model,
            max_len,
            dropout: 0.1,
            encoding,
            training: true,
        }
    }

    /// Builder-style setter for the dropout probability.
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }
}

impl QuantumModule for PositionalEncoding {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = input.data.shape();
        let seq_len = shape[1];

        // Add the precomputed encoding elementwise, clamped to the table's
        // max_len and to the input's embedding width.
        let mut output = input.data.clone();
        for b in 0..shape[0] {
            for s in 0..seq_len.min(self.max_len) {
                for d in 0..self.d_model.min(shape[2]) {
                    output[[b, s, d]] += self.encoding[[s, d]];
                }
            }
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    // The encoding table is fixed, so there are no trainable parameters.
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {}

    fn name(&self) -> &str {
        "PositionalEncoding"
    }
}
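
// A sanity-check sketch for the sinusoidal table: every entry is a sine or
// cosine value, so it must lie in [-1, 1], and adding the encoding must not
// change the input's shape. The sizes below are illustrative.
#[cfg(test)]
mod positional_tests {
    use super::*;

    #[test]
    fn encoding_is_bounded_and_shape_preserving() {
        let mut pe = PositionalEncoding::new(8, 16);
        assert!(pe.encoding.iter().all(|&v| v.abs() <= 1.0));

        let input = SciRS2Array::new(ArrayD::zeros(IxDyn(&[1, 10, 8])), false);
        let output = pe.forward(&input).expect("forward pass");
        assert_eq!(output.data.shape(), &[1, 10, 8]);
    }
}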