use mistralrs_core::*;
use crate::model_builder_trait::{build_model_from_pipeline, build_speech_pipeline};
use crate::Model;
/// Collects the settings used to load and run a speech model.
///
/// Create one with [`SpeechModelBuilder::new`], adjust it through the chained
/// `with_*` methods, and finish with [`SpeechModelBuilder::build`].
pub struct SpeechModelBuilder {
    // ---- What to load ----
    /// Identifier of the speech model, stored as an owned string.
    pub(crate) model_id: String,
    /// Identifier of a separate DAC model, if one was supplied.
    // NOTE(review): "DAC" presumably refers to the audio codec used by the
    // speech pipeline — confirm against the pipeline loader.
    pub(crate) dac_model_id: Option<String>,
    /// Which speech loader implementation handles this model.
    pub(crate) loader_type: SpeechLoaderType,
    /// Optional speech generation configuration.
    // NOTE(review): no setter for this field is visible in this file; it is
    // initialised to `None` by `new`.
    pub(crate) cfg: Option<SpeechGenerationConfig>,

    // ---- Where to fetch it from ----
    /// Source of the access token used when fetching model files.
    pub(crate) token_source: TokenSource,
    /// Optional revision to fetch.
    // NOTE(review): presumably a Hugging Face branch/tag/commit — confirm.
    pub(crate) hf_revision: Option<String>,

    // ---- How to run it ----
    /// Data type the model should be loaded with.
    pub(crate) dtype: ModelDType,
    /// When `true`, the caller asked for CPU execution.
    pub(crate) force_cpu: bool,
    /// Sequence-count limit handed to the pipeline builder.
    // NOTE(review): presumably the maximum number of concurrently scheduled
    // sequences — confirm against the scheduler configuration.
    pub(crate) max_num_seqs: usize,
    /// When `true`, the caller asked for logging to be enabled.
    pub(crate) with_logging: bool,
}
impl SpeechModelBuilder {
    /// Creates a builder for `model_id` that will use `loader_type`.
    ///
    /// Every other setting starts at its default: [`ModelDType::Auto`], CPU
    /// not forced, [`TokenSource::CacheToken`], no revision, a sequence limit
    /// of 32, logging disabled, no generation config, and no DAC model id.
    #[must_use]
    pub fn new(model_id: impl ToString, loader_type: SpeechLoaderType) -> Self {
        Self {
            model_id: model_id.to_string(),
            loader_type,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            max_num_seqs: 32,
            with_logging: false,
            cfg: None,
            dac_model_id: None,
        }
    }

    /// Sets the DAC model id.
    ///
    /// Accepts anything implementing [`ToString`], in line with [`Self::new`]
    /// and [`Self::with_hf_revision`]; passing an owned `String` keeps working.
    #[must_use]
    pub fn with_dac_model_id(mut self, dac_model_id: impl ToString) -> Self {
        self.dac_model_id = Some(dac_model_id.to_string());
        self
    }

    /// Sets the data type the model is loaded with.
    #[must_use]
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Requests CPU execution by setting the `force_cpu` flag.
    #[must_use]
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Sets the token source used when fetching the model.
    #[must_use]
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Sets the revision to fetch.
    // NOTE(review): presumably a Hugging Face branch/tag/commit — confirm.
    #[must_use]
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Overrides the sequence limit (32 unless changed).
    // NOTE(review): no validation is applied here; whether 0 is meaningful
    // depends on the scheduler — confirm.
    #[must_use]
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Requests logging by setting the `with_logging` flag.
    #[must_use]
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Consumes the builder and constructs the [`Model`].
    ///
    /// The builder is handed to `build_speech_pipeline`; its three results
    /// are then passed to `build_model_from_pipeline`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `build_speech_pipeline` if pipeline
    /// construction fails.
    pub async fn build(self) -> anyhow::Result<Model> {
        let (pipeline, scheduler_config, add_model_config) = build_speech_pipeline(self).await?;
        Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
    }
}