use super::*;
/// A text encoder that turns raw text into per-token embedding vectors.
///
/// Implementors are required to be `Send + Sync`.
pub trait CandleTextEncoder: Send + Sync {
    /// Encodes a single text.
    ///
    /// Returns `(embeddings, seq_len)`: a flat buffer of embedding values and the number of
    /// positions it covers.
    /// NOTE(review): the buffer is presumably `seq_len * hidden_dim()` values laid out
    /// position-major — confirm against implementors.
    fn encode(&self, text: &str) -> Result<(Vec<f32>, usize)>;

    /// Encodes every text in `texts` and packs the results into one buffer.
    ///
    /// Returns `(embeddings, cu_seqlens)`:
    /// - `embeddings` is the concatenation, in input order, of each buffer returned by
    ///   [`encode`](Self::encode);
    /// - `cu_seqlens` is the running total of the per-text sequence lengths. It starts with `0`
    ///   and has `texts.len() + 1` entries, so `cu_seqlens[i]..cu_seqlens[i + 1]` is the position
    ///   range belonging to `texts[i]`. An empty input yields `(vec![], vec![0])`.
    ///
    /// # Errors
    /// Propagates the first error from `encode`; texts after the failing one are not encoded.
    fn encode_batch(&self, texts: &[&str]) -> Result<(Vec<f32>, Vec<usize>)> {
        let mut packed = Vec::new();
        // One leading zero plus one boundary per text.
        let mut boundaries = Vec::with_capacity(texts.len() + 1);
        boundaries.push(0usize);

        let mut running = 0usize;
        for &text in texts {
            let (chunk, len) = self.encode(text)?;
            packed.extend(chunk);
            running += len;
            boundaries.push(running);
        }

        Ok((packed, boundaries))
    }

    /// Width of the embedding vectors produced by this encoder.
    /// NOTE(review): presumably equals the model hidden size — confirm.
    fn hidden_dim(&self) -> usize;

    /// Maximum input length accepted by this encoder.
    /// NOTE(review): presumably measured in tokens — confirm against implementors.
    fn max_length(&self) -> usize;

    /// Human-readable name of the underlying model architecture.
    fn architecture(&self) -> &str;
}
/// Hyper-parameters describing a transformer text encoder.
///
/// Ready-made presets for the supported families are available as associated constructors
/// (for example `EncoderConfig::bert_base` or `EncoderConfig::modernbert_base`).
///
/// `PartialEq` is derived so configurations can be compared against presets. `Eq` is not
/// derived because several fields are floating point.
#[derive(Debug, Clone, PartialEq)]
pub struct EncoderConfig {
    /// Number of entries in the token vocabulary.
    pub vocab_size: usize,
    /// Width of the hidden states.
    pub hidden_size: usize,
    /// Number of attention heads per layer.
    pub num_attention_heads: usize,
    /// Number of stacked encoder layers.
    pub num_hidden_layers: usize,
    /// Inner width of the feed-forward block.
    pub intermediate_size: usize,
    /// Number of positions the model is configured for (its maximum sequence length).
    pub max_position_embeddings: usize,
    /// Dropout probability applied to hidden states.
    pub hidden_dropout_prob: f32,
    /// Epsilon used by layer normalisation for numerical stability.
    pub layer_norm_eps: f64,
    /// Whether rotary position embeddings (RoPE) are used.
    pub use_rope: bool,
    /// Whether the feed-forward block uses a GeGLU gated activation.
    pub use_geglu: bool,
    /// RoPE base frequency.
    /// NOTE(review): presumably ignored when `use_rope` is false — confirm in the model code.
    pub rope_theta: f64,
    /// Whether layer norm is applied before each sub-layer (pre-norm) instead of after the
    /// residual add (post-norm).
    pub use_pre_norm: bool,
}
impl Default for EncoderConfig {
fn default() -> Self {
Self::bert_base()
}
}
impl EncoderConfig {
    /// BERT-base preset (as used by `google-bert/bert-base-uncased`): 12 layers, hidden size
    /// 768, 12 heads, 512 positions, post-norm, no RoPE and no GeGLU.
    pub fn bert_base() -> Self {
        Self {
            vocab_size: 30522,
            hidden_size: 768,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            hidden_dropout_prob: 0.1,
            layer_norm_eps: 1e-12,
            use_rope: false,
            use_geglu: false,
            rope_theta: 10000.0,
            use_pre_norm: false,
        }
    }

    /// BERT-large preset (as used by `google-bert/bert-large-uncased`): 24 layers, hidden size
    /// 1024, 16 heads, intermediate size 4096. Every other field matches
    /// [`bert_base`](Self::bert_base).
    pub fn bert_large() -> Self {
        Self {
            hidden_size: 1024,
            num_attention_heads: 16,
            num_hidden_layers: 24,
            intermediate_size: 4096,
            ..Self::bert_base()
        }
    }

    /// ModernBERT-base preset: 22 layers, hidden size 768, 12 heads, 8192 positions,
    /// pre-norm with RoPE and GeGLU, no dropout.
    pub fn modernbert_base() -> Self {
        Self {
            vocab_size: 50368,
            hidden_size: 768,
            num_attention_heads: 12,
            num_hidden_layers: 22,
            intermediate_size: 1152,
            max_position_embeddings: 8192,
            hidden_dropout_prob: 0.0,
            layer_norm_eps: 1e-5,
            use_rope: true,
            use_geglu: true,
            rope_theta: 160000.0,
            use_pre_norm: true,
        }
    }

    /// ModernBERT-large preset: 28 layers, hidden size 1024, 16 heads, 8192 positions,
    /// pre-norm with RoPE and GeGLU, no dropout.
    pub fn modernbert_large() -> Self {
        Self {
            vocab_size: 50368,
            hidden_size: 1024,
            num_attention_heads: 16,
            num_hidden_layers: 28,
            intermediate_size: 2624,
            max_position_embeddings: 8192,
            hidden_dropout_prob: 0.0,
            layer_norm_eps: 1e-5,
            use_rope: true,
            use_geglu: true,
            rope_theta: 160000.0,
            use_pre_norm: true,
        }
    }

    /// DeBERTa-v3-base preset: 12 layers, hidden size 768, 12 heads, 512 positions.
    pub fn deberta_v3_base() -> Self {
        Self {
            vocab_size: 128100,
            hidden_size: 768,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            hidden_dropout_prob: 0.1,
            layer_norm_eps: 1e-7,
            use_rope: false,
            use_geglu: false,
            rope_theta: 10000.0,
            // DeBERTa (v2/v3) normalises after the residual add, like BERT, i.e. post-norm.
            // Marking it pre-norm would mis-wire pretrained checkpoints.
            use_pre_norm: false,
        }
    }

    /// DeBERTa-v3-large preset: 24 layers, hidden size 1024, 16 heads, 512 positions.
    pub fn deberta_v3_large() -> Self {
        Self {
            vocab_size: 128100,
            hidden_size: 1024,
            num_attention_heads: 16,
            num_hidden_layers: 24,
            intermediate_size: 4096,
            max_position_embeddings: 512,
            hidden_dropout_prob: 0.1,
            layer_norm_eps: 1e-7,
            use_rope: false,
            use_geglu: false,
            rope_theta: 10000.0,
            // Post-norm, same as deberta_v3_base.
            use_pre_norm: false,
        }
    }

    /// Picks a preset from a model name or hub id, matching case-insensitively.
    ///
    /// The family is chosen by substring:
    /// - names containing `modernbert` select ModernBERT;
    /// - otherwise names containing `deberta` select DeBERTa-v3;
    /// - anything else falls back to BERT.
    ///
    /// Within the chosen family, a name containing `large` selects the large preset. Otherwise
    /// the base preset is returned.
    pub fn from_model_name(name: &str) -> Self {
        let lower = name.to_lowercase();
        let large = lower.contains("large");

        if lower.contains("modernbert") {
            if large {
                Self::modernbert_large()
            } else {
                Self::modernbert_base()
            }
        } else if lower.contains("deberta") {
            if large {
                Self::deberta_v3_large()
            } else {
                Self::deberta_v3_base()
            }
        } else if large {
            // e.g. "bert-large-uncased": a base-sized config would not match its weights.
            Self::bert_large()
        } else {
            Self::bert_base()
        }
    }
}
/// Supported encoder model families.
///
/// `Hash` is derived alongside `Eq` so the architecture can be used as a key in hash maps
/// and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EncoderArchitecture {
    /// Classic BERT encoder (reference checkpoint `google-bert/bert-base-uncased`).
    Bert,
    /// DeBERTa-v3 encoder (reference checkpoint `microsoft/deberta-v3-base`).
    DeBertaV3,
    /// ModernBERT encoder (reference checkpoint `answerdotai/ModernBERT-base`); the default.
    #[default]
    ModernBert,
}
impl EncoderArchitecture {
pub fn default_config(&self) -> EncoderConfig {
match self {
Self::Bert => EncoderConfig::bert_base(),
Self::DeBertaV3 => EncoderConfig::deberta_v3_base(),
Self::ModernBert => EncoderConfig::modernbert_base(),
}
}
pub fn default_model_id(&self) -> &'static str {
match self {
Self::Bert => "google-bert/bert-base-uncased",
Self::DeBertaV3 => "microsoft/deberta-v3-base",
Self::ModernBert => "answerdotai/ModernBERT-base",
}
}
pub fn max_length(&self) -> usize {
match self {
Self::Bert | Self::DeBertaV3 => 512,
Self::ModernBert => 8192,
}
}
pub fn uses_rope(&self) -> bool {
matches!(self, Self::ModernBert)
}
pub fn as_str(&self) -> &'static str {
match self {
Self::Bert => "BERT",
Self::DeBertaV3 => "DeBERTa-v3",
Self::ModernBert => "ModernBERT",
}
}
}
/// Formats the architecture as its human-readable name (e.g. `ModernBERT`).
///
/// `Formatter::pad` is used instead of `write!` so that width, fill, alignment and precision
/// flags are honoured. For example, `format!("{:<12}|", arch)` pads the name, as it would for
/// a plain `&str`.
impl std::fmt::Display for EncoderArchitecture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}