Skip to main content

ferrum_models/executor/
bert_executor.rs

1//! BERT Model Executor for embeddings
2//!
3//! BERT is an encoder model used for generating text embeddings.
4//! Unlike decoder models (LLaMA, Qwen), it doesn't generate tokens.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use candle_core::{DType, Device as CandleDevice, Tensor};
10use candle_nn::VarBuilder;
11use ferrum_interfaces::{
12    model_executor::{
13        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
14        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
15    },
16    BlockTable, CacheHandleStats, KvCacheHandle, ModelExecutor, TensorRef,
17};
18use ferrum_types::{DataType, Device, FerrumError, ModelInfo, Result};
19use tracing::{debug, info};
20
21use crate::multimodal::bert::BertModelWrapper;
22use crate::tensor_wrapper::CandleTensorWrapper;
23
24/// BERT Executor for embedding tasks
25pub struct BertModelExecutor {
26    model: BertModelWrapper,
27    info: ModelInfo,
28    device: CandleDevice,
29    status: ExecutorStatus,
30}
31
32impl BertModelExecutor {
33    /// Create a new BERT executor
34    pub fn new(model: BertModelWrapper, model_info: ModelInfo, device: CandleDevice) -> Self {
35        info!(
36            "Created BertModelExecutor for model: {}",
37            model_info.model_id
38        );
39
40        let status = ExecutorStatus {
41            state: ExecutorState::Ready,
42            is_ready: true,
43            current_batch_size: 0,
44            prefill_operations: 0,
45            decode_operations: 0,
46            avg_prefill_time_ms: 0.0,
47            avg_decode_time_ms: 0.0,
48            memory_usage: ExecutorMemoryUsage {
49                allocated_bytes: 0,
50                used_bytes: 0,
51                peak_bytes: 0,
52                utilization_percent: 0.0,
53            },
54            last_operation: None,
55        };
56
57        Self {
58            model,
59            info: model_info,
60            device,
61            status,
62        }
63    }
64
65    /// Load BERT executor from path
66    pub async fn from_path(
67        model_path: &str,
68        model_def: &crate::definition::ModelDefinition,
69        device: CandleDevice,
70    ) -> Result<Self> {
71        info!("Loading BERT model from: {}", model_path);
72
73        let path = std::path::Path::new(model_path);
74
75        // Find safetensors file
76        let safetensors_path = if path.join("model.safetensors").exists() {
77            path.join("model.safetensors")
78        } else {
79            // Look for any .safetensors file
80            std::fs::read_dir(path)
81                .map_err(|e| FerrumError::model(format!("Failed to read model dir: {}", e)))?
82                .filter_map(|e| e.ok())
83                .find(|e| {
84                    e.path()
85                        .extension()
86                        .map_or(false, |ext| ext == "safetensors")
87                })
88                .map(|e| e.path())
89                .ok_or_else(|| FerrumError::model("No safetensors file found"))?
90        };
91
92        info!("Loading weights from: {:?}", safetensors_path);
93
94        // Use F32 for BERT (better compatibility)
95        let dtype = DType::F32;
96
97        // Load weights
98        let vb = unsafe {
99            VarBuilder::from_mmaped_safetensors(&[&safetensors_path], dtype, &device)
100                .map_err(|e| FerrumError::model(format!("Failed to load weights: {}", e)))?
101        };
102
103        // Some BERT checkpoints prefix every tensor with `bert.` (the
104        // canonical google-bert / bert-base-chinese layout); others
105        // (sentence-transformers, MiniLM, etc.) drop the prefix. Probe
106        // and hand candle's BertModel::load the correct `pp(...)` view.
107        //
108        // KNOWN: google-bert / bert-base-chinese also uses TF-style
109        // `LayerNorm.gamma` / `.beta` instead of PyTorch's `weight` /
110        // `bias` — candle BertModel::load doesn't auto-rename them, so
111        // those checkpoints still error after this prefix fix. Tracked
112        // separately; sentence-transformers / MiniLM-style checkpoints
113        // (no prefix, weight/bias names) work end-to-end.
114        let vb = if vb.contains_tensor("bert.embeddings.word_embeddings.weight") {
115            vb.pp("bert")
116        } else {
117            vb
118        };
119
120        // Create model from config.json
121        let config_path = path.join("config.json");
122        let model = BertModelWrapper::from_config_json(vb, &config_path, device.clone(), dtype)?;
123
124        // Create model info
125        let model_info = model_def.to_model_info(model_path.to_string());
126
127        Ok(Self::new(model, model_info, device))
128    }
129
130    /// Get embeddings for input tokens
131    pub fn get_embeddings(&self, input_ids: &[u32]) -> Result<Tensor> {
132        let seq_len = input_ids.len();
133
134        // Create input tensor
135        let input_tensor = Tensor::from_vec(
136            input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
137            (1, seq_len),
138            &self.device,
139        )
140        .map_err(|e| FerrumError::model(format!("Failed to create input tensor: {}", e)))?;
141
142        // Create token type ids (all zeros for single sentence)
143        let token_type_ids = Tensor::zeros((1, seq_len), DType::I64, &self.device)
144            .map_err(|e| FerrumError::model(format!("Failed to create token type ids: {}", e)))?;
145
146        // Get sentence embedding
147        self.model
148            .get_sentence_embedding(&input_tensor, &token_type_ids, None)
149    }
150
151    /// Get model reference
152    pub fn model(&self) -> &BertModelWrapper {
153        &self.model
154    }
155}
156
/// Placeholder KV cache handle for BERT.
///
/// BERT is an encoder and keeps no per-token decode state, but
/// `PrefillOutput` requires a cache handle, so prefill returns this empty
/// stand-in.
#[derive(Clone, Debug)]
struct DummyBertCache;
160
161impl KvCacheHandle for DummyBertCache {
162    fn block_table(&self) -> &BlockTable {
163        static EMPTY: std::sync::OnceLock<BlockTable> = std::sync::OnceLock::new();
164        EMPTY.get_or_init(|| BlockTable::new(16))
165    }
166
167    fn block_table_mut(&mut self) -> &mut BlockTable {
168        unimplemented!("BERT does not use KV cache")
169    }
170
171    fn as_any(&self) -> &dyn std::any::Any {
172        self
173    }
174
175    fn device(&self) -> Device {
176        Device::CPU
177    }
178
179    fn num_layers(&self) -> usize {
180        0
181    }
182
183    fn num_heads(&self) -> usize {
184        0
185    }
186
187    fn head_dim(&self) -> usize {
188        0
189    }
190
191    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
192        Ok(None)
193    }
194
195    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
196        Ok(None)
197    }
198
199    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
200        Ok(Arc::new(self.clone()))
201    }
202
203    fn stats(&self) -> CacheHandleStats {
204        CacheHandleStats {
205            memory_bytes: 0,
206            blocks_allocated: 0,
207            tokens_stored: 0,
208            utilization: 0.0,
209            last_access: std::time::Instant::now(),
210        }
211    }
212
213    fn is_valid(&self) -> bool {
214        true
215    }
216
217    fn cache_id(&self) -> String {
218        "bert_dummy_cache".to_string()
219    }
220}
221
222#[async_trait]
223impl ModelExecutor for BertModelExecutor {
224    fn info(&self) -> &ModelInfo {
225        &self.info
226    }
227
228    /// For BERT, prefill returns the embeddings (not logits)
229    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
230        let token_ids: Vec<u32> = if let Ok(v) = input.input_ids.to_vec_u32() {
231            v
232        } else if let Ok(vf) = input.input_ids.to_vec_f32() {
233            vf.into_iter().map(|x| x as u32).collect()
234        } else {
235            return Err(FerrumError::backend("Unable to extract token ids"));
236        };
237
238        debug!("BERT prefill: {} tokens", token_ids.len());
239
240        let embeddings = self.get_embeddings(&token_ids)?;
241
242        // Wrap as TensorRef
243        let output_tensor: TensorRef = Arc::new(CandleTensorWrapper::new(embeddings));
244        let kv_cache: Arc<dyn KvCacheHandle> = Arc::new(DummyBertCache);
245
246        Ok(PrefillOutput::new(output_tensor, kv_cache))
247    }
248
249    /// BERT doesn't support decode (it's an encoder model)
250    async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
251        Err(FerrumError::model(
252            "BERT is an encoder model and does not support token generation. Use prefill() to get embeddings.",
253        ))
254    }
255
256    fn capabilities(&self) -> ExecutorCapabilities {
257        ExecutorCapabilities {
258            max_batch_size: 32,
259            max_sequence_length: self.info.max_sequence_length,
260            attention_mechanisms: vec![AttentionType::MultiHead],
261            supports_dynamic_batching: true,
262            supports_continuous_batching: false,
263            supports_speculative_decoding: false,
264            supports_tensor_parallelism: false,
265            supports_pipeline_parallelism: false,
266            supported_dtypes: vec![DataType::FP32],
267            supported_devices: vec![Device::CPU],
268            memory_requirements: MemoryRequirements {
269                parameter_memory: 0,
270                activation_memory_per_token: 0,
271                kv_cache_memory_per_token: 0,
272                overhead_memory: 0,
273            },
274        }
275    }
276
277    fn status(&self) -> ExecutorStatus {
278        self.status.clone()
279    }
280}