1use std::collections::HashMap;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use candle_core::{DType, Device as CandleDevice, Tensor};
11use candle_nn::VarBuilder;
12use ferrum_interfaces::{
13 model_executor::{
14 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
15 ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
16 },
17 BlockTable, CacheHandleStats, KvCacheHandle, ModelExecutor, TensorRef,
18};
19use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
20use tracing::info;
21
22use super::common;
23use crate::architectures::clip::ClipModelWrapper;
24use crate::image_processor::ClipImageProcessor;
25use crate::tensor_wrapper::CandleTensorWrapper;
26
/// Executor that serves CLIP text and image embedding requests.
///
/// Wraps a loaded CLIP model together with the image preprocessor that
/// matches its vision tower, plus static model metadata.
pub struct ClipModelExecutor {
    // Loaded CLIP model (text + vision towers and projection heads).
    model: ClipModelWrapper,
    // Converts raw images (path / base64) into model-ready pixel tensors;
    // constructed from the model's own image size in `new`.
    image_processor: ClipImageProcessor,
    // Static metadata reported via `ModelExecutor::info`.
    info: ModelInfo,
}
33
34impl ClipModelExecutor {
35 pub fn new(model: ClipModelWrapper, info: ModelInfo) -> Self {
36 let image_processor = ClipImageProcessor::new(model.image_size());
37 info!(
38 "Created ClipModelExecutor: {} (dim={}, image_size={})",
39 info.model_id,
40 model.projection_dim(),
41 model.image_size()
42 );
43 Self {
44 model,
45 image_processor,
46 info,
47 }
48 }
49
50 pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
52 let dir = std::path::Path::new(model_path);
53 let config_path = dir.join("config.json");
54
55 let safetensors: Vec<_> = std::fs::read_dir(dir)
56 .map_err(|e| FerrumError::model(format!("read dir: {e}")))?
57 .filter_map(|e| e.ok())
58 .map(|e| e.path())
59 .filter(|p| p.extension().map_or(false, |ext| ext == "safetensors"))
60 .collect();
61
62 if safetensors.is_empty() {
63 return Err(FerrumError::model("No safetensors files found"));
64 }
65
66 let vb = unsafe {
67 VarBuilder::from_mmaped_safetensors(&safetensors, dtype, &device)
68 .map_err(|e| FerrumError::model(format!("load weights: {e}")))?
69 };
70
71 let model = ClipModelWrapper::from_config_json(vb, &config_path, device, dtype)?;
72
73 let info = ModelInfo {
74 model_id: ferrum_types::ModelId(model_path.to_string()),
75 model_type: ModelType::Clip,
76 hidden_size: model.projection_dim(),
77 vocab_size: 0,
78 num_layers: 0,
79 num_heads: 0,
80 num_kv_heads: 0,
81 num_parameters: 0,
82 max_sequence_length: 77,
83 device: Device::CPU,
84 dtype: DataType::FP32,
85 version: None,
86 license: None,
87 metadata: HashMap::new(),
88 };
89
90 Ok(Self::new(model, info))
91 }
92
93 pub fn embed_text(&self, input_ids: &[u32]) -> Result<Tensor> {
95 let ids = Tensor::new(input_ids, self.model.device())
96 .map_err(|e| FerrumError::model(format!("tensor: {e}")))?
97 .unsqueeze(0)
98 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
99 self.model.get_text_features(&ids)
100 }
101
102 pub fn embed_image_path(&self, path: &str) -> Result<Tensor> {
104 let pixel_values = self
105 .image_processor
106 .process_path(path, self.model.device())?;
107 self.model.get_image_features(&pixel_values)
108 }
109
110 pub fn embed_image_base64(&self, data: &str) -> Result<Tensor> {
112 let pixel_values = self
113 .image_processor
114 .process_base64(data, self.model.device())?;
115 self.model.get_image_features(&pixel_values)
116 }
117
118 pub fn projection_dim(&self) -> usize {
119 self.model.projection_dim()
120 }
121}
122
/// Placeholder KV-cache handle for CLIP.
///
/// CLIP is an encoder-only model and performs no autoregressive decoding, so
/// it has no key/value state to cache; this zero-sized type satisfies the
/// `KvCacheHandle` contract required by `PrefillOutput`.
#[derive(Clone, Debug)]
struct DummyClipCache;
126
127impl KvCacheHandle for DummyClipCache {
128 fn block_table(&self) -> &BlockTable {
129 static EMPTY: std::sync::OnceLock<BlockTable> = std::sync::OnceLock::new();
130 EMPTY.get_or_init(|| BlockTable::new(16))
131 }
132
133 fn block_table_mut(&mut self) -> &mut BlockTable {
134 unimplemented!("CLIP does not use KV cache")
135 }
136
137 fn as_any(&self) -> &dyn std::any::Any {
138 self
139 }
140
141 fn device(&self) -> Device {
142 Device::CPU
143 }
144
145 fn num_layers(&self) -> usize {
146 0
147 }
148
149 fn num_heads(&self) -> usize {
150 0
151 }
152
153 fn head_dim(&self) -> usize {
154 0
155 }
156
157 fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
158 Ok(None)
159 }
160
161 fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
162 Ok(None)
163 }
164
165 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
166 Ok(Arc::new(self.clone()))
167 }
168
169 fn stats(&self) -> CacheHandleStats {
170 CacheHandleStats {
171 memory_bytes: 0,
172 blocks_allocated: 0,
173 tokens_stored: 0,
174 utilization: 0.0,
175 last_access: std::time::Instant::now(),
176 }
177 }
178
179 fn is_valid(&self) -> bool {
180 true
181 }
182
183 fn cache_id(&self) -> String {
184 "clip_dummy_cache".to_string()
185 }
186}
187
188#[async_trait]
189impl ModelExecutor for ClipModelExecutor {
190 fn info(&self) -> &ModelInfo {
191 &self.info
192 }
193
194 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
195 let tokens: Vec<u32> = input
196 .input_ids
197 .as_any()
198 .downcast_ref::<CandleTensorWrapper>()
199 .and_then(|w| w.inner().flatten_all().and_then(|t| t.to_vec1()).ok())
200 .unwrap_or_default();
201
202 if tokens.is_empty() {
203 return Err(FerrumError::model("Empty input"));
204 }
205
206 let embedding = self.embed_text(&tokens)?;
207 let embedding = embedding
208 .unsqueeze(1)
209 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
210 let tensor_ref: TensorRef = Arc::new(CandleTensorWrapper::new(embedding));
211 let cache: Arc<dyn KvCacheHandle> = Arc::new(DummyClipCache);
212
213 Ok(PrefillOutput::new(tensor_ref, cache))
214 }
215
216 async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
217 Err(FerrumError::model(
218 "CLIP is an encoder model — decode not supported",
219 ))
220 }
221
222 fn capabilities(&self) -> ExecutorCapabilities {
223 ExecutorCapabilities {
224 max_batch_size: 32,
225 max_sequence_length: 77,
226 attention_mechanisms: vec![AttentionType::MultiHead],
227 supports_dynamic_batching: false,
228 supports_continuous_batching: false,
229 supports_speculative_decoding: false,
230 supports_tensor_parallelism: false,
231 supports_pipeline_parallelism: false,
232 supported_dtypes: vec![DataType::FP32],
233 supported_devices: vec![self.info.device.clone()],
234 memory_requirements: MemoryRequirements {
235 parameter_memory: 600 * 1024 * 1024,
236 activation_memory_per_token: 0,
237 kv_cache_memory_per_token: 0,
238 overhead_memory: 0,
239 },
240 }
241 }
242
243 fn release_cache(&self, _cache_id: &str) {}
244
245 fn status(&self) -> ExecutorStatus {
246 common::default_executor_status()
247 }
248}