// ferrum_models/multimodal/clip.rs

//! CLIP model wrapper — supports OpenAI CLIP, Chinese-CLIP, and SigLIP.
//!
//! Wraps candle-transformers' `ClipModel` / `ChineseClipModel` / `siglip::Model`
//! with a unified interface for text and image embedding extraction.

use candle_core::{DType, Device as CandleDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
use candle_transformers::models::clip::{self, ClipConfig, ClipModel};
use candle_transformers::models::siglip;
use ferrum_types::{FerrumError, Result};
use parking_lot::Mutex;
use tracing::{info, warn};

15/// Which CLIP variant is loaded.
16enum ClipVariant {
17    OpenAI(ClipModel),
18    Chinese(ChineseClipModel),
19    SigLIP(siglip::Model),
20}
21
22/// Unified CLIP model wrapper.
23pub struct ClipModelWrapper {
24    model: Mutex<ClipVariant>,
25    device: CandleDevice,
26    dtype: DType,
27    image_size: usize,
28    projection_dim: usize,
29}
30
31impl ClipModelWrapper {
32    /// Load OpenAI CLIP from VarBuilder.
33    pub fn new_openai(
34        vb: VarBuilder,
35        config: &ClipConfig,
36        device: CandleDevice,
37        dtype: DType,
38    ) -> Result<Self> {
39        info!("Loading OpenAI CLIP (image_size={})", config.image_size);
40        let model = ClipModel::new(vb, config)
41            .map_err(|e| FerrumError::model(format!("CLIP load: {e}")))?;
42        Ok(Self {
43            projection_dim: config.vision_config.projection_dim,
44            image_size: config.image_size,
45            model: Mutex::new(ClipVariant::OpenAI(model)),
46            device,
47            dtype,
48        })
49    }
50
51    /// Load Chinese-CLIP from VarBuilder.
52    pub fn new_chinese(
53        vb: VarBuilder,
54        config: &ChineseClipConfig,
55        device: CandleDevice,
56        dtype: DType,
57    ) -> Result<Self> {
58        info!(
59            "Loading Chinese-CLIP (image_size={}, projection_dim={})",
60            config.image_size, config.projection_dim
61        );
62        let model = ChineseClipModel::new(vb, config)
63            .map_err(|e| FerrumError::model(format!("Chinese-CLIP load: {e}")))?;
64        Ok(Self {
65            projection_dim: config.projection_dim,
66            image_size: config.image_size,
67            model: Mutex::new(ClipVariant::Chinese(model)),
68            device,
69            dtype,
70        })
71    }
72
73    /// Load SigLIP from VarBuilder.
74    pub fn new_siglip(
75        vb: VarBuilder,
76        config: &siglip::Config,
77        device: CandleDevice,
78        dtype: DType,
79    ) -> Result<Self> {
80        let image_size = config.vision_config.image_size;
81        let projection_dim = config.vision_config.hidden_size;
82        info!(
83            "Loading SigLIP (image_size={}, hidden_size={})",
84            image_size, projection_dim
85        );
86        let model = siglip::Model::new(config, vb)
87            .map_err(|e| FerrumError::model(format!("SigLIP load: {e}")))?;
88        Ok(Self {
89            projection_dim,
90            image_size,
91            model: Mutex::new(ClipVariant::SigLIP(model)),
92            device,
93            dtype,
94        })
95    }
96
97    /// Load from config.json — auto-detects CLIP variant.
98    ///
99    /// candle's ClipConfig doesn't derive Deserialize, so we use preset configs
100    /// and override image_size / projection_dim from the JSON when present.
101    pub fn from_config_json(
102        vb: VarBuilder,
103        config_path: &std::path::Path,
104        device: CandleDevice,
105        dtype: DType,
106    ) -> Result<Self> {
107        let raw: serde_json::Value = serde_json::from_str(
108            &std::fs::read_to_string(config_path)
109                .map_err(|e| FerrumError::model(format!("read config: {e}")))?,
110        )
111        .map_err(|e| FerrumError::model(format!("parse config: {e}")))?;
112
113        let model_type = raw.get("model_type").and_then(|v| v.as_str()).unwrap_or("");
114
115        if model_type == "siglip" {
116            // SigLIP config derives Deserialize — parse directly
117            let config: siglip::Config =
118                serde_json::from_value(raw).unwrap_or_else(|_| siglip::Config::base_patch16_224());
119            return Self::new_siglip(vb, &config, device, dtype);
120        }
121
122        if model_type == "chinese_clip" {
123            let mut config = ChineseClipConfig::clip_vit_base_patch16();
124            if let Some(v) = raw.get("projection_dim").and_then(|v| v.as_u64()) {
125                config.projection_dim = v as usize;
126            }
127            if let Some(vc) = raw.get("vision_config") {
128                if let Some(v) = vc.get("image_size").and_then(|v| v.as_u64()) {
129                    config.vision_config.image_size = v as usize;
130                    config.image_size = v as usize;
131                }
132            }
133            Self::new_chinese(vb, &config, device, dtype)
134        } else {
135            let mut config = ClipConfig::vit_base_patch32();
136            // Override from config.json — supports base/large/any variant that
137            // differs from vit_base_patch32 defaults (e.g. clip-vit-large-patch14
138            // has embed_dim=768 / 24 layers / 16 heads, not 512 / 12 / 8).
139
140            // Top-level projection_dim (shared across text/vision in HF config).
141            if let Some(v) = raw.get("projection_dim").and_then(|v| v.as_u64()) {
142                config.text_config.projection_dim = v as usize;
143                config.vision_config.projection_dim = v as usize;
144            }
145
146            if let Some(tc) = raw.get("text_config") {
147                if let Some(v) = tc.get("hidden_size").and_then(|v| v.as_u64()) {
148                    config.text_config.embed_dim = v as usize;
149                }
150                if let Some(v) = tc.get("intermediate_size").and_then(|v| v.as_u64()) {
151                    config.text_config.intermediate_size = v as usize;
152                }
153                if let Some(v) = tc.get("num_hidden_layers").and_then(|v| v.as_u64()) {
154                    config.text_config.num_hidden_layers = v as usize;
155                }
156                if let Some(v) = tc.get("num_attention_heads").and_then(|v| v.as_u64()) {
157                    config.text_config.num_attention_heads = v as usize;
158                }
159                if let Some(v) = tc.get("vocab_size").and_then(|v| v.as_u64()) {
160                    config.text_config.vocab_size = v as usize;
161                }
162                if let Some(v) = tc.get("max_position_embeddings").and_then(|v| v.as_u64()) {
163                    config.text_config.max_position_embeddings = v as usize;
164                }
165                if let Some(v) = tc.get("projection_dim").and_then(|v| v.as_u64()) {
166                    config.text_config.projection_dim = v as usize;
167                }
168            }
169            if let Some(vc) = raw.get("vision_config") {
170                if let Some(v) = vc.get("image_size").and_then(|v| v.as_u64()) {
171                    config.vision_config.image_size = v as usize;
172                    config.image_size = v as usize;
173                }
174                if let Some(v) = vc.get("projection_dim").and_then(|v| v.as_u64()) {
175                    config.vision_config.projection_dim = v as usize;
176                }
177                if let Some(v) = vc.get("hidden_size").and_then(|v| v.as_u64()) {
178                    config.vision_config.embed_dim = v as usize;
179                }
180                if let Some(v) = vc.get("intermediate_size").and_then(|v| v.as_u64()) {
181                    config.vision_config.intermediate_size = v as usize;
182                }
183                if let Some(v) = vc.get("num_hidden_layers").and_then(|v| v.as_u64()) {
184                    config.vision_config.num_hidden_layers = v as usize;
185                }
186                if let Some(v) = vc.get("num_attention_heads").and_then(|v| v.as_u64()) {
187                    config.vision_config.num_attention_heads = v as usize;
188                }
189                if let Some(v) = vc.get("patch_size").and_then(|v| v.as_u64()) {
190                    config.vision_config.patch_size = v as usize;
191                }
192            }
193            Self::new_openai(vb, &config, device, dtype)
194        }
195    }
196
197    /// Get text embedding (L2-normalized).
198    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
199        let model = self.model.lock();
200        let features = match &*model {
201            ClipVariant::OpenAI(m) => m
202                .get_text_features(input_ids)
203                .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
204            ClipVariant::Chinese(m) => m
205                .get_text_features(input_ids, None, None)
206                .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
207            ClipVariant::SigLIP(m) => m
208                .get_text_features(input_ids)
209                .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
210        };
211        clip::div_l2_norm(&features).map_err(|e| FerrumError::model(format!("l2 norm: {e}")))
212    }
213
214    /// Get image embedding (L2-normalized).
215    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
216        let model = self.model.lock();
217        let features = match &*model {
218            ClipVariant::OpenAI(m) => m
219                .get_image_features(pixel_values)
220                .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
221            ClipVariant::Chinese(m) => m
222                .get_image_features(pixel_values)
223                .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
224            ClipVariant::SigLIP(m) => m
225                .get_image_features(pixel_values)
226                .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
227        };
228        clip::div_l2_norm(&features).map_err(|e| FerrumError::model(format!("l2 norm: {e}")))
229    }
230
231    pub fn device(&self) -> &CandleDevice {
232        &self.device
233    }
234
235    pub fn dtype(&self) -> DType {
236        self.dtype
237    }
238
239    pub fn image_size(&self) -> usize {
240        self.image_size
241    }
242
243    pub fn projection_dim(&self) -> usize {
244        self.projection_dim
245    }
246}