// ferrum_models/architectures/clip.rs
1//! CLIP model wrapper — supports OpenAI CLIP, Chinese-CLIP, and SigLIP.
2//!
3//! Wraps candle-transformers' ClipModel / ChineseClipModel / siglip::Model
4//! with a unified interface for text and image embedding extraction.
5
6use candle_core::{DType, Device as CandleDevice, Tensor};
7use candle_nn::VarBuilder;
8use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
9use candle_transformers::models::clip::{self, ClipConfig, ClipModel};
10use candle_transformers::models::siglip;
11use ferrum_types::{FerrumError, Result};
12use parking_lot::Mutex;
13use tracing::info;
14
/// Which CLIP variant is loaded.
///
/// Each variant owns the corresponding candle-transformers model; the
/// wrapper dispatches text/image feature extraction over this enum.
enum ClipVariant {
    /// OpenAI CLIP (`candle_transformers::models::clip::ClipModel`).
    OpenAI(ClipModel),
    /// Chinese-CLIP (`candle_transformers::models::chinese_clip::ChineseClipModel`).
    Chinese(ChineseClipModel),
    /// SigLIP (`candle_transformers::models::siglip::Model`).
    SigLIP(siglip::Model),
}
21
/// Unified CLIP model wrapper.
pub struct ClipModelWrapper {
    // The loaded variant; the Mutex serializes access to the model for
    // feature extraction (see `get_text_features` / `get_image_features`).
    model: Mutex<ClipVariant>,
    // Device the model was loaded on; exposed via `device()`.
    device: CandleDevice,
    // DType requested at load time; stored for callers via `dtype()`.
    dtype: DType,
    // Square input image edge length in pixels, taken from the config.
    image_size: usize,
    // Output embedding dimensionality (for SigLIP this is the vision
    // hidden_size, since SigLIP has no separate projection head).
    projection_dim: usize,
}
30
31impl ClipModelWrapper {
32    /// Load OpenAI CLIP from VarBuilder.
33    pub fn new_openai(
34        vb: VarBuilder,
35        config: &ClipConfig,
36        device: CandleDevice,
37        dtype: DType,
38    ) -> Result<Self> {
39        info!("Loading OpenAI CLIP (image_size={})", config.image_size);
40        let model = ClipModel::new(vb, config)
41            .map_err(|e| FerrumError::model(format!("CLIP load: {e}")))?;
42        Ok(Self {
43            projection_dim: config.vision_config.projection_dim,
44            image_size: config.image_size,
45            model: Mutex::new(ClipVariant::OpenAI(model)),
46            device,
47            dtype,
48        })
49    }
50
51    /// Load Chinese-CLIP from VarBuilder.
52    pub fn new_chinese(
53        vb: VarBuilder,
54        config: &ChineseClipConfig,
55        device: CandleDevice,
56        dtype: DType,
57    ) -> Result<Self> {
58        info!(
59            "Loading Chinese-CLIP (image_size={}, projection_dim={})",
60            config.image_size, config.projection_dim
61        );
62        let model = ChineseClipModel::new(vb, config)
63            .map_err(|e| FerrumError::model(format!("Chinese-CLIP load: {e}")))?;
64        Ok(Self {
65            projection_dim: config.projection_dim,
66            image_size: config.image_size,
67            model: Mutex::new(ClipVariant::Chinese(model)),
68            device,
69            dtype,
70        })
71    }
72
73    /// Load SigLIP from VarBuilder.
74    pub fn new_siglip(
75        vb: VarBuilder,
76        config: &siglip::Config,
77        device: CandleDevice,
78        dtype: DType,
79    ) -> Result<Self> {
80        let image_size = config.vision_config.image_size;
81        let projection_dim = config.vision_config.hidden_size;
82        info!(
83            "Loading SigLIP (image_size={}, hidden_size={})",
84            image_size, projection_dim
85        );
86        let model = siglip::Model::new(config, vb)
87            .map_err(|e| FerrumError::model(format!("SigLIP load: {e}")))?;
88        Ok(Self {
89            projection_dim,
90            image_size,
91            model: Mutex::new(ClipVariant::SigLIP(model)),
92            device,
93            dtype,
94        })
95    }
96
97    /// Load from config.json — auto-detects CLIP variant.
98    ///
99    /// candle's ClipConfig doesn't derive Deserialize, so we use preset configs
100    /// and override image_size / projection_dim from the JSON when present.
101    pub fn from_config_json(
102        vb: VarBuilder,
103        config_path: &std::path::Path,
104        device: CandleDevice,
105        dtype: DType,
106    ) -> Result<Self> {
107        let raw: serde_json::Value = serde_json::from_str(
108            &std::fs::read_to_string(config_path)
109                .map_err(|e| FerrumError::model(format!("read config: {e}")))?,
110        )
111        .map_err(|e| FerrumError::model(format!("parse config: {e}")))?;
112
113        let model_type = raw.get("model_type").and_then(|v| v.as_str()).unwrap_or("");
114
115        if model_type == "siglip" {
116            // SigLIP config derives Deserialize — parse directly
117            let config: siglip::Config =
118                serde_json::from_value(raw).unwrap_or_else(|_| siglip::Config::base_patch16_224());
119            return Self::new_siglip(vb, &config, device, dtype);
120        }
121
122        if model_type == "chinese_clip" {
123            let mut config = ChineseClipConfig::clip_vit_base_patch16();
124            if let Some(v) = raw.get("projection_dim").and_then(|v| v.as_u64()) {
125                config.projection_dim = v as usize;
126            }
127            if let Some(vc) = raw.get("vision_config") {
128                if let Some(v) = vc.get("image_size").and_then(|v| v.as_u64()) {
129                    config.vision_config.image_size = v as usize;
130                    config.image_size = v as usize;
131                }
132            }
133            Self::new_chinese(vb, &config, device, dtype)
134        } else {
135            let mut config = ClipConfig::vit_base_patch32();
136            if let Some(vc) = raw.get("vision_config") {
137                if let Some(v) = vc.get("image_size").and_then(|v| v.as_u64()) {
138                    config.vision_config.image_size = v as usize;
139                    config.image_size = v as usize;
140                }
141                if let Some(v) = vc.get("projection_dim").and_then(|v| v.as_u64()) {
142                    config.vision_config.projection_dim = v as usize;
143                }
144            }
145            Self::new_openai(vb, &config, device, dtype)
146        }
147    }
148
149    /// Get text embedding (L2-normalized).
150    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
151        let model = self.model.lock();
152        let features = match &*model {
153            ClipVariant::OpenAI(m) => m
154                .get_text_features(input_ids)
155                .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
156            ClipVariant::Chinese(m) => m
157                .get_text_features(input_ids, None, None)
158                .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
159            ClipVariant::SigLIP(m) => m
160                .get_text_features(input_ids)
161                .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
162        };
163        clip::div_l2_norm(&features).map_err(|e| FerrumError::model(format!("l2 norm: {e}")))
164    }
165
166    /// Get image embedding (L2-normalized).
167    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
168        let model = self.model.lock();
169        let features = match &*model {
170            ClipVariant::OpenAI(m) => m
171                .get_image_features(pixel_values)
172                .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
173            ClipVariant::Chinese(m) => m
174                .get_image_features(pixel_values)
175                .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
176            ClipVariant::SigLIP(m) => m
177                .get_image_features(pixel_values)
178                .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
179        };
180        clip::div_l2_norm(&features).map_err(|e| FerrumError::model(format!("l2 norm: {e}")))
181    }
182
183    pub fn device(&self) -> &CandleDevice {
184        &self.device
185    }
186
187    pub fn dtype(&self) -> DType {
188        self.dtype
189    }
190
191    pub fn image_size(&self) -> usize {
192        self.image_size
193    }
194
195    pub fn projection_dim(&self) -> usize {
196        self.projection_dim
197    }
198}