1use std::sync::Arc;
7
8use async_trait::async_trait;
9use candle_core::{DType, Device as CandleDevice, Tensor};
10use candle_nn::VarBuilder;
11use ferrum_interfaces::{
12 model_executor::{
13 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
14 ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
15 },
16 BlockTable, CacheHandleStats, KvCacheHandle, ModelExecutor, TensorRef,
17};
18use ferrum_types::{DataType, Device, FerrumError, ModelInfo, Result};
19use tracing::{debug, info};
20
21use crate::multimodal::bert::BertModelWrapper;
22use crate::tensor_wrapper::CandleTensorWrapper;
23
24pub struct BertModelExecutor {
26 model: BertModelWrapper,
27 info: ModelInfo,
28 device: CandleDevice,
29 status: ExecutorStatus,
30}
31
32impl BertModelExecutor {
33 pub fn new(model: BertModelWrapper, model_info: ModelInfo, device: CandleDevice) -> Self {
35 info!(
36 "Created BertModelExecutor for model: {}",
37 model_info.model_id
38 );
39
40 let status = ExecutorStatus {
41 state: ExecutorState::Ready,
42 is_ready: true,
43 current_batch_size: 0,
44 prefill_operations: 0,
45 decode_operations: 0,
46 avg_prefill_time_ms: 0.0,
47 avg_decode_time_ms: 0.0,
48 memory_usage: ExecutorMemoryUsage {
49 allocated_bytes: 0,
50 used_bytes: 0,
51 peak_bytes: 0,
52 utilization_percent: 0.0,
53 },
54 last_operation: None,
55 };
56
57 Self {
58 model,
59 info: model_info,
60 device,
61 status,
62 }
63 }
64
65 pub async fn from_path(
67 model_path: &str,
68 model_def: &crate::definition::ModelDefinition,
69 device: CandleDevice,
70 ) -> Result<Self> {
71 info!("Loading BERT model from: {}", model_path);
72
73 let path = std::path::Path::new(model_path);
74
75 let safetensors_path = if path.join("model.safetensors").exists() {
77 path.join("model.safetensors")
78 } else {
79 std::fs::read_dir(path)
81 .map_err(|e| FerrumError::model(format!("Failed to read model dir: {}", e)))?
82 .filter_map(|e| e.ok())
83 .find(|e| {
84 e.path()
85 .extension()
86 .map_or(false, |ext| ext == "safetensors")
87 })
88 .map(|e| e.path())
89 .ok_or_else(|| FerrumError::model("No safetensors file found"))?
90 };
91
92 info!("Loading weights from: {:?}", safetensors_path);
93
94 let dtype = DType::F32;
96
97 let vb = unsafe {
99 VarBuilder::from_mmaped_safetensors(&[&safetensors_path], dtype, &device)
100 .map_err(|e| FerrumError::model(format!("Failed to load weights: {}", e)))?
101 };
102
103 let vb = if vb.contains_tensor("bert.embeddings.word_embeddings.weight") {
115 vb.pp("bert")
116 } else {
117 vb
118 };
119
120 let config_path = path.join("config.json");
122 let model = BertModelWrapper::from_config_json(vb, &config_path, device.clone(), dtype)?;
123
124 let model_info = model_def.to_model_info(model_path.to_string());
126
127 Ok(Self::new(model, model_info, device))
128 }
129
130 pub fn get_embeddings(&self, input_ids: &[u32]) -> Result<Tensor> {
132 let seq_len = input_ids.len();
133
134 let input_tensor = Tensor::from_vec(
136 input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
137 (1, seq_len),
138 &self.device,
139 )
140 .map_err(|e| FerrumError::model(format!("Failed to create input tensor: {}", e)))?;
141
142 let token_type_ids = Tensor::zeros((1, seq_len), DType::I64, &self.device)
144 .map_err(|e| FerrumError::model(format!("Failed to create token type ids: {}", e)))?;
145
146 self.model
148 .get_sentence_embedding(&input_tensor, &token_type_ids, None)
149 }
150
151 pub fn model(&self) -> &BertModelWrapper {
153 &self.model
154 }
155}
156
157#[derive(Debug, Clone)]
159struct DummyBertCache;
160
161impl KvCacheHandle for DummyBertCache {
162 fn block_table(&self) -> &BlockTable {
163 static EMPTY: std::sync::OnceLock<BlockTable> = std::sync::OnceLock::new();
164 EMPTY.get_or_init(|| BlockTable::new(16))
165 }
166
167 fn block_table_mut(&mut self) -> &mut BlockTable {
168 unimplemented!("BERT does not use KV cache")
169 }
170
171 fn as_any(&self) -> &dyn std::any::Any {
172 self
173 }
174
175 fn device(&self) -> Device {
176 Device::CPU
177 }
178
179 fn num_layers(&self) -> usize {
180 0
181 }
182
183 fn num_heads(&self) -> usize {
184 0
185 }
186
187 fn head_dim(&self) -> usize {
188 0
189 }
190
191 fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
192 Ok(None)
193 }
194
195 fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
196 Ok(None)
197 }
198
199 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
200 Ok(Arc::new(self.clone()))
201 }
202
203 fn stats(&self) -> CacheHandleStats {
204 CacheHandleStats {
205 memory_bytes: 0,
206 blocks_allocated: 0,
207 tokens_stored: 0,
208 utilization: 0.0,
209 last_access: std::time::Instant::now(),
210 }
211 }
212
213 fn is_valid(&self) -> bool {
214 true
215 }
216
217 fn cache_id(&self) -> String {
218 "bert_dummy_cache".to_string()
219 }
220}
221
222#[async_trait]
223impl ModelExecutor for BertModelExecutor {
224 fn info(&self) -> &ModelInfo {
225 &self.info
226 }
227
228 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
230 let token_ids: Vec<u32> = if let Ok(v) = input.input_ids.to_vec_u32() {
231 v
232 } else if let Ok(vf) = input.input_ids.to_vec_f32() {
233 vf.into_iter().map(|x| x as u32).collect()
234 } else {
235 return Err(FerrumError::backend("Unable to extract token ids"));
236 };
237
238 debug!("BERT prefill: {} tokens", token_ids.len());
239
240 let embeddings = self.get_embeddings(&token_ids)?;
241
242 let output_tensor: TensorRef = Arc::new(CandleTensorWrapper::new(embeddings));
244 let kv_cache: Arc<dyn KvCacheHandle> = Arc::new(DummyBertCache);
245
246 Ok(PrefillOutput::new(output_tensor, kv_cache))
247 }
248
249 async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
251 Err(FerrumError::model(
252 "BERT is an encoder model and does not support token generation. Use prefill() to get embeddings.",
253 ))
254 }
255
256 fn capabilities(&self) -> ExecutorCapabilities {
257 ExecutorCapabilities {
258 max_batch_size: 32,
259 max_sequence_length: self.info.max_sequence_length,
260 attention_mechanisms: vec![AttentionType::MultiHead],
261 supports_dynamic_batching: true,
262 supports_continuous_batching: false,
263 supports_speculative_decoding: false,
264 supports_tensor_parallelism: false,
265 supports_pipeline_parallelism: false,
266 supported_dtypes: vec![DataType::FP32],
267 supported_devices: vec![Device::CPU],
268 memory_requirements: MemoryRequirements {
269 parameter_memory: 0,
270 activation_memory_per_token: 0,
271 kv_cache_memory_per_token: 0,
272 overhead_memory: 0,
273 },
274 }
275 }
276
277 fn status(&self) -> ExecutorStatus {
278 self.status.clone()
279 }
280}