#![warn(missing_docs)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_core::{IndexOp, Tensor};
use candle_nn::VarBuilder;
use kalosm_common::*;
use kalosm_model_types::ModelLoadingProgress;
use std::sync::{Arc, RwLock};
use tokenizers::{Encoding, PaddingParams, Tokenizer};
mod language_model;
mod raw;
mod source;
pub use crate::language_model::*;
use crate::raw::DTYPE;
pub use crate::raw::{BertModel, Config};
pub use crate::source::*;
/// A builder for a [`Bert`] model.
///
/// Collects where the model files come from and which cache is used to fetch
/// them. Obtain one with [`Bert::builder`] (or `BertBuilder::default()`),
/// adjust it with the `with_*` methods, then call [`BertBuilder::build`].
#[derive(Default)]
pub struct BertBuilder {
// Locations of the config, tokenizer and weight files, plus the optional
// search embedding prefix.
source: BertSource,
// Cache used to resolve each file of `source` to a local path.
cache: kalosm_common::Cache,
}
impl BertBuilder {
    /// Replace the source the model files are loaded from.
    pub fn with_source(self, source: BertSource) -> Self {
        Self { source, ..self }
    }

    /// Replace the cache used to fetch the model files.
    pub fn with_cache(self, cache: kalosm_common::Cache) -> Self {
        Self { cache, ..self }
    }

    /// Build the model, reporting progress through
    /// `ModelLoadingProgress::multi_bar_loading_indicator()`.
    ///
    /// # Errors
    ///
    /// Returns a [`BertLoadingError`] if any file cannot be fetched or loaded.
    pub async fn build(self) -> Result<Bert, BertLoadingError> {
        let loading_handler = ModelLoadingProgress::multi_bar_loading_indicator();
        self.build_with_loading_handler(loading_handler).await
    }

    /// Build the model, forwarding every progress update to `loading_handler`.
    ///
    /// # Errors
    ///
    /// Returns a [`BertLoadingError`] if any file cannot be fetched or loaded.
    pub async fn build_with_loading_handler(
        self,
        loading_handler: impl FnMut(ModelLoadingProgress) + Send + 'static,
    ) -> Result<Bert, BertLoadingError> {
        let model = Bert::from_builder(self, loading_handler).await?;
        Ok(model)
    }
}
/// An error that can occur while loading a [`Bert`] model.
#[derive(Debug, thiserror::Error)]
pub enum BertLoadingError {
/// The cache failed to provide one of the config, tokenizer or weight files.
#[error("Failed to load model from huggingface or local file: {0}")]
DownloadingError(#[from] CacheError),
/// Candle failed while selecting a device or loading the model weights.
#[error("Failed to load model into device: {0}")]
LoadModel(#[from] candle_core::Error),
/// The tokenizer file could not be loaded.
#[error("Failed to load tokenizer: {0}")]
LoadTokenizer(tokenizers::Error),
/// The config file was read but could not be parsed as JSON into a [`Config`].
#[error("Failed to load config: {0}")]
LoadConfig(serde_json::Error),
/// The config file could not be read from disk. Any I/O failure while
/// reading it is reported as this variant, not only a missing file.
#[error("Config not found")]
ConfigNotFound,
}
/// An error that can occur while running a [`Bert`] model.
#[derive(Debug, thiserror::Error)]
pub enum BertError {
/// A candle tensor operation or the model forward pass failed.
#[error("Failed to run model: {0}")]
Candle(#[from] candle_core::Error),
/// Tokenizing or padding the input failed.
#[error("Failed to tokenize: {0}")]
TokenizerError(tokenizers::Error),
/// A tokio task could not be joined.
// NOTE(review): not produced in this file; presumably used by callers that
// run embedding on a spawned task — confirm.
#[error("Failed to join thread: {0}")]
Join(#[from] tokio::task::JoinError),
}
/// The strategy used to reduce per-token embeddings to one embedding per sentence.
#[derive(Debug, Clone, Copy)]
pub enum Pooling {
/// Sum the attention-masked token embeddings over the sequence, scale them,
/// and L2-normalize the result.
Mean,
/// Use the embedding at the first token position of each sequence as-is
/// (no normalization is applied).
CLS,
}
/// A BERT embedding model together with its tokenizer.
///
/// Every field is behind an [`Arc`], so cloning a `Bert` shares the same
/// underlying model and tokenizer instead of copying them.
#[derive(Clone)]
pub struct Bert {
// Optional prefix taken from the source's `search_embedding_prefix`.
// NOTE(review): not read in this file; presumably prepended to search
// queries elsewhere in the crate — confirm.
embedding_search_prefix: Arc<Option<String>>,
// The loaded model weights.
model: Arc<BertModel>,
// Tokenizer used to encode input sentences; only read-locked in this file.
tokenizer: Arc<RwLock<Tokenizer>>,
}
impl Bert {
pub fn builder() -> BertBuilder {
BertBuilder::default()
}
/// Load a model using the default builder settings.
///
/// # Errors
///
/// Returns a [`BertLoadingError`] if the model cannot be fetched or loaded.
pub async fn new() -> Result<Self, BertLoadingError> {
    let builder = Self::builder();
    builder.build().await
}
/// Load a model from the source returned by `BertSource::new_for_search()`.
///
/// # Errors
///
/// Returns a [`BertLoadingError`] if the model cannot be fetched or loaded.
pub async fn new_for_search() -> Result<Self, BertLoadingError> {
    let builder = Self::builder();
    let search_source = BertSource::new_for_search();
    builder.with_source(search_source).build().await
}
async fn from_builder(
builder: BertBuilder,
mut progress_handler: impl FnMut(ModelLoadingProgress) + Send + 'static,
) -> Result<Self, BertLoadingError> {
let BertBuilder { source, cache } = builder;
let BertSource {
config,
tokenizer,
model,
search_embedding_prefix,
} = source;
let source = format!("Config ({})", config);
let mut create_progress = ModelLoadingProgress::downloading_progress(source);
let config_filename = cache
.get(&config, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let tokenizer_source = format!("Tokenizer ({})", tokenizer);
let mut create_progress = ModelLoadingProgress::downloading_progress(tokenizer_source);
let tokenizer_filename = cache
.get(&tokenizer, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let model_source = format!("Model ({})", model);
let mut create_progress = ModelLoadingProgress::downloading_progress(model_source);
let weights_filename = cache
.get(&model, |progress| {
progress_handler(create_progress(progress))
})
.await?;
let config = std::fs::read_to_string(config_filename)
.map_err(|_| BertLoadingError::ConfigNotFound)?;
let config: Config = serde_json::from_str(&config).map_err(BertLoadingError::LoadConfig)?;
let device = accelerated_device_if_available()?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[&weights_filename], DTYPE, &device)? };
let model = BertModel::load(vb, &config)?;
let mut tokenizer =
Tokenizer::from_file(&tokenizer_filename).map_err(BertLoadingError::LoadTokenizer)?;
tokenizer.with_padding(None);
Ok(Bert {
tokenizer: Arc::new(RwLock::new(tokenizer)),
model: Arc::new(model),
embedding_search_prefix: Arc::new(search_embedding_prefix),
})
}
/// Embed a batch of sentences, returning one tensor per sentence in the same
/// order as `sentences`.
///
/// The sentences are tokenized together, sorted by token count, and split
/// into batches whose estimated cost stays under a fixed budget. Each batch
/// is run through the model separately and the results are written back to
/// the position of the sentence they came from.
///
/// # Errors
///
/// Returns a [`BertError`] if tokenization or the model run fails.
///
/// # Panics
///
/// Panics if the tokenizer lock is poisoned.
pub(crate) fn embed_batch_raw(
    &self,
    sentences: Vec<&str>,
    pooling: Pooling,
) -> Result<Vec<Tensor>, BertError> {
    let embedding_dim = self.model.embedding_dim();
    // Budget for one batch; under the cost formula below this is roughly the
    // cost of two 512-token sequences.
    let limit = embedding_dim * 512usize.pow(2) * 2;

    let encodings = self
        .tokenizer
        .read()
        .unwrap()
        .encode_batch(sentences, true)
        .map_err(BertError::TokenizerError)?;

    // Remember each encoding's original position, then order by length so
    // that sequences of similar length end up in the same batch.
    let mut ordered: Vec<(usize, Encoding)> = encodings.into_iter().enumerate().collect();
    ordered.sort_unstable_by_key(|(_, encoding)| encoding.len());
    let mut outputs: Vec<Option<Tensor>> = vec![None; ordered.len()];

    let mut batches: Vec<(Vec<usize>, Vec<Encoding>)> = Vec::new();
    let mut pending_indices: Vec<usize> = Vec::new();
    let mut pending_encodings: Vec<Encoding> = Vec::new();
    let mut longest = 0;
    for (original_index, encoding) in ordered {
        let token_count = encoding.get_ids().len();
        longest = longest.max(token_count);
        // Cost of the pending batch if this encoding were added to it.
        let candidate_size = pending_indices.len() + 1;
        let cost = candidate_size * (embedding_dim * 8 + embedding_dim * longest.pow(2));
        if cost > limit {
            // Close the pending batch (possibly empty) and start a new one
            // containing only the current encoding.
            batches.push((
                std::mem::take(&mut pending_indices),
                std::mem::take(&mut pending_encodings),
            ));
            longest = token_count;
        }
        pending_indices.push(original_index);
        pending_encodings.push(encoding);
    }
    batches.push((pending_indices, pending_encodings));

    for (batch_indices, batch_encodings) in batches {
        let batch_embeddings =
            maybe_autoreleasepool(|| self.embed_batch_raw_inner(batch_encodings, pooling))?;
        for (slot, embedding) in batch_indices.into_iter().zip(batch_embeddings) {
            outputs[slot] = Some(embedding);
        }
    }

    // Every original index was assigned to exactly one batch above.
    Ok(outputs.into_iter().map(Option::unwrap).collect())
}
/// Run one batch of encodings through the model and pool the result.
///
/// The encodings are padded to the longest one in the batch, truncated to the
/// model's maximum sequence length, stacked, and passed through the model in
/// a single forward call. The output is split into one tensor per encoding,
/// in the same order as `tokens`.
///
/// Returns an empty vector when `tokens` is empty.
///
/// # Errors
///
/// Returns a [`BertError`] if padding fails or any tensor operation fails.
fn embed_batch_raw_inner(
    &self,
    mut tokens: Vec<Encoding>,
    pooling: Pooling,
) -> Result<Vec<Tensor>, BertError> {
    if tokens.is_empty() {
        return Ok(Vec::new());
    }
    let device = &self.model.device;

    // Pad to the longest encoding in this batch so the rows can be stacked.
    let pp = PaddingParams {
        strategy: tokenizers::PaddingStrategy::BatchLongest,
        ..Default::default()
    };
    tokenizers::pad_encodings(&mut tokens, &pp).map_err(BertError::TokenizerError)?;

    let n_sentences = tokens.len();
    let max_seq_len = self.model.max_seq_len();

    // Slice the borrowed ids directly; no per-sentence copy is needed before
    // handing them to `Tensor::new`.
    let token_ids = tokens
        .iter()
        .map(|encoding| {
            let ids = encoding.get_ids();
            Tensor::new(&ids[..max_seq_len.min(ids.len())], device)
        })
        .collect::<candle_core::Result<Vec<_>>>()?;
    let token_ids = Tensor::stack(&token_ids, 0)?;

    let attention_masks = tokens
        .iter()
        .map(|encoding| {
            let mask = encoding.get_attention_mask();
            Tensor::new(&mask[..max_seq_len.min(mask.len())], device)
        })
        .collect::<candle_core::Result<Vec<_>>>()?;
    let attention_mask = Tensor::stack(&attention_masks, 0)?;

    // All tokens use segment id zero.
    let token_type_ids = token_ids.zeros_like()?;
    let embeddings =
        self.model
            .forward(&token_ids, &token_type_ids, Some(&attention_mask), false)?;
    let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;

    match pooling {
        Pooling::Mean => {
            // Zero out the embeddings of padding positions before summing.
            let embeddings = embeddings.mul(
                &attention_mask
                    .to_dtype(DTYPE)?
                    .unsqueeze(2)?
                    .broadcast_as(embeddings.shape())?,
            )?;
            // The divisor is the padded sequence length, not the per-sentence
            // token count; the L2 normalization that follows removes this
            // constant positive scale factor.
            let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
            let embeddings = normalize_l2(&embeddings)?;
            Ok(embeddings.chunk(n_sentences, 0)?)
        }
        Pooling::CLS => {
            // Take the embedding at the first token position of every
            // sequence. NOTE(review): presumably the tokenizer's [CLS] token —
            // confirm for the configured tokenizer.
            let indexed_embeddings = embeddings.i((.., 0, ..))?;
            Ok(indexed_embeddings.chunk(n_sentences, 0)?)
        }
    }
}
}
/// Divide `v` by the square root of its sum of squares along dimension 1.
///
/// The sum keeps dimension 1 so the result broadcasts against `v`.
fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> {
    let sum_of_squares = v.sqr()?.sum_keepdim(1)?;
    let norm = sum_of_squares.sqrt()?;
    v.broadcast_div(&norm)
}