candle_transformers/models/quantized_blip.rs

1//! BLIP model implementation with quantization support.
2//!
3//! BLIP is a vision-language model for image understanding and generation tasks.
4//! This implementation provides quantization for reduced memory and compute.
5//!
6//! Key characteristics:
7//! - Vision encoder using ViT architecture
8//! - Text decoder using BERT-style transformer
9//! - Cross-attention between vision and text features
10//! - Support for 8-bit quantization
11//!
12//! References:
13//! - [BLIP Paper](https://arxiv.org/abs/2201.12086)
14//! - [Hugging Face Implementation](https://huggingface.co/docs/transformers/model_doc/blip)
15//!
16
17use super::quantized_blip_text as blip_text;
18use crate::quantized_nn::{layer_norm, linear, Linear};
19pub use crate::quantized_var_builder::VarBuilder;
20use candle::{Module, Result, Tensor, D};
21use candle_nn::{Conv2d, Conv2dConfig, LayerNorm};
22
23pub type VisionConfig = super::blip::VisionConfig;
24pub type Config = super::blip::Config;
25
/// Patch + class-token embedding stage of the BLIP ViT vision encoder.
#[derive(Debug, Clone)]
struct VisionEmbeddings {
    // Learned [CLS] token, shape (1, 1, hidden_size); prepended to the patch sequence.
    class_embedding: Tensor,
    // Conv2d with kernel_size == stride == patch_size; slices the image into patch embeddings.
    patch_embedding: Conv2d,
    // Learned positional table, shape (1, num_patches + 1, hidden_size).
    position_embedding: Tensor,
}
32
33impl VisionEmbeddings {
34    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
35        let class_embedding = vb
36            .get((1, 1, cfg.hidden_size), "class_embedding")?
37            .dequantize(vb.device())?;
38        let conv_cfg = Conv2dConfig {
39            stride: cfg.patch_size,
40            ..Default::default()
41        };
42        let pe_vb = vb.pp("patch_embedding");
43        let pe_weight = pe_vb
44            .get(
45                (cfg.hidden_size, 3, cfg.patch_size, cfg.patch_size),
46                "weight",
47            )?
48            .dequantize(vb.device())?;
49        let pe_bias = pe_vb
50            .get(cfg.hidden_size, "bias")?
51            .dequantize(vb.device())?;
52
53        let patch_embedding = Conv2d::new(pe_weight, Some(pe_bias), conv_cfg);
54        let num_patches1 = cfg.image_size / cfg.patch_size;
55        let num_patches = num_patches1 * num_patches1;
56        let num_positions = num_patches + 1;
57        let position_embedding = vb
58            .get((1, num_positions, cfg.hidden_size), "position_embedding")?
59            .dequantize(vb.device())?;
60        Ok(Self {
61            class_embedding,
62            patch_embedding,
63            position_embedding,
64        })
65    }
66}
67
68impl Module for VisionEmbeddings {
69    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
70        let target_dtype = xs.dtype();
71        let b_size = xs.dim(0)?;
72        let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
73        let d = self.class_embedding.dim(D::Minus1)?;
74        let class_embeds = self
75            .class_embedding
76            .broadcast_as((b_size, 1, d))?
77            .to_dtype(target_dtype)?;
78        let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
79        let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
80        embeddings.broadcast_add(&position_embedding)
81    }
82}
83
/// Multi-head self-attention with a fused QKV projection.
#[derive(Debug, Clone)]
struct Attention {
    // Single projection producing queries, keys and values (3 * embed_dim outputs).
    qkv: Linear,
    // Output projection back to embed_dim.
    projection: Linear,
    // 1 / sqrt(head_dim), applied to the raw attention scores.
    scale: f64,
    num_heads: usize,
}
91
92impl Attention {
93    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
94        let embed_dim = cfg.hidden_size;
95        let num_heads = cfg.num_attention_heads;
96        let head_dim = embed_dim / num_heads;
97        let scale = 1f64 / (head_dim as f64).sqrt();
98        let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
99        let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
100        Ok(Self {
101            qkv,
102            projection,
103            scale,
104            num_heads,
105        })
106    }
107
108    fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
109        let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
110        let mixed_qkv = xs
111            .apply(&self.qkv)?
112            .reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
113            .permute((2, 0, 3, 1, 4))?;
114        let query = mixed_qkv.get(0)?;
115        let key = mixed_qkv.get(1)?;
116        let value = mixed_qkv.get(2)?;
117        let attention_scores = query.matmul(&key.t()?)?;
118        let attention_scores = (attention_scores * self.scale)?;
119        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
120        let attention_probs = match attn_mask {
121            None => attention_probs,
122            Some(attn_mask) => (attention_probs * attn_mask)?,
123        };
124        attention_probs
125            .matmul(&value)?
126            .permute((0, 2, 1, 3))?
127            .flatten_from(D::Minus2)?
128            .apply(&self.projection)
129    }
130}
131
/// Two-layer feed-forward block used inside each encoder layer.
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    // Activation applied between the two linear layers (taken from the config).
    activation_fn: candle_nn::Activation,
    // hidden_size -> intermediate_size
    fc1: Linear,
    // intermediate_size -> hidden_size
    fc2: Linear,
}
139
140impl MLP {
141    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
142        let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
143        let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
144        Ok(Self {
145            activation_fn: cfg.hidden_act,
146            fc1,
147            fc2,
148        })
149    }
150}
151
152impl Module for MLP {
153    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
154        xs.apply(&self.fc1)?
155            .apply(&self.activation_fn)?
156            .apply(&self.fc2)
157    }
158}
159
/// One pre-norm transformer encoder layer:
/// x + Attn(LN1(x)), then y + MLP(LN2(y)).
#[derive(Debug, Clone)]
struct EncoderLayer {
    self_attn: Attention,
    // Applied before self-attention (pre-norm).
    layer_norm1: LayerNorm,
    mlp: MLP,
    // Applied before the MLP.
    layer_norm2: LayerNorm,
}
167
168impl EncoderLayer {
169    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
170        let embed_dim = cfg.hidden_size;
171        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
172        let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
173        let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
174        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
175        Ok(Self {
176            self_attn,
177            layer_norm1,
178            mlp,
179            layer_norm2,
180        })
181    }
182
183    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
184        let residual = xs;
185        let xs = xs.apply(&self.layer_norm1)?;
186        let xs = self.self_attn.forward(&xs, attention_mask)?;
187        let xs = (xs + residual)?;
188
189        let residual = &xs;
190        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
191        xs + residual
192    }
193}
194
/// Stack of `EncoderLayer`s applied sequentially.
#[derive(Debug, Clone)]
struct Encoder {
    layers: Vec<EncoderLayer>,
}
199
200impl Encoder {
201    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
202        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
203        let vb = vb.pp("layers");
204        for i in 0..cfg.num_hidden_layers {
205            let layer = EncoderLayer::new(cfg, vb.pp(i))?;
206            layers.push(layer)
207        }
208        Ok(Self { layers })
209    }
210
211    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
212        let mut xs = xs.clone();
213        for layer in self.layers.iter() {
214            xs = layer.forward(&xs, attention_mask)?
215        }
216        Ok(xs)
217    }
218}
219
/// BLIP vision tower: patch embeddings, transformer encoder, final layer norm.
#[derive(Debug, Clone)]
pub struct VisionModel {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    // Applied to the encoder output; no pooling is performed here.
    post_layernorm: LayerNorm,
}
226
227impl VisionModel {
228    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
229        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
230        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
231        let post_layernorm =
232            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
233        Ok(Self {
234            embeddings,
235            encoder,
236            post_layernorm,
237        })
238    }
239}
240
241impl Module for VisionModel {
242    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
243        let xs = xs.apply(&self.embeddings)?;
244        let encoder_outputs = self.encoder.forward(&xs, None)?;
245        // Return the last hidden state rather than pooled outputs.
246        encoder_outputs.apply(&self.post_layernorm)
247    }
248}
249
/// Quantized BLIP captioning model: a ViT vision encoder whose features feed
/// a BERT-style text decoder.
#[derive(Debug, Clone)]
pub struct BlipForConditionalGeneration {
    vision_model: VisionModel,
    text_decoder: blip_text::TextLMHeadModel,
}
255
256impl BlipForConditionalGeneration {
257    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
258        let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
259        let text_decoder =
260            blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
261        Ok(Self {
262            vision_model,
263            text_decoder,
264        })
265    }
266
267    pub fn vision_model(&self) -> &VisionModel {
268        &self.vision_model
269    }
270
271    pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
272        &mut self.text_decoder
273    }
274    pub fn reset_kv_cache(&mut self) {
275        self.text_decoder.reset_kv_cache();
276    }
277}