//! A BERT-style text decoder with cross-attention over encoder hidden states,
//! in the style of the BLIP captioning text decoder.
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub encoder_hidden_size: usize,
    pub intermediate_size: usize,
    pub projection_dim: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub max_position_embeddings: usize,
    pub hidden_act: candle_nn::Activation,
    pub layer_norm_eps: f64,
    pub is_decoder: bool,
}
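
// Illustrative values in the spirit of a BLIP-base text decoder; this is an
// assumption for orientation only, as the real values are deserialized into
// `Config` from each checkpoint's `config.json`.
#[allow(dead_code)]
fn example_config() -> Config {
    Config {
        vocab_size: 30524,
        hidden_size: 768,
        encoder_hidden_size: 768,
        intermediate_size: 3072,
        projection_dim: 768,
        num_hidden_layers: 12,
        num_attention_heads: 12,
        max_position_embeddings: 512,
        hidden_act: candle_nn::Activation::Gelu,
        layer_norm_eps: 1e-12,
        is_decoder: true,
    }
}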

#[derive(Debug, Clone)]
struct TextEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    layer_norm: LayerNorm,
    position_ids: Tensor,
}

impl TextEmbeddings {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let word_embeddings =
            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
        let position_embeddings = Embedding::new(
            cfg.max_position_embeddings,
            cfg.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        let position_ids =
            Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
        Ok(Self {
            word_embeddings,
            position_embeddings,
            layer_norm,
            position_ids,
        })
    }

    fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
        let seq_len = xs.dim(1)?;
        // Offset the position ids by the number of cached positions so that
        // incremental decoding keeps absolute positions consistent.
        let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
        let embeddings = self.word_embeddings.forward(xs)?;
        let position_embeddings = self.position_embeddings.forward(&position_ids)?;
        (embeddings + position_embeddings)?.apply(&self.layer_norm)
    }
}

#[derive(Debug, Clone)]
struct TextSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    attention_head_size: usize,
    num_attention_heads: usize,
    attention_scale: f64,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl TextSelfAttention {
    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
        let num_attention_heads = cfg.num_attention_heads;
        let attention_head_size = cfg.hidden_size / num_attention_heads;
        let all_head_size = cfg.num_attention_heads * attention_head_size;
        let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
        // For cross-attention, keys and values are projected from the encoder
        // hidden states, which may have a different width than the decoder.
        let in_size = if is_cross_attention {
            cfg.encoder_hidden_size
        } else {
            cfg.hidden_size
        };
        let key = linear(in_size, all_head_size, vb.pp("key"))?;
        let value = linear(in_size, all_head_size, vb.pp("value"))?;
        // Standard 1/sqrt(head_dim) scaling of the attention logits.
        let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
        Ok(Self {
            query,
            key,
            value,
            attention_head_size,
            num_attention_heads,
            attention_scale,
            kv_cache: None,
        })
    }

    // (b, seq, hidden) -> (b, heads, seq, head_dim)
    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, _) = xs.dims3()?;
        xs.reshape((
            b_size,
            seq_len,
            self.num_attention_heads,
            self.attention_head_size,
        ))?
        .permute((0, 2, 1, 3))
    }

    fn reset_kv_cache(&mut self) {
        self.kv_cache = None
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let query = self
            .transpose_for_scores(&self.query.forward(xs)?)?
            .contiguous()?;
        let (key, value) = match encoder_hidden_states {
            None => {
                // Self-attention: project keys/values from xs and append them
                // to the cache along the time dimension (dim 2).
                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
                let (key, value) = match &self.kv_cache {
                    None => (key, value),
                    Some((prev_key, prev_value)) => {
                        let key = Tensor::cat(&[prev_key, &key], 2)?;
                        let value = Tensor::cat(&[prev_value, &value], 2)?;
                        (key, value)
                    }
                };
                self.kv_cache = Some((key.clone(), value.clone()));
                (key, value)
            }
            // Cross-attention: keys/values come from the encoder hidden states
            // and are not cached here.
            Some(xs) => {
                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
                (key, value)
            }
        };
        let key = key.contiguous()?;
        let value = value.contiguous()?;
        let attention_scores = query.matmul(&key.t()?)?;
        let attention_scores = (attention_scores * self.attention_scale)?;
        let attention_scores = match attention_mask {
            Some(mask) => attention_scores.broadcast_add(mask)?,
            None => attention_scores,
        };
        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        // Merge the heads back: (b, heads, seq, head_dim) -> (b, seq, hidden).
        attention_probs
            .matmul(&value)?
            .permute((0, 2, 1, 3))?
            .flatten_from(D::Minus2)
    }
}
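
// Shape summary for TextSelfAttention::forward above, with b = batch,
// h = heads, d = head dim, t = query length, and s = key/value length
// (s grows as the kv-cache fills during incremental decoding):
// query (b, h, t, d), key/value (b, h, s, d), attention scores (b, h, t, s),
// and the returned tensor (b, t, h * d) after the final permute + flatten.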

#[derive(Debug, Clone)]
struct TextSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
}

impl TextSelfOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self { dense, layer_norm })
    }

    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        // Residual connection followed by layer norm.
        (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
    }
}

#[derive(Debug, Clone)]
struct TextAttention {
    self_: TextSelfAttention,
    output: TextSelfOutput,
}

impl TextAttention {
    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
        let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
        let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
        Ok(Self { self_, output })
    }

    fn reset_kv_cache(&mut self) {
        self.self_.reset_kv_cache()
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let self_outputs = self
            .self_
            .forward(xs, encoder_hidden_states, attention_mask)?;
        self.output.forward(&self_outputs, xs)
    }
}

#[derive(Debug, Clone)]
struct TextIntermediate {
    dense: Linear,
    intermediate_act_fn: candle_nn::Activation,
}

impl TextIntermediate {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
        Ok(Self {
            dense,
            intermediate_act_fn: cfg.hidden_act,
        })
    }
}

impl Module for TextIntermediate {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
    }
}

#[derive(Debug, Clone)]
struct TextOutput {
    dense: Linear,
    layer_norm: LayerNorm,
}

impl TextOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self { dense, layer_norm })
    }

    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
    }
}

#[derive(Debug, Clone)]
struct TextLayer {
    attention: TextAttention,
    cross_attention: Option<TextAttention>,
    intermediate: TextIntermediate,
    output: TextOutput,
}

impl TextLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
        let cross_attention = if cfg.is_decoder {
            Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
        } else {
            None
        };
        let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
        let output = TextOutput::new(cfg, vb.pp("output"))?;
        Ok(Self {
            attention,
            cross_attention,
            intermediate,
            output,
        })
    }

    fn reset_kv_cache(&mut self) {
        self.attention.reset_kv_cache();
        if let Some(ca) = &mut self.cross_attention {
            ca.reset_kv_cache()
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        // Causal self-attention over the text tokens...
        let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
        // ...then cross-attention over the encoder hidden states. This forward
        // path is decoder-only, so the cross-attention must be present.
        let attention_output = match &mut self.cross_attention {
            Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
            None => candle::bail!("expected some cross-attn"),
        };
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        self.output.forward(&intermediate_output, &attention_output)
    }
}

#[derive(Debug, Clone)]
struct TextEncoder {
    layers: Vec<TextLayer>,
}

impl TextEncoder {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("layer");
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            let layer = TextLayer::new(cfg, vb.pp(i))?;
            layers.push(layer)
        }
        Ok(Self { layers })
    }

    fn reset_kv_cache(&mut self) {
        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
        }
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct TextPooler {
    dense: Linear,
}

impl TextPooler {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        Ok(Self { dense })
    }
}

impl Module for TextPooler {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Pool by taking the hidden state of the first token, i.e. narrow and
        // squeeze over the sequence dimension rather than the hidden one.
        xs.narrow(D::Minus2, 0, 1)?
            .squeeze(D::Minus2)?
            .apply(&self.dense)?
            .tanh()
    }
}

#[derive(Debug, Clone)]
struct TextPredictionHeadTransform {
    dense: Linear,
    transform_act_fn: candle_nn::Activation,
    layer_norm: LayerNorm,
}

impl TextPredictionHeadTransform {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self {
            dense,
            transform_act_fn: cfg.hidden_act,
            layer_norm,
        })
    }
}

impl Module for TextPredictionHeadTransform {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.dense)?
            .apply(&self.transform_act_fn)?
            .apply(&self.layer_norm)
    }
}

#[derive(Debug, Clone)]
struct TextLMPredictionHead {
    transform: TextPredictionHeadTransform,
    decoder: Linear,
}

impl TextLMPredictionHead {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
        // The output projection stores its weight under `decoder.weight` while
        // the bias lives under a separate `bias` key at this prefix.
        let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
        let bias = vb.get(cfg.vocab_size, "bias")?;
        let decoder = Linear::from_weights(weight, Some(bias));
        Ok(Self { transform, decoder })
    }
}

impl Module for TextLMPredictionHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.transform)?.apply(&self.decoder)
    }
}

#[derive(Debug, Clone)]
struct TextOnlyMLMHead {
    predictions: TextLMPredictionHead,
}

impl TextOnlyMLMHead {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
        Ok(Self { predictions })
    }
}

impl Module for TextOnlyMLMHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.predictions.forward(xs)
    }
}

#[derive(Debug, Clone)]
struct TextModel {
    embeddings: TextEmbeddings,
    encoder: TextEncoder,
    past_kv_len: usize,
}

impl TextModel {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
        let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            past_kv_len: 0,
        })
    }

    fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let (_b_sz, seq_len) = input_ids.dims2()?;
        let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
        let sequence_output =
            self.encoder
                .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
        // Track how many positions are already in the kv-cache so the next
        // call embeds its tokens at the right offset.
        self.past_kv_len += seq_len;
        Ok(sequence_output)
    }

    fn reset_kv_cache(&mut self) {
        self.past_kv_len = 0;
        self.encoder.reset_kv_cache();
    }
}

#[derive(Debug, Clone)]
pub struct TextLMHeadModel {
    bert: TextModel,
    cls: TextOnlyMLMHead,
}

impl TextLMHeadModel {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let bert = TextModel::new(cfg, vb.pp("bert"))?;
        let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
        Ok(Self { bert, cls })
    }

    pub fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        let seq_len = input_ids.dim(1)?;
        // Additive causal mask: 0 on and below the diagonal, -inf above, so
        // each token only attends to itself and earlier positions.
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
            .collect();
        let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
        let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
        let prediction_scores = self.cls.forward(&sequence_output)?;
        Ok(prediction_scores)
    }

    pub fn reset_kv_cache(&mut self) {
        self.bert.reset_kv_cache()
    }
}
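
// Greedy decoding sketch (illustrative, not part of the original model code):
// shows how the kv-cache is meant to be used, feeding the full prompt once and
// then a single token per step. `bos_token_id`, `eos_token_id`, and `max_len`
// are assumptions about the surrounding tokenizer/pipeline.
#[allow(dead_code)]
pub fn greedy_decode_sketch(
    model: &mut TextLMHeadModel,
    encoder_hidden_states: &Tensor,
    bos_token_id: u32,
    eos_token_id: u32,
    max_len: usize,
) -> Result<Vec<u32>> {
    let device = encoder_hidden_states.device();
    let mut tokens = vec![bos_token_id];
    model.reset_kv_cache();
    for step in 0..max_len {
        // After the first step the cache already holds the prefix, so only the
        // most recent token needs to be fed in.
        let input = if step == 0 {
            &tokens[..]
        } else {
            &tokens[tokens.len() - 1..]
        };
        let input_ids = Tensor::new(input, device)?.unsqueeze(0)?;
        let logits = model.forward(&input_ids, encoder_hidden_states)?;
        // Logits are (1, seq, vocab): keep the last position and take argmax.
        let last = logits.dim(1)? - 1;
        let next = logits
            .narrow(1, last, 1)?
            .squeeze(1)?
            .squeeze(0)?
            .argmax(D::Minus1)?
            .to_scalar::<u32>()?;
        if next == eos_token_id {
            break;
        }
        tokens.push(next);
    }
    Ok(tokens)
}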