use crate::{shapes::*, tensor::*};

mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;

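/// Query tensor: shape `(num_heads, sequence_length, head_dim)`.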
pub type Query<const NUM_HEADS: usize, const HEAD_DIM: usize, E, D> =
    Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, D>;
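/// Key tensor: shape `(num_heads, head_dim, sequence_length)`. The transposed
/// layout lets `q.matmul(k)` produce attention scores without an extra permute.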
pub type Key<const NUM_HEADS: usize, const HEAD_DIM: usize, E, D> =
    Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, D>;
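/// Value tensor: shape `(num_heads, sequence_length, head_dim)`.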
pub type Value<const NUM_HEADS: usize, const HEAD_DIM: usize, E, D> =
    Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, D>;

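/// The `(query, key, value)` triple produced by the reshape.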
type QkvTuple<const NUM_HEADS: usize, const HEAD_DIM: usize, E, D> = (
    Query<NUM_HEADS, HEAD_DIM, E, D>,
    Key<NUM_HEADS, HEAD_DIM, E, D>,
    Value<NUM_HEADS, HEAD_DIM, E, D>,
);

/// Reshapes `qkv` + `past_key` + `past_value` into the `(q, k, v)` tuple used
/// in an attention layer.
pub trait TryAttentionReshape<E: Dtype>: Storage<E> {
    /// This is an inference-only kernel:
    /// within the `transformers` architecture, a core component is the `attention`
    /// layer, which can be written in many forms.
    ///
    /// This particular version expects a `qkv` tensor (obtained from a single
    /// `Linear` layer, corresponding to the stacked `query`, `key`, and `value`),
    /// plus `past_key` and `past_value`, which are the values cached within attention
    /// (this speeds up inference).
    /// For the first pass, when the cache isn't populated yet, just send
    /// zero-width tensors.
    ///
    /// Fusing everything into a single kernel instead of chaining many `cat`,
    /// `reshape`, and `permute` ops makes this operation very efficient on GPU.
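    ///
    /// A minimal usage sketch (an `ignore`d example; `dev` and `seq_len` are
    /// assumed to be an existing device and sequence length, with
    /// `NUM_HEADS = 2`, `HEAD_DIM = 3`, and an empty cache):
    ///
    /// ```rust,ignore
    /// let qkv: Tensor<(usize, Const<18>), f32, _> = dev.zeros_like(&(seq_len, Const));
    /// let past_key: Tensor<(Const<2>, Const<3>, usize), f32, _> = dev.zeros_like(&(Const, Const, 0));
    /// let past_value: Tensor<(Const<2>, usize, Const<3>), f32, _> = dev.zeros_like(&(Const, 0, Const));
    /// let (q, k, v) = dev.attention_reshape(&qkv, &past_key, &past_value);
    /// // q: (2, seq_len, 3), k: (2, 3, seq_len), v: (2, seq_len, 3)
    /// ```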
    fn attention_reshape<
        const THREE_HIDDEN_DIM: usize,
        const NUM_HEADS: usize,
        const HEAD_DIM: usize,
    >(
        &self,
        qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
        past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
        past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
    ) -> QkvTuple<NUM_HEADS, HEAD_DIM, E, Self> {
        self.try_attention_reshape(qkv, past_key, past_value)
            .unwrap()
    }

    /// Fallible version of [TryAttentionReshape::attention_reshape]
    fn try_attention_reshape<
        const THREE_HIDDEN_DIM: usize,
        const NUM_HEADS: usize,
        const HEAD_DIM: usize,
    >(
        &self,
        qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
        past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
        past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
    ) -> Result<QkvTuple<NUM_HEADS, HEAD_DIM, E, Self>, Self::Err>;
}

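/// Device-level kernel backing [TryAttentionReshape]; each backend (see
/// `cpu_kernel`, and `cuda_kernel` under the `cuda` feature) provides its
/// own `forward`.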
pub trait AttentionReshapeKernel<E: Dtype>: Storage<E> {
    fn forward<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
        &self,
        qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
        past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
        past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
    ) -> Result<QkvTuple<NUM_HEADS, HEAD_DIM, E, Self>, Self::Err>;
}

impl<E: Dtype, D: AttentionReshapeKernel<E>> TryAttentionReshape<E> for D {
    /// Fallible version of [TryAttentionReshape::attention_reshape]
    fn try_attention_reshape<
        const THREE_HIDDEN_DIM: usize,
        const NUM_HEADS: usize,
        const HEAD_DIM: usize,
    >(
        &self,
        qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
        past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
        past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
    ) -> Result<QkvTuple<NUM_HEADS, HEAD_DIM, E, Self>, Self::Err> {
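        // Delegate to the device's backend kernel.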
        let device = qkv.device.clone();
        device.forward(qkv, past_key, past_value)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{tensor_ops::*, tests::*};

    #[test]
    fn test_attention_reshape() {
        let dev: TestDevice = Default::default();

        const NUM_HEADS: usize = 2;
        const HEAD_DIM: usize = 3;
        let sequence_length = 1;
        let past_length = 3;

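        // qkv is filled with 1s, the cached key with 2s, and the cached value
        // with 3s, so each source is distinguishable in the outputs below.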
        let qkv: Tensor<(usize, Const<{ NUM_HEADS * HEAD_DIM * 3 }>), TestDtype, _> =
            dev.zeros_like(&(sequence_length, Const)) + 1.0;
        let past_key: Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), TestDtype, _> =
            dev.zeros_like(&(Const, Const, past_length)) + 2.0;
        let past_value: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), TestDtype, _> =
            dev.zeros_like(&(Const, past_length, Const)) + 3.0;

        let (q, k, v) = dev.attention_reshape(&qkv, &past_key, &past_value);

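        // Keys/values now cover past_length + sequence_length = 4 positions,
        // while the query only covers the new sequence_length = 1 position.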
        let q = q.realize::<(Const<NUM_HEADS>, Const<1>, Const<HEAD_DIM>)>();
        let k = k.realize::<(Const<NUM_HEADS>, Const<HEAD_DIM>, Const<4>)>();
        let v = v.realize::<(Const<NUM_HEADS>, Const<4>, Const<HEAD_DIM>)>();

        assert_close_to_literal!(q, [[[1.0; HEAD_DIM]; 1]; NUM_HEADS]);
        assert_close_to_literal!(
            k,
            [
                [
                    [2.0, 2.0, 2.0, 1.0],
                    [2.0, 2.0, 2.0, 1.0],
                    [2.0, 2.0, 2.0, 1.0]
                ],
                [
                    [2.0, 2.0, 2.0, 1.0],
                    [2.0, 2.0, 2.0, 1.0],
                    [2.0, 2.0, 2.0, 1.0]
                ]
            ]
        );
        assert_close_to_literal!(
            v,
            [
                [
                    [3.0, 3.0, 3.0],
                    [3.0, 3.0, 3.0],
                    [3.0, 3.0, 3.0],
                    [1.0, 1.0, 1.0]
                ],
                [
                    [3.0, 3.0, 3.0],
                    [3.0, 3.0, 3.0],
                    [3.0, 3.0, 3.0],
                    [1.0, 1.0, 1.0]
                ]
            ]
        );
    }
}