Skip to main content

hanzo_engine/pipeline/
speech.rs

1use super::text_models_inputs_processor::PagedAttentionMeta;
2use super::{
3    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
4    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
5    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6    PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::{self, DeviceMapper};
9use crate::distributed::{use_ring, WorkerTransferData};
10use crate::pipeline::{ChatTemplate, EmbeddingModulePaths, Modalities, SupportedModality};
11use crate::prefix_cacher::PrefixCacheManagerV2;
12use crate::sequence::Sequence;
13use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
14use crate::utils::progress::ProgressScopeGuard;
15use crate::utils::varbuilder_utils::DeviceForLoadTensor;
16use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
17use crate::{
18    api_get_file, distributed, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
19    SpeechGenerationConfig, TryIntoDType,
20};
21use anyhow::Result;
22use hanzo_ml::{Device, Tensor};
23use hanzo_nn::VarBuilder;
24use hanzo_quant::IsqType;
25use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
26use indexmap::IndexMap;
27use rand_isaac::Isaac64Rng;
28use regex::Regex;
29use std::any::Any;
30use std::env;
31use std::path::PathBuf;
32use std::sync::Arc;
33use tokenizers::Tokenizer;
34use tokio::sync::Mutex;
35
36#[derive(Clone, Debug)]
37pub struct SpeechModelPaths {
38    weights: Vec<PathBuf>,
39    config: PathBuf,
40}
41
42impl ModelPaths for SpeechModelPaths {
43    fn get_config_filename(&self) -> &PathBuf {
44        &self.config
45    }
46    fn get_tokenizer_filename(&self) -> &PathBuf {
47        unreachable!("Use `std::any::Any`.")
48    }
49    fn get_weight_filenames(&self) -> &[PathBuf] {
50        &self.weights
51    }
52    fn get_template_filename(&self) -> &Option<PathBuf> {
53        unreachable!("Use `std::any::Any`.")
54    }
55    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
56        unreachable!("Use `std::any::Any`.")
57    }
58    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
59        unreachable!("Use `std::any::Any`.")
60    }
61    fn get_processor_config(&self) -> &Option<PathBuf> {
62        unreachable!("Use `std::any::Any`.")
63    }
64    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
65        unreachable!("Use `std::any::Any`.")
66    }
67    fn get_adapter_paths(&self) -> &AdapterPaths {
68        unreachable!("Use `std::any::Any`.")
69    }
70    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
71        unreachable!("Use `std::any::Any`.")
72    }
73}
74
75pub struct SpeechProcessor;
76
77impl Processor for SpeechProcessor {
78    fn process(
79        &self,
80        _pipeline: &dyn Pipeline,
81        _messages: Vec<IndexMap<String, MessageContent>>,
82        _add_generation_prompt: bool,
83        _add_special_tokens: bool,
84        _enable_thinking: Option<bool>,
85        _reasoning_effort: Option<crate::request::ReasoningEffort>,
86        _tools: Vec<crate::Tool>,
87    ) -> Result<(Vec<u32>, String)> {
88        anyhow::bail!(
89            "SpeechProcessor::process should not be used. It does not expect chat messages."
90        )
91    }
92    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
93        Arc::new(SpeechInputsProcessor)
94    }
95    fn get_special_tokens(&self) -> &[&'static str] {
96        &[]
97    }
98    fn template_action(&self) -> MessagesAction {
99        // Just a default
100        MessagesAction::FlattenOnlyText
101    }
102}
103
104pub struct SpeechInputsProcessor;
105
106#[derive(Clone)]
107pub struct ModelInputs {
108    pub(crate) prompts: Vec<String>,
109}
110
111impl InputsProcessor for SpeechInputsProcessor {
112    fn get_type(&self) -> InputsProcessorType {
113        InputsProcessorType::Text
114    }
115
116    fn process_inputs(
117        &self,
118        _tokenizer: Option<Arc<Tokenizer>>,
119        input_seqs: &mut [&mut Sequence],
120        _is_prompt: bool,
121        _is_xlora: bool,
122        _device: &Device,
123        _no_kv_cache: bool,
124        _last_n_context_len: Option<(usize, usize)>,
125        _return_raw_logits: bool,
126        _sliding_window: Option<usize>,
127        _other_config: Option<Arc<dyn Any>>,
128        _paged_attn_metadata: Option<PagedAttentionMeta>,
129        _mapper: Option<&dyn DeviceMapper>,
130    ) -> Result<InputProcessorOutput> {
131        let inputs = ModelInputs {
132            prompts: input_seqs
133                .iter()
134                .map(|seq| seq.get_initial_prompt().to_string())
135                .collect(),
136        };
137        Ok(InputProcessorOutput {
138            inputs: Box::new(inputs),
139            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
140        })
141    }
142}
143
144pub struct SpeechPipeline {
145    model_id: String,
146    model: DiaPipeline,
147    metadata: Arc<GeneralMetadata>,
148    dummy_cache: EitherCache,
149    cfg: SpeechGenerationConfig,
150}
151
152pub struct SpeechLoader {
153    pub model_id: String,
154    pub dac_model_id: Option<String>,
155    pub arch: SpeechLoaderType,
156    pub cfg: Option<SpeechGenerationConfig>,
157}
158
159impl Loader for SpeechLoader {
160    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
161    fn load_model_from_hf(
162        &self,
163        revision: Option<String>,
164        token_source: TokenSource,
165        dtype: &dyn TryIntoDType,
166        device: &Device,
167        silent: bool,
168        mapper: DeviceMapSetting,
169        in_situ_quant: Option<IsqType>,
170        paged_attn_config: Option<PagedAttentionConfig>,
171    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
172        let _progress_guard = ProgressScopeGuard::new(silent);
173        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
174            // Main weights first, DAC is the final one.
175            let mut weights = Vec::new();
176
177            // Main model
178            let config = {
179                let api = ApiBuilder::new()
180                    .with_progress(!silent)
181                    .with_token(get_token(&token_source)?)
182                    .build()?;
183                let revision = revision.clone().unwrap_or("main".to_string());
184                let api = api.repo(Repo::with_revision(
185                    self.model_id.to_string(),
186                    RepoType::Model,
187                    revision.clone(),
188                ));
189                let model_id = std::path::Path::new(&self.model_id);
190
191                let weight = api_get_file!(api, "model.safetensors", &model_id, &revision);
192                let config = api_get_file!(api, "config.json", &model_id, &revision);
193                weights.push(weight);
194                config
195            };
196
197            // DAC model
198            {
199                let api = ApiBuilder::new()
200                    .with_progress(!silent)
201                    .with_token(get_token(&token_source)?)
202                    .build()?;
203                let revision = revision.unwrap_or("main".to_string());
204
205                // Apply default here
206                let dac_model = self
207                    .dac_model_id
208                    .clone()
209                    .unwrap_or_else(|| match self.arch {
210                        SpeechLoaderType::Dia => "hanzoai/dac_44khz".to_string(),
211                    });
212
213                let api = api.repo(Repo::with_revision(
214                    dac_model.clone(),
215                    RepoType::Model,
216                    revision.clone(),
217                ));
218                let model_id = std::path::Path::new(&dac_model);
219
220                let weight = api_get_file!(api, "model.safetensors", &model_id, &revision);
221                weights.push(weight);
222            }
223
224            Ok(Box::new(SpeechModelPaths { weights, config }))
225        };
226        self.load_model_from_path(
227            &paths?,
228            dtype,
229            device,
230            silent,
231            mapper,
232            in_situ_quant,
233            paged_attn_config,
234        )
235    }
236
237    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
238    fn load_model_from_path(
239        &self,
240        paths: &Box<dyn ModelPaths>,
241        dtype: &dyn TryIntoDType,
242        device: &Device,
243        silent: bool,
244        mapper: DeviceMapSetting,
245        in_situ_quant: Option<IsqType>,
246        _paged_attn_config: Option<PagedAttentionConfig>,
247    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
248        let _progress_guard = ProgressScopeGuard::new(silent);
249        let paths = &paths
250            .as_ref()
251            .as_any()
252            .downcast_ref::<SpeechModelPaths>()
253            .expect("Path downcast failed.");
254
255        if matches!(mapper, DeviceMapSetting::Map(_)) {
256            anyhow::bail!("Device mapping is not supported for speech models.")
257        }
258
259        hanzo_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);
260
261        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;
262
263        #[cfg(feature = "cuda")]
264        if let Device::Cuda(dev) = &device {
265            unsafe { dev.disable_event_tracking() };
266        }
267        let use_nccl = hanzo_quant::distributed::use_nccl();
268        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
269            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
270            let WorkerTransferData::Init { id: _, worker_rank } = payload;
271            vec![hanzo_ml::Device::new_cuda(worker_rank + 1)?]
272        } else if use_nccl || use_ring() {
273            vec![hanzo_ml::Device::new_cuda(0)?]
274        } else {
275            device_map::get_all_similar_devices(device)?
276        };
277
278        let mapper =
279            DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None, &available_devices)?;
280        let dtype = mapper.get_min_dtype(dtype)?;
281
282        // Last weight is the dac.
283        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
284        let vb = from_mmaped_safetensors(
285            model_weights,
286            Vec::new(),
287            Some(dtype),
288            device,
289            vec![None],
290            silent,
291            None,
292            |_| true,
293            Arc::new(|_| DeviceForLoadTensor::Base),
294        )?;
295
296        let dac_vb = unsafe {
297            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
298        };
299
300        // Only Dia is supported for now.
301        assert_eq!(self.arch, SpeechLoaderType::Dia);
302
303        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;
304
305        Ok(Arc::new(Mutex::new(SpeechPipeline {
306            model_id: self.model_id.clone(),
307            model,
308            metadata: Arc::new(GeneralMetadata {
309                max_seq_len: 1024,
310                llg_factory: None,
311                is_xlora: false,
312                no_prefix_cache: false,
313                num_hidden_layers: 1, // FIXME(hanzoai): we know this is only for caching, so its OK.
314                eos_tok: vec![],
315                kind: ModelKind::Normal,
316                no_kv_cache: true, // NOTE(hanzoai): no cache for these.
317                activation_dtype: dtype,
318                sliding_window: None,
319                cache_config: None,
320                cache_engine: None,
321                model_metadata: None,
322                modalities: Modalities {
323                    input: vec![SupportedModality::Text],
324                    output: vec![SupportedModality::Audio],
325                },
326            }),
327            dummy_cache: EitherCache::Full(Cache::new(0, false)),
328            cfg: self
329                .cfg
330                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
331        })))
332    }
333
334    fn get_id(&self) -> String {
335        self.model_id.clone()
336    }
337
338    fn get_kind(&self) -> ModelKind {
339        ModelKind::Normal
340    }
341}
342
343impl PreProcessingMixin for SpeechPipeline {
344    fn get_processor(&self) -> Arc<dyn Processor> {
345        Arc::new(SpeechProcessor)
346    }
347    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
348        None
349    }
350    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
351        None
352    }
353}
354
355impl IsqPipelineMixin for SpeechPipeline {
356    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
357        anyhow::bail!("Speech models do not support ISQ for now.")
358    }
359}
360
361impl CacheManagerMixin for SpeechPipeline {
362    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
363    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
364    fn set_none_cache(
365        &self,
366        _seqs: &mut [&mut Sequence],
367        _reset_non_granular: bool,
368        _modify_draft_cache: bool,
369        _load_preallocated_cache: bool,
370    ) {
371    }
372    fn cache(&self) -> &EitherCache {
373        &self.dummy_cache
374    }
375}
376
377impl MetadataMixin for SpeechPipeline {
378    fn device(&self) -> Device {
379        self.model.device().clone()
380    }
381    fn get_metadata(&self) -> Arc<GeneralMetadata> {
382        self.metadata.clone()
383    }
384    fn name(&self) -> String {
385        self.model_id.clone()
386    }
387    fn reset_non_granular_state(&self) {}
388    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
389        None
390    }
391    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
392        None
393    }
394}
395
396#[async_trait::async_trait]
397impl Pipeline for SpeechPipeline {
398    fn forward_inputs(
399        &mut self,
400        inputs: Box<dyn Any>,
401        return_raw_logits: bool,
402    ) -> hanzo_ml::Result<ForwardInputsResult> {
403        assert!(!return_raw_logits);
404
405        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
406        let mut pcms = Vec::new();
407        let mut rates = Vec::new();
408        let mut channels_all = Vec::new();
409        for prompt in prompts {
410            let SpeechGenerationOutput {
411                pcm,
412                rate,
413                channels,
414            } = self.model.generate(&prompt, &self.cfg)?;
415            pcms.push(pcm);
416            rates.push(rate);
417            channels_all.push(channels);
418        }
419
420        Ok(ForwardInputsResult::Speech {
421            pcms,
422            rates,
423            channels: channels_all,
424        })
425    }
426
427    async fn sample_causal_gen(
428        &self,
429        _seqs: &mut [&mut Sequence],
430        _logits: Vec<Tensor>,
431        _prefix_cacher: &mut PrefixCacheManagerV2,
432        _disable_eos_stop: bool,
433        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
434    ) -> Result<(), hanzo_ml::Error> {
435        hanzo_ml::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
436    }
437
438    fn category(&self) -> ModelCategory {
439        ModelCategory::Speech
440    }
441}
442
443impl AnyMoePipelineMixin for SpeechPipeline {}