use super::quantized_blip_text as blip_text;
use crate::quantized_nn::{layer_norm, linear, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{Module, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, LayerNorm};

/// Vision-tower configuration, shared with the non-quantized BLIP model.
pub type VisionConfig = super::blip::VisionConfig;
/// Full BLIP configuration (vision + text), shared with the non-quantized model.
pub type Config = super::blip::Config;
25
#[derive(Debug, Clone)]
struct VisionEmbeddings {
    // Learned class-token embedding, dequantized at load time; shape (1, 1, hidden_size).
    class_embedding: Tensor,
    // Conv projection of image patches (kernel = stride = patch_size) to hidden_size channels.
    patch_embedding: Conv2d,
    // Learned position embeddings; shape (1, num_patches + 1, hidden_size).
    position_embedding: Tensor,
}
32
33impl VisionEmbeddings {
34 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
35 let class_embedding = vb
36 .get((1, 1, cfg.hidden_size), "class_embedding")?
37 .dequantize(vb.device())?;
38 let conv_cfg = Conv2dConfig {
39 stride: cfg.patch_size,
40 ..Default::default()
41 };
42 let pe_vb = vb.pp("patch_embedding");
43 let pe_weight = pe_vb
44 .get(
45 (cfg.hidden_size, 3, cfg.patch_size, cfg.patch_size),
46 "weight",
47 )?
48 .dequantize(vb.device())?;
49 let pe_bias = pe_vb
50 .get(cfg.hidden_size, "bias")?
51 .dequantize(vb.device())?;
52
53 let patch_embedding = Conv2d::new(pe_weight, Some(pe_bias), conv_cfg);
54 let num_patches1 = cfg.image_size / cfg.patch_size;
55 let num_patches = num_patches1 * num_patches1;
56 let num_positions = num_patches + 1;
57 let position_embedding = vb
58 .get((1, num_positions, cfg.hidden_size), "position_embedding")?
59 .dequantize(vb.device())?;
60 Ok(Self {
61 class_embedding,
62 patch_embedding,
63 position_embedding,
64 })
65 }
66}
67
impl Module for VisionEmbeddings {
    // Embed a batch of images: patchify via the conv, prepend the class token,
    // then add (a prefix of) the learned position embeddings.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let target_dtype = xs.dtype();
        let b_size = xs.dim(0)?;
        // Conv output (b, hidden, h', w') -> flatten spatial dims to
        // (b, hidden, n) -> transpose last two dims to (b, n, hidden).
        let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
        let d = self.class_embedding.dim(D::Minus1)?;
        // Broadcast the single class token across the batch and match the
        // input's dtype before concatenation.
        let class_embeds = self
            .class_embedding
            .broadcast_as((b_size, 1, d))?
            .to_dtype(target_dtype)?;
        let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
        // Narrow to the actual sequence length so inputs with fewer patches
        // than the configured image size still line up.
        let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
        embeddings.broadcast_add(&position_embedding)
    }
}
83
#[derive(Debug, Clone)]
struct Attention {
    // Fused query/key/value projection (hidden -> 3 * hidden).
    qkv: Linear,
    // Output projection back to hidden size.
    projection: Linear,
    // 1 / sqrt(head_dim), applied to the raw attention scores.
    scale: f64,
    num_heads: usize,
}
91
impl Attention {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        // Standard 1/sqrt(head_dim) attention scaling.
        let scale = 1f64 / (head_dim as f64).sqrt();
        // Single fused projection produces q, k and v in one matmul.
        let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
        let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
        Ok(Self {
            qkv,
            projection,
            scale,
            num_heads,
        })
    }

    // Multi-head self-attention over a (b, seq, hidden) input.
    fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
        // (b, seq, 3 * hidden) -> (b, seq, 3, heads, head_dim)
        // -> permute to (3, b, heads, seq, head_dim) so q/k/v split on dim 0.
        let mixed_qkv = xs
            .apply(&self.qkv)?
            .reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
            .permute((2, 0, 3, 1, 4))?;
        let query = mixed_qkv.get(0)?;
        let key = mixed_qkv.get(1)?;
        let value = mixed_qkv.get(2)?;
        let attention_scores = query.matmul(&key.t()?)?;
        let attention_scores = (attention_scores * self.scale)?;
        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        // NOTE(review): the mask multiplies the probabilities *after* the
        // softmax (head-mask-style semantics) rather than being added to the
        // logits before it — presumably mirroring the reference BLIP code;
        // confirm against the upstream implementation.
        let attention_probs = match attn_mask {
            None => attention_probs,
            Some(attn_mask) => (attention_probs * attn_mask)?,
        };
        // (b, heads, seq, head_dim) -> (b, seq, heads, head_dim)
        // -> flatten heads back into hidden -> output projection.
        attention_probs
            .matmul(&value)?
            .permute((0, 2, 1, 3))?
            .flatten_from(D::Minus2)?
            .apply(&self.projection)
    }
}
131
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    // Non-linearity applied between the two linear layers.
    activation_fn: candle_nn::Activation,
    // hidden_size -> intermediate_size.
    fc1: Linear,
    // intermediate_size -> hidden_size.
    fc2: Linear,
}
139
140impl MLP {
141 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
142 let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
143 let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
144 Ok(Self {
145 activation_fn: cfg.hidden_act,
146 fc1,
147 fc2,
148 })
149 }
150}
151
152impl Module for MLP {
153 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
154 xs.apply(&self.fc1)?
155 .apply(&self.activation_fn)?
156 .apply(&self.fc2)
157 }
158}
159
#[derive(Debug, Clone)]
struct EncoderLayer {
    self_attn: Attention,
    // Applied before self-attention (pre-norm).
    layer_norm1: LayerNorm,
    mlp: MLP,
    // Applied before the MLP (pre-norm).
    layer_norm2: LayerNorm,
}
167
168impl EncoderLayer {
169 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
170 let embed_dim = cfg.hidden_size;
171 let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
172 let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
173 let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
174 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
175 Ok(Self {
176 self_attn,
177 layer_norm1,
178 mlp,
179 layer_norm2,
180 })
181 }
182
183 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
184 let residual = xs;
185 let xs = xs.apply(&self.layer_norm1)?;
186 let xs = self.self_attn.forward(&xs, attention_mask)?;
187 let xs = (xs + residual)?;
188
189 let residual = &xs;
190 let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
191 xs + residual
192 }
193}
194
#[derive(Debug, Clone)]
struct Encoder {
    // The stacked transformer layers, applied in order.
    layers: Vec<EncoderLayer>,
}
199
200impl Encoder {
201 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
202 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
203 let vb = vb.pp("layers");
204 for i in 0..cfg.num_hidden_layers {
205 let layer = EncoderLayer::new(cfg, vb.pp(i))?;
206 layers.push(layer)
207 }
208 Ok(Self { layers })
209 }
210
211 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
212 let mut xs = xs.clone();
213 for layer in self.layers.iter() {
214 xs = layer.forward(&xs, attention_mask)?
215 }
216 Ok(xs)
217 }
218}
219
#[derive(Debug, Clone)]
pub struct VisionModel {
    // Patch + class-token + position embeddings.
    embeddings: VisionEmbeddings,
    // Transformer encoder stack.
    encoder: Encoder,
    // Final layer norm applied to the encoder output.
    post_layernorm: LayerNorm,
}
226
227impl VisionModel {
228 fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
229 let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
230 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
231 let post_layernorm =
232 layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
233 Ok(Self {
234 embeddings,
235 encoder,
236 post_layernorm,
237 })
238 }
239}
240
241impl Module for VisionModel {
242 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
243 let xs = xs.apply(&self.embeddings)?;
244 let encoder_outputs = self.encoder.forward(&xs, None)?;
245 encoder_outputs.apply(&self.post_layernorm)
247 }
248}
249
#[derive(Debug, Clone)]
pub struct BlipForConditionalGeneration {
    // Image encoder producing the visual features fed to the decoder.
    vision_model: VisionModel,
    // Autoregressive text decoder with an LM head.
    text_decoder: blip_text::TextLMHeadModel,
}
255
256impl BlipForConditionalGeneration {
257 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
258 let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
259 let text_decoder =
260 blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
261 Ok(Self {
262 vision_model,
263 text_decoder,
264 })
265 }
266
267 pub fn vision_model(&self) -> &VisionModel {
268 &self.vision_model
269 }
270
271 pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
272 &mut self.text_decoder
273 }
274 pub fn reset_kv_cache(&mut self) {
275 self.text_decoder.reset_kv_cache();
276 }
277}