use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, Error, LengthMismatchPayload, OutOfRangePayload,
RankMismatchPayload, Result, ShapePairMismatchPayload, try_with_capacity,
},
tokenizer::{EncodeOptions, Tokenizer},
};
use super::{PoolingStrategy, model::EmbeddingModel, pool, pool_post};
/// Options controlling how `encode` tokenizes texts and reduces model output
/// to one embedding per text.
///
/// Build one with `EncodeConfig::new()` / `Default` and the `with_*` builders;
/// read it back through the same-named getters.
#[derive(Debug, Clone)]
pub struct EncodeConfig {
    /// Pooling strategy. `Cls` and `None` make `encode` prefer the model's own
    /// pooled output when the model returns one; otherwise it is passed to `pool`.
    strategy: PoolingStrategy,
    /// Forwarded as the `normalize` flag to `pool` / `pool_post`.
    normalize: bool,
    /// Forwarded to `EncodeOptions::with_add_special` when tokenizing.
    add_special_tokens: bool,
    /// Forwarded to `EncodeOptions::with_truncate_to`; `None` requests no truncation
    /// limit from this config (exact tokenizer behavior not visible here).
    max_length: Option<usize>,
    /// Token id written into padded positions; must fit in `i32` or encoding fails.
    pad_token_id: u32,
    /// Forwarded as the `dimension` argument to `pool` / `pool_post`
    /// (presumably an output-width override — confirm in the pooling module).
    dimension: Option<usize>,
    /// Forwarded as the layer-norm flag to `pool` / `pool_post`.
    apply_layer_norm: bool,
    /// Forwarded as the RMS-norm flag to `pool` / `pool_post`.
    apply_rms_norm: bool,
}
impl Default for EncodeConfig {
fn default() -> Self {
Self {
strategy: PoolingStrategy::Mean,
normalize: true,
add_special_tokens: true,
max_length: Some(512),
pad_token_id: 0,
dimension: None,
apply_layer_norm: false,
apply_rms_norm: false,
}
}
}
impl EncodeConfig {
    /// Creates a configuration identical to `EncodeConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the pooling strategy used to reduce model output.
    #[must_use]
    pub fn with_strategy(mut self, v: PoolingStrategy) -> Self {
        self.strategy = v;
        self
    }

    /// Sets the `normalize` flag forwarded to the pooling step.
    #[must_use]
    pub fn with_normalize(mut self, v: bool) -> Self {
        self.normalize = v;
        self
    }

    /// Sets whether the tokenizer adds special tokens.
    #[must_use]
    pub fn with_add_special_tokens(mut self, v: bool) -> Self {
        self.add_special_tokens = v;
        self
    }

    /// Sets the truncation length handed to the tokenizer (`None` = no limit from this config).
    #[must_use]
    pub fn with_max_length(mut self, v: Option<usize>) -> Self {
        self.max_length = v;
        self
    }

    /// Sets the token id used to fill padded positions (must fit in `i32` at encode time).
    #[must_use]
    pub fn with_pad_token_id(mut self, v: u32) -> Self {
        self.pad_token_id = v;
        self
    }

    /// Sets the `dimension` argument forwarded to the pooling step.
    #[must_use]
    pub fn with_dimension(mut self, v: Option<usize>) -> Self {
        self.dimension = v;
        self
    }

    /// Sets the layer-norm flag forwarded to the pooling step.
    #[must_use]
    pub fn with_apply_layer_norm(mut self, v: bool) -> Self {
        self.apply_layer_norm = v;
        self
    }

    /// Sets the RMS-norm flag forwarded to the pooling step.
    #[must_use]
    pub fn with_apply_rms_norm(mut self, v: bool) -> Self {
        self.apply_rms_norm = v;
        self
    }

    // `#[inline]` (not `always`) is enough to make these trivial getters
    // inlinable across crates without overriding the optimizer.

    /// Returns the configured pooling strategy.
    #[inline]
    pub fn strategy(&self) -> PoolingStrategy {
        self.strategy
    }

    /// Returns the `normalize` flag.
    #[inline]
    pub fn normalize(&self) -> bool {
        self.normalize
    }

    /// Returns whether special tokens are added during tokenization.
    #[inline]
    pub fn add_special_tokens(&self) -> bool {
        self.add_special_tokens
    }

    /// Returns the tokenizer truncation length, if any.
    #[inline]
    pub fn max_length(&self) -> Option<usize> {
        self.max_length
    }

    /// Returns the padding token id.
    #[inline]
    pub fn pad_token_id(&self) -> u32 {
        self.pad_token_id
    }

    /// Returns the `dimension` argument forwarded to pooling.
    #[inline]
    pub fn dimension(&self) -> Option<usize> {
        self.dimension
    }

    /// Returns the layer-norm flag.
    #[inline]
    pub fn apply_layer_norm(&self) -> bool {
        self.apply_layer_norm
    }

    /// Returns the RMS-norm flag.
    #[inline]
    pub fn apply_rms_norm(&self) -> bool {
        self.apply_rms_norm
    }
}
fn tokenize_and_pad(
tokenizer: &Tokenizer,
texts: &[&str],
add_special_tokens: bool,
max_length: Option<usize>,
pad_token_id: u32,
) -> Result<(Array, Array, usize)> {
let batch = texts.len();
let opts = EncodeOptions::new()
.with_add_special(add_special_tokens)
.with_truncate_to(max_length)
.with_return_attention_mask(true);
let mut rows: Vec<(Vec<u32>, Vec<u8>)> = try_with_capacity(batch)?;
for &text in texts {
let enc = tokenizer.encode_with(text, &opts)?;
if enc.attention_mask().len() != enc.ids().len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"encode: encode_with(return_attention_mask=true) mask.len() must match ids.len()",
enc.ids().len(),
enc.attention_mask().len(),
)));
}
rows.push(enc.into_parts());
}
let seq_len = rows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
let total = batch.checked_mul(seq_len).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"encode: batch * seq_len",
"usize",
[("batch", batch as u64), ("seq_len", seq_len as u64)],
))
})?;
let pad_id = i32::try_from(pad_token_id).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"encode: pad_token_id",
"must fit in i32 (the MLX index dtype)",
smol_str::format_smolstr!("{pad_token_id}"),
))
})?;
let mut id_data: Vec<i32> = try_with_capacity(total)?;
let mut mask_data: Vec<f32> = try_with_capacity(total)?;
for (ids, mask) in &rows {
let real = ids.len();
for &id in ids {
let id = i32::try_from(id).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"encode: token id",
"must fit in i32 (the MLX index dtype)",
smol_str::format_smolstr!("{id}"),
))
})?;
id_data.push(id);
}
mask_data.extend(mask.iter().map(|&m| f32::from(m)));
let pad = seq_len - real;
id_data.extend(std::iter::repeat_n(pad_id, pad));
mask_data.extend(std::iter::repeat_n(0.0_f32, pad));
}
let input_ids = Array::from_slice::<i32>(&id_data, &(batch, seq_len))?;
let attention_mask = Array::from_slice::<f32>(&mask_data, &(batch, seq_len))?;
Ok((input_ids, attention_mask, seq_len))
}
pub fn encode(
model: &dyn EmbeddingModel,
tokenizer: &Tokenizer,
texts: &[&str],
cfg: &EncodeConfig,
) -> Result<Array> {
crate::error::ensure_handler_installed();
crate::stream::assert_streams_not_cleared();
let (input_ids, attention_mask, _seq_len) = tokenize_and_pad(
tokenizer,
texts,
cfg.add_special_tokens,
cfg.max_length,
cfg.pad_token_id,
)?;
let output = model.forward(&input_ids, &attention_mask)?;
let (last_hidden_state, pooled_output) = output.into_parts();
if matches!(cfg.strategy, PoolingStrategy::Cls | PoolingStrategy::None)
&& let Some(pooled) = pooled_output
{
let pooled_shape = pooled.shape();
if pooled_shape.len() != 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"encode: model pooled_output must be rank-2 (batch, hidden)",
pooled_shape.len() as u32,
pooled_shape,
)));
}
if pooled_shape[0] != texts.len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"encode: model pooled_output batch must match texts",
texts.len(),
pooled_shape[0],
)));
}
let hidden_shape = last_hidden_state.shape();
if hidden_shape.len() != 3 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"encode: model last_hidden_state must be rank-3 (batch, seq_len, hidden)",
hidden_shape.len() as u32,
hidden_shape,
)));
}
if pooled_shape[1] != hidden_shape[2] {
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"encode: model pooled_output hidden width must match last_hidden_state hidden",
pooled_shape,
hidden_shape,
)));
}
return pool_post(
pooled,
cfg.normalize,
cfg.dimension,
cfg.apply_layer_norm,
cfg.apply_rms_norm,
);
}
pool(
&last_hidden_state,
&attention_mask,
cfg.strategy,
cfg.normalize,
cfg.dimension,
cfg.apply_layer_norm,
cfg.apply_rms_norm,
)
}
#[cfg(test)]
mod tests;