use candle_core::{DType, Device as CandleDevice, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::modernbert::{Config as ModernBertConfigInner, ModernBert};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::Path;
use tokenizers::Tokenizer;
use super::{NeuralError, Result};
/// Compute backend selector used by this module's configuration.
///
/// It is turned into a candle device handle by `Device::to_candle`.
#[derive(Clone, Copy, Debug, Default)]
pub enum Device {
/// Host CPU. This is the variant produced by `Device::default()`.
#[default]
Cpu,
/// CUDA backend; the payload is the ordinal passed to `CandleDevice::new_cuda`.
Cuda(usize),
/// Metal backend; `to_candle` always requests ordinal 0 for it.
Metal,
}
impl Device {
    /// Builds the candle device handle that corresponds to this selector.
    ///
    /// `Cpu` maps to `CandleDevice::Cpu`, `Cuda(n)` opens CUDA ordinal `n`, and
    /// `Metal` opens Metal ordinal 0.
    ///
    /// # Errors
    /// Propagates the candle error raised when the CUDA or Metal device cannot
    /// be created.
    pub fn to_candle(&self) -> std::result::Result<CandleDevice, candle_core::Error> {
        let handle = match *self {
            Device::Cpu => CandleDevice::Cpu,
            Device::Cuda(ordinal) => CandleDevice::new_cuda(ordinal)?,
            Device::Metal => CandleDevice::new_metal(0)?,
        };
        Ok(handle)
    }
}
/// User-facing settings for loading a ModernBERT checkpoint.
#[derive(Clone, Debug)]
pub struct ModernBertConfig {
/// Hugging Face Hub repository id that `ModernBertModel::load` fetches files from.
pub model_id: String,
/// Backend the model is placed on; converted with `Device::to_candle` in `load`.
pub device: Device,
/// Element type handed to the `VarBuilder` when the weights are loaded.
pub dtype: DType,
/// Maximum sequence length.
/// NOTE(review): no code visible in this file reads this field, so it is not
/// enforced during tokenization here — confirm whether callers rely on it.
pub max_seq_len: usize,
}
impl Default for ModernBertConfig {
    /// Default settings: the `answerdotai/ModernBERT-base` repository, the
    /// default [`Device`], `f32` weights and a sequence limit of 8192.
    fn default() -> Self {
        let model_id = String::from("answerdotai/ModernBERT-base");
        let device = Device::default();
        Self {
            model_id,
            device,
            dtype: DType::F32,
            max_seq_len: 8_192,
        }
    }
}
/// A loaded ModernBERT model bundled with its tokenizer and runtime settings.
pub struct ModernBertModel {
// The candle ModernBERT network built from the weights file.
model: ModernBert,
// Tokenizer loaded from the tokenizer file supplied at load time.
tokenizer: Tokenizer,
// Candle device the weights were loaded onto; input tensors are created on it.
device: CandleDevice,
// The configuration this model was loaded with.
config: ModernBertConfig,
// `hidden_size` as read from the model's JSON config.
hidden_size: usize,
}
impl ModernBertModel {
    /// Token id written into padded positions of a batch. Those positions carry
    /// a zero attention mask, so the value only has to be a valid vocabulary index.
    const PAD_FILL_ID: u32 = 0;

    /// Resolves the files of `config.model_id` through the Hugging Face Hub API
    /// and loads the model from them.
    ///
    /// The files requested are `model.safetensors`, `config.json` and
    /// `tokenizer.json`.
    ///
    /// # Errors
    /// * `NeuralError::DeviceNotAvailable` if the configured device cannot be created.
    /// * `NeuralError::ModelLoad` if the Hub API cannot be initialised or a file
    ///   cannot be fetched.
    /// * Any error returned by [`Self::load_from_files`].
    pub fn load(config: ModernBertConfig) -> Result<Self> {
        let device = config
            .device
            .to_candle()
            .map_err(|e| NeuralError::DeviceNotAvailable(format!("{:?}: {}", config.device, e)))?;
        let api = Api::new().map_err(|e| NeuralError::ModelLoad(e.to_string()))?;
        let repo = api.repo(Repo::new(config.model_id.clone(), RepoType::Model));
        let model_path = repo
            .get("model.safetensors")
            .map_err(|e| NeuralError::ModelLoad(format!("Failed to download model: {}", e)))?;
        let config_path = repo
            .get("config.json")
            .map_err(|e| NeuralError::ModelLoad(format!("Failed to download config: {}", e)))?;
        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| NeuralError::ModelLoad(format!("Failed to download tokenizer: {}", e)))?;
        Self::load_from_files(&model_path, &config_path, &tokenizer_path, config, device)
    }

    /// Loads the model from files already present on disk.
    ///
    /// The JSON config is read and validated first, then the tokenizer, and the
    /// weights are memory-mapped last, so a bad config or tokenizer is reported
    /// before the weights file is touched.
    ///
    /// # Errors
    /// * An I/O error if `config_path` cannot be read.
    /// * `NeuralError::ModelLoad` if the config is not valid JSON for the model.
    /// * `NeuralError::Tokenization` if the tokenizer file cannot be loaded.
    /// * A candle error if the weights cannot be mapped or the model cannot be built.
    pub fn load_from_files(
        model_path: &Path,
        config_path: &Path,
        tokenizer_path: &Path,
        config: ModernBertConfig,
        device: CandleDevice,
    ) -> Result<Self> {
        let config_json = std::fs::read_to_string(config_path)?;
        let model_config: ModernBertConfigInner = serde_json::from_str(&config_json)
            .map_err(|e| NeuralError::ModelLoad(format!("Invalid config: {}", e)))?;
        let hidden_size = model_config.hidden_size;
        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| NeuralError::Tokenization(format!("Failed to load tokenizer: {}", e)))?;
        // SAFETY: memory-mapping is only sound while the weights file is not
        // modified or truncated for as long as the mapping is alive. Nothing here
        // can enforce that; we rely on the checkpoint file being treated as
        // read-only by this process and by others.
        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], config.dtype, &device)? };
        let model = ModernBert::load(vb, &model_config)?;
        Ok(Self {
            model,
            tokenizer,
            device,
            config,
            hidden_size,
        })
    }

    /// Hidden size read from the model's JSON config at load time.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Candle device the model was loaded onto.
    pub fn device(&self) -> &CandleDevice {
        &self.device
    }

    /// Tokenizes `text` with special tokens enabled and returns the token ids.
    ///
    /// No truncation is applied by this method itself.
    ///
    /// # Errors
    /// `NeuralError::Tokenization` if the tokenizer rejects the input.
    pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| NeuralError::Tokenization(e.to_string()))?;
        Ok(encoding.get_ids().to_vec())
    }

    /// Tokenizes every text with special tokens enabled.
    ///
    /// Returns the per-text token ids together with the length of each encoding,
    /// in input order.
    ///
    /// # Errors
    /// `NeuralError::Tokenization` if the tokenizer rejects any input.
    pub fn encode_batch(&self, texts: &[&str]) -> Result<(Vec<Vec<u32>>, Vec<usize>)> {
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| NeuralError::Tokenization(e.to_string()))?;
        let lengths: Vec<usize> = encodings.iter().map(|e| e.len()).collect();
        let ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();
        Ok((ids, lengths))
    }

    /// Converts token ids back into text, skipping special tokens.
    ///
    /// # Errors
    /// `NeuralError::Tokenization` if decoding fails.
    pub fn decode(&self, ids: &[u32]) -> Result<String> {
        self.tokenizer
            .decode(ids, true)
            .map_err(|e| NeuralError::Tokenization(e.to_string()))
    }

    /// Runs the model on `input_ids` and returns its output tensor.
    ///
    /// When `attention_mask` is `None`, an all-ones `f32` mask with the same
    /// dimensions as `input_ids` is used, i.e. every position is attended to.
    ///
    /// # Errors
    /// Propagates candle errors from mask creation or the model's forward pass.
    pub fn forward(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        // Only materialise a mask when the caller did not supply one; a supplied
        // mask is borrowed as-is.
        let default_mask;
        let mask: &Tensor = match attention_mask {
            Some(provided) => provided,
            None => {
                default_mask = Tensor::ones(input_ids.dims(), DType::F32, &self.device)?;
                &default_mask
            }
        };
        Ok(self.model.forward(input_ids, mask)?)
    }

    /// Tokenizes `text` into a `(1, seq_len)` id tensor on the model's device and
    /// returns it together with `seq_len`.
    ///
    /// Rejects inputs for which the tokenizer produced no tokens, because there
    /// would be no position to read an embedding from.
    fn single_input(&self, text: &str) -> Result<(Tensor, usize)> {
        let ids = self.encode(text)?;
        if ids.is_empty() {
            return Err(NeuralError::Tokenization(
                "tokenizer produced no tokens for the input text".to_string(),
            ));
        }
        let input_ids = Tensor::new(ids.as_slice(), &self.device)?.unsqueeze(0)?;
        Ok((input_ids, ids.len()))
    }

    /// Embeds `text` as the model output at the first token position.
    ///
    /// The result is converted to `f32`, so it is valid for any configured
    /// weight dtype.
    ///
    /// # Errors
    /// `NeuralError::Tokenization` if tokenization fails or yields no tokens;
    /// candle errors from the forward pass or tensor conversion.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let (input_ids, _) = self.single_input(text)?;
        let hidden_states = self.forward(&input_ids, None)?;
        // Cast before extraction: `to_vec1::<f32>` fails on non-f32 tensors.
        let cls_embedding = hidden_states.i((0, 0))?.to_dtype(DType::F32)?;
        Ok(cls_embedding.to_vec1::<f32>()?)
    }

    /// Embeds `text` as the mean of the model output over all token positions.
    ///
    /// The output is converted to `f32` before pooling, so the result is valid
    /// for any configured weight dtype.
    ///
    /// # Errors
    /// `NeuralError::Tokenization` if tokenization fails or yields no tokens;
    /// candle errors from the forward pass or tensor conversion.
    pub fn embed_mean_pooled(&self, text: &str) -> Result<Vec<f32>> {
        let (input_ids, seq_len) = self.single_input(text)?;
        let hidden_states = self.forward(&input_ids, None)?.to_dtype(DType::F32)?;
        let sum = hidden_states.sum(1)?;
        // `seq_len` is non-zero here; `single_input` rejects empty encodings.
        let mean = (sum / (seq_len as f64))?;
        Ok(mean.squeeze(0)?.to_vec1::<f32>()?)
    }

    /// Embeds several texts in one forward pass, returning the model output at
    /// the first token position of each row, in input order.
    ///
    /// Shorter encodings are right-padded to the longest one. The attention mask
    /// is taken from the tokenizer's own per-token mask, so padding the tokenizer
    /// may already have inserted is masked out as well. An empty `texts` slice
    /// yields an empty result.
    ///
    /// # Errors
    /// `NeuralError::Tokenization` if tokenization fails or any text yields no
    /// tokens; candle errors from tensor construction, the forward pass or
    /// tensor conversion.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| NeuralError::Tokenization(e.to_string()))?;
        if encodings.iter().any(|e| e.get_ids().is_empty()) {
            return Err(NeuralError::Tokenization(
                "tokenizer produced no tokens for at least one input text".to_string(),
            ));
        }

        let batch_size = encodings.len();
        let max_len = encodings
            .iter()
            .map(|e| e.get_ids().len())
            .max()
            .unwrap_or(0);

        let mut padded_ids: Vec<u32> = Vec::with_capacity(batch_size * max_len);
        let mut attention_mask: Vec<f32> = Vec::with_capacity(batch_size * max_len);
        for encoding in &encodings {
            let ids = encoding.get_ids();
            padded_ids.extend_from_slice(ids);
            padded_ids.resize(padded_ids.len() + (max_len - ids.len()), Self::PAD_FILL_ID);

            // Mask values from the tokenizer are 0 or 1, so the cast is exact.
            attention_mask.extend(encoding.get_attention_mask().iter().map(|&m| m as f32));
            attention_mask.resize(attention_mask.len() + (max_len - ids.len()), 0.0);
        }

        let input_tensor = Tensor::from_vec(padded_ids, (batch_size, max_len), &self.device)?;
        let mask_tensor = Tensor::from_vec(attention_mask, (batch_size, max_len), &self.device)?;
        let hidden_states = self.forward(&input_tensor, Some(&mask_tensor))?;

        // Slice the first position of every row at once and copy it to the host
        // in a single transfer; cast first because `to_vec2::<f32>` fails on
        // non-f32 tensors.
        let cls_embeddings = hidden_states.i((.., 0))?.to_dtype(DType::F32)?;
        Ok(cls_embeddings.to_vec2::<f32>()?)
    }

    /// Returns the output of [`Self::forward`] for `input_ids` with no attention
    /// mask supplied.
    ///
    /// NOTE(review): despite the name, no vocabulary projection is applied in
    /// this method — the tensor is returned exactly as `forward` produced it.
    /// Confirm whether callers expect vocabulary-sized logits.
    ///
    /// # Errors
    /// Propagates errors from [`Self::forward`].
    pub fn get_mlm_logits(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.forward(input_ids, None)
    }

    /// Id of the `[MASK]` token, or `None` if the tokenizer does not define it.
    pub fn mask_token_id(&self) -> Option<u32> {
        self.tokenizer.token_to_id("[MASK]")
    }

    /// Size of the tokenizer's vocabulary, not counting added tokens.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    /// The tokenizer loaded alongside the model.
    pub fn tokenizer(&self) -> &Tokenizer {
        &self.tokenizer
    }

    /// The configuration this model was loaded with.
    pub fn config(&self) -> &ModernBertConfig {
        &self.config
    }
}
impl std::fmt::Debug for ModernBertModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ModernBertModel")
.field("model_id", &self.config.model_id)
.field("device", &self.config.device)
.field("hidden_size", &self.hidden_size)
.field("vocab_size", &self.vocab_size())
.finish()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration targets the base checkpoint on CPU in f32.
    #[test]
    fn test_config_default() {
        let ModernBertConfig {
            model_id,
            device,
            dtype,
            max_seq_len,
        } = ModernBertConfig::default();
        assert_eq!(model_id, "answerdotai/ModernBERT-base");
        assert!(matches!(device, Device::Cpu));
        assert_eq!(dtype, DType::F32);
        assert_eq!(max_seq_len, 8192);
    }

    /// The CPU selector always converts to candle's CPU device.
    #[test]
    fn test_cpu_device_conversion() {
        let converted = Device::Cpu.to_candle();
        let device = converted.expect("CPU device should exist");
        assert!(matches!(device, CandleDevice::Cpu));
    }

    /// A config file that is not JSON is reported as a `ModelLoad` error
    /// mentioning "Invalid config".
    #[test]
    fn test_load_from_files_rejects_invalid_config_json() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root = dir.path();
        let [model_path, config_path, tokenizer_path] =
            ["model.safetensors", "config.json", "tokenizer.json"].map(|name| root.join(name));

        std::fs::write(&model_path, b"").expect("empty model fixture");
        std::fs::write(&config_path, b"not json").expect("invalid config");
        std::fs::write(&tokenizer_path, b"{}").expect("minimal tokenizer fixture");

        let outcome = ModernBertModel::load_from_files(
            &model_path,
            &config_path,
            &tokenizer_path,
            ModernBertConfig::default(),
            CandleDevice::Cpu,
        );
        let err = outcome.expect_err("invalid config should be rejected before loading weights");

        let message = match err {
            NeuralError::ModelLoad(message) => message,
            other => panic!("expected ModelLoad error, got {other:?}"),
        };
        assert!(message.contains("Invalid config"));
    }
}