//! Path: ferrum_models/architectures/clip.rs — CLIP-family model wrapper.
use candle_core::{DType, Device as CandleDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
use candle_transformers::models::clip::{self, ClipConfig, ClipModel};
use candle_transformers::models::siglip;
use ferrum_types::{FerrumError, Result};
use parking_lot::Mutex;
use tracing::info;
14
15enum ClipVariant {
17 OpenAI(ClipModel),
18 Chinese(ChineseClipModel),
19 SigLIP(siglip::Model),
20}
21
22pub struct ClipModelWrapper {
24 model: Mutex<ClipVariant>,
25 device: CandleDevice,
26 dtype: DType,
27 image_size: usize,
28 projection_dim: usize,
29}
30
31impl ClipModelWrapper {
32 pub fn new_openai(
34 vb: VarBuilder,
35 config: &ClipConfig,
36 device: CandleDevice,
37 dtype: DType,
38 ) -> Result<Self> {
39 info!("Loading OpenAI CLIP (image_size={})", config.image_size);
40 let model = ClipModel::new(vb, config)
41 .map_err(|e| FerrumError::model(format!("CLIP load: {e}")))?;
42 Ok(Self {
43 projection_dim: config.vision_config.projection_dim,
44 image_size: config.image_size,
45 model: Mutex::new(ClipVariant::OpenAI(model)),
46 device,
47 dtype,
48 })
49 }
50
51 pub fn new_chinese(
53 vb: VarBuilder,
54 config: &ChineseClipConfig,
55 device: CandleDevice,
56 dtype: DType,
57 ) -> Result<Self> {
58 info!(
59 "Loading Chinese-CLIP (image_size={}, projection_dim={})",
60 config.image_size, config.projection_dim
61 );
62 let model = ChineseClipModel::new(vb, config)
63 .map_err(|e| FerrumError::model(format!("Chinese-CLIP load: {e}")))?;
64 Ok(Self {
65 projection_dim: config.projection_dim,
66 image_size: config.image_size,
67 model: Mutex::new(ClipVariant::Chinese(model)),
68 device,
69 dtype,
70 })
71 }
72
73 pub fn new_siglip(
75 vb: VarBuilder,
76 config: &siglip::Config,
77 device: CandleDevice,
78 dtype: DType,
79 ) -> Result<Self> {
80 let image_size = config.vision_config.image_size;
81 let projection_dim = config.vision_config.hidden_size;
82 info!(
83 "Loading SigLIP (image_size={}, hidden_size={})",
84 image_size, projection_dim
85 );
86 let model = siglip::Model::new(config, vb)
87 .map_err(|e| FerrumError::model(format!("SigLIP load: {e}")))?;
88 Ok(Self {
89 projection_dim,
90 image_size,
91 model: Mutex::new(ClipVariant::SigLIP(model)),
92 device,
93 dtype,
94 })
95 }
96
97 pub fn from_config_json(
102 vb: VarBuilder,
103 config_path: &std::path::Path,
104 device: CandleDevice,
105 dtype: DType,
106 ) -> Result<Self> {
107 let raw: serde_json::Value = serde_json::from_str(
108 &std::fs::read_to_string(config_path)
109 .map_err(|e| FerrumError::model(format!("read config: {e}")))?,
110 )
111 .map_err(|e| FerrumError::model(format!("parse config: {e}")))?;
112
113 let model_type = raw.get("model_type").and_then(|v| v.as_str()).unwrap_or("");
114
115 if model_type == "siglip" {
116 let config: siglip::Config =
118 serde_json::from_value(raw).unwrap_or_else(|_| siglip::Config::base_patch16_224());
119 return Self::new_siglip(vb, &config, device, dtype);
120 }
121
122 if model_type == "chinese_clip" {
123 let mut config = ChineseClipConfig::clip_vit_base_patch16();
124 if let Some(v) = raw.get("projection_dim").and_then(|v| v.as_u64()) {
125 config.projection_dim = v as usize;
126 }
127 if let Some(vc) = raw.get("vision_config") {
128 if let Some(v) = vc.get("image_size").and_then(|v| v.as_u64()) {
129 config.vision_config.image_size = v as usize;
130 config.image_size = v as usize;
131 }
132 }
133 Self::new_chinese(vb, &config, device, dtype)
134 } else {
135 let mut config = ClipConfig::vit_base_patch32();
136 if let Some(vc) = raw.get("vision_config") {
137 if let Some(v) = vc.get("image_size").and_then(|v| v.as_u64()) {
138 config.vision_config.image_size = v as usize;
139 config.image_size = v as usize;
140 }
141 if let Some(v) = vc.get("projection_dim").and_then(|v| v.as_u64()) {
142 config.vision_config.projection_dim = v as usize;
143 }
144 }
145 Self::new_openai(vb, &config, device, dtype)
146 }
147 }
148
149 pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
151 let model = self.model.lock();
152 let features = match &*model {
153 ClipVariant::OpenAI(m) => m
154 .get_text_features(input_ids)
155 .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
156 ClipVariant::Chinese(m) => m
157 .get_text_features(input_ids, None, None)
158 .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
159 ClipVariant::SigLIP(m) => m
160 .get_text_features(input_ids)
161 .map_err(|e| FerrumError::model(format!("text features: {e}")))?,
162 };
163 clip::div_l2_norm(&features).map_err(|e| FerrumError::model(format!("l2 norm: {e}")))
164 }
165
166 pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
168 let model = self.model.lock();
169 let features = match &*model {
170 ClipVariant::OpenAI(m) => m
171 .get_image_features(pixel_values)
172 .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
173 ClipVariant::Chinese(m) => m
174 .get_image_features(pixel_values)
175 .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
176 ClipVariant::SigLIP(m) => m
177 .get_image_features(pixel_values)
178 .map_err(|e| FerrumError::model(format!("image features: {e}")))?,
179 };
180 clip::div_l2_norm(&features).map_err(|e| FerrumError::model(format!("l2 norm: {e}")))
181 }
182
183 pub fn device(&self) -> &CandleDevice {
184 &self.device
185 }
186
187 pub fn dtype(&self) -> DType {
188 self.dtype
189 }
190
191 pub fn image_size(&self) -> usize {
192 self.image_size
193 }
194
195 pub fn projection_dim(&self) -> usize {
196 self.projection_dim
197 }
198}