use std::path::Path;
use std::sync::Arc;
use crate::constants::DB;
use crate::duration::DurationEstimator;
use crate::label::{LabelError, ToLabels};
use crate::mlpg_adjust::MlpgAdjust;
use crate::model::interporation_weight::InterporationWeight;
use crate::model::{ModelError, Models, VoiceSet};
use crate::speech::SpeechGenerator;
use crate::vocoder::Vocoder;
/// Errors returned while loading voices into an [`Engine`] or generating speech with it.
#[derive(Debug, thiserror::Error)]
pub enum EngineError {
/// A voice model operation failed; wraps the underlying [`ModelError`].
#[error("Model error: {0}")]
ModelError(#[from] ModelError),
/// A `KEY=VALUE` option in the voice metadata carried a value that could not be parsed.
/// The payload is the option key (e.g. `GAMMA`), not the offending value.
#[error("Failed to parse option {0}")]
ParseOptionError(String),
/// Converting the caller's input into labels failed; wraps the underlying [`LabelError`].
#[error("Label error: {0}")]
LabelError(#[from] LabelError),
}
/// Tunable synthesis settings consumed by [`Engine::generator`].
///
/// A default-constructed value has empty per-stream vectors; call
/// [`Condition::load_model`] to size them for a concrete voice set.
#[derive(Debug, Clone)]
pub struct Condition {
/// Sampling frequency; taken from the voice set's global metadata by `load_model`.
/// The setter never stores a value below 1.
sampling_frequency: usize,
/// Frame period; taken from the global metadata's `frame_period` by `load_model`.
/// The setter never stores a value below 1.
fperiod: usize,
/// Volume, stored as `exp(x * DB)` for the value `x` given to `set_volume`;
/// `get_volume` applies the inverse transform.
volume: f64,
/// Per-stream MSD threshold, one entry per stream (0.5 after `load_model`).
/// The setter clamps to `[0.0, 1.0]`.
msd_threshold: Vec<f64>,
/// Per-stream GV weight, one entry per stream (1.0 after `load_model`).
/// The setter never stores a negative value.
gv_weight: Vec<f64>,
/// When `true`, `Engine::generator` derives durations from the label times
/// instead of from `speed`.
phoneme_alignment_flag: bool,
/// Speech speed passed to the duration estimator; the setter enforces a minimum of `1.0E-06`.
speed: f64,
/// Value of the `GAMMA` option of stream 0; forwarded to the vocoder as its stage.
stage: usize,
/// Value of the `LN_GAIN` option of stream 0 (`"1"` = true, `"0"` = false).
use_log_gain: bool,
/// Value of the `ALPHA` option of stream 0, or the value given to `set_alpha`
/// (clamped to `[0.0, 1.0]` by the setter); forwarded to the vocoder.
alpha: f64,
/// Value given to `set_beta` (clamped to `[0.0, 1.0]`); forwarded to the vocoder.
beta: f64,
/// Half-tone offset applied to the stream 1 model in `Engine::generator`.
additional_half_tone: f64,
/// Interpolation weights, rebuilt by `load_model` from the voice and stream counts.
interporation_weight: InterporationWeight,
}
impl Default for Condition {
fn default() -> Self {
Self {
sampling_frequency: 0,
fperiod: 0,
volume: 1.0f64,
msd_threshold: Vec::new(),
gv_weight: Vec::new(),
speed: 1.0f64,
phoneme_alignment_flag: false,
stage: 0,
use_log_gain: false,
alpha: 0.0f64,
beta: 0.0f64,
additional_half_tone: 0.0f64,
interporation_weight: InterporationWeight::default(),
}
}
}
impl Condition {
    /// Initialises this condition from the metadata of `voices`.
    ///
    /// Copies the sampling frequency and frame period from the global metadata, resets the
    /// per-stream MSD thresholds to `0.5` and GV weights to `1.0`, reads the `GAMMA`,
    /// `LN_GAIN` and `ALPHA` options of stream 0, and rebuilds the interpolation weights.
    /// Options that are not `KEY=VALUE` pairs or whose key is unknown are reported on stderr
    /// and skipped. Settings whose option is absent keep their current value.
    ///
    /// # Errors
    ///
    /// Returns [`EngineError::ParseOptionError`] carrying the option key when the value of a
    /// recognised option cannot be parsed. In that case `self` is left untouched.
    pub fn load_model(&mut self, voices: &VoiceSet) -> Result<(), EngineError> {
        let metadata = voices.global_metadata();
        let nstream = metadata.num_streams;

        // Parse into locals first: a malformed option must not leave `self` half-updated.
        let mut stage = self.stage;
        let mut use_log_gain = self.use_log_gain;
        let mut alpha = self.alpha;
        for option in &voices.stream_metadata(0).option {
            let Some((key, value)) = option.split_once('=') else {
                eprintln!("Skipped unrecognized option {option}.");
                continue;
            };
            let parse_error = || EngineError::ParseOptionError(key.to_string());
            match key {
                "GAMMA" => stage = value.parse().map_err(|_| parse_error())?,
                "LN_GAIN" => {
                    use_log_gain = match value {
                        "1" => true,
                        "0" => false,
                        _ => return Err(parse_error()),
                    }
                }
                "ALPHA" => alpha = value.parse().map_err(|_| parse_error())?,
                _ => eprintln!("Skipped unrecognized option {option}."),
            }
        }

        // Nothing below can fail, so the update is applied as a whole.
        self.sampling_frequency = metadata.sampling_frequency;
        self.fperiod = metadata.frame_period;
        self.msd_threshold = vec![0.5; nstream];
        self.gv_weight = vec![1.0; nstream];
        self.stage = stage;
        self.use_log_gain = use_log_gain;
        self.alpha = alpha;
        self.interporation_weight = InterporationWeight::new(voices.len(), nstream);
        Ok(())
    }

    /// Sets the sampling frequency; values below 1 are raised to 1.
    pub fn set_sampling_frequency(&mut self, i: usize) {
        self.sampling_frequency = i.max(1);
    }

    /// Returns the sampling frequency.
    pub fn get_sampling_frequency(&self) -> usize {
        self.sampling_frequency
    }

    /// Sets the frame period; values below 1 are raised to 1.
    pub fn set_fperiod(&mut self, i: usize) {
        self.fperiod = i.max(1);
    }

    /// Returns the frame period.
    pub fn get_fperiod(&self) -> usize {
        self.fperiod
    }

    /// Sets the volume from `f`, which is stored internally as `exp(f * DB)`.
    pub fn set_volume(&mut self, f: f64) {
        self.volume = (f * DB).exp();
    }

    /// Returns the volume on the same scale accepted by [`Condition::set_volume`].
    pub fn get_volume(&self) -> f64 {
        self.volume.ln() / DB
    }

    /// Sets the MSD threshold of stream `stream_index`, clamped to `[0.0, 1.0]`.
    ///
    /// # Panics
    ///
    /// Panics if `stream_index` is not smaller than the number of streams of the loaded model
    /// (which is zero before [`Condition::load_model`] has been called).
    pub fn set_msd_threshold(&mut self, stream_index: usize, f: f64) {
        self.msd_threshold[stream_index] = f.clamp(0.0, 1.0);
    }

    /// Returns the MSD threshold of stream `stream_index`.
    ///
    /// # Panics
    ///
    /// Panics if `stream_index` is out of range.
    pub fn get_msd_threshold(&self, stream_index: usize) -> f64 {
        self.msd_threshold[stream_index]
    }

    /// Sets the GV weight of stream `stream_index`; negative values are raised to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `stream_index` is not smaller than the number of streams of the loaded model
    /// (which is zero before [`Condition::load_model`] has been called).
    pub fn set_gv_weight(&mut self, stream_index: usize, f: f64) {
        self.gv_weight[stream_index] = f.max(0.0);
    }

    /// Returns the GV weight of stream `stream_index`.
    ///
    /// # Panics
    ///
    /// Panics if `stream_index` is out of range.
    pub fn get_gv_weight(&self, stream_index: usize) -> f64 {
        self.gv_weight[stream_index]
    }

    /// Sets the speech speed; values below `1.0E-06` are raised to that minimum.
    pub fn set_speed(&mut self, f: f64) {
        self.speed = f.max(1.0E-06);
    }

    /// Returns the speech speed.
    pub fn get_speed(&self) -> f64 {
        self.speed
    }

    /// Chooses whether durations are taken from the label times (`true`) or estimated from
    /// the speed setting (`false`).
    pub fn set_phoneme_alignment_flag(&mut self, b: bool) {
        self.phoneme_alignment_flag = b;
    }

    /// Returns whether durations are taken from the label times.
    pub fn get_phoneme_alignment_flag(&self) -> bool {
        self.phoneme_alignment_flag
    }

    /// Sets alpha, clamped to `[0.0, 1.0]`.
    pub fn set_alpha(&mut self, f: f64) {
        self.alpha = f.clamp(0.0, 1.0);
    }

    /// Returns alpha.
    pub fn get_alpha(&self) -> f64 {
        self.alpha
    }

    /// Sets beta, clamped to `[0.0, 1.0]`.
    pub fn set_beta(&mut self, f: f64) {
        self.beta = f.clamp(0.0, 1.0);
    }

    /// Returns beta.
    pub fn get_beta(&self) -> f64 {
        self.beta
    }

    /// Sets the additional half-tone offset; the value is stored unmodified.
    pub fn set_additional_half_tone(&mut self, f: f64) {
        self.additional_half_tone = f;
    }

    /// Returns the additional half-tone offset.
    pub fn get_additional_half_tone(&self) -> f64 {
        self.additional_half_tone
    }

    /// Returns the interpolation weights.
    pub fn get_interporation_weight(&self) -> &InterporationWeight {
        &self.interporation_weight
    }

    /// Returns the interpolation weights for in-place modification.
    pub fn get_interporation_weight_mut(&mut self) -> &mut InterporationWeight {
        &mut self.interporation_weight
    }
}
/// A set of loaded voices together with the [`Condition`] used to synthesize from them.
#[derive(Debug, Clone)]
pub struct Engine {
/// Synthesis settings read by [`Engine::generator`]; may be adjusted between calls.
pub condition: Condition,
/// The voices that the models are built from.
pub voices: VoiceSet,
}
impl Engine {
    /// Loads an engine from one or more `.htsvoice` files on disk.
    ///
    /// # Errors
    ///
    /// Fails if a file cannot be read or parsed, if the voices cannot be combined into a
    /// [`VoiceSet`], or if the voice options cannot be parsed by [`Condition::load_model`].
    #[cfg(feature = "htsvoice")]
    pub fn load<P: AsRef<Path>>(voices: impl IntoIterator<Item = P>) -> Result<Self, EngineError> {
        Self::load_from_result_bytes(voices.into_iter().map(std::fs::read))
    }

    /// Loads an engine from the in-memory contents of one or more `.htsvoice` files.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`Engine::load`], except for file access.
    #[cfg(feature = "htsvoice")]
    pub fn load_from_bytes<B: AsRef<[u8]>>(
        voices: impl IntoIterator<Item = B>,
    ) -> Result<Self, EngineError> {
        Self::load_from_result_bytes(voices.into_iter().map(Ok))
    }

    /// Shared loader: parses every voice, stopping at the first read or parse failure, then
    /// derives the initial [`Condition`] from the resulting voice set.
    #[cfg(feature = "htsvoice")]
    fn load_from_result_bytes<B: AsRef<[u8]>>(
        voices: impl IntoIterator<Item = std::io::Result<B>>,
    ) -> Result<Self, EngineError> {
        use crate::model::load_htsvoice_from_bytes;

        // Both the I/O error and the parse error are funnelled into `ModelError` here.
        let voices = voices
            .into_iter()
            .map(|bytes| Ok(Arc::new(load_htsvoice_from_bytes(bytes?.as_ref())?)))
            .collect::<Result<Vec<_>, ModelError>>()?;
        let voiceset = VoiceSet::new(voices)?;

        let mut condition = Condition::default();
        condition.load_model(&voiceset)?;
        Ok(Self::new(voiceset, condition))
    }

    /// Creates an engine from an already-built voice set and condition.
    ///
    /// `condition` is used as given; it is the caller's responsibility to have sized it for
    /// `voices` (see [`Condition::load_model`]).
    pub fn new(voices: VoiceSet, condition: Condition) -> Self {
        Engine { voices, condition }
    }

    /// Synthesizes the whole utterance described by `labels` and returns its samples.
    ///
    /// # Errors
    ///
    /// Returns an error if `labels` cannot be converted into labels.
    ///
    /// # Panics
    ///
    /// See [`Engine::generator`].
    pub fn synthesize(&self, labels: impl ToLabels) -> Result<Vec<f64>, EngineError> {
        Ok(self.generator(labels)?.generate_all())
    }

    /// Builds a [`SpeechGenerator`] for the utterance described by `labels`.
    ///
    /// Parameter sequences are generated for stream 0 and stream 1 (the latter with the
    /// additional half-tone offset applied) and, when the voice has more than two streams,
    /// for stream 2. Without a third stream the vocoder is told that stream has width `0`
    /// and every frame gets an empty stream 2 vector.
    ///
    /// # Errors
    ///
    /// Returns an error if `labels` cannot be converted into labels.
    ///
    /// # Panics
    ///
    /// Panics if the condition holds fewer per-stream entries than the streams accessed here
    /// (at least two), e.g. a [`Condition::default`] on which `load_model` was never called.
    pub fn generator(&self, labels: impl ToLabels) -> Result<SpeechGenerator, EngineError> {
        // Applies `f` to an owned value and hands the value back, so a model stream can be
        // adjusted inline while being passed as an argument.
        fn mutated<T, F: FnOnce(&mut T)>(mut value: T, f: F) -> T {
            f(&mut value);
            value
        }

        let condition = &self.condition;
        let labels = labels.to_labels(condition)?;

        // Stream 2 is optional. Only touch its metadata when it exists; otherwise use a width
        // of zero, matching the empty per-frame vectors produced below.
        let has_third_stream = self.voices.global_metadata().num_streams > 2;
        let third_stream_length = if has_third_stream {
            self.voices.stream_metadata(2).vector_length
        } else {
            0
        };

        // NOTE(review): the first argument is stream 0's full vector length. The reference C
        // implementation (hts_engine API) hands its vocoder `vector_length - 1`; confirm which
        // convention `Vocoder::new` expects.
        let vocoder = Vocoder::new(
            self.voices.stream_metadata(0).vector_length,
            third_stream_length,
            condition.stage,
            condition.use_log_gain,
            condition.sampling_frequency,
            condition.alpha,
            condition.beta,
            condition.volume,
            condition.fperiod,
        );

        let models = Models::new(
            labels.labels(),
            &self.voices,
            &condition.interporation_weight,
        );

        let estimator = DurationEstimator::new(models.duration(), models.nstate());
        let durations = if condition.phoneme_alignment_flag {
            estimator.create_with_alignment(labels.times())
        } else {
            estimator.create(condition.speed)
        };

        let spectrum = MlpgAdjust::new(
            condition.gv_weight[0],
            condition.msd_threshold[0],
            models.model_stream(0),
        )
        .create(&durations);

        let lf0 = MlpgAdjust::new(
            condition.gv_weight[1],
            condition.msd_threshold[1],
            mutated(models.model_stream(1), |m| {
                m.stream
                    .apply_additional_half_tone(condition.additional_half_tone);
            }),
        )
        .create(&durations);

        let lpf = if has_third_stream {
            MlpgAdjust::new(
                condition.gv_weight[2],
                condition.msd_threshold[2],
                models.model_stream(2),
            )
            .create(&durations)
        } else {
            // One empty vector per frame, so the frame count still matches `lf0`.
            vec![vec![0.0; 0]; lf0.len()]
        };

        Ok(SpeechGenerator::new(
            condition.fperiod,
            vocoder,
            spectrum,
            lf0,
            lpf,
        ))
    }
}