quantrs2_ml/keras_api/
attention.rs

//! Attention layers for Keras-like API

use super::KerasLayer;
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// Multi-head attention layer (Keras-compatible)
pub struct MultiHeadAttention {
    /// Number of attention heads
    num_heads: usize,
    /// Key dimension per head
    key_dim: usize,
    /// Value dimension per head
    value_dim: usize,
    /// Dropout rate (stored for API compatibility; not applied in this forward pass)
    dropout: f64,
    /// Use bias (stored for API compatibility; no bias parameters are created in `build`)
    use_bias: bool,
    /// Query projection weights
    query_weights: Option<ArrayD<f64>>,
    /// Key projection weights
    key_weights: Option<ArrayD<f64>>,
    /// Value projection weights
    value_weights: Option<ArrayD<f64>>,
    /// Output projection weights
    output_weights: Option<ArrayD<f64>>,
    /// Built flag
    built: bool,
    /// Layer name
    layer_name: Option<String>,
}

impl MultiHeadAttention {
    /// Create a new MultiHeadAttention layer with the given head count and per-head key dimension
    pub fn new(num_heads: usize, key_dim: usize) -> Self {
        Self {
            num_heads,
            key_dim,
            value_dim: key_dim,
            dropout: 0.0,
            use_bias: true,
            query_weights: None,
            key_weights: None,
            value_weights: None,
            output_weights: None,
            built: false,
            layer_name: None,
        }
    }

    /// Set the per-head value dimension (defaults to `key_dim`)
    pub fn value_dim(mut self, value_dim: usize) -> Self {
        self.value_dim = value_dim;
        self
    }

    /// Set the dropout rate
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Set whether projections use bias terms
    pub fn use_bias(mut self, use_bias: bool) -> Self {
        self.use_bias = use_bias;
        self
    }

    /// Set the layer name
    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }

    /// Forward pass with explicit query, key, and value tensors (cross-attention)
    pub fn call_with_qkv(
        &self,
        query: &ArrayD<f64>,
        key: &ArrayD<f64>,
        value: &ArrayD<f64>,
    ) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::ModelNotTrained(
                "Layer not built. Call build() first.".to_string(),
            ));
        }

        let q_weights = self
            .query_weights
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("Query weights not initialized".to_string()))?;
        let k_weights = self
            .key_weights
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("Key weights not initialized".to_string()))?;
        let v_weights = self
            .value_weights
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("Value weights not initialized".to_string()))?;
        let out_weights = self.output_weights.as_ref().ok_or_else(|| {
            MLError::ModelNotTrained("Output weights not initialized".to_string())
        })?;

        let q_shape = query.shape();
        let (batch_size, seq_len, embed_dim) = (q_shape[0], q_shape[1], q_shape[2]);
        // Key and value may come from a different sequence than the query.
        let kv_seq_len = key.shape()[1];
        let head_dim = self.key_dim;
        let scale = (head_dim as f64).sqrt();

        let qk_dim = self.num_heads * self.key_dim;
        let v_dim = self.num_heads * self.value_dim;
        let mut q: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, qk_dim]));
        let mut k: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, kv_seq_len, qk_dim]));
        let mut v: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, kv_seq_len, v_dim]));

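        // Linear projections: map each position's embedding into the
        // concatenated per-head query/key/value spaces (naive O(n * d^2) loops).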
        for b in 0..batch_size {
            for s in 0..seq_len {
                for o in 0..qk_dim.min(q_weights.shape()[1]) {
                    let mut q_sum: f64 = 0.0;
                    for i in 0..embed_dim.min(q_weights.shape()[0]) {
                        q_sum += query[[b, s, i]] * q_weights[[i, o]];
                    }
                    q[[b, s, o]] = q_sum;
                }
            }
            for s in 0..kv_seq_len {
                for o in 0..qk_dim.min(k_weights.shape()[1]) {
                    let mut k_sum: f64 = 0.0;
                    for i in 0..embed_dim.min(k_weights.shape()[0]) {
                        k_sum += key[[b, s, i]] * k_weights[[i, o]];
                    }
                    k[[b, s, o]] = k_sum;
                }
                for o in 0..v_dim.min(v_weights.shape()[1]) {
                    let mut v_sum: f64 = 0.0;
                    for i in 0..embed_dim.min(v_weights.shape()[0]) {
                        v_sum += value[[b, s, i]] * v_weights[[i, o]];
                    }
                    v[[b, s, o]] = v_sum;
                }
            }
        }

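        // Scaled dot-product scores: attn[b, h, i, j] = (q_i . k_j) / sqrt(key_dim).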
        let mut attn: ArrayD<f64> =
            ArrayD::zeros(IxDyn(&[batch_size, self.num_heads, seq_len, kv_seq_len]));

        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    for j in 0..kv_seq_len {
                        let mut score: f64 = 0.0;
                        for d in 0..head_dim {
                            score += q[[b, i, h * head_dim + d]] * k[[b, j, h * head_dim + d]];
                        }
                        attn[[b, h, i, j]] = score / scale;
                    }
                }

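                // Numerically stable softmax over each row: subtract the row
                // maximum before exponentiating so large scores cannot overflow.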
                for i in 0..seq_len {
                    let max_score = (0..kv_seq_len)
                        .map(|j| attn[[b, h, i, j]])
                        .fold(f64::NEG_INFINITY, f64::max);
                    let mut sum_exp: f64 = 0.0;
                    for j in 0..kv_seq_len {
                        attn[[b, h, i, j]] = (attn[[b, h, i, j]] - max_score).exp();
                        sum_exp += attn[[b, h, i, j]];
                    }
                    for j in 0..kv_seq_len {
                        attn[[b, h, i, j]] /= sum_exp;
                    }
                }
            }
        }

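        // Attention-weighted sum of values: each query position aggregates the
        // value vectors of all key/value positions, per head.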
        let mut context: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, v_dim]));
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    for d in 0..self.value_dim {
                        let mut sum: f64 = 0.0;
                        for j in 0..kv_seq_len {
                            sum += attn[[b, h, i, j]] * v[[b, j, h * self.value_dim + d]];
                        }
                        context[[b, i, h * self.value_dim + d]] = sum;
                    }
                }
            }
        }

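        // Final linear projection from the concatenated heads back to embed_dim.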
        let mut output: ArrayD<f64> = ArrayD::zeros(IxDyn(&[batch_size, seq_len, embed_dim]));
        for b in 0..batch_size {
            for s in 0..seq_len {
                for o in 0..embed_dim.min(out_weights.shape()[1]) {
                    let mut out_sum: f64 = 0.0;
                    for i in 0..v_dim.min(out_weights.shape()[0]) {
                        out_sum += context[[b, s, i]] * out_weights[[i, o]];
                    }
                    output[[b, s, o]] = out_sum;
                }
            }
        }

        Ok(output)
    }
}

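// A minimal usage sketch (hypothetical shapes; `x`, `q`, `k`, and `v` are
// assumed to be `ArrayD<f64>` tensors of shape [batch, seq_len, embed_dim]):
//
//     let mut mha = MultiHeadAttention::new(4, 16).dropout(0.1).name("mha");
//     mha.build(&[32, 10, 64])?;                     // batch=32, seq_len=10, embed_dim=64
//     let y_self = mha.call(&x)?;                    // self-attention
//     let y_cross = mha.call_with_qkv(&q, &k, &v)?;  // cross-attention
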
impl KerasLayer for MultiHeadAttention {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        // Self-attention: the input serves as query, key, and value.
        self.call_with_qkv(input, input, input)
    }

    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
        let embed_dim = *input_shape
            .last()
            .ok_or_else(|| MLError::InvalidConfiguration("Invalid input shape".to_string()))?;

        let qk_dim = self.num_heads * self.key_dim;
        let v_dim = self.num_heads * self.value_dim;
        // Glorot-style scale for uniform initialization in [-scale, scale].
        let scale = (2.0 / (embed_dim + qk_dim) as f64).sqrt();

        let query_weights = ArrayD::from_shape_fn(IxDyn(&[embed_dim, qk_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let key_weights = ArrayD::from_shape_fn(IxDyn(&[embed_dim, qk_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let value_weights = ArrayD::from_shape_fn(IxDyn(&[embed_dim, v_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let output_weights = ArrayD::from_shape_fn(IxDyn(&[v_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });

        self.query_weights = Some(query_weights);
        self.key_weights = Some(key_weights);
        self.value_weights = Some(value_weights);
        self.output_weights = Some(output_weights);
        self.built = true;

        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        input_shape.to_vec()
    }

    fn count_params(&self) -> usize {
        let q = self.query_weights.as_ref().map_or(0, |w| w.len());
        let k = self.key_weights.as_ref().map_or(0, |w| w.len());
        let v = self.value_weights.as_ref().map_or(0, |w| w.len());
        let o = self.output_weights.as_ref().map_or(0, |w| w.len());
        q + k + v + o
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        let mut weights = vec![];
        if let Some(ref w) = self.query_weights {
            weights.push(w.clone());
        }
        if let Some(ref w) = self.key_weights {
            weights.push(w.clone());
        }
        if let Some(ref w) = self.value_weights {
            weights.push(w.clone());
        }
        if let Some(ref w) = self.output_weights {
            weights.push(w.clone());
        }
        weights
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if weights.len() < 4 {
            return Err(MLError::InvalidConfiguration(format!(
                "MultiHeadAttention expects 4 weight arrays, got {}",
                weights.len()
            )));
        }
        self.query_weights = Some(weights[0].clone());
        self.key_weights = Some(weights[1].clone());
        self.value_weights = Some(weights[2].clone());
        self.output_weights = Some(weights[3].clone());
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("multi_head_attention")
    }
}

/// Embedding layer
pub struct Embedding {
    /// Input dimension (vocabulary size)
    input_dim: usize,
    /// Output dimension (embedding size)
    output_dim: usize,
    /// Embedding weights
    embeddings: Option<ArrayD<f64>>,
    /// Mask zero (stored for API compatibility; masking is not applied in `call`)
    mask_zero: bool,
    /// Built flag
    built: bool,
    /// Layer name
    layer_name: Option<String>,
}

impl Embedding {
    /// Create new Embedding layer
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            input_dim,
            output_dim,
            embeddings: None,
            mask_zero: false,
            built: false,
            layer_name: None,
        }
    }

    /// Set mask zero
    pub fn mask_zero(mut self, mask_zero: bool) -> Self {
        self.mask_zero = mask_zero;
        self
    }

    /// Set layer name
    pub fn name(mut self, name: &str) -> Self {
        self.layer_name = Some(name.to_string());
        self
    }
}

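// A minimal usage sketch (hypothetical sizes; token ids are passed as f64):
//
//     let mut emb = Embedding::new(10_000, 128).mask_zero(true);
//     emb.build(&[32, 50])?;          // batch=32, seq_len=50
//     let vectors = emb.call(&ids)?;  // [32, 50] ids -> [32, 50, 128] vectors
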
impl KerasLayer for Embedding {
    fn call(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
        if !self.built {
            return Err(MLError::ModelNotTrained(
                "Layer not built. Call build() first.".to_string(),
            ));
        }

        let embeddings = self
            .embeddings
            .as_ref()
            .ok_or_else(|| MLError::ModelNotTrained("Embeddings not initialized".to_string()))?;

        let shape = input.shape();
        let batch_size = shape[0];
        let seq_len = *shape.get(1).unwrap_or(&1);

        let mut output = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.output_dim]));

        for b in 0..batch_size {
            for s in 0..seq_len {
                // Accept both [batch, seq_len] and [batch] inputs.
                let idx = if input.ndim() >= 2 {
                    input[[b, s]] as usize
                } else {
                    input[[b]] as usize
                };
                // Out-of-vocabulary indices map to the zero vector.
                if idx < self.input_dim {
                    for d in 0..self.output_dim {
                        output[[b, s, d]] = embeddings[[idx, d]];
                    }
                }
            }
        }

        Ok(output)
    }

    fn build(&mut self, _input_shape: &[usize]) -> Result<()> {
        // Uniform initialization scaled by 1/sqrt(vocab_size).
        let scale = (1.0 / self.input_dim as f64).sqrt();
        let embeddings = ArrayD::from_shape_fn(IxDyn(&[self.input_dim, self.output_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });

        self.embeddings = Some(embeddings);
        self.built = true;

        Ok(())
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        let mut out_shape = input_shape.to_vec();
        out_shape.push(self.output_dim);
        out_shape
    }

    fn count_params(&self) -> usize {
        self.input_dim * self.output_dim
    }

    fn get_weights(&self) -> Vec<ArrayD<f64>> {
        self.embeddings.iter().cloned().collect()
    }

    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
        if weights.is_empty() {
            return Err(MLError::InvalidConfiguration(
                "Embedding expects 1 weight array, got 0".to_string(),
            ));
        }
        self.embeddings = Some(weights[0].clone());
        Ok(())
    }

    fn built(&self) -> bool {
        self.built
    }

    fn name(&self) -> &str {
        self.layer_name.as_deref().unwrap_or("embedding")
    }
}
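
// Minimal smoke tests, sketched under the assumption that `build` accepts the
// [batch, seq_len, embed_dim] and [batch, seq_len] shapes used above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn multi_head_attention_preserves_shape() {
        let mut mha = MultiHeadAttention::new(2, 4);
        mha.build(&[1, 3, 8]).expect("build should succeed");
        let x: ArrayD<f64> = ArrayD::zeros(IxDyn(&[1, 3, 8]));
        let y = mha.call(&x).expect("forward pass should succeed");
        // Attention maps [batch, seq_len, embed_dim] to the same shape.
        assert_eq!(y.shape(), &[1, 3, 8]);
    }

    #[test]
    fn embedding_maps_ids_to_vectors() {
        let mut emb = Embedding::new(10, 4);
        emb.build(&[1, 5]).expect("build should succeed");
        // Token ids 0..5 encoded as f64, shape [batch=1, seq_len=5].
        let ids = ArrayD::from_shape_fn(IxDyn(&[1, 5]), |idx| idx[1] as f64);
        let out = emb.call(&ids).expect("lookup should succeed");
        assert_eq!(out.shape(), &[1, 5, 4]);
    }
}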