1use super::KerasLayer;
4use crate::error::{MLError, Result};
5use scirs2_core::ndarray::{ArrayD, IxDyn};
6
/// Keras-style multi-head attention layer over dynamically shaped `f64`
/// tensors of shape `(batch, seq_len, embed_dim)`.
///
/// Projection weights are created lazily by `KerasLayer::build()`; until
/// then the layer is unusable and forward passes return an error.
pub struct MultiHeadAttention {
    /// Number of parallel attention heads.
    num_heads: usize,
    /// Per-head dimensionality of the query/key projections.
    key_dim: usize,
    /// Per-head value dimensionality (defaults to `key_dim` in `new`).
    /// NOTE(review): `build()` in this file sizes all projections from
    /// `key_dim` only — confirm whether `value_dim` is intentionally unused.
    value_dim: usize,
    /// Dropout rate; stored configuration only — not applied by the
    /// forward passes visible in this file.
    dropout: f64,
    /// Bias toggle; stored configuration only — no bias tensors are
    /// allocated by `build()` in this file.
    use_bias: bool,
    /// `(embed_dim, num_heads * key_dim)` query projection, set by `build()`.
    query_weights: Option<ArrayD<f64>>,
    /// `(embed_dim, num_heads * key_dim)` key projection, set by `build()`.
    key_weights: Option<ArrayD<f64>>,
    /// `(embed_dim, num_heads * key_dim)` value projection, set by `build()`.
    value_weights: Option<ArrayD<f64>>,
    /// `(num_heads * key_dim, embed_dim)` output projection, set by `build()`.
    output_weights: Option<ArrayD<f64>>,
    /// True once `build()` has allocated the weights.
    built: bool,
    /// Optional user-assigned name reported by `KerasLayer::name`.
    layer_name: Option<String>,
}
32
33impl MultiHeadAttention {
34 pub fn new(num_heads: usize, key_dim: usize) -> Self {
36 Self {
37 num_heads,
38 key_dim,
39 value_dim: key_dim,
40 dropout: 0.0,
41 use_bias: true,
42 query_weights: None,
43 key_weights: None,
44 value_weights: None,
45 output_weights: None,
46 built: false,
47 layer_name: None,
48 }
49 }
50
51 pub fn value_dim(mut self, value_dim: usize) -> Self {
53 self.value_dim = value_dim;
54 self
55 }
56
57 pub fn dropout(mut self, dropout: f64) -> Self {
59 self.dropout = dropout;
60 self
61 }
62
63 pub fn use_bias(mut self, use_bias: bool) -> Self {
65 self.use_bias = use_bias;
66 self
67 }
68
69 pub fn name(mut self, name: &str) -> Self {
71 self.layer_name = Some(name.to_string());
72 self
73 }
74
75 pub fn call_with_qkv(
77 &mut self,
78 query: &ArrayD<f64>,
79 key: &ArrayD<f64>,
80 value: &ArrayD<f64>,
81 ) -> Result<ArrayD<f64>> {
82 if !self.built {
83 return Err(MLError::ModelNotTrained(
84 "Layer not built. Call build() first.".to_string(),
85 ));
86 }
87
88 let q_weights = self
89 .query_weights
90 .as_ref()
91 .ok_or_else(|| MLError::ModelNotTrained("Query weights not initialized".to_string()))?;
92 let k_weights = self
93 .key_weights
94 .as_ref()
95 .ok_or_else(|| MLError::ModelNotTrained("Key weights not initialized".to_string()))?;
96 let v_weights = self
97 .value_weights
98 .as_ref()
99 .ok_or_else(|| MLError::ModelNotTrained("Value weights not initialized".to_string()))?;
100 let out_weights = self.output_weights.as_ref().ok_or_else(|| {
101 MLError::ModelNotTrained("Output weights not initialized".to_string())
102 })?;
103
104 let shape = query.shape();
105 let (batch_size, seq_len, embed_dim) = (shape[0], shape[1], shape[2]);
106 let head_dim = self.key_dim;
107 let scale = (head_dim as f64).sqrt();
108
109 let total_dim = self.num_heads * head_dim;
110 let mut q: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, total_dim]));
111 let mut k: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, total_dim]));
112 let mut v: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, total_dim]));
113
114 for b in 0..batch_size {
115 for s in 0..seq_len {
116 for o in 0..total_dim.min(q_weights.shape()[1]) {
117 let mut q_sum: f64 = 0.0;
118 let mut k_sum: f64 = 0.0;
119 let mut v_sum: f64 = 0.0;
120 for i in 0..embed_dim.min(q_weights.shape()[0]) {
121 q_sum += query[[b, s, i]] * q_weights[[i, o]];
122 k_sum += key[[b, s, i]] * k_weights[[i, o]];
123 v_sum += value[[b, s, i]] * v_weights[[i, o]];
124 }
125 q[[b, s, o]] = q_sum;
126 k[[b, s, o]] = k_sum;
127 v[[b, s, o]] = v_sum;
128 }
129 }
130 }
131
132 let mut attn: ArrayD<f64> =
133 ArrayD::zeros(IxDyn(&[batch_size, self.num_heads, seq_len, seq_len]));
134
135 for b in 0..batch_size {
136 for h in 0..self.num_heads {
137 for i in 0..seq_len {
138 for j in 0..seq_len {
139 let mut score: f64 = 0.0;
140 for d in 0..head_dim {
141 score += q[[b, i, h * head_dim + d]] * k[[b, j, h * head_dim + d]];
142 }
143 attn[[b, h, i, j]] = score / scale;
144 }
145 }
146
147 for i in 0..seq_len {
148 let max_score = (0..seq_len)
149 .map(|j| attn[[b, h, i, j]])
150 .fold(f64::NEG_INFINITY, f64::max);
151 let mut sum_exp: f64 = 0.0;
152 for j in 0..seq_len {
153 attn[[b, h, i, j]] = (attn[[b, h, i, j]] - max_score).exp();
154 sum_exp += attn[[b, h, i, j]];
155 }
156 for j in 0..seq_len {
157 attn[[b, h, i, j]] /= sum_exp;
158 }
159 }
160 }
161 }
162
163 let mut context: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, total_dim]));
164 for b in 0..batch_size {
165 for h in 0..self.num_heads {
166 for i in 0..seq_len {
167 for d in 0..head_dim {
168 let mut sum: f64 = 0.0;
169 for j in 0..seq_len {
170 sum += attn[[b, h, i, j]] * v[[b, j, h * head_dim + d]];
171 }
172 context[[b, i, h * head_dim + d]] = sum;
173 }
174 }
175 }
176 }
177
178 let mut output: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, embed_dim]));
179 for b in 0..batch_size {
180 for s in 0..seq_len {
181 for o in 0..embed_dim.min(out_weights.shape()[1]) {
182 let mut out_sum: f64 = 0.0;
183 for i in 0..total_dim.min(out_weights.shape()[0]) {
184 out_sum += context[[b, s, i]] * out_weights[[i, o]];
185 }
186 output[[b, s, o]] = out_sum;
187 }
188 }
189 }
190
191 Ok(output)
192 }
193}
194
195impl KerasLayer for MultiHeadAttention {
196 fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
197 if !self.built {
198 return Err(MLError::ModelNotTrained(
199 "Layer not built. Call build() first.".to_string(),
200 ));
201 }
202
203 let q_weights = self
204 .query_weights
205 .as_ref()
206 .ok_or_else(|| MLError::ModelNotTrained("Query weights not initialized".to_string()))?;
207 let k_weights = self
208 .key_weights
209 .as_ref()
210 .ok_or_else(|| MLError::ModelNotTrained("Key weights not initialized".to_string()))?;
211 let v_weights = self
212 .value_weights
213 .as_ref()
214 .ok_or_else(|| MLError::ModelNotTrained("Value weights not initialized".to_string()))?;
215 let out_weights = self.output_weights.as_ref().ok_or_else(|| {
216 MLError::ModelNotTrained("Output weights not initialized".to_string())
217 })?;
218
219 let shape = input.shape();
220 let (batch_size, seq_len, embed_dim) = (shape[0], shape[1], shape[2]);
221 let head_dim = self.key_dim;
222 let scale = (head_dim as f64).sqrt();
223
224 let total_dim = self.num_heads * head_dim;
225 let mut q: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, total_dim]));
226 let mut k: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, total_dim]));
227 let mut v: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, total_dim]));
228
229 for b in 0..batch_size {
230 for s in 0..seq_len {
231 for o in 0..total_dim.min(q_weights.shape()[1]) {
232 let mut q_sum: f64 = 0.0;
233 let mut k_sum: f64 = 0.0;
234 let mut v_sum: f64 = 0.0;
235 for i in 0..embed_dim.min(q_weights.shape()[0]) {
236 q_sum += input[[b, s, i]] * q_weights[[i, o]];
237 k_sum += input[[b, s, i]] * k_weights[[i, o]];
238 v_sum += input[[b, s, i]] * v_weights[[i, o]];
239 }
240 q[[b, s, o]] = q_sum;
241 k[[b, s, o]] = k_sum;
242 v[[b, s, o]] = v_sum;
243 }
244 }
245 }
246
247 let mut attn: ArrayD<f64> =
248 ArrayD::zeros(IxDyn(&[batch_size, self.num_heads, seq_len, seq_len]));
249
250 for b in 0..batch_size {
251 for h in 0..self.num_heads {
252 for i in 0..seq_len {
253 for j in 0..seq_len {
254 let mut score: f64 = 0.0;
255 for d in 0..head_dim {
256 score += q[[b, i, h * head_dim + d]] * k[[b, j, h * head_dim + d]];
257 }
258 attn[[b, h, i, j]] = score / scale;
259 }
260 }
261
262 for i in 0..seq_len {
263 let max_score = (0..seq_len)
264 .map(|j| attn[[b, h, i, j]])
265 .fold(f64::NEG_INFINITY, f64::max);
266 let mut sum_exp: f64 = 0.0;
267 for j in 0..seq_len {
268 attn[[b, h, i, j]] = (attn[[b, h, i, j]] - max_score).exp();
269 sum_exp += attn[[b, h, i, j]];
270 }
271 for j in 0..seq_len {
272 attn[[b, h, i, j]] /= sum_exp;
273 }
274 }
275 }
276 }
277
278 let mut context: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, total_dim]));
279 for b in 0..batch_size {
280 for h in 0..self.num_heads {
281 for i in 0..seq_len {
282 for d in 0..head_dim {
283 let mut sum: f64 = 0.0;
284 for j in 0..seq_len {
285 sum += attn[[b, h, i, j]] * v[[b, j, h * head_dim + d]];
286 }
287 context[[b, i, h * head_dim + d]] = sum;
288 }
289 }
290 }
291 }
292
293 let mut output: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, embed_dim]));
294 for b in 0..batch_size {
295 for s in 0..seq_len {
296 for o in 0..embed_dim.min(out_weights.shape()[1]) {
297 let mut out_sum: f64 = 0.0;
298 for i in 0..total_dim.min(out_weights.shape()[0]) {
299 out_sum += context[[b, s, i]] * out_weights[[i, o]];
300 }
301 output[[b, s, o]] = out_sum;
302 }
303 }
304 }
305
306 Ok(output)
307 }
308
309 fn build(&mut self, input_shape: &[usize]) -> Result<()> {
310 let embed_dim = *input_shape
311 .last()
312 .ok_or_else(|| MLError::InvalidConfiguration("Invalid input shape".to_string()))?;
313
314 let total_dim = self.num_heads * self.key_dim;
315 let scale = (2.0 / (embed_dim + total_dim) as f64).sqrt();
316
317 let query_weights = ArrayD::from_shape_fn(IxDyn(&[embed_dim, total_dim]), |_| {
318 (fastrand::f64() * 2.0 - 1.0) * scale
319 });
320 let key_weights = ArrayD::from_shape_fn(IxDyn(&[embed_dim, total_dim]), |_| {
321 (fastrand::f64() * 2.0 - 1.0) * scale
322 });
323 let value_weights = ArrayD::from_shape_fn(IxDyn(&[embed_dim, total_dim]), |_| {
324 (fastrand::f64() * 2.0 - 1.0) * scale
325 });
326 let output_weights = ArrayD::from_shape_fn(IxDyn(&[total_dim, embed_dim]), |_| {
327 (fastrand::f64() * 2.0 - 1.0) * scale
328 });
329
330 self.query_weights = Some(query_weights);
331 self.key_weights = Some(key_weights);
332 self.value_weights = Some(value_weights);
333 self.output_weights = Some(output_weights);
334 self.built = true;
335
336 Ok(())
337 }
338
339 fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
340 input_shape.to_vec()
341 }
342
343 fn count_params(&self) -> usize {
344 let q = self.query_weights.as_ref().map_or(0, |w| w.len());
345 let k = self.key_weights.as_ref().map_or(0, |w| w.len());
346 let v = self.value_weights.as_ref().map_or(0, |w| w.len());
347 let o = self.output_weights.as_ref().map_or(0, |w| w.len());
348 q + k + v + o
349 }
350
351 fn get_weights(&self) -> Vec<ArrayD<f64>> {
352 let mut weights = vec![];
353 if let Some(ref w) = self.query_weights {
354 weights.push(w.clone());
355 }
356 if let Some(ref w) = self.key_weights {
357 weights.push(w.clone());
358 }
359 if let Some(ref w) = self.value_weights {
360 weights.push(w.clone());
361 }
362 if let Some(ref w) = self.output_weights {
363 weights.push(w.clone());
364 }
365 weights
366 }
367
368 fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
369 if weights.len() >= 4 {
370 self.query_weights = Some(weights[0].clone());
371 self.key_weights = Some(weights[1].clone());
372 self.value_weights = Some(weights[2].clone());
373 self.output_weights = Some(weights[3].clone());
374 }
375 Ok(())
376 }
377
378 fn built(&self) -> bool {
379 self.built
380 }
381
382 fn name(&self) -> &str {
383 self.layer_name.as_deref().unwrap_or("multi_head_attention")
384 }
385}
386
/// Keras-style embedding lookup layer: maps integer token ids (carried in
/// an `f64` tensor) to dense vectors of length `output_dim`.
///
/// The embedding table is created lazily by `KerasLayer::build()`.
pub struct Embedding {
    /// Vocabulary size; ids outside `0..input_dim` produce zero vectors.
    input_dim: usize,
    /// Dimensionality of each embedding vector.
    output_dim: usize,
    /// `(input_dim, output_dim)` lookup table, set by `build()`.
    embeddings: Option<ArrayD<f64>>,
    /// Whether id 0 is a padding marker.
    /// NOTE(review): stored configuration only — `call` in this file does
    /// not mask id 0; confirm where masking is applied.
    mask_zero: bool,
    /// True once `build()` has allocated the table.
    built: bool,
    /// Optional user-assigned name reported by `KerasLayer::name`.
    layer_name: Option<String>,
}
402
403impl Embedding {
404 pub fn new(input_dim: usize, output_dim: usize) -> Self {
406 Self {
407 input_dim,
408 output_dim,
409 embeddings: None,
410 mask_zero: false,
411 built: false,
412 layer_name: None,
413 }
414 }
415
416 pub fn mask_zero(mut self, mask_zero: bool) -> Self {
418 self.mask_zero = mask_zero;
419 self
420 }
421
422 pub fn name(mut self, name: &str) -> Self {
424 self.layer_name = Some(name.to_string());
425 self
426 }
427}
428
429impl KerasLayer for Embedding {
430 fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
431 if !self.built {
432 return Err(MLError::ModelNotTrained(
433 "Layer not built. Call build() first.".to_string(),
434 ));
435 }
436
437 let embeddings = self
438 .embeddings
439 .as_ref()
440 .ok_or_else(|| MLError::ModelNotTrained("Embeddings not initialized".to_string()))?;
441
442 let shape = input.shape();
443 let batch_size = shape[0];
444 let seq_len = *shape.get(1).unwrap_or(&1);
445
446 let mut output = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.output_dim]));
447
448 for b in 0..batch_size {
449 for s in 0..seq_len {
450 let idx = input[[b, s]] as usize;
451 if idx < self.input_dim {
452 for d in 0..self.output_dim {
453 output[[b, s, d]] = embeddings[[idx, d]];
454 }
455 }
456 }
457 }
458
459 Ok(output)
460 }
461
462 fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
463 let scale = (1.0 / self.input_dim as f64).sqrt();
464 let embeddings = ArrayD::from_shape_fn(IxDyn(&[self.input_dim, self.output_dim]), |_| {
465 (fastrand::f64() * 2.0 - 1.0) * scale
466 });
467
468 self.embeddings = Some(embeddings);
469 self.built = true;
470
471 Ok(())
472 }
473
474 fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
475 let mut out_shape = input_shape.to_vec();
476 out_shape.push(self.output_dim);
477 out_shape
478 }
479
480 fn count_params(&self) -> usize {
481 self.input_dim * self.output_dim
482 }
483
484 fn get_weights(&self) -> Vec<ArrayD<f64>> {
485 self.embeddings.iter().cloned().collect()
486 }
487
488 fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
489 if !weights.is_empty() {
490 self.embeddings = Some(weights[0].clone());
491 }
492 Ok(())
493 }
494
495 fn built(&self) -> bool {
496 self.built
497 }
498
499 fn name(&self) -> &str {
500 self.layer_name.as_deref().unwrap_or("embedding")
501 }
502}