use super::text_models_inputs_processor::PagedAttentionMeta;
use super::{
    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::{self, DeviceMapper};
use crate::distributed::WorkerTransferData;
use crate::pipeline::{ChatTemplate, EmbeddingModulePaths, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
use crate::utils::progress::ProgressScopeGuard;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::{
    api_get_file, distributed, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
    SpeechGenerationConfig, TryIntoDType,
};
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use regex::Regex;
use std::any::Any;
use std::env;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;

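/// File paths for a speech model: the main model safetensors (with the DAC
/// codec weights as the final entry) and the model's config JSON. Most
/// `ModelPaths` accessors are irrelevant for speech models; callers are
/// expected to downcast via `std::any::Any` instead.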
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    weights: Vec<PathBuf>,
    config: PathBuf,
}

impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        unreachable!("Use `std::any::Any`.")
    }
}

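/// `Processor` for speech generation. Speech models synthesize audio from the
/// raw text prompt, so chat-message processing is rejected outright.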
pub struct SpeechProcessor;

impl Processor for SpeechProcessor {
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _reasoning_effort: Option<crate::request::ReasoningEffort>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

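/// `InputsProcessor` that forwards each sequence's initial prompt untouched.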
pub struct SpeechInputsProcessor;

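/// Inputs to the speech model: one raw text prompt per sequence in the batch.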
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
}

impl InputsProcessor for SpeechInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Text
    }

    fn process_inputs(
        &self,
        _tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        _is_prompt: bool,
        _is_xlora: bool,
        _device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _other_config: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        _mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        let inputs = ModelInputs {
            prompts: input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().to_string())
                .collect(),
        };
        Ok(InputProcessorOutput {
            inputs: Box::new(inputs),
            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
        })
    }
}

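/// Pipeline wrapping a Dia speech model. Generation is single-shot (text in,
/// PCM audio out), so the KV cache and tokenizer plumbing are stubbed.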
pub struct SpeechPipeline {
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    dummy_cache: EitherCache,
    cfg: SpeechGenerationConfig,
}

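/// Loader for speech generation models; currently only the Dia architecture
/// is supported. A minimal construction sketch (the model ID is illustrative):
///
/// ```ignore
/// let loader = SpeechLoader {
///     model_id: "nari-labs/Dia-1.6B".to_string(),
///     // `None` falls back to the per-architecture default DAC repo
///     // ("EricB/dac_44khz" for Dia).
///     dac_model_id: None,
///     arch: SpeechLoaderType::Dia,
///     // `None` falls back to `SpeechGenerationConfig::default(arch)`.
///     cfg: None,
/// };
/// ```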
pub struct SpeechLoader {
    pub model_id: String,
    pub dac_model_id: Option<String>,
    pub arch: SpeechLoaderType,
    pub cfg: Option<SpeechGenerationConfig>,
}

impl Loader for SpeechLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let mut weights = Vec::new();

            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }
        let use_nccl = mistralrs_quant::distributed::use_nccl();
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };

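        // Device mapping was rejected above; this dummy mapper exists only to
        // resolve the minimum activation dtype across the available devices.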
        let mapper =
            DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None, &available_devices)?;
        let dtype = mapper.get_min_dtype(dtype)?;

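        // `load_model_from_hf` pushes the DAC codec weights last, so every
        // entry before the final one belongs to the main model.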
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
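                // Placeholder values: speech generation bypasses the KV cache
                // and token sampling, so the layer count and EOS tokens are
                // effectively unused.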
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}

impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for SpeechPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}

impl CacheManagerMixin for SpeechPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}

impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

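// `forward_inputs` runs the entire text-to-speech generation in one call;
// the token-sampling path below only exists to satisfy the trait and bails.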
#[async_trait::async_trait]
impl Pipeline for SpeechPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        assert!(!return_raw_logits);

        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
        let mut pcms = Vec::new();
        let mut rates = Vec::new();
        let mut channels_all = Vec::new();
        for prompt in prompts {
            let SpeechGenerationOutput {
                pcm,
                rate,
                channels,
            } = self.model.generate(&prompt, &self.cfg)?;
            pcms.push(pcm);
            rates.push(rate);
            channels_all.push(channels);
        }

        Ok(ForwardInputsResult::Speech {
            pcms,
            rates,
            channels: channels_all,
        })
    }

    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
    }

    fn category(&self) -> ModelCategory {
        ModelCategory::Speech
    }
}

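// AnyMoE is left at the trait's default behavior; it does not apply to speech models.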
impl AnyMoePipelineMixin for SpeechPipeline {}