use std::path::PathBuf;
use crate::clustering::plda::PldaTransform;
use crate::inference::ExecutionMode;
use crate::inference::embedding::EmbeddingModel;
use crate::inference::segmentation::SegmentationModel;
use crate::models::ModelBundle;
use crate::powerset::PowersetMapping;
use super::OwnedDiarizationPipeline;
use super::config::{PipelineConfig, RuntimeConfig, segmentation_step_seconds};
use super::queued::{QueueReceiver, QueueSender};
use super::types::PipelineError;
/// Step-by-step constructor for an [`OwnedDiarizationPipeline`].
///
/// A builder starts from a [`ModelBundle`] and an [`ExecutionMode`]; the two
/// configuration slots are optional and stay `None` until the matching setter
/// is called. Unset slots are filled with defaults when the pipeline is built.
pub struct PipelineBuilder {
    /// Source of the segmentation model path, embedding model path and PLDA
    /// directory consumed at build time.
    bundle: ModelBundle,
    /// Execution mode forwarded to the model loaders; validated at build time.
    mode: ExecutionMode,
    /// Runtime configuration override; `None` means `RuntimeConfig::default()`
    /// is used when building.
    runtime: Option<RuntimeConfig>,
    /// Pipeline configuration override; `None` means
    /// `PipelineConfig::for_mode(mode)` is used when building.
    pipeline: Option<PipelineConfig>,
}
impl PipelineBuilder {
    /// Starts a builder whose model bundle comes from `ModelBundle::from_dir(models_dir)`.
    ///
    /// Both configuration overrides start unset. `mode` is stored as-is and is
    /// only validated later, in [`build`](Self::build).
    pub fn from_dir(models_dir: impl Into<PathBuf>, mode: ExecutionMode) -> Self {
        Self::from_bundle(ModelBundle::from_dir(models_dir), mode)
    }

    /// Starts a builder from an already-constructed [`ModelBundle`].
    ///
    /// Both configuration overrides start unset. `mode` is stored as-is and is
    /// only validated later, in [`build`](Self::build).
    pub fn from_bundle(bundle: ModelBundle, mode: ExecutionMode) -> Self {
        Self {
            runtime: None,
            pipeline: None,
            bundle,
            mode,
        }
    }

    /// Starts a builder from the bundle returned by `ModelBundle::from_pretrained(mode)`.
    ///
    /// Only compiled when the `online` feature is enabled.
    ///
    /// # Errors
    ///
    /// Fails if `mode` does not pass validation (checked before the bundle is
    /// requested) or if `ModelBundle::from_pretrained` fails.
    #[cfg(feature = "online")]
    pub fn from_pretrained(mode: ExecutionMode) -> Result<Self, PipelineError> {
        mode.validate()?;
        Ok(Self::from_bundle(ModelBundle::from_pretrained(mode)?, mode))
    }

    /// Overrides the runtime configuration, replacing any earlier override.
    pub fn runtime(self, config: RuntimeConfig) -> Self {
        Self {
            runtime: Some(config),
            ..self
        }
    }

    /// Overrides the pipeline configuration, replacing any earlier override.
    pub fn pipeline(self, config: PipelineConfig) -> Self {
        Self {
            pipeline: Some(config),
            ..self
        }
    }

    /// Consumes the builder and assembles an [`OwnedDiarizationPipeline`].
    ///
    /// Steps, in order: validate the execution mode, resolve the pipeline and
    /// runtime configuration (falling back to `PipelineConfig::for_mode` and
    /// `RuntimeConfig::default()` when unset), then load the segmentation
    /// model, the embedding model and the PLDA transform from the bundle.
    ///
    /// # Errors
    ///
    /// Fails if the execution mode is invalid, or if loading the segmentation
    /// model, the embedding model or the PLDA transform fails.
    pub fn build(self) -> Result<OwnedDiarizationPipeline, PipelineError> {
        let Self {
            bundle,
            mode,
            runtime,
            pipeline,
        } = self;

        mode.validate()?;

        let default_config = match pipeline {
            Some(explicit) => explicit,
            None => PipelineConfig::for_mode(mode),
        };
        let runtime_config = runtime.unwrap_or_default();

        // The segmentation model takes its step as `f32`, hence the narrowing cast.
        let step_seconds = segmentation_step_seconds(mode);
        let segmentation =
            SegmentationModel::with_mode(bundle.segmentation_path(), step_seconds as f32, mode)?;

        // Only the embedding model receives the runtime configuration.
        let embedding =
            EmbeddingModel::with_mode_and_config(bundle.embedding_path(), mode, &runtime_config)?;

        let plda_transform = PldaTransform::from_dir(bundle.plda_dir())?;

        // NOTE(review): the (3, 2) arguments are hard-coded; presumably the number
        // of speaker classes and the maximum overlap size — confirm against
        // `PowersetMapping::new`.
        let powerset = PowersetMapping::new(3, 2);

        Ok(OwnedDiarizationPipeline {
            seg_model: segmentation,
            emb_model: embedding,
            plda: plda_transform,
            powerset,
            default_config,
        })
    }

    /// Builds the pipeline and converts it into its queued form, returning the
    /// sender/receiver pair produced by `into_queued`.
    ///
    /// # Errors
    ///
    /// Fails if [`build`](Self::build) fails or if `into_queued` fails.
    pub fn build_queued(self) -> Result<(QueueSender, QueueReceiver), PipelineError> {
        let (sender, receiver) = self.build()?.into_queued()?;
        Ok((sender, receiver))
    }
}