use std::sync::Arc;
use async_trait::async_trait;
use candle_core::{DType, Device as CandleDevice, Tensor};
use candle_nn::VarBuilder;
use ferrum_interfaces::{
model_executor::{
AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
},
BlockTable, CacheHandleStats, KvCacheHandle, ModelExecutor, TensorRef,
};
use ferrum_types::{DataType, Device, FerrumError, ModelInfo, Result};
use tracing::{debug, info};
use crate::architectures::bert::BertModelWrapper;
use crate::tensor_wrapper::CandleTensorWrapper;
pub struct BertModelExecutor {
model: BertModelWrapper,
info: ModelInfo,
device: CandleDevice,
status: ExecutorStatus,
}
impl BertModelExecutor {
pub fn new(model: BertModelWrapper, model_info: ModelInfo, device: CandleDevice) -> Self {
info!(
"Created BertModelExecutor for model: {}",
model_info.model_id
);
let status = ExecutorStatus {
state: ExecutorState::Ready,
is_ready: true,
current_batch_size: 0,
prefill_operations: 0,
decode_operations: 0,
avg_prefill_time_ms: 0.0,
avg_decode_time_ms: 0.0,
memory_usage: ExecutorMemoryUsage {
allocated_bytes: 0,
used_bytes: 0,
peak_bytes: 0,
utilization_percent: 0.0,
},
last_operation: None,
};
Self {
model,
info: model_info,
device,
status,
}
}
pub async fn from_path(
model_path: &str,
model_def: &crate::definition::ModelDefinition,
device: CandleDevice,
) -> Result<Self> {
info!("Loading BERT model from: {}", model_path);
let path = std::path::Path::new(model_path);
let safetensors_path = if path.join("model.safetensors").exists() {
path.join("model.safetensors")
} else {
std::fs::read_dir(path)
.map_err(|e| FerrumError::model(format!("Failed to read model dir: {}", e)))?
.filter_map(|e| e.ok())
.find(|e| {
e.path()
.extension()
.map_or(false, |ext| ext == "safetensors")
})
.map(|e| e.path())
.ok_or_else(|| FerrumError::model("No safetensors file found"))?
};
info!("Loading weights from: {:?}", safetensors_path);
let dtype = DType::F32;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[&safetensors_path], dtype, &device)
.map_err(|e| FerrumError::model(format!("Failed to load weights: {}", e)))?
};
let config_path = path.join("config.json");
let model = BertModelWrapper::from_config_json(vb, &config_path, device.clone(), dtype)?;
let model_info = model_def.to_model_info(model_path.to_string());
Ok(Self::new(model, model_info, device))
}
pub fn get_embeddings(&self, input_ids: &[u32]) -> Result<Tensor> {
let seq_len = input_ids.len();
let input_tensor = Tensor::from_vec(
input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
(1, seq_len),
&self.device,
)
.map_err(|e| FerrumError::model(format!("Failed to create input tensor: {}", e)))?;
let token_type_ids = Tensor::zeros((1, seq_len), DType::I64, &self.device)
.map_err(|e| FerrumError::model(format!("Failed to create token type ids: {}", e)))?;
self.model
.get_sentence_embedding(&input_tensor, &token_type_ids, None)
}
pub fn model(&self) -> &BertModelWrapper {
&self.model
}
}
#[derive(Debug, Clone)]
struct DummyBertCache;
impl KvCacheHandle for DummyBertCache {
fn block_table(&self) -> &BlockTable {
static EMPTY: std::sync::OnceLock<BlockTable> = std::sync::OnceLock::new();
EMPTY.get_or_init(|| BlockTable::new(16))
}
fn block_table_mut(&mut self) -> &mut BlockTable {
unimplemented!("BERT does not use KV cache")
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn device(&self) -> Device {
Device::CPU
}
fn num_layers(&self) -> usize {
0
}
fn num_heads(&self) -> usize {
0
}
fn head_dim(&self) -> usize {
0
}
fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
Ok(None)
}
fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
Ok(None)
}
fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
Ok(Arc::new(self.clone()))
}
fn stats(&self) -> CacheHandleStats {
CacheHandleStats {
memory_bytes: 0,
blocks_allocated: 0,
tokens_stored: 0,
utilization: 0.0,
last_access: std::time::Instant::now(),
}
}
fn is_valid(&self) -> bool {
true
}
fn cache_id(&self) -> String {
"bert_dummy_cache".to_string()
}
}
#[async_trait]
impl ModelExecutor for BertModelExecutor {
fn info(&self) -> &ModelInfo {
&self.info
}
async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
let token_ids: Vec<u32> = if let Ok(v) = input.input_ids.to_vec_u32() {
v
} else if let Ok(vf) = input.input_ids.to_vec_f32() {
vf.into_iter().map(|x| x as u32).collect()
} else {
return Err(FerrumError::backend("Unable to extract token ids"));
};
debug!("BERT prefill: {} tokens", token_ids.len());
let embeddings = self.get_embeddings(&token_ids)?;
let output_tensor: TensorRef = Arc::new(CandleTensorWrapper::new(embeddings));
let kv_cache: Arc<dyn KvCacheHandle> = Arc::new(DummyBertCache);
Ok(PrefillOutput::new(output_tensor, kv_cache))
}
async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
Err(FerrumError::model(
"BERT is an encoder model and does not support token generation. Use prefill() to get embeddings.",
))
}
fn capabilities(&self) -> ExecutorCapabilities {
ExecutorCapabilities {
max_batch_size: 32,
max_sequence_length: self.info.max_sequence_length,
attention_mechanisms: vec![AttentionType::MultiHead],
supports_dynamic_batching: true,
supports_continuous_batching: false,
supports_speculative_decoding: false,
supports_tensor_parallelism: false,
supports_pipeline_parallelism: false,
supported_dtypes: vec![DataType::FP32],
supported_devices: vec![Device::CPU],
memory_requirements: MemoryRequirements {
parameter_memory: 0,
activation_memory_per_token: 0,
kv_cache_memory_per_token: 0,
overhead_memory: 0,
},
}
}
fn status(&self) -> ExecutorStatus {
self.status.clone()
}
}