candle_transformers/models/
blip_text.rs

1//! Implementation of BLIP text encoder/decoder.
2//!
//! - 📝 [Paper](https://arxiv.org/abs/2201.12086). BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
4//!
5//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning)
6//! - 💻 [GH Link](https://github.com/salesforce/BLIP)
7//! - 🤗 [HF Link](https://huggingface.co/Salesforce/blip-image-captioning-base)
8//! - 📝 [Paper](https://arxiv.org/abs/2201.12086)
9//!
10use super::with_tracing::{linear, Embedding, Linear};
11use candle::{Module, Result, Tensor, D};
12use candle_nn::{layer_norm, LayerNorm, VarBuilder};
13use serde::Deserialize;
14
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// Number of entries in the token vocabulary.
    pub vocab_size: usize,
    /// Width of the text model's hidden states.
    pub hidden_size: usize,
    /// Width of the vision-encoder states consumed by the cross-attention
    /// key/value projections.
    pub encoder_hidden_size: usize,
    /// Inner width of the feed-forward (intermediate) layers.
    pub intermediate_size: usize,
    // Not referenced by the layers defined in this module.
    pub projection_dim: usize,
    /// Number of stacked transformer layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads; `hidden_size` is split evenly across them.
    pub num_attention_heads: usize,
    /// Maximum sequence length supported by the learned position embeddings.
    pub max_position_embeddings: usize,
    /// Activation used in the feed-forward and prediction-head transforms.
    pub hidden_act: candle_nn::Activation,
    /// Epsilon for all layer-norm operations.
    pub layer_norm_eps: f64,
    /// When true, each layer also gets a cross-attention block.
    pub is_decoder: bool,
}
29
30#[derive(Debug, Clone)]
31struct TextEmbeddings {
32    word_embedddings: Embedding,
33    position_embeddings: Embedding,
34    layer_norm: LayerNorm,
35    position_ids: Tensor,
36}
37
38impl TextEmbeddings {
39    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
40        let word_embedddings =
41            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
42        let position_embeddings = Embedding::new(
43            cfg.max_position_embeddings,
44            cfg.hidden_size,
45            vb.pp("position_embeddings"),
46        )?;
47        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
48        let position_ids =
49            Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
50        Ok(Self {
51            word_embedddings,
52            position_embeddings,
53            layer_norm,
54            position_ids,
55        })
56    }
57
58    fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
59        let seq_len = xs.dim(1)?;
60        let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
61        let embeddings = self.word_embedddings.forward(xs)?;
62        let position_embeddings = self.position_embeddings.forward(&position_ids)?;
63        (embeddings + position_embeddings)?.apply(&self.layer_norm)
64    }
65}
66
#[derive(Debug, Clone)]
struct TextSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    // Per-head width: hidden_size / num_attention_heads.
    attention_head_size: usize,
    num_attention_heads: usize,
    // 1 / sqrt(attention_head_size), applied to raw attention scores.
    attention_scale: f64,
    // Cached (key, value) from previous self-attention steps; not used by
    // cross-attention (see `forward`).
    kv_cache: Option<(Tensor, Tensor)>,
}
77
impl TextSelfAttention {
    /// Builds the Q/K/V projections.
    ///
    /// Queries always project from the text hidden states; for
    /// cross-attention layers the key/value projections instead consume
    /// encoder states of width `encoder_hidden_size`.
    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
        let num_attention_heads = cfg.num_attention_heads;
        let attention_head_size = cfg.hidden_size / num_attention_heads;
        let all_head_size = cfg.num_attention_heads * attention_head_size;
        let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
        let in_size = if is_cross_attention {
            cfg.encoder_hidden_size
        } else {
            cfg.hidden_size
        };
        let key = linear(in_size, all_head_size, vb.pp("key"))?;
        let value = linear(in_size, all_head_size, vb.pp("value"))?;
        // Standard scaled dot-product factor.
        let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
        Ok(Self {
            query,
            key,
            value,
            attention_head_size,
            num_attention_heads,
            attention_scale,
            kv_cache: None,
        })
    }

    // (b, seq, heads * head_dim) -> (b, heads, seq, head_dim)
    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, _) = xs.dims3()?;
        xs.reshape((
            b_size,
            seq_len,
            self.num_attention_heads,
            self.attention_head_size,
        ))?
        .permute((0, 2, 1, 3))
    }

    /// Drops the cached keys/values; call before starting a new sequence.
    fn reset_kv_cache(&mut self) {
        self.kv_cache = None
    }

    /// Attention forward pass.
    ///
    /// * `encoder_hidden_states` — `Some(states)` selects cross-attention:
    ///   keys/values are projected from `states` and the kv-cache is not
    ///   touched. `None` selects self-attention with incremental caching:
    ///   new keys/values are appended to the cache along the sequence axis.
    /// * `attention_mask` — optional additive mask broadcast onto the
    ///   pre-softmax scores (expected to contain 0 / -inf entries).
    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let query = self
            .transpose_for_scores(&self.query.forward(xs)?)?
            .contiguous()?;
        let (key, value) = match encoder_hidden_states {
            None => {
                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
                let (key, value) = match &self.kv_cache {
                    None => (key, value),
                    Some((prev_key, prev_value)) => {
                        // Dim 2 is the sequence axis after transpose_for_scores.
                        let key = Tensor::cat(&[prev_key, &key], 2)?;
                        let value = Tensor::cat(&[prev_value, &value], 2)?;
                        (key, value)
                    }
                };
                self.kv_cache = Some((key.clone(), value.clone()));
                (key, value)
            }
            Some(xs) => {
                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
                // no kv-cache in this case, but the results could probably be memoized.
                (key, value)
            }
        };
        let key = key.contiguous()?;
        let value = value.contiguous()?;
        let attention_scores = query.matmul(&key.t()?)?;
        let attention_scores = (attention_scores * self.attention_scale)?;
        let attention_scores = match attention_mask {
            Some(mask) => attention_scores.broadcast_add(mask)?,
            None => attention_scores,
        };
        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        // (b, heads, seq, head_dim) -> (b, seq, heads * head_dim)
        attention_probs
            .matmul(&value)?
            .permute((0, 2, 1, 3))?
            .flatten_from(D::Minus2)
    }
}
164
// Post-attention projection with residual connection and layer-norm.
#[derive(Debug, Clone)]
struct TextSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
}
170
171impl TextSelfOutput {
172    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
173        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
174        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
175        Ok(Self { dense, layer_norm })
176    }
177
178    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
179        (xs.apply(&self.dense) + input_tensor)?.apply(&self.layer_norm)
180    }
181}
182
// An attention block (self- or cross-) plus its output projection.
#[derive(Debug, Clone)]
struct TextAttention {
    self_: TextSelfAttention,
    output: TextSelfOutput,
}
188
189impl TextAttention {
190    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
191        let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
192        let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
193        Ok(Self { self_, output })
194    }
195
196    fn reset_kv_cache(&mut self) {
197        self.self_.reset_kv_cache()
198    }
199
200    fn forward(
201        &mut self,
202        xs: &Tensor,
203        encoder_hidden_states: Option<&Tensor>,
204        attention_mask: Option<&Tensor>,
205    ) -> Result<Tensor> {
206        let self_outputs = self
207            .self_
208            .forward(xs, encoder_hidden_states, attention_mask)?;
209        self.output.forward(&self_outputs, xs)
210    }
211}
212
// First half of the feed-forward block: expand to intermediate_size + activation.
#[derive(Debug, Clone)]
struct TextIntermediate {
    dense: Linear,
    intermediate_act_fn: candle_nn::Activation,
}
218
219impl TextIntermediate {
220    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
221        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
222        Ok(Self {
223            dense,
224            intermediate_act_fn: cfg.hidden_act,
225        })
226    }
227}
228
229impl Module for TextIntermediate {
230    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
231        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
232    }
233}
234
// Second half of the feed-forward block: contract back to hidden_size,
// with residual connection and layer-norm.
#[derive(Debug, Clone)]
struct TextOutput {
    dense: Linear,
    layer_norm: LayerNorm,
}
240
241impl TextOutput {
242    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
243        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
244        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
245        Ok(Self { dense, layer_norm })
246    }
247
248    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
249        (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
250    }
251}
252
// One transformer layer: self-attention, optional cross-attention, feed-forward.
#[derive(Debug, Clone)]
struct TextLayer {
    attention: TextAttention,
    // Present only when `Config::is_decoder` is set (see `new`).
    cross_attention: Option<TextAttention>,
    intermediate: TextIntermediate,
    output: TextOutput,
}
260
261impl TextLayer {
262    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
263        let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
264        let cross_attention = if cfg.is_decoder {
265            Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
266        } else {
267            None
268        };
269        let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
270        let output = TextOutput::new(cfg, vb.pp("output"))?;
271        Ok(Self {
272            attention,
273            cross_attention,
274            intermediate,
275            output,
276        })
277    }
278
279    fn reset_kv_cache(&mut self) {
280        self.attention.reset_kv_cache();
281        if let Some(ca) = &mut self.cross_attention {
282            ca.reset_kv_cache()
283        }
284    }
285
286    fn forward(
287        &mut self,
288        xs: &Tensor,
289        encoder_hidden_states: &Tensor,
290        attention_mask: &Tensor,
291    ) -> Result<Tensor> {
292        let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
293        let attention_output = match &mut self.cross_attention {
294            Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
295            None => candle::bail!("expected some cross-attn"),
296        };
297        let intermediate_output = self.intermediate.forward(&attention_output)?;
298        self.output.forward(&intermediate_output, &attention_output)
299    }
300}
301
// Stack of `num_hidden_layers` transformer layers.
#[derive(Debug, Clone)]
struct TextEncoder {
    layers: Vec<TextLayer>,
}
306
307impl TextEncoder {
308    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
309        let vb = vb.pp("layer");
310        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
311        for i in 0..cfg.num_hidden_layers {
312            let layer = TextLayer::new(cfg, vb.pp(i))?;
313            layers.push(layer)
314        }
315        Ok(Self { layers })
316    }
317
318    fn reset_kv_cache(&mut self) {
319        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
320    }
321
322    fn forward(
323        &mut self,
324        xs: &Tensor,
325        encoder_hidden_states: &Tensor,
326        attention_mask: &Tensor,
327    ) -> Result<Tensor> {
328        let mut xs = xs.clone();
329        for layer in self.layers.iter_mut() {
330            xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
331        }
332        Ok(xs)
333    }
334}
335
// BERT-style pooler head; not used by the caption-generation path below.
#[derive(Debug, Clone)]
pub struct TextPooler {
    dense: Linear,
}
340
341impl TextPooler {
342    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
343        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
344        Ok(Self { dense })
345    }
346}
347
348impl Module for TextPooler {
349    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
350        xs.narrow(D::Minus1, 0, 1)?
351            .squeeze(D::Minus1)?
352            .apply(&self.dense)?
353            .tanh()
354    }
355}
356
// dense -> activation -> layer-norm transform applied before the LM decoder.
#[derive(Debug, Clone)]
struct TextPredictionHeadTransform {
    dense: Linear,
    transform_act_fn: candle_nn::Activation,
    layer_norm: LayerNorm,
}
363
364impl TextPredictionHeadTransform {
365    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
366        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
367        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
368        Ok(Self {
369            dense,
370            transform_act_fn: cfg.hidden_act,
371            layer_norm,
372        })
373    }
374}
375
376impl Module for TextPredictionHeadTransform {
377    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
378        xs.apply(&self.dense)?
379            .apply(&self.transform_act_fn)?
380            .apply(&self.layer_norm)
381    }
382}
383
// Language-modeling head: transform followed by a projection to vocab logits.
#[derive(Debug, Clone)]
struct TextLMPredictionHead {
    transform: TextPredictionHeadTransform,
    decoder: Linear,
}
389
impl TextLMPredictionHead {
    /// Loads the transform and the vocabulary projection.
    ///
    /// Note the asymmetric weight paths: the projection matrix lives under
    /// `decoder.weight`, but its bias is fetched one level up as `bias`
    /// (presumably matching the upstream checkpoint layout — confirm against
    /// the HF weights if this ever fails to load).
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
        let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
        let bias = vb.get(cfg.vocab_size, "bias")?;
        let decoder = Linear::from_weights(weight, Some(bias));
        Ok(Self { transform, decoder })
    }
}
399
400impl Module for TextLMPredictionHead {
401    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
402        xs.apply(&self.transform)?.apply(&self.decoder)
403    }
404}
405
// Thin wrapper matching the `cls.predictions` checkpoint hierarchy.
#[derive(Debug, Clone)]
struct TextOnlyMLMHead {
    predictions: TextLMPredictionHead,
}
410
411impl TextOnlyMLMHead {
412    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
413        let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
414        Ok(Self { predictions })
415    }
416}
417
418impl Module for TextOnlyMLMHead {
419    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
420        self.predictions.forward(xs)
421    }
422}
423
// Embeddings + encoder stack, with a counter tracking how many tokens have
// already been consumed (keeps position ids aligned with the kv-cache).
#[derive(Debug, Clone)]
struct TextModel {
    embeddings: TextEmbeddings,
    encoder: TextEncoder,
    past_kv_len: usize,
    // We do not need the pooler for caption generation
}
431
432impl TextModel {
433    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
434        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
435        let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
436        Ok(Self {
437            embeddings,
438            encoder,
439            past_kv_len: 0,
440        })
441    }
442
443    fn forward(
444        &mut self,
445        input_ids: &Tensor,
446        encoder_hidden_states: &Tensor,
447        attention_mask: &Tensor,
448    ) -> Result<Tensor> {
449        let (_b_sz, seq_len) = input_ids.dims2()?;
450        let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
451        let sequence_output =
452            self.encoder
453                .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
454        self.past_kv_len += seq_len;
455        // We're interested in the sequence-output rather than the pooled-output.
456        Ok(sequence_output)
457    }
458
459    fn reset_kv_cache(&mut self) {
460        self.past_kv_len = 0;
461        self.encoder.reset_kv_cache();
462    }
463}
464
// Decoder with language-modeling head: BERT-style text model + MLM head.
#[derive(Debug, Clone)]
pub struct TextLMHeadModel {
    bert: TextModel,
    cls: TextOnlyMLMHead,
}
470
471impl TextLMHeadModel {
472    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
473        let bert = TextModel::new(cfg, vb.pp("bert"))?;
474        let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
475        Ok(Self { bert, cls })
476    }
477
478    pub fn forward(
479        &mut self,
480        input_ids: &Tensor,
481        encoder_hidden_states: &Tensor,
482    ) -> Result<Tensor> {
483        let seq_len = input_ids.dim(1)?;
484        let mask: Vec<_> = (0..seq_len)
485            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
486            .collect();
487        let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
488        let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
489        let prediction_scores = self.cls.forward(&sequence_output)?;
490        // return_logits is false so we don't discard the last sequence element.
491        Ok(prediction_scores)
492    }
493
494    pub fn reset_kv_cache(&mut self) {
495        self.bert.reset_kv_cache()
496    }
497}