Skip to main content

quantrs2_ml/quantum_transformer/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use quantrs2_core::gate::{multi::*, single::*, GateOp};
6use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, Axis};
7
8use super::types::{
9    ActivationType, PositionEncodingType, QuantumAttentionType, QuantumFeedForward,
10    QuantumMultiHeadAttention, QuantumPositionEncoding, QuantumTransformerConfig,
11};
12
13/// Helper function to create causal attention mask
14pub fn create_causal_mask(batch_size: usize, seq_len: usize) -> Array3<bool> {
15    let mut mask = Array3::from_elem((batch_size, seq_len, seq_len), false);
16    for batch_idx in 0..batch_size {
17        for i in 0..seq_len {
18            for j in (i + 1)..seq_len {
19                mask[[batch_idx, i, j]] = true;
20            }
21        }
22    }
23    mask
24}
25/// Helper function to create padding mask
26pub fn create_padding_mask(
27    batch_size: usize,
28    seq_len: usize,
29    actual_lengths: &[usize],
30) -> Array3<bool> {
31    let mut mask = Array3::from_elem((batch_size, seq_len, seq_len), false);
32    for (batch_idx, &actual_len) in actual_lengths.iter().enumerate() {
33        if batch_idx < batch_size {
34            for i in 0..seq_len {
35                for j in actual_len..seq_len {
36                    mask[[batch_idx, i, j]] = true;
37                }
38            }
39        }
40    }
41    mask
42}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default and `large` presets expose the expected dimensions.
    #[test]
    fn test_quantum_transformer_config() {
        let base = QuantumTransformerConfig::default();
        assert_eq!(base.model_dim, 512);
        assert_eq!(base.num_heads, 8);
        assert_eq!(base.num_layers, 6);

        let large = QuantumTransformerConfig::large();
        assert_eq!(large.model_dim, 1024);
        assert_eq!(large.num_heads, 16);
    }

    /// Constructing multi-head attention succeeds and derives `head_dim`.
    #[test]
    fn test_quantum_multi_head_attention_creation() {
        let built = QuantumMultiHeadAttention::new(
            8,
            512,
            QuantumAttentionType::HybridQuantumClassical,
            10,
        );
        assert!(built.is_ok());

        let built = built.expect("Attention creation should succeed");
        assert_eq!(built.num_heads, 8);
        assert_eq!(built.model_dim, 512);
        // 512 model dimensions split across 8 heads.
        assert_eq!(built.head_dim, 64);
    }

    /// Constructing a sinusoidal position encoding keeps the requested sizes.
    #[test]
    fn test_quantum_position_encoding() {
        let built = QuantumPositionEncoding::new(PositionEncodingType::Sinusoidal, 256, 512, 8);
        assert!(built.is_ok());

        let built = built.expect("Position encoding creation should succeed");
        assert_eq!(built.model_dim, 256);
        assert_eq!(built.max_seq_len, 512);
    }

    /// Constructing a feed-forward block keeps the requested layer sizes.
    #[test]
    fn test_quantum_feedforward() {
        let built = QuantumFeedForward::new(256, 1024, 256, 8, ActivationType::QuantumGELU, 0.1);
        assert!(built.is_ok());

        let built = built.expect("Feedforward creation should succeed");
        assert_eq!(built.input_dim, 256);
        assert_eq!(built.hidden_dim, 1024);
        assert_eq!(built.output_dim, 256);
    }

    /// The causal mask is `true` strictly above the diagonal only.
    #[test]
    fn test_causal_mask_creation() {
        let mask = create_causal_mask(2, 4);
        assert_eq!(mask.dim(), (2, 4, 4));

        // On or below the diagonal: not masked.
        let unmasked: [[usize; 3]; 3] = [[0, 0, 0], [0, 1, 0], [0, 1, 1]];
        for idx in unmasked {
            assert!(!mask[idx], "expected unmasked entry at {:?}", idx);
        }

        // Strictly above the diagonal: masked.
        let masked: [[usize; 3]; 3] = [[0, 0, 1], [0, 0, 2], [0, 1, 2]];
        for idx in masked {
            assert!(mask[idx], "expected masked entry at {:?}", idx);
        }
    }

    /// The padding mask is `true` from each sequence's length onward.
    #[test]
    fn test_padding_mask_creation() {
        let lengths: [usize; 2] = [3, 2];
        let mask = create_padding_mask(2, 4, &lengths);

        // Last real column of each sequence: not masked.
        let unmasked: [[usize; 3]; 2] = [[0, 0, 2], [1, 0, 1]];
        for idx in unmasked {
            assert!(!mask[idx], "expected unmasked entry at {:?}", idx);
        }

        // Columns at or beyond each sequence's length: masked.
        let masked: [[usize; 3]; 3] = [[0, 0, 3], [1, 0, 2], [1, 0, 3]];
        for idx in masked {
            assert!(mask[idx], "expected masked entry at {:?}", idx);
        }
    }
}