use yscv_kernels::{matmul_2d, softmax_last_dim};
use yscv_tensor::Tensor;

use crate::ModelError;

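/// Single-head scaled dot-product attention: `softmax(Q * K^T / sqrt(d_k)) * V`.
///
/// `query` and `key` must be 2-D `[seq_len, d_k]` tensors; `value` is a 2-D
/// tensor whose rows are weighted by the attention distribution.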
pub fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
) -> Result<Tensor, ModelError> {
    let q_shape = query.shape();
    let k_shape = key.shape();
    if q_shape.len() != 2 || k_shape.len() != 2 {
        return Err(ModelError::InvalidParameterShape {
            parameter: "attention QKV",
            // `[0, 0]` stands in for "any 2-D shape"; `got` reports the
            // query's shape even when `key` is the mis-shaped input.
            expected: vec![0, 0],
            got: q_shape.to_vec(),
        });
    }
    let d_k = q_shape[1] as f32;

    let kt = key.transpose_2d()?;
    let scores = matmul_2d(query, &kt)?;
    let scale = 1.0 / d_k.sqrt();
    let scaled = scores.scale(scale);
    let attn_weights = softmax_last_dim(&scaled)?;
    let output = matmul_2d(&attn_weights, value)?;
    Ok(output)
}

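/// Configuration for [`MultiHeadAttention`]: the model width `d_model` and the
/// number of heads it is split across.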
pub struct MultiHeadAttentionConfig {
    pub d_model: usize,
    pub num_heads: usize,
}

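/// Multi-head self-attention with `[d_model, d_model]` projection matrices
/// `w_q`, `w_k`, `w_v` and output projection `w_o`. Heads are formed by
/// slicing the projected tensors into `num_heads` column blocks of width
/// `d_k = d_model / num_heads`.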
pub struct MultiHeadAttention {
    pub w_q: Tensor,
    pub w_k: Tensor,
    pub w_v: Tensor,
    pub w_o: Tensor,
    pub num_heads: usize,
    pub d_k: usize,
}

impl MultiHeadAttention {
    pub fn new(config: &MultiHeadAttentionConfig) -> Result<Self, ModelError> {
        let d = config.d_model;
        let h = config.num_heads;
        if !d.is_multiple_of(h) {
            return Err(ModelError::InvalidParameterShape {
                parameter: "d_model must be divisible by num_heads",
                expected: vec![d, h],
                got: vec![d % h],
            });
        }
        let d_k = d / h;
        let z = vec![0.0f32; d * d];
        Ok(Self {
            w_q: Tensor::from_vec(vec![d, d], z.clone())?,
            w_k: Tensor::from_vec(vec![d, d], z.clone())?,
            w_v: Tensor::from_vec(vec![d, d], z.clone())?,
            w_o: Tensor::from_vec(vec![d, d], z)?,
            num_heads: h,
            d_k,
        })
    }

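    /// Runs self-attention over a 2-D `input` of shape `[seq_len, d_model]`:
    /// project to Q, K, V, attend separately over each `d_k`-wide head slice,
    /// concatenate the head outputs, and apply the output projection `w_o`.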
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        let _seq_len = shape[0];
        let _d_model = shape[1];

        let q = matmul_2d(input, &self.w_q)?;
        let k = matmul_2d(input, &self.w_k)?;
        let v = matmul_2d(input, &self.w_v)?;

        let mut head_outputs = Vec::new();
        for h in 0..self.num_heads {
            let start = h * self.d_k;
            let qh = q.narrow(1, start, self.d_k)?;
            let kh = k.narrow(1, start, self.d_k)?;
            let vh = v.narrow(1, start, self.d_k)?;
            let attn = scaled_dot_product_attention(&qh, &kh, &vh)?;
            head_outputs.push(attn);
        }

        let concat = Tensor::cat(&head_outputs.iter().collect::<Vec<_>>(), 1)?;
        let output = matmul_2d(&concat, &self.w_o)?;
        Ok(output)
    }
}

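/// Position-wise feed-forward network: a `[d_model, d_ff]` expansion with bias
/// `b1`, a ReLU, then a `[d_ff, d_model]` contraction with bias `b2`.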
pub struct FeedForward {
    pub w1: Tensor,
    pub b1: Tensor,
    pub w2: Tensor,
    pub b2: Tensor,
}

impl FeedForward {
    pub fn new(d_model: usize, d_ff: usize) -> Result<Self, ModelError> {
        Ok(Self {
            w1: Tensor::from_vec(vec![d_model, d_ff], vec![0.0; d_model * d_ff])?,
            b1: Tensor::from_vec(vec![d_ff], vec![0.0; d_ff])?,
            w2: Tensor::from_vec(vec![d_ff, d_model], vec![0.0; d_ff * d_model])?,
            b2: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
        })
    }

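    /// Computes `relu(input * w1 + b1) * w2 + b2` for a 2-D `input` of shape
    /// `[seq_len, d_model]`, broadcasting the biases across rows via `unsqueeze`.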
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let h = matmul_2d(input, &self.w1)?;
        let h = h.add(&self.b1.unsqueeze(0)?)?;
        // ReLU applied element-wise on the raw data buffer.
        let data: Vec<f32> = h.data().iter().map(|&v| v.max(0.0)).collect();
        let h = Tensor::from_vec(h.shape().to_vec(), data)?;
        let out = matmul_2d(&h, &self.w2)?;
        let out = out.add(&self.b2.unsqueeze(0)?)?;
        Ok(out)
    }
}

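/// Builds a `[seq_len, seq_len]` additive causal mask: `0.0` on and below the
/// diagonal, `-inf` above it, intended to be added to attention scores before
/// the softmax so position `i` cannot attend to positions `j > i`.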
pub fn generate_causal_mask(seq_len: usize) -> Result<Tensor, ModelError> {
    let mut data = vec![0.0f32; seq_len * seq_len];
    for i in 0..seq_len {
        for j in (i + 1)..seq_len {
            data[i * seq_len + j] = f32::NEG_INFINITY;
        }
    }
    Ok(Tensor::from_vec(vec![seq_len, seq_len], data)?)
}

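/// Builds a `[batch, max_len]` additive padding mask from per-sequence
/// lengths: entries at positions `j >= lengths[b]` are `-inf`, the rest `0.0`.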
pub fn generate_padding_mask(lengths: &[usize], max_len: usize) -> Result<Tensor, ModelError> {
    let batch = lengths.len();
    let mut data = vec![0.0f32; batch * max_len];
    for (b, &len) in lengths.iter().enumerate() {
        for j in len..max_len {
            data[b * max_len + j] = f32::NEG_INFINITY;
        }
    }
    Ok(Tensor::from_vec(vec![batch, max_len], data)?)
}

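/// A post-norm Transformer encoder block: multi-head self-attention and a
/// feed-forward network, each wrapped in a residual connection followed by
/// layer normalization.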
pub struct TransformerEncoderBlock {
    pub mha: MultiHeadAttention,
    pub ffn: FeedForward,
    pub ln1_gamma: Tensor,
    pub ln1_beta: Tensor,
    pub ln2_gamma: Tensor,
    pub ln2_beta: Tensor,
    pub d_model: usize,
}

impl TransformerEncoderBlock {
    pub fn new(d_model: usize, num_heads: usize, d_ff: usize) -> Result<Self, ModelError> {
        let config = MultiHeadAttentionConfig { d_model, num_heads };
        Ok(Self {
            mha: MultiHeadAttention::new(&config)?,
            ffn: FeedForward::new(d_model, d_ff)?,
            ln1_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln1_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            ln2_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln2_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            d_model,
        })
    }

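    /// Applies the block to a 2-D `input` of shape `[seq_len, d_model]`:
    /// `x = LayerNorm(input + MHA(input))`, then `LayerNorm(x + FFN(x))`.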
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let attn_out = self.mha.forward(input)?;
        let residual1 = input.add(&attn_out)?;
        let norm1 = layer_norm_2d(&residual1, &self.ln1_gamma, &self.ln1_beta, self.d_model)?;

        let ffn_out = self.ffn.forward(&norm1)?;
        let residual2 = norm1.add(&ffn_out)?;
        let norm2 = layer_norm_2d(&residual2, &self.ln2_gamma, &self.ln2_beta, self.d_model)?;
        Ok(norm2)
    }
}

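/// Thin wrapper around the `yscv_kernels` layer-norm kernel: normalizes each
/// row of a 2-D tensor over its last dimension using `gamma`, `beta`, and a
/// fixed epsilon of `1e-5`.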
fn layer_norm_2d(
    input: &Tensor,
    gamma: &Tensor,
    beta: &Tensor,
    _d: usize,
) -> Result<Tensor, ModelError> {
    let params = yscv_kernels::LayerNormLastDimParams {
        gamma,
        beta,
        epsilon: 1e-5,
    };
    yscv_kernels::layer_norm_last_dim(input, params).map_err(Into::into)
}
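
// A minimal smoke-test sketch. It assumes only the APIs already used above
// (`Tensor::from_vec`, `shape()`, `TransformerEncoderBlock::new`/`forward`)
// and that `ModelError` implements `Debug`; the dimensions are illustrative.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encoder_block_preserves_shape() -> Result<(), ModelError> {
        let (seq_len, d_model): (usize, usize) = (3, 4);
        let block = TransformerEncoderBlock::new(d_model, 2, 8)?;
        let input = Tensor::from_vec(vec![seq_len, d_model], vec![0.0; seq_len * d_model])?;
        let output = block.forward(&input)?;
        // The block maps [seq_len, d_model] to [seq_len, d_model].
        assert_eq!(output.shape().to_vec(), vec![seq_len, d_model]);
        Ok(())
    }
}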