mistralrs_core/pipeline/speech.rs

use super::text_models_inputs_processor::PagedAttentionMeta;
use super::{
    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::WorkerTransferData;
use crate::pipeline::{ChatTemplate, EmbeddingModulePaths, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
use crate::utils::progress::ProgressScopeGuard;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::{
    api_get_file, distributed, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
    SpeechGenerationConfig, TryIntoDType,
};
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use regex::Regex;
use std::any::Any;
use std::env;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;

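/// Paths for a speech model: the safetensors weight files (main model first,
/// DAC vocoder last) and the model's `config.json`.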
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    weights: Vec<PathBuf>,
    config: PathBuf,
}

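// Only the config and weight accessors are meaningful for speech models; the
// rest are `unreachable!` because there is no tokenizer, chat template, or
// adapter, and callers are expected to downcast via `std::any::Any` instead.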
impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        unreachable!("Use `std::any::Any`.")
    }
}

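/// `Processor` for speech pipelines. Speech generation consumes a raw text
/// prompt rather than chat messages, so `process` always bails.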
pub struct SpeechProcessor;

impl Processor for SpeechProcessor {
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _reasoning_effort: Option<crate::request::ReasoningEffort>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        // Just a default
        MessagesAction::FlattenOnlyText
    }
}

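/// Inputs processor that forwards each sequence's initial prompt verbatim;
/// no tokenization happens at this stage.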
pub struct SpeechInputsProcessor;

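/// The raw text prompts to synthesize, one per input sequence.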
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
}

impl InputsProcessor for SpeechInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Text
    }

    fn process_inputs(
        &self,
        _tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        _is_prompt: bool,
        _is_xlora: bool,
        _device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _other_config: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        _mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        let inputs = ModelInputs {
            prompts: input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().to_string())
                .collect(),
        };
        Ok(InputProcessorOutput {
            inputs: Box::new(inputs),
            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
        })
    }
}

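/// A loaded speech generation pipeline, currently backed by [`DiaPipeline`].
/// The cache is a placeholder: generation is performed in a single forward
/// call and keeps no KV state between scheduler steps.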
pub struct SpeechPipeline {
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    dummy_cache: EitherCache,
    cfg: SpeechGenerationConfig,
}

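/// Loader for speech models. `dac_model_id` optionally overrides the DAC
/// vocoder repository; when `None`, an architecture-specific default is used.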
pub struct SpeechLoader {
    pub model_id: String,
    pub dac_model_id: Option<String>,
    pub arch: SpeechLoaderType,
    pub cfg: Option<SpeechGenerationConfig>,
}

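// Downloads `model.safetensors` and `config.json` for the main model, then the
// DAC vocoder weights (defaulting to `EricB/dac_44khz` for Dia). Ordering is
// significant: `load_model_from_path` treats the final weight file as the DAC
// weights.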
impl Loader for SpeechLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            // Main weights first, DAC is the final one.
            let mut weights = Vec::new();

            // Main model
            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

            // DAC model
            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                // Apply default here
                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

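    // Loads from local paths: rejects device mapping, applies immediate ISQ if
    // requested, parses the Dia config, then memory-maps the main weights (all
    // but the last file) separately from the DAC weights (the last file).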
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }
        let use_nccl = mistralrs_quant::distributed::use_nccl();
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };

        let mapper =
            DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None, &available_devices)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // Last weight is the dac.
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        // Only Dia is supported for now.
        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so it's OK.
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}

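// The mixin impls below are largely no-ops: speech pipelines use no tokenizer
// or chat template, keep no KV cache, and do not support re-quantization.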
impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for SpeechPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}

impl CacheManagerMixin for SpeechPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}

impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

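// `forward_inputs` runs the entire generation synchronously, producing one PCM
// buffer (with its sample rate and channel count) per prompt. Token-level
// sampling does not apply, so `sample_causal_gen` always errors.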
#[async_trait::async_trait]
impl Pipeline for SpeechPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        assert!(!return_raw_logits);

        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
        let mut pcms = Vec::new();
        let mut rates = Vec::new();
        let mut channels_all = Vec::new();
        for prompt in prompts {
            let SpeechGenerationOutput {
                pcm,
                rate,
                channels,
            } = self.model.generate(&prompt, &self.cfg)?;
            pcms.push(pcm);
            rates.push(rate);
            channels_all.push(channels);
        }

        Ok(ForwardInputsResult::Speech {
            pcms,
            rates,
            channels: channels_all,
        })
    }

    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
    }

    fn category(&self) -> ModelCategory {
        ModelCategory::Speech
    }
}

impl AnyMoePipelineMixin for SpeechPipeline {}
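
// A minimal usage sketch, not part of this module: it assumes the public
// mistral.rs re-exports (`TokenSource`, `ModelDType`, `DeviceMapSetting`) and
// a hypothetical Dia checkpoint id. It shows how a `SpeechLoader` is built and
// turned into a shared `Pipeline`:
//
//     let loader = SpeechLoader {
//         model_id: "nari-labs/Dia-1.6B".to_string(), // hypothetical model id
//         dac_model_id: None, // falls back to "EricB/dac_44khz" for Dia
//         arch: SpeechLoaderType::Dia,
//         cfg: None, // falls back to SpeechGenerationConfig::default(arch)
//     };
//     let pipeline = loader.load_model_from_hf(
//         None,                       // revision: defaults to "main"
//         TokenSource::CacheToken,    // use the cached Hugging Face token
//         &ModelDType::Auto,          // assumed dtype selector implementing `TryIntoDType`
//         &Device::cuda_if_available(0)?,
//         false,                      // silent
//         DeviceMapSetting::dummy(),  // `Map(_)` would be rejected for speech models
//         None,                       // no in-situ quantization
//         None,                       // paged attention is unused here
//     )?;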