Skip to main content

hanzo_engine/pipeline/
mod.rs

1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod embedding;
6mod ggml;
7mod gguf;
8pub(crate) mod hf;
9mod inputs_processor;
10mod isq;
11pub(crate) mod llg;
12mod loaders;
13mod macros;
14mod multimodal;
15mod normal;
16mod paths;
17mod processing;
18mod response;
19pub(crate) mod sampling;
20mod speech;
21
22pub use super::diffusion_models::DiffusionGenerationParams;
23use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
24use crate::device_map::DeviceMapper;
25use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
26use crate::prefix_cacher::PrefixCacheManagerV2;
27use crate::PagedAttentionConfig;
28pub use amoe::{AnyMoeLoader, AnyMoePipeline};
29pub use auto::{AutoLoader, AutoLoaderBuilder};
30use chat_template::ChatTemplate;
31pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
32pub(crate) use embedding::EmbeddingLoadContext;
33pub use embedding::{EmbeddingLoader, EmbeddingLoaderBuilder, EmbeddingSpecificConfig};
34pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
35pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
36use image::DynamicImage;
37pub use inputs_processor::InputProcessorOutput;
38pub(crate) use isq::IsqModelLoader;
39pub use isq::{
40    expand_isq_value, parse_isq_value, parse_uqff_shard, resolve_uqff_shorthand, IsqModel,
41    IsqOrganization, UQFF_MULTI_FILE_DELIMITER,
42};
43use llguidance::toktrie::TokEnv;
44pub use loaders::{
45    AdapterKind, AutoDeviceMapParams, AutoEmbeddingLoader, AutoMultimodalLoader, AutoNormalLoader,
46    DeepSeekV2Loader, DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType,
47    DiffusionModel, DiffusionModelLoader, EmbeddingGemmaLoader, EmbeddingLoaderType,
48    EmbeddingModel, EmbeddingModelLoader, EmbeddingModelPaths, EmbeddingModule,
49    EmbeddingModulePaths, EmbeddingModuleType, FluxLoader, GLM4Loader, GLM4MoeLiteLoader,
50    GLM4MoeLoader, Gemma2Loader, Gemma3Loader, Gemma3nLoader, Gemma4Loader, GemmaLoader,
51    GptOssLoader, GraniteMoeHybridLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
52    LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader,
53    MistralLoader, MixtralLoader, ModelKind, ModelPaths, MultimodalLoaderType, MultimodalModel,
54    MultimodalModelLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader,
55    Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName,
56    QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3EmbeddingLoader,
57    Qwen3Loader, Qwen3MoELoader, Qwen3NextLoader, Qwen3VLLoader, Qwen3VLMoELoader, Qwen3_5Loader,
58    Qwen3_5MoeLoader, SmolLm3Loader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader,
59    VoxtralLoader,
60};
61#[allow(clippy::too_many_arguments)]
62pub(crate) fn get_device_layers_for_loader(
63    loader: &dyn loaders::DeviceMappedModelLoader,
64    config: &str,
65    num_layers: usize,
66    layer_sizes_in_bytes: Vec<usize>,
67    non_mapped_size_in_bytes: usize,
68    total_model_size_in_bytes: usize,
69    devices: &[Device],
70    dtype: DType,
71    params: &loaders::AutoDeviceMapParams,
72    paged_attn_config: Option<&PagedAttentionConfig>,
73) -> Result<crate::device_map::DeviceMapMetadata> {
74    loaders::auto_device_map::get_device_layers(
75        loader,
76        config,
77        num_layers,
78        layer_sizes_in_bytes,
79        non_mapped_size_in_bytes,
80        total_model_size_in_bytes,
81        devices,
82        dtype,
83        params,
84        paged_attn_config,
85    )
86}
87use hanzo_quant::IsqType;
88pub use multimodal::{MultimodalLoader, MultimodalLoaderBuilder, MultimodalSpecificConfig};
89pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
90pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
91pub use paths::{AdapterPaths, LoraAdapterPaths};
92pub(crate) use processing::{
93    apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
94};
95use rand_isaac::Isaac64Rng;
96pub use speech::{SpeechLoader, SpeechPipeline};
97use std::any::Any;
98use std::fmt::Debug;
99use std::sync::atomic::AtomicUsize;
100use std::sync::Arc;
101use std::time::{Duration, Instant};
102
103use tokenizers::Tokenizer;
104
105use anyhow::Result;
106use hanzo_ml::{DType, Device, IndexOp, Tensor, Var};
107
108use crate::sequence::Sequence;
109
110pub use self::inputs_processor::{
111    text_models_inputs_processor, InputsProcessor, InputsProcessorType,
112};
113use self::text_models_inputs_processor::PagedAttentionMeta;
114pub use crate::kv_cache::{
115    Cache, CacheManager, EitherCache, HybridLayerCache, KvCache, LayerCaches, NormalCache,
116    NormalCacheType,
117};
118
/// One kind of data a model can consume or produce.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    Text,
    Audio,
    Vision,
    Video,
    Embedding,
}

impl Debug for SupportedModality {
    /// Renders an emoji-prefixed, human-readable label such as `📝 Text`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
            Self::Video => "🎬 Video",
            Self::Embedding => "🔢 Embedding",
        };
        f.write_str(label)
    }
}
139
140#[derive(Debug, Clone)]
141pub struct Modalities {
142    pub input: Vec<SupportedModality>,
143    pub output: Vec<SupportedModality>,
144}
145
146pub struct GeneralMetadata {
147    pub max_seq_len: usize,
148    /// Only None if it doesn't make sense for the model
149    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
150    pub no_kv_cache: bool,
151    pub no_prefix_cache: bool,
152    pub num_hidden_layers: usize,
153    pub eos_tok: Vec<u32>,
154    pub kind: ModelKind,
155    // TODO: Replace is_xlora queries to check via kind instead:
156    pub is_xlora: bool,
157    pub activation_dtype: DType,
158    pub sliding_window: Option<usize>,
159    // PagedAttention stuff
160    pub cache_config: Option<CacheConfig>,
161    pub cache_engine: Option<CacheEngine>,
162    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
163    pub modalities: Modalities,
164}
165
166impl GeneralMetadata {
167    pub fn tok_env(&self) -> Option<TokEnv> {
168        self.llg_factory.as_ref().map(|f| f.tok_env().clone())
169    }
170}
171
/// A cache operation `Pipeline::step` performs around the forward pass when the
/// default (non-PagedAttention) cache backend is in use.
#[derive(Clone, Copy)]
pub enum CacheInstruction {
    /// Before the forward pass: copy each sequence's cache into the model (`clone_in_cache`).
    In,
    /// After the forward pass: copy the model cache back to the sequences (`clone_out_cache`).
    Out,
    /// Clear the model cache via `set_none_cache`.
    ///
    /// `load_preallocated_cache` asks for the preallocated cache to be loaded, if applicable;
    /// `reset_non_granular` also resets non-granular state, if applicable.
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    /// Leave the cache untouched.
    Nothing,
}
183
184pub trait PreProcessingMixin: MetadataMixin {
185    fn get_processor(&self) -> Arc<dyn Processor> {
186        Arc::new(BasicProcessor)
187    }
188    /// Only None if it doesnt make sense for the model
189    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
190    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
191}
192
193pub trait IsqPipelineMixin {
194    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
195}
196
197pub trait CacheManagerMixin {
198    /// Clone the cache FROM the sequences' cache TO the model cache. Only called for completion seqs.
199    /// It is not a guarantee that this will be called for each completion step.
200    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
201    /// Clone the cache FROM the model cache TO the sequences. Called for prompt and completion seqs.
202    /// It is not a guarantee that this will be called for each step.
203    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
204    /// Set the model cache to all None. Only called for prompt seqs.
205    /// It is not a guarantee that this will be called for each prompt step.
206    /// This may also reset the non granular state if applicable.
207    fn set_none_cache(
208        &self,
209        seqs: &mut [&mut Sequence],
210        reset_non_granular: bool,
211        modify_draft_cache: bool,
212        load_preallocated_cache: bool,
213    );
214    fn cache(&self) -> &EitherCache;
215}
216
217pub trait MetadataMixin {
218    fn device(&self) -> Device;
219    /// Only None if it doesnt make sense for the model
220    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
221    fn name(&self) -> String;
222    fn reset_non_granular_state(&self);
223    fn get_metadata(&self) -> Arc<GeneralMetadata>;
224    fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
225        None
226    }
227    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
228}
229
230/// Implemented by the base model of an AnyMoe.
231pub trait AnyMoePipelineMixin {
232    /// Get vars for each gating layer
233    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
234        unreachable!()
235    }
236    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> hanzo_ml::Result<()> {
237        unreachable!()
238    }
239    fn amoe_base_model_trainable_params(&self) -> usize {
240        unreachable!()
241    }
242    fn amoe_supported(&self) -> bool {
243        false
244    }
245    /// Per-layer cached outputs.
246    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
247        unreachable!()
248    }
249    /// Inject the MoE layers
250    #[allow(clippy::too_many_arguments)]
251    fn amoe_create_layers(
252        &mut self,
253        _model_ids: Vec<String>,
254        _token: &TokenSource,
255        _revision: Option<String>,
256        _match_regex: &str,
257        _config: AnyMoeConfig,
258        _dtype: DType,
259        _dev: &Device,
260        (_prefix, _mlp): (String, String),
261        _layers: Vec<usize>,
262        _expert_type: AnyMoeExpertType,
263        _silent: bool,
264        _gate_model_id: Option<String>,
265    ) -> hanzo_ml::Result<()> {
266        unreachable!()
267    }
268    /// Pre-train the gating layers
269    #[allow(clippy::too_many_arguments)]
270    fn amoe_pre_train(
271        &self,
272        _inputs: AnyMoeTrainingInputs,
273        (_prefix, _mlp): (String, String),
274        _model_ids: Vec<String>,
275        _token: TokenSource,
276        _revision: Option<String>,
277        _layers: Vec<usize>,
278        _silent: bool,
279    ) -> Result<Option<AnyMoeTrainingResult>, hanzo_ml::Error> {
280        unreachable!()
281    }
282}
283
284/// Category of the model. This can also be used to extract model-category specific tools,
285/// such as the multimodal model prompt prefixer.
286#[derive(Clone)]
287pub enum ModelCategory {
288    Text,
289    Multimodal {
290        prefixer: Arc<dyn MultimodalPromptPrefixer>,
291    },
292    Diffusion,
293    Audio,
294    Speech,
295    Embedding,
296}
297
298impl std::fmt::Debug for ModelCategory {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        match self {
301            ModelCategory::Text => write!(f, "ModelCategory::Text"),
302            ModelCategory::Multimodal { .. } => {
303                write!(f, "ModelCategory::Multimodal {{ prefixer: .. }}")
304            }
305            ModelCategory::Diffusion => write!(f, "ModelCategory::Diffusion"),
306            ModelCategory::Audio => write!(f, "ModelCategory::Audio"),
307            ModelCategory::Speech => write!(f, "ModelCategory::Speech"),
308            ModelCategory::Embedding => write!(f, "ModelCategory::Embedding"),
309        }
310    }
311}
312
313impl PartialEq for ModelCategory {
314    fn eq(&self, other: &Self) -> bool {
315        match (self, other) {
316            (Self::Text, Self::Text) => true,
317            (Self::Multimodal { .. }, Self::Multimodal { .. }) => true,
318            (Self::Audio, Self::Audio) => true,
319            (Self::Speech, Self::Speech) => true,
320            (Self::Diffusion, Self::Diffusion) => true,
321            (Self::Embedding, Self::Embedding) => true,
322            (
323                Self::Text
324                | Self::Multimodal { .. }
325                | Self::Diffusion
326                | Self::Audio
327                | Self::Speech
328                | Self::Embedding,
329                _,
330            ) => false,
331        }
332    }
333}
334
/// Prepends model-specific media tags (image, audio, video) to a prompt.
///
/// Media indices are assumed to start at 0. Every method defaults to returning the prompt
/// unchanged, which suits models whose chat template already inserts the tags.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Prefix `prompt` for inclusion in messages with tags for the given images.
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        String::from(prompt)
    }
    /// Prefix `prompt` for inclusion in messages with tags for the given audio clips.
    fn prefix_audio(&self, _audio_indices: Vec<usize>, prompt: &str) -> String {
        String::from(prompt)
    }
    /// Prefix `prompt` for inclusion in messages with tags for the given videos.
    fn prefix_video(&self, _video_indices: Vec<usize>, prompt: &str) -> String {
        String::from(prompt)
    }
}
350
351#[derive(Clone)]
352pub enum CacheBackendMetadata {
353    DefaultInstructions {
354        pre_op: CacheInstruction,
355        post_op: CacheInstruction,
356    },
357    PagedAttention {
358        metadata: PagedAttentionMeta,
359    },
360}
361
362#[derive(Clone, Debug)]
363pub enum ForwardInputsResult {
364    RawLogits {
365        logits: Tensor,
366    },
367    Embeddings {
368        embeddings: Tensor,
369    },
370    CausalGeneration {
371        logits: Tensor,
372    },
373    Image {
374        images: Vec<DynamicImage>,
375    },
376    Speech {
377        pcms: Vec<Arc<Vec<f32>>>,
378        rates: Vec<usize>,
379        channels: Vec<usize>,
380    },
381}
382
383impl ForwardInputsResult {
384    fn index_bs(&self, bs_idx: usize) -> hanzo_ml::Result<Self> {
385        match self {
386            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
387                logits: logits.i(bs_idx)?,
388            }),
389            Self::Embeddings { embeddings } => Ok(Self::Embeddings {
390                embeddings: embeddings.i(bs_idx)?,
391            }),
392            Self::RawLogits { logits } => Ok(Self::RawLogits {
393                logits: logits.i(bs_idx)?,
394            }),
395            Self::Image { images } => Ok(Self::Image {
396                images: vec![images[bs_idx].clone()],
397            }),
398            Self::Speech {
399                pcms,
400                rates,
401                channels,
402            } => Ok(Self::Speech {
403                pcms: vec![pcms[bs_idx].clone()],
404                rates: vec![rates[bs_idx]],
405                channels: vec![channels[bs_idx]],
406            }),
407        }
408    }
409
410    fn to_device(&self, device: &Device) -> hanzo_ml::Result<Self> {
411        match self {
412            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
413                logits: logits.to_device(device)?,
414            }),
415            Self::RawLogits { logits } => Ok(Self::RawLogits {
416                logits: logits.to_device(device)?,
417            }),
418            Self::Embeddings { embeddings } => Ok(Self::Embeddings {
419                embeddings: embeddings.to_device(device)?,
420            }),
421            Self::Image { .. } => Ok(self.clone()),
422            Self::Speech { .. } => Ok(self.clone()),
423        }
424    }
425}
426
427#[derive(serde::Serialize, serde::Deserialize)]
428pub(crate) struct FileListCache {
429    files: Vec<String>,
430}
431
432#[async_trait::async_trait]
433pub trait Pipeline:
434    Send
435    + Sync
436    + PreProcessingMixin
437    + IsqPipelineMixin
438    + CacheManagerMixin
439    + MetadataMixin
440    + AnyMoePipelineMixin
441{
442    fn forward_inputs(
443        &mut self,
444        inputs: Box<dyn Any>,
445        return_raw_logits: bool,
446    ) -> Result<ForwardInputsResult, hanzo_ml::Error>;
447
448    fn attach_speculative(
449        &mut self,
450        _config: crate::speculative::SpeculativeConfig,
451    ) -> Result<(), hanzo_ml::Error> {
452        hanzo_ml::bail!("This pipeline does not support speculative decoding attachment.")
453    }
454
455    #[allow(clippy::too_many_arguments)]
456    async fn try_sample_speculative_causal_gen(
457        &mut self,
458        _input_seqs: &mut [&mut Sequence],
459        _logits: &[Tensor],
460        _prefix_cacher: &mut PrefixCacheManagerV2,
461        _disable_eos_stop: bool,
462        _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
463        _metadata: Option<PagedAttentionMeta>,
464    ) -> Result<bool, hanzo_ml::Error> {
465        Ok(false)
466    }
467
468    /// Returns the total of model execution time.
469    #[allow(clippy::too_many_arguments)]
470    async fn step(
471        &mut self,
472        input_seqs: &mut [&mut Sequence],
473        is_prompt: bool,
474        return_raw_logits: bool,
475        prefix_cacher: &mut PrefixCacheManagerV2,
476        disable_eos_stop: bool,
477        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
478        backend_metadata: CacheBackendMetadata,
479    ) -> Result<Duration, hanzo_ml::Error> {
480        match backend_metadata {
481            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
482                if !is_prompt && !return_raw_logits {
483                    crate::speculative::driver::clear_staged_speculative_tokens(input_seqs);
484                }
485
486                let inputs_iter =
487                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
488                        self.tokenizer(),
489                        input_seqs,
490                        is_prompt,
491                        self.get_metadata().is_xlora,
492                        &self.device(),
493                        self.get_metadata().no_kv_cache,
494                        None,
495                        return_raw_logits,
496                        self.get_metadata().sliding_window,
497                        self.get_input_processor_config(),
498                        None,
499                        self.device_mapper(),
500                    ));
501
502                let mut logits = vec![None; input_seqs.len()];
503                let len_inputs = 1;
504                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
505                let mut embedding_logits = vec![None; input_seqs.len()];
506
507                let mut exec_duration = Duration::ZERO;
508                for (i, inputs) in inputs_iter.into_iter().enumerate() {
509                    let InputProcessorOutput {
510                        inputs,
511                        seq_indices,
512                    } = inputs.map_err(hanzo_ml::Error::msg)?;
513                    if i == 0 {
514                        match pre_op {
515                            CacheInstruction::In => self.clone_in_cache(input_seqs),
516                            CacheInstruction::Nothing => (),
517                            CacheInstruction::Reset {
518                                load_preallocated_cache,
519                                reset_non_granular,
520                            } => self.set_none_cache(
521                                input_seqs,
522                                reset_non_granular,
523                                false,
524                                load_preallocated_cache,
525                            ),
526                            _ => unreachable!("Unreachable PRE cache op."),
527                        }
528                    }
529
530                    let start = Instant::now();
531                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
532                    let end = Instant::now();
533                    exec_duration += end.duration_since(start);
534
535                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
536                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
537                            raw_out_logits[seq_idx][i] =
538                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
539                        } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
540                            embedding_logits[seq_idx] =
541                                Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
542                        } else {
543                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
544                        }
545                    }
546                }
547
548                match post_op {
549                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
550                    CacheInstruction::Nothing => (),
551                    CacheInstruction::Reset {
552                        load_preallocated_cache,
553                        reset_non_granular,
554                    } => self.set_none_cache(
555                        input_seqs,
556                        reset_non_granular,
557                        false,
558                        load_preallocated_cache,
559                    ),
560                    _ => unreachable!("Unreachable POST cache op."),
561                }
562
563                if raw_out_logits[0][0].is_some() {
564                    let start = Instant::now();
565                    response::send_raw_responses(
566                        input_seqs,
567                        raw_out_logits
568                            .into_iter()
569                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
570                            .collect(),
571                    )
572                    .await?;
573                    let end = Instant::now();
574                    exec_duration += end.duration_since(start);
575
576                    return Ok(exec_duration);
577                }
578                if embedding_logits[0].is_some() {
579                    let start = Instant::now();
580                    response::send_embedding_responses(
581                        input_seqs,
582                        embedding_logits
583                            .into_iter()
584                            .map(|raw| {
585                                raw.unwrap()
586                                    .to_dtype(DType::F32)
587                                    .unwrap()
588                                    .to_vec1::<f32>()
589                                    .unwrap()
590                            })
591                            .collect(),
592                    )
593                    .await?;
594                    let end = Instant::now();
595                    exec_duration += end.duration_since(start);
596
597                    return Ok(exec_duration);
598                }
599
600                let start = Instant::now();
601                let logits_on_cpu = logits.len() > 1;
602                let logits = logits
603                    .into_iter()
604                    .map(|l| {
605                        let l = l.expect("missing forward result");
606                        if logits_on_cpu {
607                            l.to_device(&Device::Cpu)
608                        } else {
609                            Ok(l)
610                        }
611                    })
612                    .collect::<hanzo_ml::Result<Vec<_>>>()?;
613
614                match &logits[0] {
615                    ForwardInputsResult::RawLogits { .. }
616                    | ForwardInputsResult::Embeddings { .. } => unreachable!(),
617                    ForwardInputsResult::CausalGeneration { .. } => {
618                        let logits = logits
619                            .into_iter()
620                            .map(|r| {
621                                #[allow(irrefutable_let_patterns)]
622                                let ForwardInputsResult::CausalGeneration { logits } = r
623                                else {
624                                    unreachable!(
625                                        "All results must have same type, `CausalGeneration`"
626                                    )
627                                };
628                                logits
629                            })
630                            .collect::<Vec<_>>();
631                        if is_prompt
632                            || return_raw_logits
633                            || !self
634                                .try_sample_speculative_causal_gen(
635                                    input_seqs,
636                                    &logits,
637                                    prefix_cacher,
638                                    disable_eos_stop,
639                                    rng.clone(),
640                                    None,
641                                )
642                                .await?
643                        {
644                            self.sample_causal_gen(
645                                input_seqs,
646                                logits,
647                                prefix_cacher,
648                                disable_eos_stop,
649                                rng,
650                            )
651                            .await?;
652                        }
653                    }
654                    ForwardInputsResult::Image { .. } => {
655                        response::send_image_responses(
656                            input_seqs,
657                            logits
658                                .into_iter()
659                                .map(|r| {
660                                    #[allow(irrefutable_let_patterns)]
661                                    let ForwardInputsResult::Image { images } = r
662                                    else {
663                                        unreachable!("All results must have same type, `Image`")
664                                    };
665                                    images
666                                        .into_iter()
667                                        .next()
668                                        .expect("Must have at least 1 element.")
669                                })
670                                .collect::<Vec<_>>(),
671                        )
672                        .await?;
673                    }
674                    ForwardInputsResult::Speech { .. } => {
675                        let rates = logits
676                            .iter()
677                            .map(|r| {
678                                #[allow(irrefutable_let_patterns)]
679                                let ForwardInputsResult::Speech { rates, .. } = r
680                                else {
681                                    unreachable!("All results must have same type, `Speech`")
682                                };
683                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
684                                *rates.first().unwrap()
685                            })
686                            .collect::<Vec<_>>();
687                        let channels = logits
688                            .iter()
689                            .map(|r| {
690                                #[allow(irrefutable_let_patterns)]
691                                let ForwardInputsResult::Speech { channels, .. } = r
692                                else {
693                                    unreachable!("All results must have same type, `Speech`")
694                                };
695                                assert_eq!(
696                                    channels.len(),
697                                    1,
698                                    "Each sequence must have 1 PCM output."
699                                );
700                                *channels.first().unwrap()
701                            })
702                            .collect::<Vec<_>>();
703                        let pcms = logits
704                            .into_iter()
705                            .map(|r| {
706                                #[allow(irrefutable_let_patterns)]
707                                let ForwardInputsResult::Speech { pcms, .. } = r
708                                else {
709                                    unreachable!("All results must have same type, `Speech`")
710                                };
711                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
712                                pcms.into_iter().nth(0).unwrap()
713                            })
714                            .collect::<Vec<_>>();
715                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
716                            .await?;
717                    }
718                }
719                let end = Instant::now();
720                exec_duration += end.duration_since(start);
721
722                Ok(exec_duration)
723            }
724            CacheBackendMetadata::PagedAttention { metadata } => {
725                let speculative_metadata = metadata.clone();
726                // For hybrid models, build state_indices tensor from sequences'
727                // recurrent_state_idx so recurrent layers are active during forward.
728                // Paged attention manages KV caches separately, but recurrent state
729                // pool access still needs the indices tensor to be set.
730                if self.cache().is_hybrid() {
731                    let mut hybrid_cache = self.cache().hybrid();
732                    let recurrent_device = hybrid_cache.caches.iter().find_map(|c| {
733                        if let HybridLayerCache::Recurrent(pool) = c {
734                            Some(pool.device().clone())
735                        } else {
736                            None
737                        }
738                    });
739                    if let Some(device) = recurrent_device {
740                        #[allow(clippy::cast_possible_truncation)]
741                        let indices: Vec<u32> = input_seqs
742                            .iter()
743                            .filter_map(|seq| seq.recurrent_state_idx().map(|idx| idx as u32))
744                            .collect();
745                        if indices.len() == input_seqs.len() {
746                            if let Ok(si) = Tensor::from_vec(indices, (input_seqs.len(),), &device)
747                            {
748                                hybrid_cache.set_state_indices(Some(si));
749                            }
750                        }
751                    }
752                }
753
754                let inputs_iter =
755                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
756                        self.tokenizer(),
757                        input_seqs,
758                        is_prompt,
759                        self.get_metadata().is_xlora,
760                        &self.device(),
761                        self.get_metadata().no_kv_cache,
762                        None,
763                        return_raw_logits,
764                        self.get_metadata().sliding_window,
765                        self.get_input_processor_config(),
766                        Some(metadata),
767                        self.device_mapper(),
768                    ));
769
770                let mut logits = vec![None; input_seqs.len()];
771                let len_inputs = 1;
772                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
773                let mut embedding_logits = vec![None; input_seqs.len()];
774
775                let mut exec_duration = Duration::ZERO;
776                for (i, inputs) in inputs_iter.into_iter().enumerate() {
777                    let InputProcessorOutput {
778                        inputs,
779                        seq_indices,
780                    } = inputs.map_err(hanzo_ml::Error::msg)?;
781
782                    let start = Instant::now();
783                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
784                    let end = Instant::now();
785                    exec_duration += end.duration_since(start);
786
787                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
788                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
789                            raw_out_logits[seq_idx][i] =
790                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
791                        } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
792                            embedding_logits[seq_idx] =
793                                Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
794                        } else {
795                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
796                        }
797                    }
798                }
799
800                if raw_out_logits[0][0].is_some() {
801                    let start = Instant::now();
802                    response::send_raw_responses(
803                        input_seqs,
804                        raw_out_logits
805                            .into_iter()
806                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
807                            .collect(),
808                    )
809                    .await?;
810                    let end = Instant::now();
811                    exec_duration += end.duration_since(start);
812
813                    return Ok(exec_duration);
814                }
815                if embedding_logits[0].is_some() {
816                    let start = Instant::now();
817                    response::send_embedding_responses(
818                        input_seqs,
819                        embedding_logits
820                            .into_iter()
821                            .map(|raw| {
822                                raw.unwrap()
823                                    .to_dtype(DType::F32)
824                                    .unwrap()
825                                    .to_vec1::<f32>()
826                                    .unwrap()
827                            })
828                            .collect(),
829                    )
830                    .await?;
831                    let end = Instant::now();
832                    exec_duration += end.duration_since(start);
833
834                    return Ok(exec_duration);
835                }
836
837                let start = Instant::now();
838                let logits_on_cpu = logits.len() > 1;
839                let logits = logits
840                    .into_iter()
841                    .map(|l| {
842                        let l = l.expect("missing forward result");
843                        if logits_on_cpu {
844                            l.to_device(&Device::Cpu)
845                        } else {
846                            Ok(l)
847                        }
848                    })
849                    .collect::<hanzo_ml::Result<Vec<_>>>()?;
850                match &logits[0] {
851                    ForwardInputsResult::RawLogits { .. }
852                    | ForwardInputsResult::Embeddings { .. } => unreachable!(),
853                    ForwardInputsResult::CausalGeneration { .. } => {
854                        let logits = logits
855                            .into_iter()
856                            .map(|r| {
857                                #[allow(irrefutable_let_patterns)]
858                                let ForwardInputsResult::CausalGeneration { logits } = r
859                                else {
860                                    unreachable!("All results must have same type")
861                                };
862                                logits
863                            })
864                            .collect::<Vec<_>>();
865                        if is_prompt
866                            || return_raw_logits
867                            || !self
868                                .try_sample_speculative_causal_gen(
869                                    input_seqs,
870                                    &logits,
871                                    prefix_cacher,
872                                    disable_eos_stop,
873                                    rng.clone(),
874                                    Some(speculative_metadata),
875                                )
876                                .await?
877                        {
878                            self.sample_causal_gen(
879                                input_seqs,
880                                logits,
881                                prefix_cacher,
882                                disable_eos_stop,
883                                rng,
884                            )
885                            .await?;
886                        }
887                    }
888                    ForwardInputsResult::Image { .. } => {
889                        response::send_image_responses(
890                            input_seqs,
891                            logits
892                                .into_iter()
893                                .map(|r| {
894                                    #[allow(irrefutable_let_patterns)]
895                                    let ForwardInputsResult::Image { images } = r
896                                    else {
897                                        unreachable!("All results must have same type, `Image`")
898                                    };
899                                    images
900                                        .into_iter()
901                                        .next()
902                                        .expect("Must have at least 1 element.")
903                                })
904                                .collect::<Vec<_>>(),
905                        )
906                        .await?;
907                    }
908                    ForwardInputsResult::Speech { .. } => {
909                        let rates = logits
910                            .iter()
911                            .map(|r| {
912                                #[allow(irrefutable_let_patterns)]
913                                let ForwardInputsResult::Speech { rates, .. } = r
914                                else {
915                                    unreachable!("All results must have same type, `Speech`")
916                                };
917                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
918                                *rates.first().unwrap()
919                            })
920                            .collect::<Vec<_>>();
921                        let channels = logits
922                            .iter()
923                            .map(|r| {
924                                #[allow(irrefutable_let_patterns)]
925                                let ForwardInputsResult::Speech { channels, .. } = r
926                                else {
927                                    unreachable!("All results must have same type, `Speech`")
928                                };
929                                assert_eq!(
930                                    channels.len(),
931                                    1,
932                                    "Each sequence must have 1 PCM output."
933                                );
934                                *channels.first().unwrap()
935                            })
936                            .collect::<Vec<_>>();
937                        let pcms = logits
938                            .into_iter()
939                            .map(|r| {
940                                #[allow(irrefutable_let_patterns)]
941                                let ForwardInputsResult::Speech { pcms, .. } = r
942                                else {
943                                    unreachable!("All results must have same type, `Speech`")
944                                };
945                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
946                                pcms.into_iter().nth(0).unwrap()
947                            })
948                            .collect::<Vec<_>>();
949                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
950                            .await?;
951                    }
952                }
953                let end = Instant::now();
954                exec_duration += end.duration_since(start);
955
956                Ok(exec_duration)
957            }
958        }
959    }
960
    /// Sample from causal-generation logits for the given sequences.
    ///
    /// The default step logic of this trait calls this with `logits` holding one tensor per
    /// entry of `seqs`, in the same order. It does so when the step is a prompt step, when raw
    /// logits were requested, or when `try_sample_speculative_causal_gen` returned `false`.
    /// Speculative sampling is only attempted in the last case.
    ///
    /// NOTE(review): presumably draws the next token for each sequence using `rng` and
    /// updates `prefix_cacher` / stop handling (honouring `disable_eos_stop`); confirm
    /// against the implementations.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), hanzo_ml::Error>;
969
    /// Report which [`ModelCategory`] this pipeline belongs to.
    fn category(&self) -> ModelCategory;
971
    /// Return encoder cache hit/miss counters as `(hits, misses)` if this pipeline has an
    /// encoder cache.
    ///
    /// The default implementation returns `None`. Pipelines that keep an encoder cache are
    /// expected to override it and hand out shared handles to their counters.
    fn encoder_cache_counters(&self) -> Option<(Arc<AtomicUsize>, Arc<AtomicUsize>)> {
        None
    }
976}
977
978pub(crate) fn extract_logits(
979    logits: &Tensor,
980    context_lens: Vec<(usize, usize)>,
981) -> hanzo_ml::Result<Tensor> {
982    let mut toks = Vec::new();
983    for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
984        toks.push(dim.narrow(1, start, len)?);
985    }
986    Tensor::cat(&toks, 0)
987}
988
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    /// Build an `IndexMap<_, Value>` whose values are wrapped in `Value::String`,
    /// preallocated to the number of entries.
    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    /// Render every template in `templates` against `inputs` and compare the results with
    /// `expected_outputs`, pairwise and in order.
    ///
    /// Each template tuple is `(has_system, bos, eos, unk, template)`. When `has_system` is
    /// `false`, the first message of `inputs` is dropped before rendering (in these tests it
    /// is the system prompt). Render errors and mismatches are collected for all templates
    /// and reported together, tagged with the index of the offending template.
    ///
    /// # Panics
    ///
    /// - Panics if `templates` and `expected_outputs` differ in length. Otherwise `zip` would
    ///   silently skip templates without an expected output.
    /// - Panics if any template fails to render or produces unexpected output.
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;

        assert_eq!(
            templates.len(),
            expected_outputs.len(),
            "every chat template needs exactly one expected output"
        );
        let mut failed: Vec<(usize, String)> = Vec::new();
        let n_templates = templates.len();
        for (idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let output = match apply_chat_template_to(
                if !has_system {
                    inputs[1..].to_vec()
                } else {
                    inputs.clone()
                },
                true,
                None,
                None, // reasoning_effort
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push((idx, format!("Failed with {e}.")));
                    continue;
                }
            };
            if output != *expected {
                failed.push((
                    idx,
                    format!(
                        "Expected: `{}` \n\nGot:      `{}`",
                        expected.replace('\n', "\\n"),
                        output.replace('\n', "\\n")
                    ),
                ));
            }
        }
        if !failed.is_empty() {
            for (idx, line) in &failed {
                println!("------------ Template {idx} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    #[test]
    /// Generating these cases:
    /// ```py
    /// >>> t=transformers.AutoTokenizer.from_pretrained(...)
    /// # If non-system prompt model
    /// >>> t.apply_chat_template([{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// # If system prompt model
    /// >>> t.apply_chat_template([{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// ```
    fn test_chat_templates() {
        // Each entry here must have a matching entry in `expected_outputs` below.
        // Multimodal (list-content) templates such as idefics2 are covered by
        // `test_image_chat_templates`.
        let templates = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            // mistralai/Mistral-7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // meta-llama/Llama-2-13b-chat-hf
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // google/gemma-7b-it
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
        ];
        let expected_outputs = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            // mistralai/Mistral-7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST]   I am an assistant   </s> [INST] Another question [/INST]",
            // meta-llama/Llama-2-13b-chat-hf
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
            // google/gemma-7b-it
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", "   I am an assistant   "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::new();
        for [role, content] in messages {
            let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Left(content.to_string()));
            inputs.push(message);
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    #[test]
    /// Generating these cases:
    /// ```py
    /// >>> processor=transformers.AutoProcessor.from_pretrained(...)
    /// >>> processor.apply_chat_template([
    ///         {"role":"system","content":[{"type":"text", "text": "You are a helpful assistant"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Hello, please describe the above."}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "Hi there"}]},
    ///         {"role":"user","content":[{"type":"text", "text": "Who are you"}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "   I am an assistant   "}]},
    ///         {"role":"user","content":[{"type":"text", "text": "Another question"}]}
    ///     ], add_generation_prompt=True, tokenize=False)
    /// ```
    fn test_image_chat_templates() {
        let templates = [
            // HuggingFaceM4/idefics2-8b-chatty
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            // HuggingFaceM4/idefics2-8b-chatty
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:    I am an assistant   <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let mut inputs = Vec::new();

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("system".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "You are a helpful assistant".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Hello, please describe the above.".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "Hi there".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "This is me, who are you".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "   I am an assistant   ".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Another question, what is this?".to_string()
                },
            ]),
        );
        inputs.push(message);

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}