Skip to main content

ferrum_models/executor/
clip_executor.rs

1//! CLIP Model Executor for multimodal embeddings.
2//!
3//! Supports both text and image embedding via unified interface.
4//! Text goes through CLIP text encoder, images through vision encoder.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use candle_core::{DType, Device as CandleDevice, Tensor};
11use candle_nn::VarBuilder;
12use ferrum_interfaces::{
13    model_executor::{
14        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
15        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
16    },
17    BlockTable, CacheHandleStats, KvCacheHandle, ModelExecutor, TensorRef,
18};
19use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
20use tracing::info;
21
22use super::common;
23use crate::architectures::clip::ClipModelWrapper;
24use crate::image_processor::ClipImageProcessor;
25use crate::tensor_wrapper::CandleTensorWrapper;
26
/// CLIP executor for text and image embeddings.
pub struct ClipModelExecutor {
    /// Underlying CLIP model (text and vision towers plus projection heads).
    model: ClipModelWrapper,
    /// Preprocessor that converts raw images into pixel tensors sized to
    /// the model's expected input resolution (see `new`).
    image_processor: ClipImageProcessor,
    /// Static metadata about the loaded model (id, dims, device, dtype, ...).
    info: ModelInfo,
}
33
34impl ClipModelExecutor {
35    pub fn new(model: ClipModelWrapper, info: ModelInfo) -> Self {
36        let image_processor = ClipImageProcessor::new(model.image_size());
37        info!(
38            "Created ClipModelExecutor: {} (dim={}, image_size={})",
39            info.model_id,
40            model.projection_dim(),
41            model.image_size()
42        );
43        Self {
44            model,
45            image_processor,
46            info,
47        }
48    }
49
50    /// Load from model directory (config.json + safetensors).
51    pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
52        let dir = std::path::Path::new(model_path);
53        let config_path = dir.join("config.json");
54
55        let safetensors: Vec<_> = std::fs::read_dir(dir)
56            .map_err(|e| FerrumError::model(format!("read dir: {e}")))?
57            .filter_map(|e| e.ok())
58            .map(|e| e.path())
59            .filter(|p| p.extension().map_or(false, |ext| ext == "safetensors"))
60            .collect();
61
62        if safetensors.is_empty() {
63            return Err(FerrumError::model("No safetensors files found"));
64        }
65
66        let vb = unsafe {
67            VarBuilder::from_mmaped_safetensors(&safetensors, dtype, &device)
68                .map_err(|e| FerrumError::model(format!("load weights: {e}")))?
69        };
70
71        let model = ClipModelWrapper::from_config_json(vb, &config_path, device, dtype)?;
72
73        let info = ModelInfo {
74            model_id: ferrum_types::ModelId(model_path.to_string()),
75            model_type: ModelType::Clip,
76            hidden_size: model.projection_dim(),
77            vocab_size: 0,
78            num_layers: 0,
79            num_heads: 0,
80            num_kv_heads: 0,
81            num_parameters: 0,
82            max_sequence_length: 77,
83            device: Device::CPU,
84            dtype: DataType::FP32,
85            version: None,
86            license: None,
87            metadata: HashMap::new(),
88        };
89
90        Ok(Self::new(model, info))
91    }
92
93    /// Embed text tokens → L2-normalized vector.
94    pub fn embed_text(&self, input_ids: &[u32]) -> Result<Tensor> {
95        let ids = Tensor::new(input_ids, self.model.device())
96            .map_err(|e| FerrumError::model(format!("tensor: {e}")))?
97            .unsqueeze(0)
98            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
99        self.model.get_text_features(&ids)
100    }
101
102    /// Embed image from file path → L2-normalized vector.
103    pub fn embed_image_path(&self, path: &str) -> Result<Tensor> {
104        let pixel_values = self
105            .image_processor
106            .process_path(path, self.model.device())?;
107        self.model.get_image_features(&pixel_values)
108    }
109
110    /// Embed image from base64 data → L2-normalized vector.
111    pub fn embed_image_base64(&self, data: &str) -> Result<Tensor> {
112        let pixel_values = self
113            .image_processor
114            .process_base64(data, self.model.device())?;
115        self.model.get_image_features(&pixel_values)
116    }
117
118    pub fn projection_dim(&self) -> usize {
119        self.model.projection_dim()
120    }
121}
122
// Dummy KV cache for encoder-only CLIP (same pattern as BERT).
// Carries no state: every accessor reports an empty/zero-sized cache so the
// executor can satisfy the ModelExecutor contract without allocating blocks.
#[derive(Clone, Debug)]
struct DummyClipCache;
126
impl KvCacheHandle for DummyClipCache {
    /// Returns a shared, lazily-initialized empty block table — CLIP never
    /// allocates KV blocks, so all instances can share one static table.
    fn block_table(&self) -> &BlockTable {
        // 16 is presumably the block size expected by BlockTable::new —
        // TODO confirm; it is irrelevant here since the table stays empty.
        static EMPTY: std::sync::OnceLock<BlockTable> = std::sync::OnceLock::new();
        EMPTY.get_or_init(|| BlockTable::new(16))
    }

    /// Mutable access is never legitimate for CLIP; reaching this is a bug
    /// in the caller, so panic loudly rather than hand out the shared table.
    fn block_table_mut(&mut self) -> &mut BlockTable {
        unimplemented!("CLIP does not use KV cache")
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Placeholder device; the dummy cache owns no real memory.
    fn device(&self) -> Device {
        Device::CPU
    }

    // Zero layers/heads/head-dim: no attention cache exists for an
    // encoder-only forward pass.
    fn num_layers(&self) -> usize {
        0
    }

    fn num_heads(&self) -> usize {
        0
    }

    fn head_dim(&self) -> usize {
        0
    }

    /// No key tensors are ever stored.
    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
        Ok(None)
    }

    /// No value tensors are ever stored.
    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
        Ok(None)
    }

    /// Cloning is trivial since the handle is a stateless unit struct.
    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
        Ok(Arc::new(self.clone()))
    }

    /// All-zero stats; only `last_access` carries a live timestamp.
    fn stats(&self) -> CacheHandleStats {
        CacheHandleStats {
            memory_bytes: 0,
            blocks_allocated: 0,
            tokens_stored: 0,
            utilization: 0.0,
            last_access: std::time::Instant::now(),
        }
    }

    /// A stateless handle can never become invalid.
    fn is_valid(&self) -> bool {
        true
    }

    /// Fixed id shared by all dummy handles (they are interchangeable).
    fn cache_id(&self) -> String {
        "clip_dummy_cache".to_string()
    }
}
187
188#[async_trait]
189impl ModelExecutor for ClipModelExecutor {
190    fn info(&self) -> &ModelInfo {
191        &self.info
192    }
193
194    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
195        let tokens: Vec<u32> = input
196            .input_ids
197            .as_any()
198            .downcast_ref::<CandleTensorWrapper>()
199            .and_then(|w| w.inner().flatten_all().and_then(|t| t.to_vec1()).ok())
200            .unwrap_or_default();
201
202        if tokens.is_empty() {
203            return Err(FerrumError::model("Empty input"));
204        }
205
206        let embedding = self.embed_text(&tokens)?;
207        let embedding = embedding
208            .unsqueeze(1)
209            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
210        let tensor_ref: TensorRef = Arc::new(CandleTensorWrapper::new(embedding));
211        let cache: Arc<dyn KvCacheHandle> = Arc::new(DummyClipCache);
212
213        Ok(PrefillOutput::new(tensor_ref, cache))
214    }
215
216    async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
217        Err(FerrumError::model(
218            "CLIP is an encoder model — decode not supported",
219        ))
220    }
221
222    fn capabilities(&self) -> ExecutorCapabilities {
223        ExecutorCapabilities {
224            max_batch_size: 32,
225            max_sequence_length: 77,
226            attention_mechanisms: vec![AttentionType::MultiHead],
227            supports_dynamic_batching: false,
228            supports_continuous_batching: false,
229            supports_speculative_decoding: false,
230            supports_tensor_parallelism: false,
231            supports_pipeline_parallelism: false,
232            supported_dtypes: vec![DataType::FP32],
233            supported_devices: vec![self.info.device.clone()],
234            memory_requirements: MemoryRequirements {
235                parameter_memory: 600 * 1024 * 1024,
236                activation_memory_per_token: 0,
237                kv_cache_memory_per_token: 0,
238                overhead_memory: 0,
239            },
240        }
241    }
242
243    fn release_cache(&self, _cache_id: &str) {}
244
245    fn status(&self) -> ExecutorStatus {
246        common::default_executor_status()
247    }
248}