use super::config::EncoderConfig;
use super::model::{Encoder, EncoderClient, Pooling};
use crate::error::Result;
use crate::format::Gguf;
use crate::format::gguf_tokenizer::GgufTokenizer;
use crate::nn::Weight;
use crate::quant::traits::DequantOps;
use numr::dtype::DType;
use numr::ops::{IndexingOps, ScalarOps, TensorOps, TypeConversionOps};
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use splintr::{Tokenize, Tokenizer};
/// Chooses the dtype the encoder should compute in for runtime `R`.
///
/// Half precision (`DType::F16`) is chosen only when the crate is built with the
/// `f16` feature *and* the runtime reports its name as `"cuda"`; every other
/// combination falls back to `DType::F32`.
fn preferred_compute_dtype<R: Runtime>() -> DType {
    let runtime_is_cuda = R::name() == "cuda";
    #[cfg(feature = "f16")]
    let half_precision = runtime_is_cuda.then_some(DType::F16);
    #[cfg(not(feature = "f16"))]
    let half_precision: Option<DType> = {
        // Without the `f16` feature the runtime check is irrelevant.
        let _ = runtime_is_cuda;
        None
    };
    half_precision.unwrap_or(DType::F32)
}
/// Fixed padded sequence lengths used by the padded (non-varlen) batch path.
const SEQ_LEN_BUCKETS: [usize; 6] = [64, 128, 192, 256, 384, 512];

/// Rounds a raw sequence length up to a padded length, capped at `max_seq`.
///
/// The smallest entry of [`SEQ_LEN_BUCKETS`] that is `>= raw` is used. Lengths
/// beyond the largest bucket are rounded up to the next multiple of 16. The
/// result is never larger than `max_seq`.
fn seq_len_bucket(raw: usize, max_seq: usize) -> usize {
    let padded = match SEQ_LEN_BUCKETS.iter().find(|&&bucket| bucket >= raw) {
        Some(&bucket) => bucket,
        None => raw.next_multiple_of(16),
    };
    std::cmp::min(padded, max_seq)
}
/// Text-embedding pipeline: tokenizes text with `tokenizer`, builds token-id
/// tensors on `device`, and runs `encoder` to produce embedding vectors.
///
/// `T` defaults to `splintr::Tokenizer`; the `from_gguf` constructor builds the
/// `GgufTokenizer` variant.
pub struct EmbeddingPipeline<R: Runtime, T: Tokenize = Tokenizer> {
    /// Encoder model that produces the embeddings.
    encoder: Encoder<R>,
    /// Tokenizer that turns input text into token ids.
    tokenizer: T,
    /// Device on which input tensors are created.
    device: R::Device,
}
impl<R: Runtime<DType = DType>, T: Tokenize> EmbeddingPipeline<R, T> {
    /// Creates a pipeline from a loaded encoder, a tokenizer, and the device on
    /// which input tensors are allocated.
    pub fn new(encoder: Encoder<R>, tokenizer: T, device: R::Device) -> Self {
        Self {
            encoder,
            tokenizer,
            device,
        }
    }

    /// Largest number of tokens a single text is truncated to before encoding.
    ///
    /// For most architectures this is `max_position_embeddings`. XLM-RoBERTa
    /// position ids start at `padding_token_id + 1` (see
    /// `embed_one_varlen_batch`). Truncating to `max_position_embeddings` would
    /// therefore let position ids reach `padding_token_id + max_position_embeddings`,
    /// which is past the end of the position table, assuming the table has
    /// `max_position_embeddings` rows as the name suggests. The offset is
    /// reserved here so every generated position id stays in range.
    ///
    /// NOTE(review): the padded `Encoder::embed_inference` path is assumed to use
    /// the same position offset for XLM-RoBERTa; confirm in `Encoder`.
    fn max_input_tokens(&self) -> usize {
        use super::config::ArchFamily;

        let config = self.encoder.config();
        let max_positions = config.max_position_embeddings;
        match config.arch_family {
            ArchFamily::XlmRoberta => {
                // If `padding_token_id + 1` is not positive, no tokens are reserved.
                let first_position = config.padding_token_id.saturating_add(1);
                let offset = usize::try_from(first_position).unwrap_or(0);
                max_positions.saturating_sub(offset)
            }
            ArchFamily::NomicBert | ArchFamily::Bert | ArchFamily::GemmaEmbedding => {
                max_positions
            }
        }
    }

    /// Tokenizes `text` and keeps at most `limit` tokens.
    fn encode_truncated(&self, text: &str, limit: usize) -> Vec<u32> {
        let mut ids = self.tokenizer.encode(text);
        ids.truncate(limit);
        ids
    }

    /// Embeds a single text.
    ///
    /// The text is tokenized and truncated to the model's input limit. It is
    /// then run through the encoder as a batch of one, without an attention
    /// mask. If the text produces no tokens, an empty vector is returned. This
    /// matches what [`Self::embed_texts`] returns for a batch of such texts,
    /// and avoids running the encoder on a zero-length sequence.
    ///
    /// # Errors
    /// Propagates any error from the encoder forward pass.
    pub fn embed_text<C>(&self, client: &C, text: &str) -> Result<Vec<f32>>
    where
        C: EncoderClient<R>,
        R::Client: TensorOps<R> + ScalarOps<R> + IndexingOps<R>,
    {
        let token_ids = self.encode_truncated(text, self.max_input_tokens());
        if token_ids.is_empty() {
            return Ok(Vec::new());
        }
        let seq_len = token_ids.len();
        let input: Vec<i64> = token_ids.into_iter().map(i64::from).collect();
        let input_tensor = Tensor::<R>::from_slice(&input, &[1, seq_len], &self.device);
        let embedding = self.encoder.embed_inference(client, &input_tensor, None)?;
        Ok(embedding.to_vec())
    }

    /// Embeds a batch of texts, returning one vector per input text, in order.
    ///
    /// NomicBert always uses packed variable-length batching. The other
    /// families use it only when their head dimension is 64 or 128 (Bert and
    /// XLM-RoBERTa) or 64, 128 or 256 (GemmaEmbedding). Otherwise the texts are
    /// right-padded with token id 0 to a common bucketed length, and a 1/0
    /// attention mask is passed to the encoder.
    ///
    /// If every text tokenizes to nothing in the padded path, one empty vector
    /// is returned per text.
    ///
    /// # Errors
    /// Propagates any error from the encoder forward pass.
    pub fn embed_texts<C>(&self, client: &C, texts: &[&str]) -> Result<Vec<Vec<f32>>>
    where
        C: EncoderClient<R>,
        R::Client: TensorOps<R> + ScalarOps<R> + IndexingOps<R>,
    {
        use super::config::ArchFamily;

        if texts.is_empty() {
            return Ok(vec![]);
        }
        let max_tokens = self.max_input_tokens();
        let all_ids: Vec<Vec<u32>> = texts
            .iter()
            .map(|text| self.encode_truncated(text, max_tokens))
            .collect();

        let config = self.encoder.config();
        // Presumably these are the head dims the varlen attention path supports.
        let use_varlen = match config.arch_family {
            ArchFamily::NomicBert => true,
            ArchFamily::Bert | ArchFamily::XlmRoberta => matches!(config.head_dim(), 64 | 128),
            ArchFamily::GemmaEmbedding => {
                matches!(config.resolved_head_dim(), 64 | 128 | 256)
            }
        };
        if use_varlen {
            return self.embed_texts_varlen(client, &all_ids);
        }

        let longest = all_ids.iter().map(Vec::len).max().unwrap_or(0);
        if longest == 0 {
            return Ok(vec![vec![]; texts.len()]);
        }
        // Padding to a bucketed length (presumably to limit the number of distinct
        // shapes the encoder sees). `seq_len_bucket` never returns less than
        // `longest`, because `longest <= max_tokens`.
        let max_len = seq_len_bucket(longest, max_tokens);

        let batch_size = texts.len();
        let mut flat: Vec<i64> = Vec::with_capacity(batch_size * max_len);
        let mut mask_flat: Vec<f32> = Vec::with_capacity(batch_size * max_len);
        for ids in &all_ids {
            let real_len = ids.len();
            let pad = max_len - real_len;
            flat.extend(ids.iter().map(|&t| i64::from(t)));
            flat.extend(std::iter::repeat_n(0i64, pad));
            mask_flat.extend(std::iter::repeat_n(1.0f32, real_len));
            mask_flat.extend(std::iter::repeat_n(0.0f32, pad));
        }
        let input_tensor = Tensor::<R>::from_slice(&flat, &[batch_size, max_len], &self.device);
        let mask_tensor = Tensor::<R>::from_slice(&mask_flat, &[batch_size, max_len], &self.device);
        let embeddings = self
            .encoder
            .embed_inference(client, &input_tensor, Some(&mask_tensor))?;
        let data: Vec<f32> = embeddings.to_vec();
        let hidden = config.hidden_size;
        Ok(data.chunks_exact(hidden).map(|c| c.to_vec()).collect())
    }

    /// Embeds pre-tokenized documents using packed variable-length batches.
    ///
    /// Documents are greedily grouped into consecutive sub-batches whose total
    /// token count does not exceed `max_tokens_per_forward` (or
    /// `DEFAULT_MAX_TOKENS_PER_FORWARD` when unset). A sub-batch always holds at
    /// least one document, even if that document alone exceeds the budget.
    /// Results are returned in input order.
    fn embed_texts_varlen<C>(&self, client: &C, all_ids: &[Vec<u32>]) -> Result<Vec<Vec<f32>>>
    where
        C: EncoderClient<R>,
        R::Client: TensorOps<R> + ScalarOps<R> + IndexingOps<R>,
    {
        use super::config::DEFAULT_MAX_TOKENS_PER_FORWARD;

        let budget = self
            .encoder
            .config()
            .max_tokens_per_forward
            .unwrap_or(DEFAULT_MAX_TOKENS_PER_FORWARD);

        let mut result: Vec<Vec<f32>> = Vec::with_capacity(all_ids.len());
        let mut start = 0usize;
        while start < all_ids.len() {
            // The first document is always taken so oversized documents still progress.
            let mut end = start + 1;
            let mut tokens_in_sub = all_ids[start].len();
            while end < all_ids.len() && tokens_in_sub + all_ids[end].len() <= budget {
                tokens_in_sub += all_ids[end].len();
                end += 1;
            }
            result.extend(self.embed_one_varlen_batch(client, &all_ids[start..end])?);
            start = end;
        }
        Ok(result)
    }

    /// Runs one packed variable-length forward pass over `ids_chunk`.
    ///
    /// All token ids are concatenated into a single 1-D tensor. The other
    /// inputs to `embed_inference_varlen` are:
    ///
    /// - cumulative sequence offsets (`cu_seqlens`, length `sub_batch + 1`),
    /// - per-token position ids,
    /// - per-token segment (document) indices,
    /// - the longest document length.
    ///
    /// XLM-RoBERTa position ids start at `padding_token_id + 1`; every other
    /// family starts at 0. If the chunk contains no tokens at all, one empty
    /// vector is returned per document.
    fn embed_one_varlen_batch<C>(&self, client: &C, ids_chunk: &[Vec<u32>]) -> Result<Vec<Vec<f32>>>
    where
        C: EncoderClient<R>,
        R::Client: TensorOps<R> + ScalarOps<R> + IndexingOps<R>,
    {
        use super::config::ArchFamily;

        let sub_batch = ids_chunk.len();
        if sub_batch == 0 {
            return Ok(vec![]);
        }
        let total_tokens: usize = ids_chunk.iter().map(Vec::len).sum();
        if total_tokens == 0 {
            return Ok(vec![vec![]; sub_batch]);
        }

        let config = self.encoder.config();
        let pos_offset: i64 = match config.arch_family {
            ArchFamily::XlmRoberta => config.padding_token_id + 1,
            ArchFamily::NomicBert | ArchFamily::Bert | ArchFamily::GemmaEmbedding => 0,
        };

        let mut flat_ids: Vec<i64> = Vec::with_capacity(total_tokens);
        let mut pos_ids: Vec<i64> = Vec::with_capacity(total_tokens);
        let mut seg_ids: Vec<i32> = Vec::with_capacity(total_tokens);
        let mut cu_seqlens: Vec<i32> = Vec::with_capacity(sub_batch + 1);
        let mut max_seqlen = 0usize;
        // Running token offset. Each sub-batch is bounded by the token budget
        // (or by one truncated document), so it is expected to fit in i32.
        let mut cumulative = 0i32;
        cu_seqlens.push(0);
        for (segment, ids) in ids_chunk.iter().enumerate() {
            let n = ids.len();
            max_seqlen = max_seqlen.max(n);
            flat_ids.extend(ids.iter().map(|&t| i64::from(t)));
            pos_ids.extend((0..n as i64).map(|p| pos_offset + p));
            seg_ids.extend(std::iter::repeat_n(segment as i32, n));
            cumulative += n as i32;
            cu_seqlens.push(cumulative);
        }

        let d = &self.device;
        let input_t = Tensor::<R>::from_slice(&flat_ids, &[total_tokens], d);
        let cu_t = Tensor::<R>::from_slice(&cu_seqlens, &[sub_batch + 1], d);
        let pos_t = Tensor::<R>::from_slice(&pos_ids, &[total_tokens], d);
        let seg_t = Tensor::<R>::from_slice(&seg_ids, &[total_tokens], d);
        let embeddings = self.encoder.embed_inference_varlen(
            client, &input_t, &cu_t, &pos_t, &seg_t, sub_batch, max_seqlen,
        )?;
        let data: Vec<f32> = embeddings.to_vec();
        let hidden = config.hidden_size;
        Ok(data.chunks_exact(hidden).map(|c| c.to_vec()).collect())
    }

    /// Returns the underlying encoder.
    pub fn encoder(&self) -> &Encoder<R> {
        &self.encoder
    }

    /// Returns the tokenizer.
    pub fn tokenizer(&self) -> &T {
        &self.tokenizer
    }

    /// Returns the encoder's configuration.
    pub fn config(&self) -> &EncoderConfig {
        self.encoder.config()
    }
}
impl<R: Runtime<DType = DType>> EmbeddingPipeline<R, GgufTokenizer>
where
    R::Client: Clone + TypeConversionOps<R> + DequantOps<R>,
{
    /// Builds a pipeline from a GGUF file.
    ///
    /// The tokenizer and the encoder configuration are read from the file's
    /// metadata. Tensors are loaded onto `device` through
    /// `Gguf::load_tensor_f32`, and the encoder uses mean pooling. The compute
    /// dtype comes from `preferred_compute_dtype`.
    ///
    /// NomicBert and GemmaEmbedding request GGUF tensor names directly. Bert
    /// and XLM-RoBERTa request Hugging Face names, which are translated with
    /// `hf_name_to_gguf`.
    ///
    /// # Errors
    /// Returns an error if the tokenizer, the config, or any tensor fails to load.
    pub fn from_gguf(gguf: &mut Gguf, device: R::Device) -> Result<Self> {
        use super::config::ArchFamily;

        let tokenizer = GgufTokenizer::from_gguf(gguf)?;
        let config = {
            let mut cfg = EncoderConfig::from_gguf_metadata(gguf.metadata())?;
            cfg.compute_dtype = preferred_compute_dtype::<R>();
            cfg
        };
        let dev = &device;
        let encoder = match config.arch_family {
            ArchFamily::NomicBert => {
                let client = R::default_client(dev);
                Encoder::from_weights_nomic(config, Pooling::Mean, &client, |tensor_name| {
                    gguf.load_tensor_f32::<R>(tensor_name, dev)
                        .map(Weight::Standard)
                })?
            }
            ArchFamily::GemmaEmbedding => {
                let client = R::default_client(dev);
                Encoder::from_weights_gemma(config, Pooling::Mean, &client, |tensor_name| {
                    gguf.load_tensor_f32::<R>(tensor_name, dev)
                        .map(Weight::Standard)
                })?
            }
            ArchFamily::Bert | ArchFamily::XlmRoberta => {
                Encoder::from_weights(config, Pooling::Mean, |hf_name| {
                    let tensor_name = hf_name_to_gguf(hf_name);
                    gguf.load_tensor_f32::<R>(&tensor_name, dev)
                })?
            }
        };
        Ok(Self::new(encoder, tokenizer, device))
    }
}
/// Translates a Hugging Face BERT-style tensor name into its GGUF equivalent.
///
/// Embedding tensors map to `token_embd.*`, `position_embd.*` and
/// `token_embd_norm.*`. Per-layer tensors named `encoder.layer.{i}.{suffix}`
/// map to `blk.{i}.{gguf_suffix}` (for example `attn_q.weight` or
/// `ffn_down.bias`). Any name that is not recognised, including an
/// unrecognised layer suffix, is returned unchanged.
fn hf_name_to_gguf(hf: &str) -> String {
    /// GGUF suffix for a recognised per-layer Hugging Face suffix.
    fn layer_suffix(suffix: &str) -> Option<&'static str> {
        let mapped = match suffix {
            "attention.self.query.weight" => "attn_q.weight",
            "attention.self.query.bias" => "attn_q.bias",
            "attention.self.key.weight" => "attn_k.weight",
            "attention.self.key.bias" => "attn_k.bias",
            "attention.self.value.weight" => "attn_v.weight",
            "attention.self.value.bias" => "attn_v.bias",
            "attention.output.dense.weight" => "attn_output.weight",
            "attention.output.dense.bias" => "attn_output.bias",
            "attention.output.LayerNorm.weight" => "attn_output_norm.weight",
            "attention.output.LayerNorm.bias" => "attn_output_norm.bias",
            "intermediate.dense.weight" => "ffn_up.weight",
            "intermediate.dense.bias" => "ffn_up.bias",
            "output.dense.weight" => "ffn_down.weight",
            "output.dense.bias" => "ffn_down.bias",
            "output.LayerNorm.weight" => "layer_output_norm.weight",
            "output.LayerNorm.bias" => "layer_output_norm.bias",
            _ => return None,
        };
        Some(mapped)
    }

    match hf {
        "embeddings.word_embeddings.weight" => return "token_embd.weight".into(),
        "embeddings.position_embeddings.weight" => return "position_embd.weight".into(),
        "embeddings.layer_norm.weight" => return "token_embd_norm.weight".into(),
        "embeddings.layer_norm.bias" => return "token_embd_norm.bias".into(),
        _ => {}
    }

    // Split "encoder.layer.{layer}.{suffix}" at the first dot after the prefix.
    let layer_and_suffix = hf
        .strip_prefix("encoder.layer.")
        .and_then(|rest| rest.split_once('.'));
    match layer_and_suffix {
        Some((layer, suffix)) => match layer_suffix(suffix) {
            Some(mapped) => format!("blk.{layer}.{mapped}"),
            None => hf.to_string(),
        },
        None => hf.to_string(),
    }
}
#[cfg(test)]
#[path = "pipeline_tests.rs"]
mod tests;