1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod embedding;
6mod ggml;
7mod gguf;
8pub(crate) mod hf;
9mod inputs_processor;
10mod isq;
11pub(crate) mod llg;
12mod loaders;
13mod macros;
14mod normal;
15mod paths;
16mod processing;
17mod response;
18mod sampling;
19mod speculative;
20mod speech;
21mod vision;
22
23pub use super::diffusion_models::DiffusionGenerationParams;
24use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
25use crate::device_map::DeviceMapper;
26use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
27use crate::prefix_cacher::PrefixCacheManagerV2;
28use crate::PagedAttentionConfig;
29pub use amoe::{AnyMoeLoader, AnyMoePipeline};
30pub use auto::{AutoLoader, AutoLoaderBuilder};
31use chat_template::ChatTemplate;
32pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
33pub use embedding::{EmbeddingLoader, EmbeddingLoaderBuilder, EmbeddingSpecificConfig};
34pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
35pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
36use image::DynamicImage;
37pub use inputs_processor::InputProcessorOutput;
38pub(crate) use isq::IsqModelLoader;
39pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
40use llguidance::toktrie::TokEnv;
41pub use loaders::{
42 AdapterKind, AutoDeviceMapParams, AutoEmbeddingLoader, AutoNormalLoader, AutoVisionLoader,
43 DeepSeekV2Loader, DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType,
44 DiffusionModel, DiffusionModelLoader, EmbeddingGemmaLoader, EmbeddingLoaderType,
45 EmbeddingModel, EmbeddingModelLoader, EmbeddingModelPaths, EmbeddingModule,
46 EmbeddingModulePaths, EmbeddingModuleType, FluxLoader, GLM4Loader, GLM4MoeLiteLoader,
47 GLM4MoeLoader, Gemma2Loader, Gemma3Loader, Gemma3nLoader, GemmaLoader, GptOssLoader,
48 GraniteMoeHybridLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
49 LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader,
50 MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel,
51 NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader,
52 PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader,
53 Qwen3EmbeddingLoader, Qwen3Loader, Qwen3MoELoader, Qwen3VLLoader, Qwen3VLMoELoader,
54 SmolLm3Loader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType,
55 VisionModel, VisionModelLoader,
56};
57#[allow(clippy::too_many_arguments)]
58pub(crate) fn get_device_layers_for_loader(
59 loader: &dyn loaders::DeviceMappedModelLoader,
60 config: &str,
61 num_layers: usize,
62 layer_sizes_in_bytes: Vec<usize>,
63 non_mapped_size_in_bytes: usize,
64 total_model_size_in_bytes: usize,
65 devices: &[Device],
66 dtype: DType,
67 params: &loaders::AutoDeviceMapParams,
68 paged_attn_config: Option<&PagedAttentionConfig>,
69) -> Result<crate::device_map::DeviceMapMetadata> {
70 loaders::auto_device_map::get_device_layers(
71 loader,
72 config,
73 num_layers,
74 layer_sizes_in_bytes,
75 non_mapped_size_in_bytes,
76 total_model_size_in_bytes,
77 devices,
78 dtype,
79 params,
80 paged_attn_config,
81 )
82}
83use mistralrs_quant::IsqType;
84pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
85pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
86pub use paths::{AdapterPaths, LoraAdapterPaths};
87pub(crate) use processing::{
88 apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
89};
90use rand_isaac::Isaac64Rng;
91pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
92pub use speech::{SpeechLoader, SpeechPipeline};
93use std::any::Any;
94use std::collections::HashMap;
95use std::fmt::Debug;
96use std::sync::Arc;
97use std::time::{Duration, Instant};
98use tokenizers::Tokenizer;
99pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
100
101use anyhow::Result;
102use candle_core::{DType, Device, IndexOp, Tensor, Var};
103
104use crate::sequence::Sequence;
105
106pub use self::inputs_processor::{
107 text_models_inputs_processor, InputsProcessor, InputsProcessorType,
108};
109use self::text_models_inputs_processor::PagedAttentionMeta;
110pub use crate::kv_cache::{
111 Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
112};
113
/// A kind of data a model can consume or produce.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    Text,
    Audio,
    Vision,
    Embedding,
}

/// `Debug` renders each modality as an emoji followed by its name.
impl Debug for SupportedModality {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
            Self::Embedding => "🔢 Embedding",
        };
        f.write_str(label)
    }
}
132
/// The input and output modalities associated with a model.
#[derive(Debug, Clone)]
pub struct Modalities {
    /// Modalities listed as inputs.
    pub input: Vec<SupportedModality>,
    /// Modalities listed as outputs.
    pub output: Vec<SupportedModality>,
}
138
139pub struct GeneralMetadata {
140 pub max_seq_len: usize,
141 pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
143 pub no_kv_cache: bool,
144 pub no_prefix_cache: bool,
145 pub num_hidden_layers: usize,
146 pub eos_tok: Vec<u32>,
147 pub kind: ModelKind,
148 pub is_xlora: bool,
150 pub activation_dtype: DType,
151 pub sliding_window: Option<usize>,
152 pub cache_config: Option<CacheConfig>,
154 pub cache_engine: Option<CacheEngine>,
155 pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
156 pub modalities: Modalities,
157}
158
159impl GeneralMetadata {
160 pub fn tok_env(&self) -> Option<TokEnv> {
161 self.llg_factory.as_ref().map(|f| f.tok_env().clone())
162 }
163}
164
/// What `Pipeline::step` should do with the cache before (`pre_op`) or after
/// (`post_op`) the forward pass when using the default cache backend.
pub enum CacheInstruction {
    /// Call `clone_in_cache`. Only valid as a pre-op.
    In,
    /// Call `clone_out_cache`. Only valid as a post-op.
    Out,
    /// Call `set_none_cache` with the two flags below.
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    /// Leave the cache untouched.
    Nothing,
}
175
/// Pre-processing hooks of a pipeline: the processor, the chat template and
/// the opaque configuration passed to the inputs processor.
pub trait PreProcessingMixin: MetadataMixin {
    /// Returns the processor to use; defaults to a fresh [`BasicProcessor`].
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(BasicProcessor)
    }
    /// Returns the chat template, if the pipeline has one.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
    /// Returns the type-erased configuration that `Pipeline::step` forwards to
    /// the inputs processor, if any.
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
}
184
/// Pipelines whose model can be re-quantized in place.
pub trait IsqPipelineMixin {
    /// Re-applies ISQ to the model using the given quantization type.
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
}
188
/// Cache management operations used by `Pipeline::step` with the default
/// (non-PagedAttention) cache backend.
pub trait CacheManagerMixin {
    /// Invoked for `CacheInstruction::In` before the forward pass.
    /// NOTE(review): presumably copies the sequences' caches into the model cache — confirm.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
    /// Invoked for `CacheInstruction::Out` after the forward pass.
    /// NOTE(review): presumably copies the model cache back into the sequences — confirm.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
    /// Invoked for `CacheInstruction::Reset`; `Pipeline::step` always passes
    /// `modify_draft_cache = false`.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
    /// Returns the cache owned by this pipeline.
    fn cache(&self) -> &EitherCache;
    /// Returns `true` exactly when [`Self::cache`] is the `EitherCache::Normal` variant.
    fn do_preallocated_cache(&self) -> bool {
        matches!(self.cache(), EitherCache::Normal(_))
    }
}
211
/// Accessors for the basic properties of a pipeline.
pub trait MetadataMixin {
    /// The device handed to the inputs processor in `Pipeline::step`.
    fn device(&self) -> Device;
    /// The tokenizer, if this pipeline has one.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
    /// A name identifying this pipeline.
    fn name(&self) -> String;
    /// Resets the non-granular state.
    /// NOTE(review): exact state covered is implementation-defined — confirm per implementor.
    fn reset_non_granular_state(&self);
    /// Shared metadata describing the loaded model.
    fn get_metadata(&self) -> Arc<GeneralMetadata>;
    /// The device mapper, if this pipeline has one.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}
221
/// AnyMoE hooks of a pipeline.
///
/// Every default implementation except [`Self::amoe_supported`] panics via
/// `unreachable!()`, and `amoe_supported` defaults to `false`. Implementors
/// that report support are presumably expected to override the other methods.
pub trait AnyMoePipelineMixin {
    /// Default: panics (`unreachable!()`).
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        unreachable!()
    }
    /// Default: panics (`unreachable!()`).
    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Default: panics (`unreachable!()`).
    fn amoe_base_model_trainable_params(&self) -> usize {
        unreachable!()
    }
    /// Whether this pipeline supports AnyMoE. Default: `false`.
    fn amoe_supported(&self) -> bool {
        false
    }
    /// Default: panics (`unreachable!()`).
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        unreachable!()
    }
    /// Default: panics (`unreachable!()`).
    #[allow(clippy::too_many_arguments)]
    fn amoe_create_layers(
        &mut self,
        _model_ids: Vec<String>,
        _token: &TokenSource,
        _revision: Option<String>,
        _match_regex: &str,
        _config: AnyMoeConfig,
        _dtype: DType,
        _dev: &Device,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _silent: bool,
        _gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Default: panics (`unreachable!()`).
    #[allow(clippy::too_many_arguments)]
    fn amoe_pre_train(
        &self,
        _inputs: AnyMoeTrainingInputs,
        (_prefix, _mlp): (String, String),
        _model_ids: Vec<String>,
        _token: TokenSource,
        _revision: Option<String>,
        _layers: Vec<usize>,
        _silent: bool,
    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        unreachable!()
    }
}
275
/// Broad category of a model, as reported by `Pipeline::category`.
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    Vision {
        /// Hook used to rewrite prompts that reference images or audio.
        prefixer: Arc<dyn MultimodalPromptPrefixer>,
    },
    Diffusion,
    Audio,
    Speech,
    Embedding,
}

/// `Debug` prints the variant path; the `Vision` prefixer is elided as `..`
/// because the trait object carries no `Debug` bound.
impl std::fmt::Debug for ModelCategory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = match self {
            ModelCategory::Text => "ModelCategory::Text",
            ModelCategory::Vision { .. } => "ModelCategory::Vision { prefixer: .. }",
            ModelCategory::Diffusion => "ModelCategory::Diffusion",
            ModelCategory::Audio => "ModelCategory::Audio",
            ModelCategory::Speech => "ModelCategory::Speech",
            ModelCategory::Embedding => "ModelCategory::Embedding",
        };
        f.write_str(rendered)
    }
}

/// Two categories are equal when they are the same variant; the `Vision`
/// prefixer does not take part in the comparison.
impl PartialEq for ModelCategory {
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}

/// Rewrites a prompt so that it references attached images or audio.
///
/// Both default implementations return the prompt unchanged.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Returns the prompt to use given the indices of the attached images.
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
    /// Returns the prompt to use given the indices of the attached audio clips.
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
}
336
/// Selects the cache backend for one `Pipeline::step` call and carries the
/// data that backend needs.
pub enum CacheBackendMetadata {
    /// Default cache backend: cache instructions applied before (`pre_op`)
    /// and after (`post_op`) the forward pass.
    DefaultInstructions {
        pre_op: CacheInstruction,
        post_op: CacheInstruction,
    },
    /// PagedAttention backend: `blocks_to_copy` is given to the cache engine's
    /// `execute_scheduler_ops`, and `metadata` is forwarded to the inputs processor.
    PagedAttention {
        metadata: PagedAttentionMeta,
        blocks_to_copy: HashMap<usize, Vec<usize>>,
    },
}
347
348#[derive(Clone, Debug)]
349pub enum ForwardInputsResult {
350 RawLogits {
351 logits: Tensor,
352 },
353 Embeddings {
354 embeddings: Tensor,
355 },
356 CausalGeneration {
357 logits: Tensor,
358 },
359 Image {
360 images: Vec<DynamicImage>,
361 },
362 Speech {
363 pcms: Vec<Arc<Vec<f32>>>,
364 rates: Vec<usize>,
365 channels: Vec<usize>,
366 },
367}
368
369impl ForwardInputsResult {
370 fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
371 match self {
372 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
373 logits: logits.i(bs_idx)?,
374 }),
375 Self::Embeddings { embeddings } => Ok(Self::Embeddings {
376 embeddings: embeddings.i(bs_idx)?,
377 }),
378 Self::RawLogits { logits } => Ok(Self::RawLogits {
379 logits: logits.i(bs_idx)?,
380 }),
381 Self::Image { images } => Ok(Self::Image {
382 images: vec![images[bs_idx].clone()],
383 }),
384 Self::Speech {
385 pcms,
386 rates,
387 channels,
388 } => Ok(Self::Speech {
389 pcms: vec![pcms[bs_idx].clone()],
390 rates: vec![rates[bs_idx]],
391 channels: vec![channels[bs_idx]],
392 }),
393 }
394 }
395
396 fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
397 match self {
398 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
399 logits: logits.to_device(device)?,
400 }),
401 Self::RawLogits { logits } => Ok(Self::RawLogits {
402 logits: logits.to_device(device)?,
403 }),
404 Self::Embeddings { embeddings } => Ok(Self::Embeddings {
405 embeddings: embeddings.to_device(device)?,
406 }),
407 Self::Image { .. } => Ok(self.clone()),
408 Self::Speech { .. } => Ok(self.clone()),
409 }
410 }
411}
412
/// Serializable list of file names.
/// NOTE(review): presumably a cached repository file listing — confirm against its users.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct FileListCache {
    files: Vec<String>,
}
417
418#[async_trait::async_trait]
419pub trait Pipeline:
420 Send
421 + Sync
422 + PreProcessingMixin
423 + IsqPipelineMixin
424 + CacheManagerMixin
425 + MetadataMixin
426 + AnyMoePipelineMixin
427{
428 fn forward_inputs(
429 &mut self,
430 inputs: Box<dyn Any>,
431 return_raw_logits: bool,
432 ) -> Result<ForwardInputsResult, candle_core::Error>;
433
434 #[allow(clippy::too_many_arguments)]
436 async fn step(
437 &mut self,
438 input_seqs: &mut [&mut Sequence],
439 is_prompt: bool,
440 return_raw_logits: bool,
441 prefix_cacher: &mut PrefixCacheManagerV2,
442 disable_eos_stop: bool,
443 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
444 backend_metadata: CacheBackendMetadata,
445 ) -> Result<Duration, candle_core::Error> {
446 match backend_metadata {
447 CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
448 let inputs_iter =
449 std::iter::once(self.get_processor().inputs_processor().process_inputs(
450 self.tokenizer(),
451 input_seqs,
452 is_prompt,
453 self.get_metadata().is_xlora,
454 &self.device(),
455 self.get_metadata().no_kv_cache,
456 None,
457 return_raw_logits,
458 self.get_input_processor_config(),
459 None,
460 self.device_mapper(),
461 ));
462
463 let mut logits = vec![None; input_seqs.len()];
464 let len_inputs = 1;
465 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
466 let mut embedding_logits = vec![None; input_seqs.len()];
467
468 let mut exec_duration = Duration::ZERO;
469 for (i, inputs) in inputs_iter.into_iter().enumerate() {
470 let InputProcessorOutput {
471 inputs,
472 seq_indices,
473 } = inputs.map_err(candle_core::Error::msg)?;
474 if i == 0 {
475 match pre_op {
476 CacheInstruction::In => self.clone_in_cache(input_seqs),
477 CacheInstruction::Nothing => (),
478 CacheInstruction::Reset {
479 load_preallocated_cache,
480 reset_non_granular,
481 } => self.set_none_cache(
482 input_seqs,
483 reset_non_granular,
484 false,
485 load_preallocated_cache,
486 ),
487 _ => unreachable!("Unreachable PRE cache op."),
488 }
489 }
490
491 let start = Instant::now();
492 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
493 let end = Instant::now();
494 exec_duration += end.duration_since(start);
495
496 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
497 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
498 raw_out_logits[seq_idx][i] =
499 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
500 } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
501 embedding_logits[seq_idx] =
502 Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
503 } else {
504 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
505 }
506 }
507 }
508
509 match post_op {
510 CacheInstruction::Out => self.clone_out_cache(input_seqs),
511 CacheInstruction::Nothing => (),
512 CacheInstruction::Reset {
513 load_preallocated_cache,
514 reset_non_granular,
515 } => self.set_none_cache(
516 input_seqs,
517 reset_non_granular,
518 false,
519 load_preallocated_cache,
520 ),
521 _ => unreachable!("Unreachable POST cache op."),
522 }
523
524 if raw_out_logits[0][0].is_some() {
525 let start = Instant::now();
526 response::send_raw_responses(
527 input_seqs,
528 raw_out_logits
529 .into_iter()
530 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
531 .collect(),
532 )
533 .await?;
534 let end = Instant::now();
535 exec_duration += end.duration_since(start);
536
537 return Ok(exec_duration);
538 }
539 if embedding_logits[0].is_some() {
540 let start = Instant::now();
541 response::send_embedding_responses(
542 input_seqs,
543 embedding_logits
544 .into_iter()
545 .map(|raw| {
546 raw.unwrap()
547 .to_dtype(DType::F32)
548 .unwrap()
549 .to_vec1::<f32>()
550 .unwrap()
551 })
552 .collect(),
553 )
554 .await?;
555 let end = Instant::now();
556 exec_duration += end.duration_since(start);
557
558 return Ok(exec_duration);
559 }
560
561 let start = Instant::now();
562 let logits_on_cpu = logits.len() > 1;
563 let logits = logits
564 .into_iter()
565 .map(|l| {
566 let l = l.expect("Did not get any inputs. This is shocking.");
567 if logits_on_cpu {
568 l.to_device(&Device::Cpu)
569 } else {
570 Ok(l)
571 }
572 })
573 .collect::<candle_core::Result<Vec<_>>>()?;
574
575 match &logits[0] {
576 ForwardInputsResult::RawLogits { .. }
577 | ForwardInputsResult::Embeddings { .. } => unreachable!(),
578 ForwardInputsResult::CausalGeneration { .. } => {
579 self.sample_causal_gen(
580 input_seqs,
581 logits
582 .into_iter()
583 .map(|r| {
584 #[allow(irrefutable_let_patterns)]
585 let ForwardInputsResult::CausalGeneration { logits } = r
586 else {
587 unreachable!(
588 "All results must have same type, `CausalGeneration`"
589 )
590 };
591 logits
592 })
593 .collect::<Vec<_>>(),
594 prefix_cacher,
595 disable_eos_stop,
596 rng,
597 )
598 .await?;
599 }
600 ForwardInputsResult::Image { .. } => {
601 response::send_image_responses(
602 input_seqs,
603 logits
604 .into_iter()
605 .map(|r| {
606 #[allow(irrefutable_let_patterns)]
607 let ForwardInputsResult::Image { images } = r
608 else {
609 unreachable!("All results must have same type, `Image`")
610 };
611 images
612 .into_iter()
613 .next()
614 .expect("Must have at least 1 element.")
615 })
616 .collect::<Vec<_>>(),
617 )
618 .await?;
619 }
620 ForwardInputsResult::Speech { .. } => {
621 let rates = logits
622 .iter()
623 .map(|r| {
624 #[allow(irrefutable_let_patterns)]
625 let ForwardInputsResult::Speech { rates, .. } = r
626 else {
627 unreachable!("All results must have same type, `Speech`")
628 };
629 assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
630 *rates.first().unwrap()
631 })
632 .collect::<Vec<_>>();
633 let channels = logits
634 .iter()
635 .map(|r| {
636 #[allow(irrefutable_let_patterns)]
637 let ForwardInputsResult::Speech { channels, .. } = r
638 else {
639 unreachable!("All results must have same type, `Speech`")
640 };
641 assert_eq!(
642 channels.len(),
643 1,
644 "Each sequence must have 1 PCM output."
645 );
646 *channels.first().unwrap()
647 })
648 .collect::<Vec<_>>();
649 let pcms = logits
650 .into_iter()
651 .map(|r| {
652 #[allow(irrefutable_let_patterns)]
653 let ForwardInputsResult::Speech { pcms, .. } = r
654 else {
655 unreachable!("All results must have same type, `Speech`")
656 };
657 assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
658 pcms.into_iter().nth(0).unwrap()
659 })
660 .collect::<Vec<_>>();
661 response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
662 .await?;
663 }
664 }
665 let end = Instant::now();
666 exec_duration += end.duration_since(start);
667
668 Ok(exec_duration)
669 }
670 CacheBackendMetadata::PagedAttention {
671 metadata,
672 blocks_to_copy,
673 } => {
674 self.get_metadata()
676 .cache_engine
677 .as_ref()
678 .expect("PagedAttention must have cache engines.")
679 .execute_scheduler_ops(&blocks_to_copy)?;
680
681 let inputs_iter =
682 std::iter::once(self.get_processor().inputs_processor().process_inputs(
683 self.tokenizer(),
684 input_seqs,
685 is_prompt,
686 self.get_metadata().is_xlora,
687 &self.device(),
688 self.get_metadata().no_kv_cache,
689 None,
690 return_raw_logits,
691 self.get_input_processor_config(),
692 Some(metadata),
693 self.device_mapper(),
694 ));
695
696 let mut logits = vec![None; input_seqs.len()];
697 let len_inputs = 1;
698 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
699 let mut embedding_logits = vec![None; input_seqs.len()];
700
701 let mut exec_duration = Duration::ZERO;
702 for (i, inputs) in inputs_iter.into_iter().enumerate() {
703 let InputProcessorOutput {
704 inputs,
705 seq_indices,
706 } = inputs.map_err(candle_core::Error::msg)?;
707
708 let start = Instant::now();
709 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
710 let end = Instant::now();
711 exec_duration += end.duration_since(start);
712
713 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
714 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
715 raw_out_logits[seq_idx][i] =
716 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
717 } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
718 embedding_logits[seq_idx] =
719 Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
720 } else {
721 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
722 }
723 }
724 }
725
726 if raw_out_logits[0][0].is_some() {
727 let start = Instant::now();
728 response::send_raw_responses(
729 input_seqs,
730 raw_out_logits
731 .into_iter()
732 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
733 .collect(),
734 )
735 .await?;
736 let end = Instant::now();
737 exec_duration += end.duration_since(start);
738
739 return Ok(exec_duration);
740 }
741 if embedding_logits[0].is_some() {
742 let start = Instant::now();
743 response::send_embedding_responses(
744 input_seqs,
745 embedding_logits
746 .into_iter()
747 .map(|raw| {
748 raw.unwrap()
749 .to_dtype(DType::F32)
750 .unwrap()
751 .to_vec1::<f32>()
752 .unwrap()
753 })
754 .collect(),
755 )
756 .await?;
757 let end = Instant::now();
758 exec_duration += end.duration_since(start);
759
760 return Ok(exec_duration);
761 }
762
763 let start = Instant::now();
764 let logits_on_cpu = logits.len() > 1;
765 let logits = logits
766 .into_iter()
767 .map(|l| {
768 let l = l.expect("Did not get any inputs. This is shocking.");
769 if logits_on_cpu {
770 l.to_device(&Device::Cpu)
771 } else {
772 Ok(l)
773 }
774 })
775 .collect::<candle_core::Result<Vec<_>>>()?;
776
777 match &logits[0] {
778 ForwardInputsResult::RawLogits { .. }
779 | ForwardInputsResult::Embeddings { .. } => unreachable!(),
780 ForwardInputsResult::CausalGeneration { .. } => {
781 self.sample_causal_gen(
782 input_seqs,
783 logits
784 .into_iter()
785 .map(|r| {
786 #[allow(irrefutable_let_patterns)]
787 let ForwardInputsResult::CausalGeneration { logits } = r
788 else {
789 unreachable!("All results must have same type")
790 };
791 logits
792 })
793 .collect::<Vec<_>>(),
794 prefix_cacher,
795 disable_eos_stop,
796 rng,
797 )
798 .await?;
799 }
800 ForwardInputsResult::Image { .. } => {
801 response::send_image_responses(
802 input_seqs,
803 logits
804 .into_iter()
805 .map(|r| {
806 #[allow(irrefutable_let_patterns)]
807 let ForwardInputsResult::Image { images } = r
808 else {
809 unreachable!("All results must have same type, `Image`")
810 };
811 images
812 .into_iter()
813 .next()
814 .expect("Must have at least 1 element.")
815 })
816 .collect::<Vec<_>>(),
817 )
818 .await?;
819 }
820 ForwardInputsResult::Speech { .. } => {
821 let rates = logits
822 .iter()
823 .map(|r| {
824 #[allow(irrefutable_let_patterns)]
825 let ForwardInputsResult::Speech { rates, .. } = r
826 else {
827 unreachable!("All results must have same type, `Speech`")
828 };
829 assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
830 *rates.first().unwrap()
831 })
832 .collect::<Vec<_>>();
833 let channels = logits
834 .iter()
835 .map(|r| {
836 #[allow(irrefutable_let_patterns)]
837 let ForwardInputsResult::Speech { channels, .. } = r
838 else {
839 unreachable!("All results must have same type, `Speech`")
840 };
841 assert_eq!(
842 channels.len(),
843 1,
844 "Each sequence must have 1 PCM output."
845 );
846 *channels.first().unwrap()
847 })
848 .collect::<Vec<_>>();
849 let pcms = logits
850 .into_iter()
851 .map(|r| {
852 #[allow(irrefutable_let_patterns)]
853 let ForwardInputsResult::Speech { pcms, .. } = r
854 else {
855 unreachable!("All results must have same type, `Speech`")
856 };
857 assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
858 pcms.into_iter().nth(0).unwrap()
859 })
860 .collect::<Vec<_>>();
861 response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
862 .await?;
863 }
864 }
865 let end = Instant::now();
866 exec_duration += end.duration_since(start);
867
868 Ok(exec_duration)
869 }
870 }
871 }
872
873 async fn sample_causal_gen(
874 &self,
875 seqs: &mut [&mut Sequence],
876 logits: Vec<Tensor>,
877 prefix_cacher: &mut PrefixCacheManagerV2,
878 disable_eos_stop: bool,
879 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
880 ) -> Result<(), candle_core::Error>;
881
882 fn category(&self) -> ModelCategory;
883}
884
885pub(crate) fn extract_logits(
886 logits: &Tensor,
887 context_lens: Vec<(usize, usize)>,
888) -> candle_core::Result<Tensor> {
889 let mut toks = Vec::new();
890 for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
891 toks.push(dim.narrow(1, start, len)?);
892 }
893 Tensor::cat(&toks, 0)
894}
895
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    /// Builds an `IndexMap` whose values are wrapped in `Value::String`,
    /// pre-sized to the number of entries.
    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    /// A `{"type": "text", "text": <text>}` content part.
    fn text_part(text: &str) -> IndexMap<String, Value> {
        hashmap! {
            "type".to_string() => "text".to_string(),
            "text".to_string() => text.to_string()
        }
    }

    /// A `{"type": "image"}` content part.
    fn image_part() -> IndexMap<String, Value> {
        hashmap! {
            "type".to_string() => "image".to_string()
        }
    }

    /// Renders every template against `inputs` and compares the result with
    /// the expected output at the same index.
    ///
    /// Each template tuple is `(has_system, bos, eos, unk, template)`. When
    /// `has_system` is false the first message (the system prompt) is dropped
    /// before rendering. All failures are collected and printed together with
    /// the index of the failing template before panicking.
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;

        // `zip` would otherwise silently skip templates without an expectation.
        assert_eq!(
            templates.len(),
            expected_outputs.len(),
            "Every template needs exactly one expected output."
        );

        let n_templates = templates.len();
        let mut failed = Vec::new();
        for (idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let messages = if *has_system {
                inputs.clone()
            } else {
                inputs[1..].to_vec()
            };
            let rendered = apply_chat_template_to(
                messages,
                true,
                None,
                None,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            );
            match rendered {
                Ok(output) => {
                    if output != *expected {
                        failed.push((
                            idx,
                            format!(
                                "Expected: `{}` \n\nGot: `{}`",
                                expected.replace('\n', "\\n"),
                                output.replace('\n', "\\n")
                            ),
                        ));
                    }
                }
                Err(e) => failed.push((idx, format!("Failed with {e}."))),
            }
        }

        if !failed.is_empty() {
            for (idx, line) in &failed {
                println!("------------ Template {idx} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    #[test]
    fn test_chat_templates() {
        // Text-only templates. The structured-content (image) template is
        // covered by `test_image_chat_templates`; it had no expected output
        // here and was therefore never actually checked by this test.
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
        ];
        let expected_outputs = [
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", " I am an assistant "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::new();
        for [role, content] in messages {
            let mut message: IndexMap<String, MessageContent> = IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Left(content.to_string()));
            inputs.push(message);
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    #[test]
    fn test_image_chat_templates() {
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        // The second assistant turn keeps its leading space, so it follows the
        // template's ": " separator with two spaces in total.
        let expected_outputs = [
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:  I am an assistant <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let conversation = vec![
            ("system", vec![text_part("You are a helpful assistant")]),
            (
                "user",
                vec![image_part(), text_part("Hello, please describe the above.")],
            ),
            ("assistant", vec![text_part("Hi there")]),
            (
                "user",
                vec![image_part(), text_part("This is me, who are you")],
            ),
            ("assistant", vec![text_part(" I am an assistant ")]),
            (
                "user",
                vec![image_part(), text_part("Another question, what is this?")],
            ),
        ];

        let mut inputs = Vec::new();
        for (role, content) in conversation {
            let mut message: IndexMap<String, MessageContent> = IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Right(content));
            inputs.push(message);
        }

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}