1use super::text_models_inputs_processor::PagedAttentionMeta;
2use super::{
3 AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
4 GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
5 Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6 PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::{self, DeviceMapper};
9use crate::distributed::{use_ring, WorkerTransferData};
10use crate::pipeline::{ChatTemplate, EmbeddingModulePaths, Modalities, SupportedModality};
11use crate::prefix_cacher::PrefixCacheManagerV2;
12use crate::sequence::Sequence;
13use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
14use crate::utils::progress::ProgressScopeGuard;
15use crate::utils::varbuilder_utils::DeviceForLoadTensor;
16use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
17use crate::{
18 api_get_file, distributed, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
19 SpeechGenerationConfig, TryIntoDType,
20};
21use anyhow::Result;
22use hanzo_ml::{Device, Tensor};
23use hanzo_nn::VarBuilder;
24use hanzo_quant::IsqType;
25use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
26use indexmap::IndexMap;
27use rand_isaac::Isaac64Rng;
28use regex::Regex;
29use std::any::Any;
30use std::env;
31use std::path::PathBuf;
32use std::sync::Arc;
33use tokenizers::Tokenizer;
34use tokio::sync::Mutex;
35
/// File locations needed to load a speech model.
///
/// Only the config file and the weight files are tracked; the other
/// [`ModelPaths`] accessors are not backed by any data for this type.
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    /// Weight files. The loader pushes the main model weights first and the DAC
    /// codec weights last, and later treats the final entry as the DAC weights.
    weights: Vec<PathBuf>,
    /// Path to the model's `config.json`.
    config: PathBuf,
}
41
42impl ModelPaths for SpeechModelPaths {
43 fn get_config_filename(&self) -> &PathBuf {
44 &self.config
45 }
46 fn get_tokenizer_filename(&self) -> &PathBuf {
47 unreachable!("Use `std::any::Any`.")
48 }
49 fn get_weight_filenames(&self) -> &[PathBuf] {
50 &self.weights
51 }
52 fn get_template_filename(&self) -> &Option<PathBuf> {
53 unreachable!("Use `std::any::Any`.")
54 }
55 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
56 unreachable!("Use `std::any::Any`.")
57 }
58 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
59 unreachable!("Use `std::any::Any`.")
60 }
61 fn get_processor_config(&self) -> &Option<PathBuf> {
62 unreachable!("Use `std::any::Any`.")
63 }
64 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
65 unreachable!("Use `std::any::Any`.")
66 }
67 fn get_adapter_paths(&self) -> &AdapterPaths {
68 unreachable!("Use `std::any::Any`.")
69 }
70 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
71 unreachable!("Use `std::any::Any`.")
72 }
73}
74
/// Stateless [`Processor`] for speech pipelines.
///
/// It does not handle chat messages; its job is to hand out a
/// [`SpeechInputsProcessor`].
pub struct SpeechProcessor;
76
77impl Processor for SpeechProcessor {
78 fn process(
79 &self,
80 _pipeline: &dyn Pipeline,
81 _messages: Vec<IndexMap<String, MessageContent>>,
82 _add_generation_prompt: bool,
83 _add_special_tokens: bool,
84 _enable_thinking: Option<bool>,
85 _reasoning_effort: Option<crate::request::ReasoningEffort>,
86 _tools: Vec<crate::Tool>,
87 ) -> Result<(Vec<u32>, String)> {
88 anyhow::bail!(
89 "SpeechProcessor::process should not be used. It does not expect chat messages."
90 )
91 }
92 fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
93 Arc::new(SpeechInputsProcessor)
94 }
95 fn get_special_tokens(&self) -> &[&'static str] {
96 &[]
97 }
98 fn template_action(&self) -> MessagesAction {
99 MessagesAction::FlattenOnlyText
101 }
102}
103
/// Stateless [`InputsProcessor`] that turns sequences into [`ModelInputs`]
/// holding each sequence's initial prompt text.
pub struct SpeechInputsProcessor;
105
/// Inputs handed to the speech pipeline's forward pass.
#[derive(Clone)]
pub struct ModelInputs {
    /// One prompt string per input sequence, in sequence order.
    pub(crate) prompts: Vec<String>,
}
110
111impl InputsProcessor for SpeechInputsProcessor {
112 fn get_type(&self) -> InputsProcessorType {
113 InputsProcessorType::Text
114 }
115
116 fn process_inputs(
117 &self,
118 _tokenizer: Option<Arc<Tokenizer>>,
119 input_seqs: &mut [&mut Sequence],
120 _is_prompt: bool,
121 _is_xlora: bool,
122 _device: &Device,
123 _no_kv_cache: bool,
124 _last_n_context_len: Option<(usize, usize)>,
125 _return_raw_logits: bool,
126 _sliding_window: Option<usize>,
127 _other_config: Option<Arc<dyn Any>>,
128 _paged_attn_metadata: Option<PagedAttentionMeta>,
129 _mapper: Option<&dyn DeviceMapper>,
130 ) -> Result<InputProcessorOutput> {
131 let inputs = ModelInputs {
132 prompts: input_seqs
133 .iter()
134 .map(|seq| seq.get_initial_prompt().to_string())
135 .collect(),
136 };
137 Ok(InputProcessorOutput {
138 inputs: Box::new(inputs),
139 seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
140 })
141 }
142}
143
/// Pipeline that generates speech audio from text prompts using a
/// [`DiaPipeline`].
pub struct SpeechPipeline {
    /// Identifier of the loaded model; also reported as the pipeline name.
    model_id: String,
    /// The underlying speech model used to generate audio.
    model: DiaPipeline,
    /// Shared metadata describing the pipeline (modalities, dtype, ...).
    metadata: Arc<GeneralMetadata>,
    /// Placeholder cache returned from `cache()`; the cache-manager hooks on
    /// this pipeline are no-ops.
    dummy_cache: EitherCache,
    /// Generation settings passed to the model on every `generate` call.
    cfg: SpeechGenerationConfig,
}
151
/// Loader that builds a [`SpeechPipeline`].
pub struct SpeechLoader {
    /// Hugging Face repository id of the main speech model.
    pub model_id: String,
    /// Optional repository id of the DAC codec model. When `None`, a default
    /// chosen by `arch` is used (`hanzoai/dac_44khz` for Dia).
    pub dac_model_id: Option<String>,
    /// Which speech architecture to load.
    pub arch: SpeechLoaderType,
    /// Optional generation config. When `None`, the pipeline falls back to
    /// `SpeechGenerationConfig::default(arch)`.
    pub cfg: Option<SpeechGenerationConfig>,
}
158
159impl Loader for SpeechLoader {
160 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
161 fn load_model_from_hf(
162 &self,
163 revision: Option<String>,
164 token_source: TokenSource,
165 dtype: &dyn TryIntoDType,
166 device: &Device,
167 silent: bool,
168 mapper: DeviceMapSetting,
169 in_situ_quant: Option<IsqType>,
170 paged_attn_config: Option<PagedAttentionConfig>,
171 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
172 let _progress_guard = ProgressScopeGuard::new(silent);
173 let paths: anyhow::Result<Box<dyn ModelPaths>> = {
174 let mut weights = Vec::new();
176
177 let config = {
179 let api = ApiBuilder::new()
180 .with_progress(!silent)
181 .with_token(get_token(&token_source)?)
182 .build()?;
183 let revision = revision.clone().unwrap_or("main".to_string());
184 let api = api.repo(Repo::with_revision(
185 self.model_id.to_string(),
186 RepoType::Model,
187 revision.clone(),
188 ));
189 let model_id = std::path::Path::new(&self.model_id);
190
191 let weight = api_get_file!(api, "model.safetensors", &model_id, &revision);
192 let config = api_get_file!(api, "config.json", &model_id, &revision);
193 weights.push(weight);
194 config
195 };
196
197 {
199 let api = ApiBuilder::new()
200 .with_progress(!silent)
201 .with_token(get_token(&token_source)?)
202 .build()?;
203 let revision = revision.unwrap_or("main".to_string());
204
205 let dac_model = self
207 .dac_model_id
208 .clone()
209 .unwrap_or_else(|| match self.arch {
210 SpeechLoaderType::Dia => "hanzoai/dac_44khz".to_string(),
211 });
212
213 let api = api.repo(Repo::with_revision(
214 dac_model.clone(),
215 RepoType::Model,
216 revision.clone(),
217 ));
218 let model_id = std::path::Path::new(&dac_model);
219
220 let weight = api_get_file!(api, "model.safetensors", &model_id, &revision);
221 weights.push(weight);
222 }
223
224 Ok(Box::new(SpeechModelPaths { weights, config }))
225 };
226 self.load_model_from_path(
227 &paths?,
228 dtype,
229 device,
230 silent,
231 mapper,
232 in_situ_quant,
233 paged_attn_config,
234 )
235 }
236
237 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
238 fn load_model_from_path(
239 &self,
240 paths: &Box<dyn ModelPaths>,
241 dtype: &dyn TryIntoDType,
242 device: &Device,
243 silent: bool,
244 mapper: DeviceMapSetting,
245 in_situ_quant: Option<IsqType>,
246 _paged_attn_config: Option<PagedAttentionConfig>,
247 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
248 let _progress_guard = ProgressScopeGuard::new(silent);
249 let paths = &paths
250 .as_ref()
251 .as_any()
252 .downcast_ref::<SpeechModelPaths>()
253 .expect("Path downcast failed.");
254
255 if matches!(mapper, DeviceMapSetting::Map(_)) {
256 anyhow::bail!("Device mapping is not supported for speech models.")
257 }
258
259 hanzo_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);
260
261 let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;
262
263 #[cfg(feature = "cuda")]
264 if let Device::Cuda(dev) = &device {
265 unsafe { dev.disable_event_tracking() };
266 }
267 let use_nccl = hanzo_quant::distributed::use_nccl();
268 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
269 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
270 let WorkerTransferData::Init { id: _, worker_rank } = payload;
271 vec![hanzo_ml::Device::new_cuda(worker_rank + 1)?]
272 } else if use_nccl || use_ring() {
273 vec![hanzo_ml::Device::new_cuda(0)?]
274 } else {
275 device_map::get_all_similar_devices(device)?
276 };
277
278 let mapper =
279 DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None, &available_devices)?;
280 let dtype = mapper.get_min_dtype(dtype)?;
281
282 let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
284 let vb = from_mmaped_safetensors(
285 model_weights,
286 Vec::new(),
287 Some(dtype),
288 device,
289 vec![None],
290 silent,
291 None,
292 |_| true,
293 Arc::new(|_| DeviceForLoadTensor::Base),
294 )?;
295
296 let dac_vb = unsafe {
297 VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
298 };
299
300 assert_eq!(self.arch, SpeechLoaderType::Dia);
302
303 let model = DiaPipeline::new(&cfg, vb, dac_vb)?;
304
305 Ok(Arc::new(Mutex::new(SpeechPipeline {
306 model_id: self.model_id.clone(),
307 model,
308 metadata: Arc::new(GeneralMetadata {
309 max_seq_len: 1024,
310 llg_factory: None,
311 is_xlora: false,
312 no_prefix_cache: false,
313 num_hidden_layers: 1, eos_tok: vec![],
315 kind: ModelKind::Normal,
316 no_kv_cache: true, activation_dtype: dtype,
318 sliding_window: None,
319 cache_config: None,
320 cache_engine: None,
321 model_metadata: None,
322 modalities: Modalities {
323 input: vec![SupportedModality::Text],
324 output: vec![SupportedModality::Audio],
325 },
326 }),
327 dummy_cache: EitherCache::Full(Cache::new(0, false)),
328 cfg: self
329 .cfg
330 .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
331 })))
332 }
333
334 fn get_id(&self) -> String {
335 self.model_id.clone()
336 }
337
338 fn get_kind(&self) -> ModelKind {
339 ModelKind::Normal
340 }
341}
342
343impl PreProcessingMixin for SpeechPipeline {
344 fn get_processor(&self) -> Arc<dyn Processor> {
345 Arc::new(SpeechProcessor)
346 }
347 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
348 None
349 }
350 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
351 None
352 }
353}
354
355impl IsqPipelineMixin for SpeechPipeline {
356 fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
357 anyhow::bail!("Speech models do not support ISQ for now.")
358 }
359}
360
361impl CacheManagerMixin for SpeechPipeline {
362 fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
363 fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
364 fn set_none_cache(
365 &self,
366 _seqs: &mut [&mut Sequence],
367 _reset_non_granular: bool,
368 _modify_draft_cache: bool,
369 _load_preallocated_cache: bool,
370 ) {
371 }
372 fn cache(&self) -> &EitherCache {
373 &self.dummy_cache
374 }
375}
376
377impl MetadataMixin for SpeechPipeline {
378 fn device(&self) -> Device {
379 self.model.device().clone()
380 }
381 fn get_metadata(&self) -> Arc<GeneralMetadata> {
382 self.metadata.clone()
383 }
384 fn name(&self) -> String {
385 self.model_id.clone()
386 }
387 fn reset_non_granular_state(&self) {}
388 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
389 None
390 }
391 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
392 None
393 }
394}
395
396#[async_trait::async_trait]
397impl Pipeline for SpeechPipeline {
398 fn forward_inputs(
399 &mut self,
400 inputs: Box<dyn Any>,
401 return_raw_logits: bool,
402 ) -> hanzo_ml::Result<ForwardInputsResult> {
403 assert!(!return_raw_logits);
404
405 let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
406 let mut pcms = Vec::new();
407 let mut rates = Vec::new();
408 let mut channels_all = Vec::new();
409 for prompt in prompts {
410 let SpeechGenerationOutput {
411 pcm,
412 rate,
413 channels,
414 } = self.model.generate(&prompt, &self.cfg)?;
415 pcms.push(pcm);
416 rates.push(rate);
417 channels_all.push(channels);
418 }
419
420 Ok(ForwardInputsResult::Speech {
421 pcms,
422 rates,
423 channels: channels_all,
424 })
425 }
426
427 async fn sample_causal_gen(
428 &self,
429 _seqs: &mut [&mut Sequence],
430 _logits: Vec<Tensor>,
431 _prefix_cacher: &mut PrefixCacheManagerV2,
432 _disable_eos_stop: bool,
433 _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
434 ) -> Result<(), hanzo_ml::Error> {
435 hanzo_ml::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
436 }
437
438 fn category(&self) -> ModelCategory {
439 ModelCategory::Speech
440 }
441}
442
// Empty impl: the speech pipeline overrides nothing and relies entirely on the
// trait's default method implementations.
impl AnyMoePipelineMixin for SpeechPipeline {}