//! mistralrs-core 0.8.1
//!
//! Fast, flexible LLM inference.
//!
//! Documentation
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{cmp::Ordering, sync::Arc};

use anyhow::{Context, Result};
use candle_core::{DType, Device, Error as E};
use mistralrs_quant::log::once_log_info;
use tokenizers::Tokenizer;
use tokio::sync::Mutex as TokioMutex;

use crate::pipeline::ForwardInputsResult;
use crate::{
    embedding_models::inputs_processor::{make_prompt_chunk, ModelInputs},
    engine::SearchEmbeddingModel,
    get_mut_arcmutex, AutoDeviceMapParams, DeviceMapSetting, EmbeddingLoaderBuilder,
    EmbeddingSpecificConfig, ModelDType, Pipeline, TokenSource,
};

use super::SearchResult;

/// Maximum number of same-length prompts embedded in a single forward pass.
const EMBEDDING_BATCH: usize = 8;

/// Embedding pipeline used to rank search-result document chunks against a query.
pub struct SearchPipeline {
    // The loaded embedding model; locked per forward batch in `embed`.
    model: Arc<TokioMutex<dyn Pipeline + Send + Sync>>,
    // Tokenizer exposed by the model, used for both embedding and chunk sizing.
    tokenizer: Arc<Tokenizer>,
    // Device inputs are placed on before each forward pass.
    device: Device,
    // Passed to `make_prompt_chunk`; always `false` here (set in `new`) since
    // the embedding model attends without a causal mask.
    has_causal_attention: bool,
    // Model context-window limit, used as the chunking budget.
    max_seq_len: usize,
}

/// One slice of a source document, sized to fit the model's context window.
#[derive(Debug, Clone)]
struct DocumentChunk {
    // Full prompt to embed: the document template wrapped around `content`.
    prompt: String,
    // Raw chunk text (without the prompt template).
    content: String,
    // Number of content tokens this chunk covers (1 for fallback single-token chunks).
    token_len: usize,
}

/// A document chunk scored against the query, as returned by [`rank_document_chunks`].
#[derive(Debug, Clone)]
pub struct ScoredChunk {
    /// Index into the original `results` slice this chunk came from.
    pub result_index: usize,
    /// Raw chunk text.
    pub content: String,
    /// Number of content tokens in the chunk.
    pub token_len: usize,
    /// Cosine similarity between the chunk embedding and the query embedding.
    pub score: f32,
}

impl SearchPipeline {
    /// Load the search embedding model onto `runner_device` and capture the
    /// tokenizer, device, and context-window metadata needed for chunking.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded from the Hugging Face hub or does
    /// not expose a tokenizer.
    pub fn new(model: SearchEmbeddingModel, runner_device: &Device) -> anyhow::Result<Self> {
        let model_id = model.hf_model_id().to_string();

        once_log_info(format!("Loading embedding model ({model_id})."));

        let loader = EmbeddingLoaderBuilder::new(
            EmbeddingSpecificConfig::default(),
            None,
            Some(model_id.clone()),
        )
        .build(None);

        let pipeline = loader.load_model_from_hf(
            None,
            TokenSource::CacheToken,
            &ModelDType::Auto,
            runner_device,
            true,
            DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
            None,
            None,
        )?;

        // Snapshot everything we need while holding the lock once, then drop it
        // so `embed` can re-acquire it per batch.
        let guard = get_mut_arcmutex!(pipeline);
        let tokenizer = guard
            .tokenizer()
            .with_context(|| "Embedding model did not expose a tokenizer")?
            .clone();
        let device = guard.device();
        let max_seq_len = guard.get_metadata().max_seq_len;
        drop(guard);

        Ok(Self {
            model: pipeline,
            tokenizer,
            device,
            // Embedding forward passes run without a causal mask.
            has_causal_attention: false,
            max_seq_len,
        })
    }

    /// Embed `prompts`, returning one F32 vector per prompt, in input order.
    ///
    /// Prompts are grouped by token length so every forward batch is uniform
    /// (no padding required), then processed in batches of [`EMBEDDING_BATCH`].
    ///
    /// # Errors
    /// Fails on tokenization or forward-pass errors, if the pipeline returns a
    /// non-embedding output, or if the batch sizes don't line up.
    fn embed(&mut self, prompts: &[String]) -> anyhow::Result<Vec<Vec<f32>>> {
        if prompts.is_empty() {
            return Ok(Vec::new());
        }

        // Bucket token-id sequences by length; BTreeMap keeps batches ordered
        // (deterministic) and each bucket remembers its original prompt index.
        use std::collections::BTreeMap;
        let mut by_len: BTreeMap<usize, Vec<(usize, Vec<u32>)>> = BTreeMap::new();
        for (idx, prompt) in prompts.iter().enumerate() {
            let encoding = self
                .tokenizer
                .encode(prompt.as_str(), true)
                .map_err(E::msg)?;
            let ids = encoding.get_ids().to_vec();
            by_len.entry(ids.len()).or_default().push((idx, ids));
        }

        let mut outputs = vec![Vec::new(); prompts.len()];
        for (_, sequences) in by_len {
            for chunk_entries in sequences.chunks(EMBEDDING_BATCH) {
                let slices: Vec<&[u32]> = chunk_entries
                    .iter()
                    .map(|(_, ids)| ids.as_slice())
                    .collect();
                let chunk = make_prompt_chunk(
                    0,
                    slices,
                    &self.device,
                    None,
                    self.has_causal_attention,
                    None,
                )?;
                let inputs = Box::new(ModelInputs {
                    input_ids: chunk.input,
                    flash_meta: chunk.flash_meta,
                });
                // Hold the pipeline lock only for the forward pass itself.
                let mut pipeline = get_mut_arcmutex!(self.model);
                let ForwardInputsResult::Embeddings { embeddings } =
                    pipeline.forward_inputs(inputs, false)?
                else {
                    anyhow::bail!("Embedding pipeline returned non-embedding output");
                };
                drop(pipeline);
                let vecs = embeddings
                    .to_dtype(DType::F32)?
                    .to_device(&Device::Cpu)?
                    .to_vec2::<f32>()?;
                // Guard against a row-count mismatch: `zip` below would
                // otherwise silently drop entries, leaving empty output vectors.
                if vecs.len() != chunk_entries.len() {
                    anyhow::bail!(
                        "Embedding model returned {} embeddings for a batch of {} prompts",
                        vecs.len(),
                        chunk_entries.len()
                    );
                }
                for ((idx, _), embedding) in chunk_entries.iter().zip(vecs.into_iter()) {
                    outputs[*idx] = embedding;
                }
            }
        }

        Ok(outputs)
    }
}

impl SearchPipeline {
    /// Split `text` into chunks that each fit within `max_seq_len` tokens once
    /// wrapped in the document prompt template.
    ///
    /// Returns an empty list for blank input. Each chunk carries the full
    /// prompt to embed, the raw chunk content, and its content-token count.
    fn chunk_document(&self, sanitized_title: &str, text: &str) -> Result<Vec<DocumentChunk>> {
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return Ok(Vec::new());
        }

        // Token cost of the template alone (empty text) — the fixed overhead
        // every chunk pays on top of its content tokens.
        let prefix_prompt = format_document_prompt(sanitized_title, "");
        let prefix_tokens = self
            .tokenizer
            .encode(prefix_prompt.as_str(), true)
            .map_err(E::msg)?
            .len();
        // Degenerate case: the template alone exhausts the context window, so
        // chunking cannot help — emit the whole document as one chunk.
        if self.max_seq_len <= prefix_tokens {
            return Ok(vec![DocumentChunk {
                prompt: format_document_prompt(sanitized_title, trimmed),
                content: trimmed.to_string(),
                token_len: self.tokenizer.encode(trimmed, false).map_err(E::msg)?.len(),
            }]);
        }

        // Approximate per-chunk content budget in tokens (template overhead removed).
        let budget = (self.max_seq_len - prefix_tokens).max(1);
        let encoding = self.tokenizer.encode(trimmed, false).map_err(E::msg)?;
        let token_count = encoding.len();
        if token_count == 0 {
            return Ok(Vec::new());
        }

        // NOTE(review): assumes offsets are byte offsets into `trimmed` that
        // land on char boundaries (true for `tokenizers` encodings of the same
        // string); a mismatch would panic on the slice below — confirm.
        let offsets = encoding.get_offsets();
        let mut chunks = Vec::new();
        let mut start_idx = 0;

        while start_idx < token_count {
            // Optimistically take a full budget's worth of tokens, then shrink
            // the window from the right until the wrapped prompt actually fits.
            let mut end_idx = (start_idx + budget).min(token_count);
            let mut accepted: Option<(DocumentChunk, usize)> = None;

            while end_idx > start_idx {
                let range = offsets[start_idx].0..offsets[end_idx - 1].1;
                let chunk_slice = trimmed[range.clone()].trim();
                if chunk_slice.is_empty() {
                    end_idx -= 1;
                    continue;
                }

                // Re-tokenize the full candidate prompt rather than trusting
                // the content-token count: template + content can merge or
                // split at token boundaries, so only a fresh encode is exact.
                let candidate = format_document_prompt(sanitized_title, chunk_slice);
                let candidate_len = self
                    .tokenizer
                    .encode(candidate.as_str(), true)
                    .map_err(E::msg)?
                    .len();

                if candidate_len <= self.max_seq_len {
                    let token_len = end_idx - start_idx;
                    accepted = Some((
                        DocumentChunk {
                            prompt: candidate,
                            content: chunk_slice.to_string(),
                            token_len,
                        },
                        end_idx,
                    ));
                    break;
                }
                end_idx -= 1;
            }

            // A fitting window was found: commit it and resume after its end.
            if let Some((chunk, next_idx)) = accepted {
                chunks.push(chunk);
                start_idx = next_idx;
                continue;
            }

            // No window starting here fit (or all were whitespace): fall back
            // to a single token so the outer loop always makes progress.
            let single_range = offsets[start_idx].0..offsets[start_idx].1;
            let chunk_slice = trimmed[single_range].trim();
            if !chunk_slice.is_empty() {
                chunks.push(DocumentChunk {
                    prompt: format_document_prompt(sanitized_title, chunk_slice),
                    content: chunk_slice.to_string(),
                    token_len: 1,
                });
            }
            start_idx += 1;
        }

        Ok(chunks)
    }
}

/// Normalize a search-result title for use in a document prompt.
///
/// Yields the trimmed title, or the literal placeholder `"none"` when the
/// title is blank after trimming.
fn sanitize_title(title: &str) -> String {
    match title.trim() {
        "" => "none".to_string(),
        t => t.to_string(),
    }
}

/// Build the embedding prompt for a search query (whitespace-trimmed).
fn format_query_prompt(query: &str) -> String {
    let q = query.trim();
    format!("task: search result | query: {q}")
}

/// Build the embedding prompt for a document chunk; the text is trimmed but
/// the title is used verbatim (callers pass a sanitized title).
fn format_document_prompt(title: &str, text: &str) -> String {
    let body = text.trim();
    format!("title: {title} | text: {body}")
}

/// Cosine similarity of two vectors; returns 0.0 when either has zero norm.
///
/// Pairs elements up to the shorter length (inputs are expected to be
/// equal-length embeddings).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}

/// Chunk every search result, embed all chunks plus the query, and return the
/// chunks sorted by descending cosine similarity to the query.
///
/// # Errors
/// Fails if the query embedding cannot be produced or any chunking/embedding
/// step fails.
pub fn rank_document_chunks(
    query: &str,
    results: &[SearchResult],
    pipeline: &mut SearchPipeline,
) -> Result<Vec<ScoredChunk>> {
    if results.is_empty() {
        return Ok(Vec::new());
    }

    let query_embedding = pipeline
        .embed(&[format_query_prompt(query)])?
        .into_iter()
        .next()
        .context("Failed to generate embedding for search query")?;

    // Flatten every result into (result_index, chunk) pairs so one embed call
    // can batch all chunks together.
    let mut bindings: Vec<(usize, DocumentChunk)> = Vec::new();
    for (result_index, result) in results.iter().enumerate() {
        let title = sanitize_title(&result.title);
        bindings.extend(
            pipeline
                .chunk_document(&title, &result.content)?
                .into_iter()
                .map(|chunk| (result_index, chunk)),
        );
    }

    if bindings.is_empty() {
        return Ok(Vec::new());
    }

    let prompts: Vec<String> = bindings
        .iter()
        .map(|(_, chunk)| chunk.prompt.clone())
        .collect();

    let chunk_embeddings = pipeline.embed(&prompts)?;
    let mut scored = Vec::with_capacity(bindings.len());
    for ((result_index, chunk), embedding) in bindings.into_iter().zip(chunk_embeddings.into_iter())
    {
        let score = cosine_similarity(&query_embedding, &embedding);
        scored.push(ScoredChunk {
            result_index,
            content: chunk.content,
            token_len: chunk.token_len,
            score,
        });
    }

    // Descending by score. Use `total_cmp` rather than
    // `partial_cmp(..).unwrap_or(..)`: with NaN scores the latter is not a
    // total order, which `sort_by` detects and panics on (Rust >= 1.81).
    // `total_cmp` orders NaN deterministically instead.
    scored.sort_by(|a, b| b.score.total_cmp(&a.score));
    Ok(scored)
}