// ferrum_models/architectures/clip.rs
use candle_core::{DType, Device as CandleDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
use candle_transformers::models::clip::{self, ClipConfig, ClipModel};
use candle_transformers::models::siglip;
use ferrum_types::{FerrumError, Result};
use parking_lot::Mutex;
use tracing::{info, warn};

15enum ClipVariant {
17 OpenAI(ClipModel),
18 Chinese(ChineseClipModel),
19 SigLIP(siglip::Model),
20}
21
22pub struct ClipModelWrapper {
24 model: Mutex<ClipVariant>,
25 device: CandleDevice,
26 dtype: DType,
27 image_size: usize,
28 projection_dim: usize,
29}
30
31impl ClipModelWrapper {
32 pub fn new_openai(
34 vb: VarBuilder,
35 config: &ClipConfig,
36 device: CandleDevice,
37 dtype: DType,
38 ) -> Result<Self> {
39 info!("Loading OpenAI CLIP (image_size={})", config.image_size);
40 let model = ClipModel::new(vb, config)
41 .map_err(|e| FerrumError::model(format!("CLIP load: {e}")))?;
42 Ok(Self {
43 projection_dim: config.vision_config.projection_dim,
44 image_size: config.image_size,
45 model: Mutex::new(ClipVariant::OpenAI(model)),
46 device,
47 dtype,
48 })
49 }
50
51 pub fn new_chinese(
53 vb: VarBuilder,
54 config: &ChineseClipConfig,
55 device: CandleDevice,
56 dtype: DType,
57 ) -> Result<Self> {
58 info!(
59 "Loading Chinese-CLIP (image_size={}, projection_dim={})",
60 config.image_size, config.projection_dim
61 );
62 let model = ChineseClipModel::new(vb, config)
63 .map_err(|e| FerrumError::model(format!("Chinese-CLIP load: {e}")))?;
64 Ok(Self {
65 projection_dim: config.projection_dim,
66 image_size: config.image_size,
67 model: Mutex::new(ClipVariant::Chinese(model)),
68 device,
69 dtype,
70 })
71 }
72
73 pub fn new_siglip(
75 vb: VarBuilder,
76 config: &siglip::Config,
77 device: CandleDevice,
78 dtype: DType,
79 ) -> Result<Self> {
80 let image_size = config.vision_config.image_size;
81 let projection_dim = config.vision_config.hidden_size;
82 info!(
83 "Loading SigLIP (image_size={}, hidden_size={})",
84 image_size, projection_dim
85 );
86 let model = siglip::Model::new(config, vb)
87 .map_err(|e| FerrumError::model(format!("SigLIP load: {e}")))?;
88 Ok(Self {
89 projection_dim,
90 image_size,
91 model: Mutex::new(ClipVariant::SigLIP(model)),
92 device,
93 dtype,
94 })
95 }
96
97 pub fn from_config_json(
102 vb: VarBuilder,
103 config_path: &std::path::Path,
104 device: CandleDevice,
105 dtype: DType,
106 ) -> Result<Self> {
107 let raw: serde_json::Value = serde_json::from_str(
108 &std::fs::read_to_string(config_path)
109 .map_err(|e| FerrumError::model(format!("read config: {e}")))?,
110 )
111 .map_err(|e| FerrumError::model(format!("parse config: {e}")))?;
112
113 let model_type = raw.get("model_type").and_then(|v| v.as_str()).unwrap_or("");
114
115 if model_type == "siglip" {
116 let config: siglip::Config =
118 serde_json::from_value(raw).unwrap_or_else(|_| siglip::Config::base_patch16_224());
119 return Self::new_siglip(vb, &config, device, dtype);
120 }
121
122 if model_type == "chinese_clip" {
123 let mut config = ChineseClipConfig::clip_vit_base_patch16();
124 if let Some(v) = raw.get("projection_dim").and_then(|v| v.as_u64()) {
125 config.projection_dim = v as usize;
126 }
127 if let Some(vc) = raw.get("vision_config") {
128 if let Some(v) = vc.get("image_size").and_then(|v| v.as_u64()) {
129 config.vision_config.image_size = v as usize;
130 config.image_size = v as usize;
131 }
132 }
133 Self::new_chinese(vb, &config, device, dtype)
134 } else {
135 let mut config = ClipConfig::vit_base_patch32();
136 if let Some(v) = raw.get("projection_dim").and_then(|v| v.as_u64()) {
142 config.text_config.projection_dim = v as usize;
143 config.vision_config.projection_dim = v as usize;
144 }
145
146 if let Some(tc) = raw.get("text_config") {
147 if let Some(v) = tc.get("hidden_size").and_then(|v| v.as_u64()) {
148 config.text_config.embed_dim = v as usize;
149 }
150 if let Some(v) = tc.get("intermediate_size").and_then(|v| v.as_u64()) {
151 config.text_config.intermediate_size = v as usize;
152 }
153 if let Some(v) = tc.get("num_hidden_layers").and_then(|v| v.as_u64()) {
154 config.text_config.num_hidden_layers = v as usize;
155 }
156 if let Some(v) = tc.get("num_attention_heads").and_then(|v| v.as_u64()) {
157 config.text_config.num_attention_heads = v as usize;
158 }
159 if let Some(v) = tc.get("vocab_size").and_then(|v| v.as_u64()) {
160 config.text_config.vocab_size = v as usize;
161 }
162 if let Some(v) = tc.get("max_position_embeddings").and_then(|v| v.as_u64()) {
163 config.text_config.max_position_embeddings = v as usize;
164 }
165 if let Some(v) = tc.get("projection_dim").and_then(|v| v.as_u64()) {
166 config.text_config.projection_dim = v as usize;
167 }
168 }
169 if let Some(vc) = raw.get("vision_config") {
170 if let Some(v) = vc.get("image_size").and_then(|v| v.as_u64()) {
171 config.vision_config.image_size = v as usize;
172 config.image_size = v as usize;
173 }
174 if let Some(v) = vc.get("projection_dim").and_then(|v| v.as_u64()) {
175 config.vision_config.projection_dim = v as usize;
176 }
177 if let Some(v) = vc.get("hidden_size").and_then(|v| v.as_u64()) {
178 config.vision_config.embed_dim = v as usize;
179 }
180 if let Some(v) = vc.get("intermediate_size").and_then(|v| v.as_u64()) {
181 config.vision_config.intermediate_size = v as usize;
182 }
183 if let Some(v) = vc.get("num_hidden_layers").and_then(|v| v.as_u64()) {
184 config.vision_config.num_hidden_layers = v as usize;
185 }
186 if let Some(v) = vc.get("num_attention_heads").and_then(|v| v.as_u64()) {
187 config.vision_config.num_attention_heads = v as usize;
188 }
189 if let Some(v) = vc.get("patch_size").and_then(|v| v.as_u64()) {
190 config.vision_config.patch_size = v as usize;
191 }
192 }
193 Self::new_openai(vb, &config, device, dtype)
194 }
195 }
196
197 pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
199 let model = self.model.lock();
200 let features = match &*model {
201 ClipVariant::OpenAI(m) => m
202 .get_text_features(input_ids)
203 .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
204 ClipVariant::Chinese(m) => m
205 .get_text_features(input_ids, None, None)
206 .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
207 ClipVariant::SigLIP(m) => m
208 .get_text_features(input_ids)
209 .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
210 };
211 clip::div_l2_norm(&features).map_err(|e| FerrumError::model(format!("l2 norm: {e}")))
212 }
213
214 pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
216 let model = self.model.lock();
217 let features = match &*model {
218 ClipVariant::OpenAI(m) => m
219 .get_image_features(pixel_values)
220 .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
221 ClipVariant::Chinese(m) => m
222 .get_image_features(pixel_values)
223 .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
224 ClipVariant::SigLIP(m) => m
225 .get_image_features(pixel_values)
226 .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
227 };
228 clip::div_l2_norm(&features).map_err(|e| FerrumError::model(format!("l2 norm: {e}")))
229 }
230
231 pub fn device(&self) -> &CandleDevice {
232 &self.device
233 }
234
235 pub fn dtype(&self) -> DType {
236 self.dtype
237 }
238
239 pub fn image_size(&self) -> usize {
240 self.image_size
241 }
242
243 pub fn projection_dim(&self) -> usize {
244 self.projection_dim
245 }
246}