#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, fmt::Debug, sync::Arc};

use anyhow::Result;
use candle_core::{Device, Tensor, WithDType};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{make_flash_params, FlashParams, PagedAttentionMeta},
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
};
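
/// Right-pad every sequence in `x` to `max_len` with `pad` and stack the rows into a
/// single `(batch, max_len)` tensor on `device`.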
fn _make_tensor_with_pad<D: WithDType>(
    x: Vec<Vec<D>>,
    max_len: usize,
    pad: D,
    device: &Device,
) -> Result<Tensor> {
    let mut padded_x = Vec::new();
    for mut x_i in x {
        assert!(x_i.len() <= max_len);
        x_i.extend([pad].repeat(max_len - x_i.len()));
        // Keep each padded sequence as a (1, max_len) row so the concatenation below
        // produces a (batch, max_len) tensor rather than one flat vector.
        padded_x.push(Tensor::from_vec(x_i, (1, max_len), device)?);
    }
    Tensor::cat(&padded_x[..], 0).map_err(anyhow::Error::msg)
}
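
/// The padded input tensor and flash-attention metadata built for one batch of sequences.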
pub struct InputMetadata {
    pub input: Tensor,
    pub flash_meta: FlashParams,
}
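
/// Model inputs together with the indices of the sequences they were built from.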
pub struct InnerInputProcessorOutput {
    pub inputs: InputMetadata,
    pub seq_indices: Vec<usize>,
}
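
/// Build the padded token tensor and flash-attention metadata for one prompt chunk.
///
/// Every sequence in `toks` is right-padded to the longest sequence and stacked into a
/// `(batch, max_len)` tensor. `chunk_offset_toks` is the number of tokens already
/// processed for these sequences; it is folded into the key lengths when flash
/// attention is enabled.
///
/// A minimal usage sketch (illustrative values, not taken from the crate):
///
/// ```ignore
/// let seqs: Vec<&[u32]> = vec![&[1, 2, 3], &[4, 5]];
/// let meta = make_prompt_chunk(0, seqs, &Device::Cpu, None, true, None)?;
/// // meta.input now has shape (2, 3): both rows are padded to the longest sequence.
/// ```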
#[allow(clippy::too_many_arguments)]
pub fn make_prompt_chunk<T: WithDType + Debug>(
    chunk_offset_toks: usize,
    toks: Vec<&[T]>,
    device: &Device,
    mapper: Option<&dyn DeviceMapper>,
    has_causal_attention: bool,
    sliding_window: Option<usize>,
) -> Result<InputMetadata> {
    let max_len = toks
        .iter()
        .map(|seq| seq.len())
        .max()
        .expect("No sequences");
    let padding_tok = T::zero();
    let mut seqs_tensors = Vec::new();
    let flash_attn = crate::using_flash_attn();
    let mut seqlens_q = if flash_attn { vec![0] } else { Vec::new() };
    let mut seqlens_k = if flash_attn { vec![0] } else { Vec::new() };
    for ctxt in toks {
        // Right-pad every sequence to the length of the longest one in the batch.
        let mut ctxt = ctxt.to_vec();
        ctxt.extend(std::iter::repeat_n(
            padding_tok,
            max_len.saturating_sub(ctxt.len()),
        ));
        if flash_attn {
            // Record the per-sequence (padded) query length; the key length additionally
            // covers the tokens already processed before this chunk.
            seqlens_q.push(ctxt.len() as u32);
            seqlens_k.push((ctxt.len() + chunk_offset_toks) as u32);
        }
        seqs_tensors.push(Tensor::new(ctxt, device)?.unsqueeze(0)?);
    }
    let flash_meta = if flash_attn {
        make_flash_params(
            device,
            mapper,
            &seqlens_q,
            &seqlens_k,
            sliding_window,
            has_causal_attention,
        )?
    } else {
        FlashParams::empty(has_causal_attention)
    };
    let input = Tensor::cat(&seqs_tensors, 0)?;
    Ok(InputMetadata { input, flash_meta })
}
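
/// Build prompt-phase inputs for `input_seqs`, using the token offset of the first
/// sequence as the shared chunk offset. Assumes `input_seqs` is non-empty.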
#[allow(clippy::too_many_arguments)]
pub(crate) fn get_prompt_input<T: WithDType + Debug>(
    toks: Vec<&[T]>,
    input_seqs: &[&mut Sequence],
    device: &Device,
    mapper: Option<&dyn DeviceMapper>,
    has_causal_attention: bool,
    sliding_window: Option<usize>,
) -> Result<InnerInputProcessorOutput> {
    let offset = input_seqs[0].token_offset();
    make_prompt_chunk(
        offset,
        toks,
        device,
        mapper,
        has_causal_attention,
        sliding_window,
    )
    .map(|inputs| InnerInputProcessorOutput {
        inputs,
        seq_indices: (0..input_seqs.len()).collect(),
    })
}
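
/// Inputs handed to an embedding model: the padded token IDs and flash-attention metadata.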
#[derive(Clone)]
pub struct ModelInputs {
    pub input_ids: Tensor,
    pub flash_meta: FlashParams,
}
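
/// Inputs processor for embedding models. Only prompt (prefill) batches are supported;
/// decode batches cause a panic.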
pub struct EmbeddingInputsProcessor {
    pub has_causal_attention: bool,
}

impl InputsProcessor for EmbeddingInputsProcessor {
    fn process_inputs(
        &self,
        _: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        _is_xlora: bool,
        device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _sliding_window: Option<usize>,
        _: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        // Embedding models never decode autoregressively, so every batch is a prompt batch.
        assert!(is_prompt);

        let metadata = get_prompt_input(
            input_seqs
                .iter()
                .map(|seq| seq.get_toks())
                .collect::<Vec<_>>(),
            input_seqs,
            device,
            mapper,
            self.has_causal_attention,
            None,
        )?;
        let InnerInputProcessorOutput {
            inputs:
                InputMetadata {
                    input: input_ids,
                    flash_meta,
                },
            seq_indices,
        } = metadata;
        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }

    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Embedding
    }
}
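
/// `Processor` implementation for embedding models, wrapping [`EmbeddingInputsProcessor`].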
pub struct EmbeddingProcessor {
    pub has_causal_attention: bool,
}

impl Processor for EmbeddingProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(EmbeddingInputsProcessor {
            has_causal_attention: self.has_causal_attention,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}