1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod embedding;
6mod ggml;
7mod gguf;
8pub(crate) mod hf;
9mod inputs_processor;
10mod isq;
11pub(crate) mod llg;
12mod loaders;
13mod macros;
14mod multimodal;
15mod normal;
16mod paths;
17mod processing;
18mod response;
19pub(crate) mod sampling;
20mod speech;
21
22pub use super::diffusion_models::DiffusionGenerationParams;
23use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
24use crate::device_map::DeviceMapper;
25use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
26use crate::prefix_cacher::PrefixCacheManagerV2;
27use crate::PagedAttentionConfig;
28pub use amoe::{AnyMoeLoader, AnyMoePipeline};
29pub use auto::{AutoLoader, AutoLoaderBuilder};
30use chat_template::ChatTemplate;
31pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
32pub(crate) use embedding::EmbeddingLoadContext;
33pub use embedding::{EmbeddingLoader, EmbeddingLoaderBuilder, EmbeddingSpecificConfig};
34pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
35pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
36use image::DynamicImage;
37pub use inputs_processor::InputProcessorOutput;
38pub(crate) use isq::IsqModelLoader;
39pub use isq::{
40 expand_isq_value, parse_isq_value, parse_uqff_shard, resolve_uqff_shorthand, IsqModel,
41 IsqOrganization, UQFF_MULTI_FILE_DELIMITER,
42};
43use llguidance::toktrie::TokEnv;
44pub use loaders::{
45 AdapterKind, AutoDeviceMapParams, AutoEmbeddingLoader, AutoMultimodalLoader, AutoNormalLoader,
46 DeepSeekV2Loader, DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType,
47 DiffusionModel, DiffusionModelLoader, EmbeddingGemmaLoader, EmbeddingLoaderType,
48 EmbeddingModel, EmbeddingModelLoader, EmbeddingModelPaths, EmbeddingModule,
49 EmbeddingModulePaths, EmbeddingModuleType, FluxLoader, GLM4Loader, GLM4MoeLiteLoader,
50 GLM4MoeLoader, Gemma2Loader, Gemma3Loader, Gemma3nLoader, Gemma4Loader, GemmaLoader,
51 GptOssLoader, GraniteMoeHybridLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
52 LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader,
53 MistralLoader, MixtralLoader, ModelKind, ModelPaths, MultimodalLoaderType, MultimodalModel,
54 MultimodalModelLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader,
55 Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName,
56 QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3EmbeddingLoader,
57 Qwen3Loader, Qwen3MoELoader, Qwen3NextLoader, Qwen3VLLoader, Qwen3VLMoELoader, Qwen3_5Loader,
58 Qwen3_5MoeLoader, SmolLm3Loader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader,
59 VoxtralLoader,
60};
61#[allow(clippy::too_many_arguments)]
62pub(crate) fn get_device_layers_for_loader(
63 loader: &dyn loaders::DeviceMappedModelLoader,
64 config: &str,
65 num_layers: usize,
66 layer_sizes_in_bytes: Vec<usize>,
67 non_mapped_size_in_bytes: usize,
68 total_model_size_in_bytes: usize,
69 devices: &[Device],
70 dtype: DType,
71 params: &loaders::AutoDeviceMapParams,
72 paged_attn_config: Option<&PagedAttentionConfig>,
73) -> Result<crate::device_map::DeviceMapMetadata> {
74 loaders::auto_device_map::get_device_layers(
75 loader,
76 config,
77 num_layers,
78 layer_sizes_in_bytes,
79 non_mapped_size_in_bytes,
80 total_model_size_in_bytes,
81 devices,
82 dtype,
83 params,
84 paged_attn_config,
85 )
86}
87use hanzo_quant::IsqType;
88pub use multimodal::{MultimodalLoader, MultimodalLoaderBuilder, MultimodalSpecificConfig};
89pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
90pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
91pub use paths::{AdapterPaths, LoraAdapterPaths};
92pub(crate) use processing::{
93 apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
94};
95use rand_isaac::Isaac64Rng;
96pub use speech::{SpeechLoader, SpeechPipeline};
97use std::any::Any;
98use std::fmt::Debug;
99use std::sync::atomic::AtomicUsize;
100use std::sync::Arc;
101use std::time::{Duration, Instant};
102
103use tokenizers::Tokenizer;
104
105use anyhow::Result;
106use hanzo_ml::{DType, Device, IndexOp, Tensor, Var};
107
108use crate::sequence::Sequence;
109
110pub use self::inputs_processor::{
111 text_models_inputs_processor, InputsProcessor, InputsProcessorType,
112};
113use self::text_models_inputs_processor::PagedAttentionMeta;
114pub use crate::kv_cache::{
115 Cache, CacheManager, EitherCache, HybridLayerCache, KvCache, LayerCaches, NormalCache,
116 NormalCacheType,
117};
118
/// A kind of data a model can accept as input or produce as output.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    Text,
    Audio,
    Vision,
    Video,
    Embedding,
}

impl Debug for SupportedModality {
    /// Renders the modality as an emoji-prefixed label such as `📝 Text`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
            Self::Video => "🎬 Video",
            Self::Embedding => "🔢 Embedding",
        };
        f.write_str(label)
    }
}

/// The modalities a model consumes (`input`) and produces (`output`).
#[derive(Debug, Clone)]
pub struct Modalities {
    pub input: Vec<SupportedModality>,
    pub output: Vec<SupportedModality>,
}
145
146pub struct GeneralMetadata {
147 pub max_seq_len: usize,
148 pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
150 pub no_kv_cache: bool,
151 pub no_prefix_cache: bool,
152 pub num_hidden_layers: usize,
153 pub eos_tok: Vec<u32>,
154 pub kind: ModelKind,
155 pub is_xlora: bool,
157 pub activation_dtype: DType,
158 pub sliding_window: Option<usize>,
159 pub cache_config: Option<CacheConfig>,
161 pub cache_engine: Option<CacheEngine>,
162 pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
163 pub modalities: Modalities,
164}
165
166impl GeneralMetadata {
167 pub fn tok_env(&self) -> Option<TokEnv> {
168 self.llg_factory.as_ref().map(|f| f.tok_env().clone())
169 }
170}
171
/// A cache operation that the default `Pipeline::step` applies around the forward pass.
///
/// `In` and `Reset` are valid before the pass (`pre_op`). `Out` and `Reset` are valid after it
/// (`post_op`). `Nothing` is valid in either position.
#[derive(Clone, Copy)]
pub enum CacheInstruction {
    /// Run `CacheManagerMixin::clone_in_cache` on the batch.
    In,
    /// Run `CacheManagerMixin::clone_out_cache` on the batch.
    Out,
    /// Run `CacheManagerMixin::set_none_cache` with these flags (draft cache untouched).
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    /// Leave the caches as they are.
    Nothing,
}
183
184pub trait PreProcessingMixin: MetadataMixin {
185 fn get_processor(&self) -> Arc<dyn Processor> {
186 Arc::new(BasicProcessor)
187 }
188 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
190 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
191}
192
193pub trait IsqPipelineMixin {
194 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
195}
196
197pub trait CacheManagerMixin {
198 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
201 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
204 fn set_none_cache(
208 &self,
209 seqs: &mut [&mut Sequence],
210 reset_non_granular: bool,
211 modify_draft_cache: bool,
212 load_preallocated_cache: bool,
213 );
214 fn cache(&self) -> &EitherCache;
215}
216
217pub trait MetadataMixin {
218 fn device(&self) -> Device;
219 fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
221 fn name(&self) -> String;
222 fn reset_non_granular_state(&self);
223 fn get_metadata(&self) -> Arc<GeneralMetadata>;
224 fn generation_defaults(&self) -> Option<crate::ModelGenerationDefaults> {
225 None
226 }
227 fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
228}
229
230pub trait AnyMoePipelineMixin {
232 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
234 unreachable!()
235 }
236 fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> hanzo_ml::Result<()> {
237 unreachable!()
238 }
239 fn amoe_base_model_trainable_params(&self) -> usize {
240 unreachable!()
241 }
242 fn amoe_supported(&self) -> bool {
243 false
244 }
245 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
247 unreachable!()
248 }
249 #[allow(clippy::too_many_arguments)]
251 fn amoe_create_layers(
252 &mut self,
253 _model_ids: Vec<String>,
254 _token: &TokenSource,
255 _revision: Option<String>,
256 _match_regex: &str,
257 _config: AnyMoeConfig,
258 _dtype: DType,
259 _dev: &Device,
260 (_prefix, _mlp): (String, String),
261 _layers: Vec<usize>,
262 _expert_type: AnyMoeExpertType,
263 _silent: bool,
264 _gate_model_id: Option<String>,
265 ) -> hanzo_ml::Result<()> {
266 unreachable!()
267 }
268 #[allow(clippy::too_many_arguments)]
270 fn amoe_pre_train(
271 &self,
272 _inputs: AnyMoeTrainingInputs,
273 (_prefix, _mlp): (String, String),
274 _model_ids: Vec<String>,
275 _token: TokenSource,
276 _revision: Option<String>,
277 _layers: Vec<usize>,
278 _silent: bool,
279 ) -> Result<Option<AnyMoeTrainingResult>, hanzo_ml::Error> {
280 unreachable!()
281 }
282}
283
/// The broad kind of model a pipeline serves.
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    /// Multimodal model; `prefixer` decorates prompts with media placeholders.
    Multimodal {
        prefixer: Arc<dyn MultimodalPromptPrefixer>,
    },
    Diffusion,
    Audio,
    Speech,
    Embedding,
}

impl std::fmt::Debug for ModelCategory {
    /// Prints the variant path; the multimodal prefixer is shown as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            ModelCategory::Text => "ModelCategory::Text",
            ModelCategory::Multimodal { .. } => "ModelCategory::Multimodal { prefixer: .. }",
            ModelCategory::Diffusion => "ModelCategory::Diffusion",
            ModelCategory::Audio => "ModelCategory::Audio",
            ModelCategory::Speech => "ModelCategory::Speech",
            ModelCategory::Embedding => "ModelCategory::Embedding",
        };
        f.write_str(text)
    }
}

impl PartialEq for ModelCategory {
    /// Categories are equal when they are the same variant; prefixers are never compared.
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}

/// Inserts media placeholders into a prompt for multimodal models.
///
/// Every default implementation returns the prompt unchanged.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Adds image placeholders for `_image_indices` to `prompt`.
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }

    /// Adds audio placeholders for `_audio_indexes` to `prompt`.
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }

    /// Adds video placeholders for `_video_indexes` to `prompt`.
    fn prefix_video(&self, _video_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
}
350
351#[derive(Clone)]
352pub enum CacheBackendMetadata {
353 DefaultInstructions {
354 pre_op: CacheInstruction,
355 post_op: CacheInstruction,
356 },
357 PagedAttention {
358 metadata: PagedAttentionMeta,
359 },
360}
361
362#[derive(Clone, Debug)]
363pub enum ForwardInputsResult {
364 RawLogits {
365 logits: Tensor,
366 },
367 Embeddings {
368 embeddings: Tensor,
369 },
370 CausalGeneration {
371 logits: Tensor,
372 },
373 Image {
374 images: Vec<DynamicImage>,
375 },
376 Speech {
377 pcms: Vec<Arc<Vec<f32>>>,
378 rates: Vec<usize>,
379 channels: Vec<usize>,
380 },
381}
382
383impl ForwardInputsResult {
384 fn index_bs(&self, bs_idx: usize) -> hanzo_ml::Result<Self> {
385 match self {
386 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
387 logits: logits.i(bs_idx)?,
388 }),
389 Self::Embeddings { embeddings } => Ok(Self::Embeddings {
390 embeddings: embeddings.i(bs_idx)?,
391 }),
392 Self::RawLogits { logits } => Ok(Self::RawLogits {
393 logits: logits.i(bs_idx)?,
394 }),
395 Self::Image { images } => Ok(Self::Image {
396 images: vec![images[bs_idx].clone()],
397 }),
398 Self::Speech {
399 pcms,
400 rates,
401 channels,
402 } => Ok(Self::Speech {
403 pcms: vec![pcms[bs_idx].clone()],
404 rates: vec![rates[bs_idx]],
405 channels: vec![channels[bs_idx]],
406 }),
407 }
408 }
409
410 fn to_device(&self, device: &Device) -> hanzo_ml::Result<Self> {
411 match self {
412 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
413 logits: logits.to_device(device)?,
414 }),
415 Self::RawLogits { logits } => Ok(Self::RawLogits {
416 logits: logits.to_device(device)?,
417 }),
418 Self::Embeddings { embeddings } => Ok(Self::Embeddings {
419 embeddings: embeddings.to_device(device)?,
420 }),
421 Self::Image { .. } => Ok(self.clone()),
422 Self::Speech { .. } => Ok(self.clone()),
423 }
424 }
425}
426
427#[derive(serde::Serialize, serde::Deserialize)]
428pub(crate) struct FileListCache {
429 files: Vec<String>,
430}
431
432#[async_trait::async_trait]
433pub trait Pipeline:
434 Send
435 + Sync
436 + PreProcessingMixin
437 + IsqPipelineMixin
438 + CacheManagerMixin
439 + MetadataMixin
440 + AnyMoePipelineMixin
441{
442 fn forward_inputs(
443 &mut self,
444 inputs: Box<dyn Any>,
445 return_raw_logits: bool,
446 ) -> Result<ForwardInputsResult, hanzo_ml::Error>;
447
448 fn attach_speculative(
449 &mut self,
450 _config: crate::speculative::SpeculativeConfig,
451 ) -> Result<(), hanzo_ml::Error> {
452 hanzo_ml::bail!("This pipeline does not support speculative decoding attachment.")
453 }
454
455 #[allow(clippy::too_many_arguments)]
456 async fn try_sample_speculative_causal_gen(
457 &mut self,
458 _input_seqs: &mut [&mut Sequence],
459 _logits: &[Tensor],
460 _prefix_cacher: &mut PrefixCacheManagerV2,
461 _disable_eos_stop: bool,
462 _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
463 _metadata: Option<PagedAttentionMeta>,
464 ) -> Result<bool, hanzo_ml::Error> {
465 Ok(false)
466 }
467
468 #[allow(clippy::too_many_arguments)]
470 async fn step(
471 &mut self,
472 input_seqs: &mut [&mut Sequence],
473 is_prompt: bool,
474 return_raw_logits: bool,
475 prefix_cacher: &mut PrefixCacheManagerV2,
476 disable_eos_stop: bool,
477 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
478 backend_metadata: CacheBackendMetadata,
479 ) -> Result<Duration, hanzo_ml::Error> {
480 match backend_metadata {
481 CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
482 if !is_prompt && !return_raw_logits {
483 crate::speculative::driver::clear_staged_speculative_tokens(input_seqs);
484 }
485
486 let inputs_iter =
487 std::iter::once(self.get_processor().inputs_processor().process_inputs(
488 self.tokenizer(),
489 input_seqs,
490 is_prompt,
491 self.get_metadata().is_xlora,
492 &self.device(),
493 self.get_metadata().no_kv_cache,
494 None,
495 return_raw_logits,
496 self.get_metadata().sliding_window,
497 self.get_input_processor_config(),
498 None,
499 self.device_mapper(),
500 ));
501
502 let mut logits = vec![None; input_seqs.len()];
503 let len_inputs = 1;
504 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
505 let mut embedding_logits = vec![None; input_seqs.len()];
506
507 let mut exec_duration = Duration::ZERO;
508 for (i, inputs) in inputs_iter.into_iter().enumerate() {
509 let InputProcessorOutput {
510 inputs,
511 seq_indices,
512 } = inputs.map_err(hanzo_ml::Error::msg)?;
513 if i == 0 {
514 match pre_op {
515 CacheInstruction::In => self.clone_in_cache(input_seqs),
516 CacheInstruction::Nothing => (),
517 CacheInstruction::Reset {
518 load_preallocated_cache,
519 reset_non_granular,
520 } => self.set_none_cache(
521 input_seqs,
522 reset_non_granular,
523 false,
524 load_preallocated_cache,
525 ),
526 _ => unreachable!("Unreachable PRE cache op."),
527 }
528 }
529
530 let start = Instant::now();
531 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
532 let end = Instant::now();
533 exec_duration += end.duration_since(start);
534
535 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
536 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
537 raw_out_logits[seq_idx][i] =
538 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
539 } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
540 embedding_logits[seq_idx] =
541 Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
542 } else {
543 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
544 }
545 }
546 }
547
548 match post_op {
549 CacheInstruction::Out => self.clone_out_cache(input_seqs),
550 CacheInstruction::Nothing => (),
551 CacheInstruction::Reset {
552 load_preallocated_cache,
553 reset_non_granular,
554 } => self.set_none_cache(
555 input_seqs,
556 reset_non_granular,
557 false,
558 load_preallocated_cache,
559 ),
560 _ => unreachable!("Unreachable POST cache op."),
561 }
562
563 if raw_out_logits[0][0].is_some() {
564 let start = Instant::now();
565 response::send_raw_responses(
566 input_seqs,
567 raw_out_logits
568 .into_iter()
569 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
570 .collect(),
571 )
572 .await?;
573 let end = Instant::now();
574 exec_duration += end.duration_since(start);
575
576 return Ok(exec_duration);
577 }
578 if embedding_logits[0].is_some() {
579 let start = Instant::now();
580 response::send_embedding_responses(
581 input_seqs,
582 embedding_logits
583 .into_iter()
584 .map(|raw| {
585 raw.unwrap()
586 .to_dtype(DType::F32)
587 .unwrap()
588 .to_vec1::<f32>()
589 .unwrap()
590 })
591 .collect(),
592 )
593 .await?;
594 let end = Instant::now();
595 exec_duration += end.duration_since(start);
596
597 return Ok(exec_duration);
598 }
599
600 let start = Instant::now();
601 let logits_on_cpu = logits.len() > 1;
602 let logits = logits
603 .into_iter()
604 .map(|l| {
605 let l = l.expect("missing forward result");
606 if logits_on_cpu {
607 l.to_device(&Device::Cpu)
608 } else {
609 Ok(l)
610 }
611 })
612 .collect::<hanzo_ml::Result<Vec<_>>>()?;
613
614 match &logits[0] {
615 ForwardInputsResult::RawLogits { .. }
616 | ForwardInputsResult::Embeddings { .. } => unreachable!(),
617 ForwardInputsResult::CausalGeneration { .. } => {
618 let logits = logits
619 .into_iter()
620 .map(|r| {
621 #[allow(irrefutable_let_patterns)]
622 let ForwardInputsResult::CausalGeneration { logits } = r
623 else {
624 unreachable!(
625 "All results must have same type, `CausalGeneration`"
626 )
627 };
628 logits
629 })
630 .collect::<Vec<_>>();
631 if is_prompt
632 || return_raw_logits
633 || !self
634 .try_sample_speculative_causal_gen(
635 input_seqs,
636 &logits,
637 prefix_cacher,
638 disable_eos_stop,
639 rng.clone(),
640 None,
641 )
642 .await?
643 {
644 self.sample_causal_gen(
645 input_seqs,
646 logits,
647 prefix_cacher,
648 disable_eos_stop,
649 rng,
650 )
651 .await?;
652 }
653 }
654 ForwardInputsResult::Image { .. } => {
655 response::send_image_responses(
656 input_seqs,
657 logits
658 .into_iter()
659 .map(|r| {
660 #[allow(irrefutable_let_patterns)]
661 let ForwardInputsResult::Image { images } = r
662 else {
663 unreachable!("All results must have same type, `Image`")
664 };
665 images
666 .into_iter()
667 .next()
668 .expect("Must have at least 1 element.")
669 })
670 .collect::<Vec<_>>(),
671 )
672 .await?;
673 }
674 ForwardInputsResult::Speech { .. } => {
675 let rates = logits
676 .iter()
677 .map(|r| {
678 #[allow(irrefutable_let_patterns)]
679 let ForwardInputsResult::Speech { rates, .. } = r
680 else {
681 unreachable!("All results must have same type, `Speech`")
682 };
683 assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
684 *rates.first().unwrap()
685 })
686 .collect::<Vec<_>>();
687 let channels = logits
688 .iter()
689 .map(|r| {
690 #[allow(irrefutable_let_patterns)]
691 let ForwardInputsResult::Speech { channels, .. } = r
692 else {
693 unreachable!("All results must have same type, `Speech`")
694 };
695 assert_eq!(
696 channels.len(),
697 1,
698 "Each sequence must have 1 PCM output."
699 );
700 *channels.first().unwrap()
701 })
702 .collect::<Vec<_>>();
703 let pcms = logits
704 .into_iter()
705 .map(|r| {
706 #[allow(irrefutable_let_patterns)]
707 let ForwardInputsResult::Speech { pcms, .. } = r
708 else {
709 unreachable!("All results must have same type, `Speech`")
710 };
711 assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
712 pcms.into_iter().nth(0).unwrap()
713 })
714 .collect::<Vec<_>>();
715 response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
716 .await?;
717 }
718 }
719 let end = Instant::now();
720 exec_duration += end.duration_since(start);
721
722 Ok(exec_duration)
723 }
724 CacheBackendMetadata::PagedAttention { metadata } => {
725 let speculative_metadata = metadata.clone();
726 if self.cache().is_hybrid() {
731 let mut hybrid_cache = self.cache().hybrid();
732 let recurrent_device = hybrid_cache.caches.iter().find_map(|c| {
733 if let HybridLayerCache::Recurrent(pool) = c {
734 Some(pool.device().clone())
735 } else {
736 None
737 }
738 });
739 if let Some(device) = recurrent_device {
740 #[allow(clippy::cast_possible_truncation)]
741 let indices: Vec<u32> = input_seqs
742 .iter()
743 .filter_map(|seq| seq.recurrent_state_idx().map(|idx| idx as u32))
744 .collect();
745 if indices.len() == input_seqs.len() {
746 if let Ok(si) = Tensor::from_vec(indices, (input_seqs.len(),), &device)
747 {
748 hybrid_cache.set_state_indices(Some(si));
749 }
750 }
751 }
752 }
753
754 let inputs_iter =
755 std::iter::once(self.get_processor().inputs_processor().process_inputs(
756 self.tokenizer(),
757 input_seqs,
758 is_prompt,
759 self.get_metadata().is_xlora,
760 &self.device(),
761 self.get_metadata().no_kv_cache,
762 None,
763 return_raw_logits,
764 self.get_metadata().sliding_window,
765 self.get_input_processor_config(),
766 Some(metadata),
767 self.device_mapper(),
768 ));
769
770 let mut logits = vec![None; input_seqs.len()];
771 let len_inputs = 1;
772 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
773 let mut embedding_logits = vec![None; input_seqs.len()];
774
775 let mut exec_duration = Duration::ZERO;
776 for (i, inputs) in inputs_iter.into_iter().enumerate() {
777 let InputProcessorOutput {
778 inputs,
779 seq_indices,
780 } = inputs.map_err(hanzo_ml::Error::msg)?;
781
782 let start = Instant::now();
783 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
784 let end = Instant::now();
785 exec_duration += end.duration_since(start);
786
787 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
788 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
789 raw_out_logits[seq_idx][i] =
790 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
791 } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
792 embedding_logits[seq_idx] =
793 Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
794 } else {
795 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
796 }
797 }
798 }
799
800 if raw_out_logits[0][0].is_some() {
801 let start = Instant::now();
802 response::send_raw_responses(
803 input_seqs,
804 raw_out_logits
805 .into_iter()
806 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
807 .collect(),
808 )
809 .await?;
810 let end = Instant::now();
811 exec_duration += end.duration_since(start);
812
813 return Ok(exec_duration);
814 }
815 if embedding_logits[0].is_some() {
816 let start = Instant::now();
817 response::send_embedding_responses(
818 input_seqs,
819 embedding_logits
820 .into_iter()
821 .map(|raw| {
822 raw.unwrap()
823 .to_dtype(DType::F32)
824 .unwrap()
825 .to_vec1::<f32>()
826 .unwrap()
827 })
828 .collect(),
829 )
830 .await?;
831 let end = Instant::now();
832 exec_duration += end.duration_since(start);
833
834 return Ok(exec_duration);
835 }
836
837 let start = Instant::now();
838 let logits_on_cpu = logits.len() > 1;
839 let logits = logits
840 .into_iter()
841 .map(|l| {
842 let l = l.expect("missing forward result");
843 if logits_on_cpu {
844 l.to_device(&Device::Cpu)
845 } else {
846 Ok(l)
847 }
848 })
849 .collect::<hanzo_ml::Result<Vec<_>>>()?;
850 match &logits[0] {
851 ForwardInputsResult::RawLogits { .. }
852 | ForwardInputsResult::Embeddings { .. } => unreachable!(),
853 ForwardInputsResult::CausalGeneration { .. } => {
854 let logits = logits
855 .into_iter()
856 .map(|r| {
857 #[allow(irrefutable_let_patterns)]
858 let ForwardInputsResult::CausalGeneration { logits } = r
859 else {
860 unreachable!("All results must have same type")
861 };
862 logits
863 })
864 .collect::<Vec<_>>();
865 if is_prompt
866 || return_raw_logits
867 || !self
868 .try_sample_speculative_causal_gen(
869 input_seqs,
870 &logits,
871 prefix_cacher,
872 disable_eos_stop,
873 rng.clone(),
874 Some(speculative_metadata),
875 )
876 .await?
877 {
878 self.sample_causal_gen(
879 input_seqs,
880 logits,
881 prefix_cacher,
882 disable_eos_stop,
883 rng,
884 )
885 .await?;
886 }
887 }
888 ForwardInputsResult::Image { .. } => {
889 response::send_image_responses(
890 input_seqs,
891 logits
892 .into_iter()
893 .map(|r| {
894 #[allow(irrefutable_let_patterns)]
895 let ForwardInputsResult::Image { images } = r
896 else {
897 unreachable!("All results must have same type, `Image`")
898 };
899 images
900 .into_iter()
901 .next()
902 .expect("Must have at least 1 element.")
903 })
904 .collect::<Vec<_>>(),
905 )
906 .await?;
907 }
908 ForwardInputsResult::Speech { .. } => {
909 let rates = logits
910 .iter()
911 .map(|r| {
912 #[allow(irrefutable_let_patterns)]
913 let ForwardInputsResult::Speech { rates, .. } = r
914 else {
915 unreachable!("All results must have same type, `Speech`")
916 };
917 assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
918 *rates.first().unwrap()
919 })
920 .collect::<Vec<_>>();
921 let channels = logits
922 .iter()
923 .map(|r| {
924 #[allow(irrefutable_let_patterns)]
925 let ForwardInputsResult::Speech { channels, .. } = r
926 else {
927 unreachable!("All results must have same type, `Speech`")
928 };
929 assert_eq!(
930 channels.len(),
931 1,
932 "Each sequence must have 1 PCM output."
933 );
934 *channels.first().unwrap()
935 })
936 .collect::<Vec<_>>();
937 let pcms = logits
938 .into_iter()
939 .map(|r| {
940 #[allow(irrefutable_let_patterns)]
941 let ForwardInputsResult::Speech { pcms, .. } = r
942 else {
943 unreachable!("All results must have same type, `Speech`")
944 };
945 assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
946 pcms.into_iter().nth(0).unwrap()
947 })
948 .collect::<Vec<_>>();
949 response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
950 .await?;
951 }
952 }
953 let end = Instant::now();
954 exec_duration += end.duration_since(start);
955
956 Ok(exec_duration)
957 }
958 }
959 }
960
961 async fn sample_causal_gen(
962 &self,
963 seqs: &mut [&mut Sequence],
964 logits: Vec<Tensor>,
965 prefix_cacher: &mut PrefixCacheManagerV2,
966 disable_eos_stop: bool,
967 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
968 ) -> Result<(), hanzo_ml::Error>;
969
970 fn category(&self) -> ModelCategory;
971
972 fn encoder_cache_counters(&self) -> Option<(Arc<AtomicUsize>, Arc<AtomicUsize>)> {
974 None
975 }
976}
977
978pub(crate) fn extract_logits(
979 logits: &Tensor,
980 context_lens: Vec<(usize, usize)>,
981) -> hanzo_ml::Result<Tensor> {
982 let mut toks = Vec::new();
983 for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
984 toks.push(dim.narrow(1, start, len)?);
985 }
986 Tensor::cat(&toks, 0)
987}
988
989#[cfg(test)]
990mod tests {
991 use crate::MessageContent;
992 use either::Either;
993 use indexmap::IndexMap;
994 use serde_json::Value;
995
996 macro_rules! hashmap {
997 (@single $($x:tt)*) => (());
998 (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));
999
1000 ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
1001 ($($key:expr => $value:expr),*) => {
1002 {
1003 let _cap = hashmap!(@count $($key),*);
1004 let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
1005 $(
1006 let _ = _map.insert($key, Value::String($value));
1007 )*
1008 _map
1009 }
1010 };
1011 }
1012
1013 #[cfg(test)]
1014 #[track_caller]
1015 fn test_with_inputs(
1016 templates: &[(bool, &str, &str, &str, &str)],
1017 expected_outputs: &[&str],
1018 inputs: Vec<IndexMap<String, MessageContent>>,
1019 ) {
1020 use crate::pipeline::chat_template::ChatTemplateValue;
1021
1022 use super::chat_template::apply_chat_template_to;
1023 let mut failed = Vec::new();
1024 let n_templates = templates.len();
1025 for ((has_system, bos, eos, unk, template), expected) in
1026 templates.iter().zip(expected_outputs)
1027 {
1028 let output = match apply_chat_template_to(
1029 if !has_system {
1030 inputs[1..].to_vec()
1031 } else {
1032 inputs.clone()
1033 },
1034 true,
1035 None,
1036 None, &ChatTemplateValue(Either::Left(template.to_string())),
1038 Some(bos.to_string()),
1039 Some(eos.to_string()),
1040 Some(unk.to_string()),
1041 Vec::new(),
1042 ) {
1043 Ok(v) => v,
1044 Err(e) => {
1045 failed.push(format!("Failed with {e}."));
1046 continue;
1047 }
1048 };
1049 if output != *expected {
1050 failed.push(format!(
1051 "Expected: `{}` \n\nGot: `{}`",
1052 expected.replace('\n', "\\n"),
1053 output.replace('\n', "\\n")
1054 ));
1055 }
1056 }
1057 if !failed.is_empty() {
1058 for (i, line) in failed.iter().enumerate() {
1059 println!("------------ Template {i} ------------");
1060 println!("{line}");
1061 }
1062 println!("------------------------");
1063 panic!("{}/{n_templates} chat templates failed.", failed.len());
1064 }
1065 }
1066
1067 #[test]
1068 fn test_chat_templates() {
1077 let templates = [
1078 (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
1080 (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
1082 (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
1084 (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
1086 (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
1088 (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
1090 ];
1091 let expected_outputs = [
1092 "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
1094 "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
1096 "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
1098 "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
1100 "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
1102 ];
1103 let messages = [
1104 ["system", "You are a helpful assistant"],
1105 ["user", "Hello"],
1106 ["assistant", "Hi there"],
1107 ["user", "Who are you"],
1108 ["assistant", " I am an assistant "],
1109 ["user", "Another question"],
1110 ];
1111 let mut inputs = Vec::new();
1112 for [role, content] in messages {
1113 let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
1114 IndexMap::new();
1115 message.insert("role".to_string(), Either::Left(role.to_string()));
1116 message.insert("content".to_string(), Either::Left(content.to_string()));
1117 inputs.push(message);
1118 }
1119 test_with_inputs(&templates, &expected_outputs, inputs);
1120 }
1121
#[test]
fn test_image_chat_templates() {
    // A single message, and one typed part of a message's list-style content.
    type Part = IndexMap<String, Value>;
    type Message = IndexMap<String, Either<String, Vec<Part>>>;

    // Template that walks `message['content']` as a list of typed parts. It emits ':'
    // after the capitalized role when the first part is an image, and ': ' otherwise.
    let templates = [
        (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
    ];
    // NOTE(review): for the " I am an assistant " turn, the template above emits ': '
    // followed by the raw text, i.e. two spaces after "Assistant:", but the expected
    // string below has one. Presumably the rendering path trims something; confirm.
    let expected_outputs = [
        "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant: I am an assistant <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
    ];

    // Conversation as (role, has a leading image part, text part).
    // A table avoids six hand-copied construction blocks drifting apart.
    let turns = [
        ("system", false, "You are a helpful assistant"),
        ("user", true, "Hello, please describe the above."),
        ("assistant", false, "Hi there"),
        ("user", true, "This is me, who are you"),
        ("assistant", false, " I am an assistant "),
        ("user", true, "Another question, what is this?"),
    ];

    let inputs: Vec<Message> = turns
        .iter()
        .map(|&(role, with_image, text)| {
            // Parts are pushed in the same order the original test used:
            // an optional `{"type": "image"}`, then `{"type": "text", "text": ...}`.
            let mut content: Vec<Part> = Vec::new();
            if with_image {
                content.push(hashmap! {
                    "type".to_string() => "image".to_string()
                });
            }
            content.push(hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => text.to_string()
            });

            // IndexMap keeps insertion order, so "role" precedes "content" as before.
            let mut message: Message = IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Right(content));
            message
        })
        .collect();

    test_with_inputs(&templates, &expected_outputs, inputs);
}
1236}