quantrs2_ml/quantum_transformer/
functions.rs1use quantrs2_core::gate::{multi::*, single::*, GateOp};
6use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, Axis};
7
8use super::types::{
9 ActivationType, PositionEncodingType, QuantumAttentionType, QuantumFeedForward,
10 QuantumMultiHeadAttention, QuantumPositionEncoding, QuantumTransformerConfig,
11};
12
13pub fn create_causal_mask(batch_size: usize, seq_len: usize) -> Array3<bool> {
15 let mut mask = Array3::from_elem((batch_size, seq_len, seq_len), false);
16 for batch_idx in 0..batch_size {
17 for i in 0..seq_len {
18 for j in (i + 1)..seq_len {
19 mask[[batch_idx, i, j]] = true;
20 }
21 }
22 }
23 mask
24}
25pub fn create_padding_mask(
27 batch_size: usize,
28 seq_len: usize,
29 actual_lengths: &[usize],
30) -> Array3<bool> {
31 let mut mask = Array3::from_elem((batch_size, seq_len, seq_len), false);
32 for (batch_idx, &actual_len) in actual_lengths.iter().enumerate() {
33 if batch_idx < batch_size {
34 for i in 0..seq_len {
35 for j in actual_len..seq_len {
36 mask[[batch_idx, i, j]] = true;
37 }
38 }
39 }
40 }
41 mask
42}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_transformer_config() {
        let config = QuantumTransformerConfig::default();
        assert_eq!(config.model_dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.num_layers, 6);
        let large_config = QuantumTransformerConfig::large();
        assert_eq!(large_config.model_dim, 1024);
        assert_eq!(large_config.num_heads, 16);
    }

    #[test]
    fn test_quantum_multi_head_attention_creation() {
        // `expect` already fails the test (with the error) on `Err`, so no
        // separate `is_ok` assertion is needed.
        let attn = QuantumMultiHeadAttention::new(
            8,
            512,
            QuantumAttentionType::HybridQuantumClassical,
            10,
        )
        .expect("Attention creation should succeed");
        assert_eq!(attn.num_heads, 8);
        assert_eq!(attn.model_dim, 512);
        assert_eq!(attn.head_dim, 64);
    }

    #[test]
    fn test_quantum_position_encoding() {
        let pe = QuantumPositionEncoding::new(PositionEncodingType::Sinusoidal, 256, 512, 8)
            .expect("Position encoding creation should succeed");
        assert_eq!(pe.model_dim, 256);
        assert_eq!(pe.max_seq_len, 512);
    }

    #[test]
    fn test_quantum_feedforward() {
        let feedforward =
            QuantumFeedForward::new(256, 1024, 256, 8, ActivationType::QuantumGELU, 0.1)
                .expect("Feedforward creation should succeed");
        assert_eq!(feedforward.input_dim, 256);
        assert_eq!(feedforward.hidden_dim, 1024);
        assert_eq!(feedforward.output_dim, 256);
    }

    #[test]
    fn test_causal_mask_creation() {
        let mask = create_causal_mask(2, 4);
        assert_eq!(mask.dim(), (2, 4, 4));
        assert!(!mask[[0, 0, 0]]);
        assert!(!mask[[0, 1, 0]]);
        assert!(!mask[[0, 1, 1]]);
        assert!(mask[[0, 0, 1]]);
        assert!(mask[[0, 0, 2]]);
        assert!(mask[[0, 1, 2]]);
        // Every batch element (not just the first) must carry the strictly
        // upper-triangular pattern.
        for ((b, i, j), &masked) in mask.indexed_iter() {
            assert_eq!(masked, j > i, "unexpected causal mask value at ({b}, {i}, {j})");
        }
    }

    #[test]
    fn test_causal_mask_degenerate_shapes() {
        assert_eq!(create_causal_mask(0, 3).dim(), (0, 3, 3));
        assert_eq!(create_causal_mask(2, 0).dim(), (2, 0, 0));
        let single = create_causal_mask(1, 1);
        assert!(!single[[0, 0, 0]]);
    }

    #[test]
    fn test_padding_mask_creation() {
        let actual_lengths = [3, 2];
        let mask = create_padding_mask(2, 4, &actual_lengths);
        assert_eq!(mask.dim(), (2, 4, 4));
        assert!(!mask[[0, 0, 2]]);
        assert!(mask[[0, 0, 3]]);
        assert!(!mask[[1, 0, 1]]);
        assert!(mask[[1, 0, 2]]);
        assert!(mask[[1, 0, 3]]);
        // Padding depends only on the column, so every row must agree.
        for ((b, i, j), &masked) in mask.indexed_iter() {
            assert_eq!(
                masked,
                j >= actual_lengths[b],
                "unexpected padding mask value at ({b}, {i}, {j})"
            );
        }
    }

    #[test]
    fn test_padding_mask_length_edge_cases() {
        // A length beyond seq_len masks nothing (and must not panic); a batch
        // element without a recorded length stays unmasked.
        let mask = create_padding_mask(2, 3, &[5]);
        assert_eq!(mask.dim(), (2, 3, 3));
        assert!(mask.iter().all(|&m| !m));

        // Lengths beyond batch_size are ignored.
        let mask = create_padding_mask(1, 3, &[1, 0]);
        assert_eq!(mask.dim(), (1, 3, 3));
        assert!(!mask[[0, 2, 0]]);
        assert!(mask[[0, 2, 1]]);
        assert!(mask[[0, 0, 2]]);
    }
}