Skip to main content

mistralrs_core/pipeline/
mod.rs

1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod embedding;
6mod ggml;
7mod gguf;
8pub(crate) mod hf;
9mod inputs_processor;
10mod isq;
11pub(crate) mod llg;
12mod loaders;
13mod macros;
14mod normal;
15mod paths;
16mod processing;
17mod response;
18mod sampling;
19mod speculative;
20mod speech;
21mod vision;
22
23pub use super::diffusion_models::DiffusionGenerationParams;
24use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
25use crate::device_map::DeviceMapper;
26use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
27use crate::prefix_cacher::PrefixCacheManagerV2;
28use crate::PagedAttentionConfig;
29pub use amoe::{AnyMoeLoader, AnyMoePipeline};
30pub use auto::{AutoLoader, AutoLoaderBuilder};
31use chat_template::ChatTemplate;
32pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
33pub use embedding::{EmbeddingLoader, EmbeddingLoaderBuilder, EmbeddingSpecificConfig};
34pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
35pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
36use image::DynamicImage;
37pub use inputs_processor::InputProcessorOutput;
38pub(crate) use isq::IsqModelLoader;
39pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
40use llguidance::toktrie::TokEnv;
41pub use loaders::{
42    AdapterKind, AutoDeviceMapParams, AutoEmbeddingLoader, AutoNormalLoader, AutoVisionLoader,
43    DeepSeekV2Loader, DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType,
44    DiffusionModel, DiffusionModelLoader, EmbeddingGemmaLoader, EmbeddingLoaderType,
45    EmbeddingModel, EmbeddingModelLoader, EmbeddingModelPaths, EmbeddingModule,
46    EmbeddingModulePaths, EmbeddingModuleType, FluxLoader, GLM4Loader, GLM4MoeLiteLoader,
47    GLM4MoeLoader, Gemma2Loader, Gemma3Loader, Gemma3nLoader, GemmaLoader, GptOssLoader,
48    GraniteMoeHybridLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
49    LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader,
50    MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel,
51    NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader,
52    PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader,
53    Qwen3EmbeddingLoader, Qwen3Loader, Qwen3MoELoader, Qwen3VLLoader, Qwen3VLMoELoader,
54    SmolLm3Loader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType,
55    VisionModel, VisionModelLoader,
56};
57#[allow(clippy::too_many_arguments)]
58pub(crate) fn get_device_layers_for_loader(
59    loader: &dyn loaders::DeviceMappedModelLoader,
60    config: &str,
61    num_layers: usize,
62    layer_sizes_in_bytes: Vec<usize>,
63    non_mapped_size_in_bytes: usize,
64    total_model_size_in_bytes: usize,
65    devices: &[Device],
66    dtype: DType,
67    params: &loaders::AutoDeviceMapParams,
68    paged_attn_config: Option<&PagedAttentionConfig>,
69) -> Result<crate::device_map::DeviceMapMetadata> {
70    loaders::auto_device_map::get_device_layers(
71        loader,
72        config,
73        num_layers,
74        layer_sizes_in_bytes,
75        non_mapped_size_in_bytes,
76        total_model_size_in_bytes,
77        devices,
78        dtype,
79        params,
80        paged_attn_config,
81    )
82}
83use mistralrs_quant::IsqType;
84pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
85pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
86pub use paths::{AdapterPaths, LoraAdapterPaths};
87pub(crate) use processing::{
88    apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
89};
90use rand_isaac::Isaac64Rng;
91pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
92pub use speech::{SpeechLoader, SpeechPipeline};
93use std::any::Any;
94use std::collections::HashMap;
95use std::fmt::Debug;
96use std::sync::Arc;
97use std::time::{Duration, Instant};
98use tokenizers::Tokenizer;
99pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
100
101use anyhow::Result;
102use candle_core::{DType, Device, IndexOp, Tensor, Var};
103
104use crate::sequence::Sequence;
105
106pub use self::inputs_processor::{
107    text_models_inputs_processor, InputsProcessor, InputsProcessorType,
108};
109use self::text_models_inputs_processor::PagedAttentionMeta;
110pub use crate::kv_cache::{
111    Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
112};
113
/// A kind of data that can appear as a model input or output (see [`Modalities`]).
///
/// `Debug` is implemented by hand: each variant prints as an emoji followed by its name
/// instead of the bare variant identifier.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    Text,
    Audio,
    Vision,
    Embedding,
}

impl Debug for SupportedModality {
    /// Writes the emoji-prefixed label for the variant. Formatter flags such as width or
    /// alignment are not applied; the label is written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
            Self::Embedding => "🔢 Embedding",
        };
        f.write_str(label)
    }
}
132
133#[derive(Debug, Clone)]
134pub struct Modalities {
135    pub input: Vec<SupportedModality>,
136    pub output: Vec<SupportedModality>,
137}
138
139pub struct GeneralMetadata {
140    pub max_seq_len: usize,
141    /// Only None if it doesn't make sense for the model
142    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
143    pub no_kv_cache: bool,
144    pub no_prefix_cache: bool,
145    pub num_hidden_layers: usize,
146    pub eos_tok: Vec<u32>,
147    pub kind: ModelKind,
148    // TODO: Replace is_xlora queries to check via kind instead:
149    pub is_xlora: bool,
150    pub activation_dtype: DType,
151    pub sliding_window: Option<usize>,
152    // PagedAttention stuff
153    pub cache_config: Option<CacheConfig>,
154    pub cache_engine: Option<CacheEngine>,
155    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
156    pub modalities: Modalities,
157}
158
159impl GeneralMetadata {
160    pub fn tok_env(&self) -> Option<TokEnv> {
161        self.llg_factory.as_ref().map(|f| f.tok_env().clone())
162    }
163}
164
/// A cache operation requested around a forward pass.
///
/// `Pipeline::step` interprets these as a pre-op (before the first forward pass) and a post-op
/// (after the forward passes). `Out` is not accepted as a pre-op and `In` is not accepted as a
/// post-op; both hit `unreachable!` there.
pub enum CacheInstruction {
    /// As a pre-op: calls `clone_in_cache` on the sequences.
    In,
    /// As a post-op: calls `clone_out_cache` on the sequences.
    Out,
    /// Calls `set_none_cache` with the two flags below.
    ///
    /// `load_preallocated_cache` means to load the preallocated cache, if applicable.
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    /// No cache operation is performed.
    Nothing,
}
175
176pub trait PreProcessingMixin: MetadataMixin {
177    fn get_processor(&self) -> Arc<dyn Processor> {
178        Arc::new(BasicProcessor)
179    }
180    /// Only None if it doesnt make sense for the model
181    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
182    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
183}
184
185pub trait IsqPipelineMixin {
186    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
187}
188
189pub trait CacheManagerMixin {
190    /// Clone the cache FROM the sequences' cache TO the model cache. Only called for completion seqs.
191    /// It is not a guarantee that this will be called for each completion step.
192    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
193    /// Clone the cache FROM the model cache TO the sequences. Called for prompt and completion seqs.
194    /// It is not a guarantee that this will be called for each step.
195    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
196    /// Set the model cache to all None. Only called for prompt seqs.
197    /// It is not a guarantee that this will be called for each prompt step.
198    /// This may also reset the non granular state if applicable.
199    fn set_none_cache(
200        &self,
201        seqs: &mut [&mut Sequence],
202        reset_non_granular: bool,
203        modify_draft_cache: bool,
204        load_preallocated_cache: bool,
205    );
206    fn cache(&self) -> &EitherCache;
207    fn do_preallocated_cache(&self) -> bool {
208        matches!(self.cache(), EitherCache::Normal(_))
209    }
210}
211
212pub trait MetadataMixin {
213    fn device(&self) -> Device;
214    /// Only None if it doesnt make sense for the model
215    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
216    fn name(&self) -> String;
217    fn reset_non_granular_state(&self);
218    fn get_metadata(&self) -> Arc<GeneralMetadata>;
219    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
220}
221
222/// Implemented by the base model of an AnyMoe.
223pub trait AnyMoePipelineMixin {
224    /// Get vars for each gating layer
225    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
226        unreachable!()
227    }
228    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
229        unreachable!()
230    }
231    fn amoe_base_model_trainable_params(&self) -> usize {
232        unreachable!()
233    }
234    fn amoe_supported(&self) -> bool {
235        false
236    }
237    /// Per-layer cached outputs.
238    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
239        unreachable!()
240    }
241    /// Inject the MoE layers
242    #[allow(clippy::too_many_arguments)]
243    fn amoe_create_layers(
244        &mut self,
245        _model_ids: Vec<String>,
246        _token: &TokenSource,
247        _revision: Option<String>,
248        _match_regex: &str,
249        _config: AnyMoeConfig,
250        _dtype: DType,
251        _dev: &Device,
252        (_prefix, _mlp): (String, String),
253        _layers: Vec<usize>,
254        _expert_type: AnyMoeExpertType,
255        _silent: bool,
256        _gate_model_id: Option<String>,
257    ) -> candle_core::Result<()> {
258        unreachable!()
259    }
260    /// Pre-train the gating layers
261    #[allow(clippy::too_many_arguments)]
262    fn amoe_pre_train(
263        &self,
264        _inputs: AnyMoeTrainingInputs,
265        (_prefix, _mlp): (String, String),
266        _model_ids: Vec<String>,
267        _token: TokenSource,
268        _revision: Option<String>,
269        _layers: Vec<usize>,
270        _silent: bool,
271    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
272        unreachable!()
273    }
274}
275
276/// Category of the model. This can also be used to extract model-category specific tools,
277/// such as the vision model prompt prefixer.
278#[derive(Clone)]
279pub enum ModelCategory {
280    Text,
281    Vision {
282        prefixer: Arc<dyn MultimodalPromptPrefixer>,
283    },
284    Diffusion,
285    Audio,
286    Speech,
287    Embedding,
288}
289
290impl std::fmt::Debug for ModelCategory {
291    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292        match self {
293            ModelCategory::Text => write!(f, "ModelCategory::Text"),
294            ModelCategory::Vision { .. } => write!(f, "ModelCategory::Vision {{ prefixer: .. }}"),
295            ModelCategory::Diffusion => write!(f, "ModelCategory::Diffusion"),
296            ModelCategory::Audio => write!(f, "ModelCategory::Audio"),
297            ModelCategory::Speech => write!(f, "ModelCategory::Speech"),
298            ModelCategory::Embedding => write!(f, "ModelCategory::Embedding"),
299        }
300    }
301}
302
303impl PartialEq for ModelCategory {
304    fn eq(&self, other: &Self) -> bool {
305        match (self, other) {
306            (Self::Text, Self::Text) => true,
307            (Self::Vision { .. }, Self::Vision { .. }) => true,
308            (Self::Audio, Self::Audio) => true,
309            (Self::Speech, Self::Speech) => true,
310            (Self::Diffusion, Self::Diffusion) => true,
311            (Self::Embedding, Self::Embedding) => true,
312            (
313                Self::Text
314                | Self::Vision { .. }
315                | Self::Diffusion
316                | Self::Audio
317                | Self::Speech
318                | Self::Embedding,
319                _,
320            ) => false,
321        }
322    }
323}
324
/// Prepends a tag appropriate for the model to the prompt. Image indexing is assumed to start
/// at 0.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Image prefix for inclusion in messages (may do nothing if the chat template handles it).
    ///
    /// The default implementation ignores the indices and returns `prompt` unchanged.
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
    /// Audio prefix for inclusion in messages (may do nothing if the chat template handles it).
    ///
    /// The default implementation ignores the indices and returns `prompt` unchanged.
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
}
336
337pub enum CacheBackendMetadata {
338    DefaultInstructions {
339        pre_op: CacheInstruction,
340        post_op: CacheInstruction,
341    },
342    PagedAttention {
343        metadata: PagedAttentionMeta,
344        blocks_to_copy: HashMap<usize, Vec<usize>>,
345    },
346}
347
348#[derive(Clone, Debug)]
349pub enum ForwardInputsResult {
350    RawLogits {
351        logits: Tensor,
352    },
353    Embeddings {
354        embeddings: Tensor,
355    },
356    CausalGeneration {
357        logits: Tensor,
358    },
359    Image {
360        images: Vec<DynamicImage>,
361    },
362    Speech {
363        pcms: Vec<Arc<Vec<f32>>>,
364        rates: Vec<usize>,
365        channels: Vec<usize>,
366    },
367}
368
369impl ForwardInputsResult {
370    fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
371        match self {
372            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
373                logits: logits.i(bs_idx)?,
374            }),
375            Self::Embeddings { embeddings } => Ok(Self::Embeddings {
376                embeddings: embeddings.i(bs_idx)?,
377            }),
378            Self::RawLogits { logits } => Ok(Self::RawLogits {
379                logits: logits.i(bs_idx)?,
380            }),
381            Self::Image { images } => Ok(Self::Image {
382                images: vec![images[bs_idx].clone()],
383            }),
384            Self::Speech {
385                pcms,
386                rates,
387                channels,
388            } => Ok(Self::Speech {
389                pcms: vec![pcms[bs_idx].clone()],
390                rates: vec![rates[bs_idx]],
391                channels: vec![channels[bs_idx]],
392            }),
393        }
394    }
395
396    fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
397        match self {
398            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
399                logits: logits.to_device(device)?,
400            }),
401            Self::RawLogits { logits } => Ok(Self::RawLogits {
402                logits: logits.to_device(device)?,
403            }),
404            Self::Embeddings { embeddings } => Ok(Self::Embeddings {
405                embeddings: embeddings.to_device(device)?,
406            }),
407            Self::Image { .. } => Ok(self.clone()),
408            Self::Speech { .. } => Ok(self.clone()),
409        }
410    }
411}
412
413#[derive(serde::Serialize, serde::Deserialize)]
414pub(crate) struct FileListCache {
415    files: Vec<String>,
416}
417
418#[async_trait::async_trait]
419pub trait Pipeline:
420    Send
421    + Sync
422    + PreProcessingMixin
423    + IsqPipelineMixin
424    + CacheManagerMixin
425    + MetadataMixin
426    + AnyMoePipelineMixin
427{
428    fn forward_inputs(
429        &mut self,
430        inputs: Box<dyn Any>,
431        return_raw_logits: bool,
432    ) -> Result<ForwardInputsResult, candle_core::Error>;
433
434    /// Returns the total of model execution time.
435    #[allow(clippy::too_many_arguments)]
436    async fn step(
437        &mut self,
438        input_seqs: &mut [&mut Sequence],
439        is_prompt: bool,
440        return_raw_logits: bool,
441        prefix_cacher: &mut PrefixCacheManagerV2,
442        disable_eos_stop: bool,
443        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
444        backend_metadata: CacheBackendMetadata,
445    ) -> Result<Duration, candle_core::Error> {
446        match backend_metadata {
447            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
448                let inputs_iter =
449                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
450                        self.tokenizer(),
451                        input_seqs,
452                        is_prompt,
453                        self.get_metadata().is_xlora,
454                        &self.device(),
455                        self.get_metadata().no_kv_cache,
456                        None,
457                        return_raw_logits,
458                        self.get_input_processor_config(),
459                        None,
460                        self.device_mapper(),
461                    ));
462
463                let mut logits = vec![None; input_seqs.len()];
464                let len_inputs = 1;
465                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
466                let mut embedding_logits = vec![None; input_seqs.len()];
467
468                let mut exec_duration = Duration::ZERO;
469                for (i, inputs) in inputs_iter.into_iter().enumerate() {
470                    let InputProcessorOutput {
471                        inputs,
472                        seq_indices,
473                    } = inputs.map_err(candle_core::Error::msg)?;
474                    if i == 0 {
475                        match pre_op {
476                            CacheInstruction::In => self.clone_in_cache(input_seqs),
477                            CacheInstruction::Nothing => (),
478                            CacheInstruction::Reset {
479                                load_preallocated_cache,
480                                reset_non_granular,
481                            } => self.set_none_cache(
482                                input_seqs,
483                                reset_non_granular,
484                                false,
485                                load_preallocated_cache,
486                            ),
487                            _ => unreachable!("Unreachable PRE cache op."),
488                        }
489                    }
490
491                    let start = Instant::now();
492                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
493                    let end = Instant::now();
494                    exec_duration += end.duration_since(start);
495
496                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
497                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
498                            raw_out_logits[seq_idx][i] =
499                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
500                        } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
501                            embedding_logits[seq_idx] =
502                                Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
503                        } else {
504                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
505                        }
506                    }
507                }
508
509                match post_op {
510                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
511                    CacheInstruction::Nothing => (),
512                    CacheInstruction::Reset {
513                        load_preallocated_cache,
514                        reset_non_granular,
515                    } => self.set_none_cache(
516                        input_seqs,
517                        reset_non_granular,
518                        false,
519                        load_preallocated_cache,
520                    ),
521                    _ => unreachable!("Unreachable POST cache op."),
522                }
523
524                if raw_out_logits[0][0].is_some() {
525                    let start = Instant::now();
526                    response::send_raw_responses(
527                        input_seqs,
528                        raw_out_logits
529                            .into_iter()
530                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
531                            .collect(),
532                    )
533                    .await?;
534                    let end = Instant::now();
535                    exec_duration += end.duration_since(start);
536
537                    return Ok(exec_duration);
538                }
539                if embedding_logits[0].is_some() {
540                    let start = Instant::now();
541                    response::send_embedding_responses(
542                        input_seqs,
543                        embedding_logits
544                            .into_iter()
545                            .map(|raw| {
546                                raw.unwrap()
547                                    .to_dtype(DType::F32)
548                                    .unwrap()
549                                    .to_vec1::<f32>()
550                                    .unwrap()
551                            })
552                            .collect(),
553                    )
554                    .await?;
555                    let end = Instant::now();
556                    exec_duration += end.duration_since(start);
557
558                    return Ok(exec_duration);
559                }
560
561                let start = Instant::now();
562                let logits_on_cpu = logits.len() > 1;
563                let logits = logits
564                    .into_iter()
565                    .map(|l| {
566                        let l = l.expect("Did not get any inputs. This is shocking.");
567                        if logits_on_cpu {
568                            l.to_device(&Device::Cpu)
569                        } else {
570                            Ok(l)
571                        }
572                    })
573                    .collect::<candle_core::Result<Vec<_>>>()?;
574
575                match &logits[0] {
576                    ForwardInputsResult::RawLogits { .. }
577                    | ForwardInputsResult::Embeddings { .. } => unreachable!(),
578                    ForwardInputsResult::CausalGeneration { .. } => {
579                        self.sample_causal_gen(
580                            input_seqs,
581                            logits
582                                .into_iter()
583                                .map(|r| {
584                                    #[allow(irrefutable_let_patterns)]
585                                    let ForwardInputsResult::CausalGeneration { logits } = r
586                                    else {
587                                        unreachable!(
588                                            "All results must have same type, `CausalGeneration`"
589                                        )
590                                    };
591                                    logits
592                                })
593                                .collect::<Vec<_>>(),
594                            prefix_cacher,
595                            disable_eos_stop,
596                            rng,
597                        )
598                        .await?;
599                    }
600                    ForwardInputsResult::Image { .. } => {
601                        response::send_image_responses(
602                            input_seqs,
603                            logits
604                                .into_iter()
605                                .map(|r| {
606                                    #[allow(irrefutable_let_patterns)]
607                                    let ForwardInputsResult::Image { images } = r
608                                    else {
609                                        unreachable!("All results must have same type, `Image`")
610                                    };
611                                    images
612                                        .into_iter()
613                                        .next()
614                                        .expect("Must have at least 1 element.")
615                                })
616                                .collect::<Vec<_>>(),
617                        )
618                        .await?;
619                    }
620                    ForwardInputsResult::Speech { .. } => {
621                        let rates = logits
622                            .iter()
623                            .map(|r| {
624                                #[allow(irrefutable_let_patterns)]
625                                let ForwardInputsResult::Speech { rates, .. } = r
626                                else {
627                                    unreachable!("All results must have same type, `Speech`")
628                                };
629                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
630                                *rates.first().unwrap()
631                            })
632                            .collect::<Vec<_>>();
633                        let channels = logits
634                            .iter()
635                            .map(|r| {
636                                #[allow(irrefutable_let_patterns)]
637                                let ForwardInputsResult::Speech { channels, .. } = r
638                                else {
639                                    unreachable!("All results must have same type, `Speech`")
640                                };
641                                assert_eq!(
642                                    channels.len(),
643                                    1,
644                                    "Each sequence must have 1 PCM output."
645                                );
646                                *channels.first().unwrap()
647                            })
648                            .collect::<Vec<_>>();
649                        let pcms = logits
650                            .into_iter()
651                            .map(|r| {
652                                #[allow(irrefutable_let_patterns)]
653                                let ForwardInputsResult::Speech { pcms, .. } = r
654                                else {
655                                    unreachable!("All results must have same type, `Speech`")
656                                };
657                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
658                                pcms.into_iter().nth(0).unwrap()
659                            })
660                            .collect::<Vec<_>>();
661                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
662                            .await?;
663                    }
664                }
665                let end = Instant::now();
666                exec_duration += end.duration_since(start);
667
668                Ok(exec_duration)
669            }
670            CacheBackendMetadata::PagedAttention {
671                metadata,
672                blocks_to_copy,
673            } => {
674                // Cloning might be bad?
675                self.get_metadata()
676                    .cache_engine
677                    .as_ref()
678                    .expect("PagedAttention must have cache engines.")
679                    .execute_scheduler_ops(&blocks_to_copy)?;
680
681                let inputs_iter =
682                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
683                        self.tokenizer(),
684                        input_seqs,
685                        is_prompt,
686                        self.get_metadata().is_xlora,
687                        &self.device(),
688                        self.get_metadata().no_kv_cache,
689                        None,
690                        return_raw_logits,
691                        self.get_input_processor_config(),
692                        Some(metadata),
693                        self.device_mapper(),
694                    ));
695
696                let mut logits = vec![None; input_seqs.len()];
697                let len_inputs = 1;
698                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
699                let mut embedding_logits = vec![None; input_seqs.len()];
700
701                let mut exec_duration = Duration::ZERO;
702                for (i, inputs) in inputs_iter.into_iter().enumerate() {
703                    let InputProcessorOutput {
704                        inputs,
705                        seq_indices,
706                    } = inputs.map_err(candle_core::Error::msg)?;
707
708                    let start = Instant::now();
709                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
710                    let end = Instant::now();
711                    exec_duration += end.duration_since(start);
712
713                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
714                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
715                            raw_out_logits[seq_idx][i] =
716                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
717                        } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
718                            embedding_logits[seq_idx] =
719                                Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
720                        } else {
721                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
722                        }
723                    }
724                }
725
726                if raw_out_logits[0][0].is_some() {
727                    let start = Instant::now();
728                    response::send_raw_responses(
729                        input_seqs,
730                        raw_out_logits
731                            .into_iter()
732                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
733                            .collect(),
734                    )
735                    .await?;
736                    let end = Instant::now();
737                    exec_duration += end.duration_since(start);
738
739                    return Ok(exec_duration);
740                }
741                if embedding_logits[0].is_some() {
742                    let start = Instant::now();
743                    response::send_embedding_responses(
744                        input_seqs,
745                        embedding_logits
746                            .into_iter()
747                            .map(|raw| {
748                                raw.unwrap()
749                                    .to_dtype(DType::F32)
750                                    .unwrap()
751                                    .to_vec1::<f32>()
752                                    .unwrap()
753                            })
754                            .collect(),
755                    )
756                    .await?;
757                    let end = Instant::now();
758                    exec_duration += end.duration_since(start);
759
760                    return Ok(exec_duration);
761                }
762
763                let start = Instant::now();
764                let logits_on_cpu = logits.len() > 1;
765                let logits = logits
766                    .into_iter()
767                    .map(|l| {
768                        let l = l.expect("Did not get any inputs. This is shocking.");
769                        if logits_on_cpu {
770                            l.to_device(&Device::Cpu)
771                        } else {
772                            Ok(l)
773                        }
774                    })
775                    .collect::<candle_core::Result<Vec<_>>>()?;
776
777                match &logits[0] {
778                    ForwardInputsResult::RawLogits { .. }
779                    | ForwardInputsResult::Embeddings { .. } => unreachable!(),
780                    ForwardInputsResult::CausalGeneration { .. } => {
781                        self.sample_causal_gen(
782                            input_seqs,
783                            logits
784                                .into_iter()
785                                .map(|r| {
786                                    #[allow(irrefutable_let_patterns)]
787                                    let ForwardInputsResult::CausalGeneration { logits } = r
788                                    else {
789                                        unreachable!("All results must have same type")
790                                    };
791                                    logits
792                                })
793                                .collect::<Vec<_>>(),
794                            prefix_cacher,
795                            disable_eos_stop,
796                            rng,
797                        )
798                        .await?;
799                    }
800                    ForwardInputsResult::Image { .. } => {
801                        response::send_image_responses(
802                            input_seqs,
803                            logits
804                                .into_iter()
805                                .map(|r| {
806                                    #[allow(irrefutable_let_patterns)]
807                                    let ForwardInputsResult::Image { images } = r
808                                    else {
809                                        unreachable!("All results must have same type, `Image`")
810                                    };
811                                    images
812                                        .into_iter()
813                                        .next()
814                                        .expect("Must have at least 1 element.")
815                                })
816                                .collect::<Vec<_>>(),
817                        )
818                        .await?;
819                    }
820                    ForwardInputsResult::Speech { .. } => {
821                        let rates = logits
822                            .iter()
823                            .map(|r| {
824                                #[allow(irrefutable_let_patterns)]
825                                let ForwardInputsResult::Speech { rates, .. } = r
826                                else {
827                                    unreachable!("All results must have same type, `Speech`")
828                                };
829                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
830                                *rates.first().unwrap()
831                            })
832                            .collect::<Vec<_>>();
833                        let channels = logits
834                            .iter()
835                            .map(|r| {
836                                #[allow(irrefutable_let_patterns)]
837                                let ForwardInputsResult::Speech { channels, .. } = r
838                                else {
839                                    unreachable!("All results must have same type, `Speech`")
840                                };
841                                assert_eq!(
842                                    channels.len(),
843                                    1,
844                                    "Each sequence must have 1 PCM output."
845                                );
846                                *channels.first().unwrap()
847                            })
848                            .collect::<Vec<_>>();
849                        let pcms = logits
850                            .into_iter()
851                            .map(|r| {
852                                #[allow(irrefutable_let_patterns)]
853                                let ForwardInputsResult::Speech { pcms, .. } = r
854                                else {
855                                    unreachable!("All results must have same type, `Speech`")
856                                };
857                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
858                                pcms.into_iter().nth(0).unwrap()
859                            })
860                            .collect::<Vec<_>>();
861                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
862                            .await?;
863                    }
864                }
865                let end = Instant::now();
866                exec_duration += end.duration_since(start);
867
868                Ok(exec_duration)
869            }
870        }
871    }
872
/// Consumes the logits produced by a causal-generation forward pass.
///
/// The step logic in this trait calls this when the forward results are
/// `ForwardInputsResult::CausalGeneration`, handing over the sequences that were
/// run, the `logits` tensor taken from each result, the prefix cache manager, the
/// `disable_eos_stop` flag and the shared RNG.
///
/// This is a required method with no default body; what an implementation does
/// with the logits is not visible here.
/// NOTE(review): presumably samples the next token for each sequence and emits the
/// responses — confirm against the implementors.
///
/// # Errors
/// Returns a `candle_core::Error`; the caller propagates it with `?`.
873    async fn sample_causal_gen(
874        &self,
875        seqs: &mut [&mut Sequence],
876        logits: Vec<Tensor>,
877        prefix_cacher: &mut PrefixCacheManagerV2,
878        disable_eos_stop: bool,
879        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
880    ) -> Result<(), candle_core::Error>;
881
/// Returns the `ModelCategory` of this pipeline.
882    fn category(&self) -> ModelCategory;
883}
884
885pub(crate) fn extract_logits(
886    logits: &Tensor,
887    context_lens: Vec<(usize, usize)>,
888) -> candle_core::Result<Tensor> {
889    let mut toks = Vec::new();
890    for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
891        toks.push(dim.narrow(1, start, len)?);
892    }
893    Tensor::cat(&toks, 0)
894}
895
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    // Builds an `IndexMap` from `key => value` pairs, wrapping every value in
    // `Value::String`. The capacity is reserved up front by counting the keys.
    macro_rules! hashmap {
        // Counting helpers: turn each key into `()` and take the length of that array.
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        // A trailing comma is accepted by forwarding to the comma-separated form.
        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    /// Creates one chat message holding a `role` entry followed by a `content` entry.
    fn message(role: &str, content: MessageContent) -> IndexMap<String, MessageContent> {
        let mut msg = IndexMap::new();
        msg.insert("role".to_string(), Either::Left(role.to_string()));
        msg.insert("content".to_string(), content);
        msg
    }

    /// Creates a structured content part of the form `{"type": "text", "text": text}`.
    fn text_part(text: &str) -> IndexMap<String, Value> {
        hashmap! {
            "type".to_string() => "text".to_string(),
            "text".to_string() => text.to_string()
        }
    }

    /// Creates a structured content part of the form `{"type": "image"}`.
    fn image_part() -> IndexMap<String, Value> {
        hashmap! {
            "type".to_string() => "image".to_string()
        }
    }

    /// Renders `inputs` through every chat template and compares each result with the
    /// expected output at the same position.
    ///
    /// Each template tuple is `(has_system, bos, eos, unk, template)`. When `has_system`
    /// is `false`, the first message of `inputs` is dropped before rendering.
    ///
    /// Every template is rendered before the test fails, so a single run reports all
    /// mismatches, each labelled with the index of the template that produced it.
    ///
    /// # Panics
    /// Panics if `templates` and `expected_outputs` differ in length, or if any template
    /// fails to render or renders something other than its expected output.
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;

        // `zip` below stops at the shorter slice, so a missing expectation would
        // otherwise leave a template silently untested.
        assert_eq!(
            templates.len(),
            expected_outputs.len(),
            "Each chat template needs exactly one expected output."
        );

        let n_templates = templates.len();
        let mut failed = Vec::new();
        for (idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let messages = if *has_system {
                inputs.clone()
            } else {
                inputs[1..].to_vec()
            };
            let output = match apply_chat_template_to(
                messages,
                true,
                None,
                None, // reasoning_effort
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push((idx, format!("Failed with {e}.")));
                    continue;
                }
            };
            if output != *expected {
                failed.push((
                    idx,
                    format!(
                        "Expected: `{}` \n\nGot:      `{}`",
                        expected.replace('\n', "\\n"),
                        output.replace('\n', "\\n")
                    ),
                ));
            }
        }
        if !failed.is_empty() {
            for (idx, report) in &failed {
                println!("------------ Template {idx} ------------");
                println!("{report}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    /// Generating these cases:
    /// ```py
    /// >>> t=transformers.AutoTokenizer.from_pretrained(...)
    /// # If non-system prompt model
    /// >>> t.apply_chat_template([{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// # If system prompt model
    /// >>> t.apply_chat_template([{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// ```
    #[test]
    fn test_chat_templates() {
        // `templates` and `expected_outputs` are matched by position and must stay in sync.
        // The idefics2 template needs structured (list) content and is therefore
        // exercised by `test_image_chat_templates` rather than here.
        let templates = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            // mistralai/Mistral-7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // meta-llama/Llama-2-13b-chat-hf
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // google/gemma-7b-it
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
        ];
        let expected_outputs = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            // mistralai/Mistral-7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST]   I am an assistant   </s> [INST] Another question [/INST]",
            // meta-llama/Llama-2-13b-chat-hf
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
            // google/gemma-7b-it
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", "   I am an assistant   "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::with_capacity(messages.len());
        for [role, content] in messages {
            inputs.push(message(role, Either::Left(content.to_string())));
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    /// Generating these cases:
    /// ```py
    /// >>> processor=transformers.AutoProcessor.from_pretrained(...)
    /// >>> processor.apply_chat_template([
    ///         {"role":"system","content":[{"type":"text", "text": "You are a helpful assistant"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Hello, please describe the above."}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "Hi there"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "This is me, who are you"}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "   I am an assistant   "}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Another question, what is this?"}]}
    ///     ], add_generation_prompt=True, tokenize=False)
    /// ```
    #[test]
    fn test_image_chat_templates() {
        let templates = [
            // HuggingFaceM4/idefics2-8b-chatty
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            // HuggingFaceM4/idefics2-8b-chatty
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:    I am an assistant   <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        // Every user turn leads with an image part; system and assistant turns are text only.
        let inputs = vec![
            message(
                "system",
                Either::Right(vec![text_part("You are a helpful assistant")]),
            ),
            message(
                "user",
                Either::Right(vec![
                    image_part(),
                    text_part("Hello, please describe the above."),
                ]),
            ),
            message("assistant", Either::Right(vec![text_part("Hi there")])),
            message(
                "user",
                Either::Right(vec![image_part(), text_part("This is me, who are you")]),
            ),
            message(
                "assistant",
                Either::Right(vec![text_part("   I am an assistant   ")]),
            ),
            message(
                "user",
                Either::Right(vec![
                    image_part(),
                    text_part("Another question, what is this?"),
                ]),
            ),
        ];

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}