//! mistralrs-core 0.8.1
//!
//! Fast, flexible LLM inference.
//! See the crate documentation for details.
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, fmt::Debug, sync::Arc};

use anyhow::Result;
use candle_core::{Device, Tensor, WithDType};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{make_flash_params, FlashParams, PagedAttentionMeta},
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
};

/// Right-pads each row of `x` with `pad` up to `max_len`, converts every row
/// to a tensor on `device`, and concatenates the rows along dimension 0.
///
/// Note: each padded row becomes a 1-D tensor of shape `(max_len,)`, so the
/// result is a flat 1-D tensor of length `x.len() * max_len` with rows laid
/// out back to back — not a 2-D batch.
///
/// # Errors
/// Returns an error if any row is longer than `max_len`, or if tensor
/// creation or concatenation fails.
fn _make_tensor_with_pad<D: WithDType>(
    x: Vec<Vec<D>>,
    max_len: usize,
    pad: D,
    device: &Device,
) -> Result<Tensor> {
    let mut padded_x = Vec::with_capacity(x.len());
    for mut x_i in x {
        // Report over-long rows as an error instead of panicking: the
        // function already returns `Result`, so callers can recover.
        anyhow::ensure!(
            x_i.len() <= max_len,
            "row of length {} exceeds max_len {}",
            x_i.len(),
            max_len
        );
        // `resize` pads in place without the intermediate Vec that
        // `extend([pad].repeat(..))` allocates.
        x_i.resize(max_len, pad);
        padded_x.push(Tensor::from_vec(x_i, (max_len,), device)?);
    }
    Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
}

/// A padded token tensor plus the flash-attention metadata needed to run it.
pub struct InputMetadata {
    /// Token-id tensor; as built by [`make_prompt_chunk`] it has shape
    /// `(batch, max_len)`, right-padded with `T::zero()`.
    pub input: Tensor,
    /// Flash-attention parameters; empty when flash attention is disabled.
    pub flash_meta: FlashParams,
}

/// Output of [`get_prompt_input`]: the prepared model inputs together with
/// the indices of the sequences they correspond to.
pub struct InnerInputProcessorOutput {
    /// Prepared input tensor and flash-attention metadata.
    pub inputs: InputMetadata,
    /// Indices into the original sequence batch (here always `0..len`).
    pub seq_indices: Vec<usize>,
}

// chunk_offset_toks is the number of tokens by which the tokens are offset,
// chunk_offset_toks / prompt_chunksize = number of batches
#[allow(clippy::too_many_arguments)]
pub fn make_prompt_chunk<T: WithDType + Debug>(
    chunk_offset_toks: usize,
    toks: Vec<&[T]>,
    device: &Device,
    mapper: Option<&dyn DeviceMapper>,
    has_causal_attention: bool,
    sliding_window: Option<usize>,
) -> Result<InputMetadata> {
    let max_len = toks
        .iter()
        .map(|seq| seq.len())
        .max()
        .expect("No sequences");
    let padding_tok = T::zero();
    // Pad each sequence by the padding token to the max len.
    let mut seqs_tensors = Vec::new();
    let flash_attn = crate::using_flash_attn();
    let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
    let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
    for ctxt in toks {
        let mut ctxt = ctxt.to_vec();
        ctxt.extend(std::iter::repeat_n(
            padding_tok,
            max_len.saturating_sub(ctxt.len()),
        ));

        if flash_attn {
            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
        }

        seqs_tensors.push(Tensor::new(ctxt, device).unwrap().unsqueeze(0).unwrap());
    }

    let flash_meta = if flash_attn {
        make_flash_params(
            device,
            mapper,
            &seqlens_q,
            &seqlens_k,
            sliding_window,
            has_causal_attention,
        )?
    } else {
        FlashParams::empty(has_causal_attention)
    };

    let input = Tensor::cat(&seqs_tensors, 0).unwrap();

    Ok(InputMetadata { input, flash_meta })
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn get_prompt_input<T: WithDType + std::fmt::Debug>(
    toks: Vec<&[T]>,
    input_seqs: &[&mut Sequence],
    device: &Device,
    mapper: Option<&dyn DeviceMapper>,
    has_causal_attention: bool,
    sliding_window: Option<usize>,
) -> Result<InnerInputProcessorOutput> {
    let offset = input_seqs[0].token_offset();
    make_prompt_chunk(
        offset,
        toks,
        device,
        mapper,
        has_causal_attention,
        sliding_window,
    )
    .map(|inputs| InnerInputProcessorOutput {
        inputs,
        seq_indices: (0..input_seqs.len()).collect(),
    })
}

/// Inputs handed to an embedding model's forward pass (boxed as `dyn Any`
/// inside [`InputProcessorOutput`]).
#[derive(Clone)]
pub struct ModelInputs {
    /// Padded token-id tensor produced by [`get_prompt_input`].
    pub input_ids: Tensor,
    /// Flash-attention metadata for the batch.
    pub flash_meta: FlashParams,
}

/// [`InputsProcessor`] for embedding models: pads and batches prompt tokens
/// into a [`ModelInputs`].
pub struct EmbeddingInputsProcessor {
    /// Whether attention is causal; forwarded into the flash parameters.
    pub has_causal_attention: bool,
}

impl InputsProcessor for EmbeddingInputsProcessor {
    /// Turns the tokenized prompt of every sequence into a boxed
    /// [`ModelInputs`], padding all sequences to a common length.
    fn process_inputs(
        &self,
        _: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        _is_xlora: bool,
        device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _sliding_window: Option<usize>,
        _: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        // Embedding requests are always processed as a single prompt pass.
        assert!(is_prompt);

        // Collect a token slice per sequence for padding/batching.
        let token_slices: Vec<_> = input_seqs.iter().map(|seq| seq.get_toks()).collect();
        let InnerInputProcessorOutput {
            inputs,
            seq_indices,
        } = get_prompt_input(
            token_slices,
            input_seqs,
            device,
            mapper,
            self.has_causal_attention,
            None,
        )?;

        let model_inputs = ModelInputs {
            input_ids: inputs.input,
            flash_meta: inputs.flash_meta,
        };
        let boxed: Box<dyn Any> = Box::new(model_inputs);
        Ok(InputProcessorOutput {
            inputs: boxed,
            seq_indices,
        })
    }

    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Embedding
    }
}

/// [`Processor`] implementation for embedding pipelines.
pub struct EmbeddingProcessor {
    /// Forwarded to the [`EmbeddingInputsProcessor`] this processor creates.
    pub has_causal_attention: bool,
}

impl Processor for EmbeddingProcessor {
    /// Hands out a fresh inputs processor carrying this processor's
    /// causal-attention flag.
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        let processor = EmbeddingInputsProcessor {
            has_causal_attention: self.has_causal_attention,
        };
        Arc::new(processor)
    }

    /// Embedding models introduce no extra special tokens.
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }

    /// Chat-template messages are forwarded untouched.
    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}