//! dakera-inference 0.6.2
//!
//! Embedded inference engine for Dakera - generates embeddings locally.
//!
//! Batch processing utilities for efficient embedding generation.

use crate::error::{InferenceError, Result};
use crate::models::EmbeddingModel;
use candle_core::{Device, Tensor};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tracing::{debug, instrument};

/// Prepared batch of tokenized inputs ready for inference.
///
/// All three tensors share the shape `[batch_size, seq_len]`, where `seq_len`
/// is the uniform (padded) sequence length of the batch.
#[derive(Debug)]
pub struct PreparedBatch {
    /// Token IDs tensor [batch_size, seq_len]
    pub input_ids: Tensor,
    /// Attention mask tensor [batch_size, seq_len]
    /// (tokenizers convention: 1 = real token, 0 = padding)
    pub attention_mask: Tensor,
    /// Token type IDs tensor [batch_size, seq_len]
    pub token_type_ids: Tensor,
    /// Number of items in this batch
    pub batch_size: usize,
    /// Original text lengths in bytes, i.e. `String::len` (for debugging)
    pub original_lengths: Vec<usize>,
}

/// Batch processor for preparing text inputs for embedding models.
pub struct BatchProcessor {
    // Tokenizer configured by `new` with batch-longest padding and truncation.
    tokenizer: Tokenizer,
    // Model descriptor; supplies query/document prefixes and max sequence length.
    model: EmbeddingModel,
    // Upper bound on the number of texts accepted per batch.
    max_batch_size: usize,
}

impl BatchProcessor {
    /// Create a new batch processor.
    pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
        // Configure padding
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
            pad_token: tokenizer
                .get_padding()
                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));

        // Configure truncation
        let truncation = TruncationParams {
            max_length: model.max_seq_length(),
            ..Default::default()
        };
        let _ = tokenizer.with_truncation(Some(truncation));

        Self {
            tokenizer,
            model,
            max_batch_size,
        }
    }

    /// Get the maximum batch size.
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Prepare texts for embedding, optionally applying model-specific prefixes.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
        let prefix = if is_query {
            self.model.query_prefix()
        } else {
            self.model.document_prefix()
        };

        match prefix {
            Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
            None => texts.to_vec(),
        }
    }

    /// Tokenize a batch of texts and prepare tensors for the model.
    #[instrument(skip(self, texts, device), fields(count = texts.len()))]
    pub fn tokenize_batch(&self, texts: &[String], device: &Device) -> Result<PreparedBatch> {
        if texts.is_empty() {
            return Err(InferenceError::InvalidInput("Empty text batch".into()));
        }

        if texts.len() > self.max_batch_size {
            return Err(InferenceError::InvalidInput(format!(
                "Batch size {} exceeds maximum {}",
                texts.len(),
                self.max_batch_size
            )));
        }

        let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();

        debug!(
            "Tokenizing {} texts, max length: {}",
            texts.len(),
            original_lengths.iter().max().unwrap_or(&0)
        );

        // Tokenize all texts
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        let batch_size = encodings.len();

        // Extract token IDs
        let input_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();

        // Extract attention masks
        let attention_masks: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().to_vec())
            .collect();

        // Extract token type IDs (or create zeros)
        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| {
                let type_ids = e.get_type_ids();
                if type_ids.is_empty() {
                    vec![0u32; e.get_ids().len()]
                } else {
                    type_ids.to_vec()
                }
            })
            .collect();

        // Get sequence length (should be uniform after padding)
        let seq_len = input_ids.first().map(|v| v.len()).unwrap_or(0);

        // Convert to tensors
        let input_ids_flat: Vec<u32> = input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();

        let input_ids = Tensor::from_vec(input_ids_flat, (batch_size, seq_len), device)?;
        let attention_mask = Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), device)?;
        let token_type_ids = Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), device)?;

        debug!(
            "Created tensors: input_ids {:?}, attention_mask {:?}",
            input_ids.shape(),
            attention_mask.shape()
        );

        Ok(PreparedBatch {
            input_ids,
            attention_mask,
            token_type_ids,
            batch_size,
            original_lengths,
        })
    }

    /// Split texts into batches of maximum size.
    pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
        texts.chunks(self.max_batch_size).collect()
    }
}

/// Apply mean pooling to model outputs.
///
/// Averages the token embeddings over the sequence dimension, weighting each
/// position by the attention mask so that padded tokens contribute nothing.
///
/// Expected shapes: `last_hidden_state` is [batch, seq_len, hidden_size] and
/// `attention_mask` is [batch, seq_len]; the result is [batch, hidden_size].
#[instrument(skip_all)]
pub fn mean_pooling(last_hidden_state: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    // Lift the mask to [batch, seq_len, 1] in the hidden state's dtype, then
    // broadcast it to the full [batch, seq_len, hidden_size] shape.
    let mask = attention_mask
        .unsqueeze(2)?
        .to_dtype(last_hidden_state.dtype())?
        .broadcast_as(last_hidden_state.shape())?;

    // Zero out padded positions, then reduce over the sequence axis.
    let summed = last_hidden_state.mul(&mask)?.sum(1)?; // [batch, hidden_size]

    // Per-dimension token counts, clamped away from zero so a fully-masked
    // row cannot cause a division by zero.
    let counts = mask.sum(1)?.clamp(1e-9, f64::MAX)?; // [batch, hidden_size]

    let pooled = summed.div(&counts)?;

    debug!("Mean pooled shape: {:?}", pooled.shape());

    Ok(pooled)
}

/// Normalize embeddings to unit length (L2 normalization).
///
/// Each row is divided by its L2 norm; norms are clamped away from zero so a
/// zero vector cannot trigger a division by zero.
#[instrument(skip_all)]
pub fn normalize_embeddings(embeddings: &Tensor) -> Result<Tensor> {
    // Per-row L2 norm, kept as [batch, 1] so it broadcasts over the division.
    let l2_norm = embeddings
        .sqr()?
        .sum_keepdim(1)?
        .sqrt()?
        .clamp(1e-12, f64::MAX)?;

    let unit_vectors = embeddings.broadcast_div(&l2_norm)?;

    debug!("Normalized embeddings shape: {:?}", unit_vectors.shape());

    Ok(unit_vectors)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a throwaway tokenizer so tests need no network access.
    /// `prepare_texts` exercises only the model's prefix logic, so the
    /// tokenizer's actual vocabulary is irrelevant here.
    fn stub_tokenizer() -> Tokenizer {
        use tokenizers::models::bpe::BPE;
        Tokenizer::new(BPE::default())
    }

    #[test]
    fn test_prepare_texts_with_prefix() {
        let processor = BatchProcessor::new(stub_tokenizer(), EmbeddingModel::E5Small, 32);
        let texts: Vec<String> = ["Hello world", "Test query"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(
            prepared,
            vec!["query: Hello world", "query: Test query"]
        );
    }

    #[test]
    fn test_prepare_texts_no_prefix() {
        let processor = BatchProcessor::new(stub_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["Hello world".to_string()];

        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(prepared, vec!["Hello world"]);
    }
}