candle_transformers/models/clip/mod.rs

//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images and their associated texts, learning a shared embedding space
//! for both modalities.
//!
//! - 💻 [GH Link](https://github.com/openai/CLIP)
//! - 💻 Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip)
//! - 🤗 [HF Model](https://huggingface.co/openai/clip-vit-large-patch14-336)
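//!
//! A minimal usage sketch (the safetensors file name and the way `pixel_values` and
//! `input_ids` are produced are assumptions; adapt them to the checkpoint and
//! preprocessing you actually use):
//!
//! ```ignore
//! use candle::{DType, Device, Tensor};
//! use candle_nn::VarBuilder;
//! use candle_transformers::models::clip::{ClipConfig, ClipModel};
//!
//! let device = Device::Cpu;
//! let vb = unsafe {
//!     VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
//! };
//! let config = ClipConfig::vit_base_patch32();
//! let model = ClipModel::new(vb, &config)?;
//!
//! // `pixel_values`: (batch, 3, image_size, image_size), `input_ids`: tokenized text.
//! let (logits_per_text, logits_per_image) = model.forward(&pixel_values, &input_ids)?;
//! ```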

use self::{
    text_model::{Activation, ClipTextTransformer},
    vision_model::ClipVisionTransformer,
};
use candle::{Result, Tensor, D};

pub mod text_model;
pub mod vision_model;

#[derive(Clone, Debug)]
pub struct ClipModel {
    text_model: ClipTextTransformer,
    vision_model: ClipVisionTransformer,
    visual_projection: candle_nn::Linear,
    text_projection: candle_nn::Linear,
    logit_scale: Tensor,
}

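/// Configuration for either of the two CLIP encoders, letting the text and vision
/// transformers share common building blocks.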
#[derive(Clone, Debug)]
pub enum EncoderConfig {
    Text(text_model::ClipTextConfig),
    Vision(vision_model::ClipVisionConfig),
}

impl EncoderConfig {
    pub fn embed_dim(&self) -> usize {
        match self {
            Self::Text(c) => c.embed_dim,
            Self::Vision(c) => c.embed_dim,
        }
    }

    pub fn num_attention_heads(&self) -> usize {
        match self {
            Self::Text(c) => c.num_attention_heads,
            Self::Vision(c) => c.num_attention_heads,
        }
    }

    pub fn intermediate_size(&self) -> usize {
        match self {
            Self::Text(c) => c.intermediate_size,
            Self::Vision(c) => c.intermediate_size,
        }
    }

    pub fn num_hidden_layers(&self) -> usize {
        match self {
            Self::Text(c) => c.num_hidden_layers,
            Self::Vision(c) => c.num_hidden_layers,
        }
    }

    pub fn activation(&self) -> Activation {
        match self {
            Self::Text(_c) => Activation::QuickGelu,
            Self::Vision(c) => c.activation,
        }
    }
}

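/// Top-level CLIP configuration combining the text and vision encoder configurations.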
#[derive(Clone, Debug)]
pub struct ClipConfig {
    pub text_config: text_model::ClipTextConfig,
    pub vision_config: vision_model::ClipVisionConfig,
    pub logit_scale_init_value: f32,
    pub image_size: usize,
}

impl ClipConfig {
    /// Default ViT-B/32 configuration: 224x224 input images, roughly 600 MB of weights.
    pub fn vit_base_patch32() -> Self {
        let text_config = text_model::ClipTextConfig::vit_base_patch32();
        let vision_config = vision_model::ClipVisionConfig::vit_base_patch32();

        Self {
            text_config,
            vision_config,
            logit_scale_init_value: 2.6592,
            image_size: 224,
        }
    }
}

impl ClipModel {
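    /// Builds the model from a `VarBuilder`, expecting weights under the usual
    /// HuggingFace prefixes: `text_model`, `vision_model`, `visual_projection`,
    /// `text_projection`, and optionally `logit_scale`.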
    pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
        let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
        let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
        let visual_projection = candle_nn::linear_no_bias(
            c.vision_config.embed_dim,
            c.vision_config.projection_dim,
            vs.pp("visual_projection"),
        )?;
        let text_projection = candle_nn::linear_no_bias(
            c.text_config.embed_dim,
            c.text_config.projection_dim,
            vs.pp("text_projection"),
        )?;
        // `logit_scale` is an `nn.Parameter` in the original implementation; fall back to
        // the configured init value when the checkpoint does not store it.
        let logit_scale = if vs.contains_tensor("logit_scale") {
            vs.get(&[], "logit_scale")?
        } else {
            Tensor::new(&[c.logit_scale_init_value], vs.device())?
        };
        Ok(Self {
            text_model,
            vision_model,
            visual_projection,
            text_projection,
            logit_scale,
        })
    }

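    /// Runs the text encoder and projects its pooled output into the shared embedding space.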
    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
        input_ids
            .apply(&self.text_model)?
            .apply(&self.text_projection)
    }

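    /// Runs the vision encoder and projects its pooled output into the shared embedding space.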
    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
        pixel_values
            .apply(&self.vision_model)?
            .apply(&self.visual_projection)
    }

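    /// Returns `(logits_per_text, logits_per_image)`: the scaled cosine similarities between
    /// every text and every image in the batch, with shapes `(n_text, n_image)` and
    /// `(n_image, n_text)` respectively.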
    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
        let image_features = self.get_image_features(pixel_values)?;
        let text_features = self.get_text_features(input_ids)?;
        let image_features_normalized = div_l2_norm(&image_features)?;
        let text_features_normalized = div_l2_norm(&text_features)?;
        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
        let logit_scale = self.logit_scale.exp()?;
        let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
        let logits_per_image = logits_per_text.t()?;
        Ok((logits_per_text, logits_per_image))
    }
}

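/// Divides a tensor by its L2 norm taken along the last dimension.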
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
    let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
    v.broadcast_div(&l2_norm)
}