use crate::band_correct::BandedCorrector;
use crate::denoise::SpectrumDenoiser;
use crate::ternary_gates::{GateMode, init_ternary_gates};
use anyhow::Result;
use serde::{Deserialize, Serialize};
/// How the mel-path and spectrum-path gate vectors relate to each other.
///
/// See `sync_spec_gates_for_layout` for how each layout derives the spectrum gates
/// from the mel gates. Variant order may be part of the serialised form for
/// index-based serde formats, so do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GateLayout {
/// One shared gate vector: spectrum gates are a copy of the mel gates.
/// Used by the `Default` config and the `single_sparse_*` presets.
SingleSparse,
/// Separate paths: spectrum gates are forced to all-`Forward` while the mel gates
/// stay independent. Used by `dual_path_mel_sparse`.
DualMelSpec,
/// Spectrum gates are a copy of the mel gates, as with `SingleSparse`.
/// Used by `fusion_only`. NOTE(review): presumably the mel gates are all-forward in
/// this layout (e.g. from `all_forward_gates`). Confirm with callers.
AllForward,
}
/// Selects which spectrum corrector `SpectrumCorrection::from_kind` builds.
///
/// Variant order may be part of the serialised form for index-based serde formats,
/// so do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrectorKind {
/// No correction: maps to `SpectrumCorrection::Identity`.
Identity,
/// Per-bin affine denoiser (`SpectrumDenoiser::identity`).
Affine,
/// Banded corrector built with `BandedCorrector::identity`.
BandNarrow,
/// Banded corrector: `BandedCorrector::identity` on the mel path,
/// `BandedCorrector::identity_spectrum` otherwise.
BandWide,
/// Banded corrector built with `BandedCorrector::identity_dense`.
Dense,
}
/// Architecture knobs for the ternary-gated model: gate layout, corrector choice,
/// compute targets and whether reverse gates are permitted.
///
/// Field order may matter for positional serde formats, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TernaryArchConfig {
/// How mel and spectrum gates are related (see `GateLayout`).
pub gate_layout: GateLayout,
/// Which spectrum corrector to build (see `CorrectorKind`).
pub corrector: CorrectorKind,
// NOTE(review): presumably the target fraction of gates left active on the primary
// (mel) path. Values used in this file lie in (0, 1]. Confirm the exact meaning.
pub target_compute_fraction: f32,
// NOTE(review): presumably the same target for the spectrum path. The dual-path
// preset keeps it at 1.0 while the mel path is sparser. Confirm.
pub target_spec_compute_fraction: f32,
// NOTE(review): likely gates whether `GateMode::Reverse` entries are kept
// (cf. `strip_reverse_gates`). Confirm where this flag is consumed.
pub allow_reverse: bool,
}
impl Default for TernaryArchConfig {
    /// Single-sparse layout with the wide banded corrector.
    ///
    /// Sets `target_compute_fraction` to 0.96 and `target_spec_compute_fraction`
    /// to 1.0, and allows reverse gates.
    fn default() -> Self {
        let target_compute_fraction = 0.96;
        let target_spec_compute_fraction = 1.0;
        Self {
            corrector: CorrectorKind::BandWide,
            gate_layout: GateLayout::SingleSparse,
            allow_reverse: true,
            target_compute_fraction,
            target_spec_compute_fraction,
        }
    }
}
impl TernaryArchConfig {
pub fn fusion_only() -> Self {
Self {
gate_layout: GateLayout::AllForward,
corrector: CorrectorKind::Identity,
target_compute_fraction: 1.0,
target_spec_compute_fraction: 1.0,
allow_reverse: false,
}
}
pub fn dual_path_mel_sparse() -> Self {
Self {
gate_layout: GateLayout::DualMelSpec,
corrector: CorrectorKind::BandWide,
target_compute_fraction: 0.94,
target_spec_compute_fraction: 1.0,
allow_reverse: false,
}
}
pub fn single_sparse_band(target: f32) -> Self {
Self {
gate_layout: GateLayout::SingleSparse,
corrector: CorrectorKind::BandWide,
target_compute_fraction: target,
target_spec_compute_fraction: target,
allow_reverse: true,
}
}
pub fn single_sparse_affine(target: f32) -> Self {
Self {
gate_layout: GateLayout::SingleSparse,
corrector: CorrectorKind::Affine,
target_compute_fraction: target,
target_spec_compute_fraction: target,
allow_reverse: false,
}
}
pub fn single_sparse_dense(target: f32) -> Self {
Self {
gate_layout: GateLayout::SingleSparse,
corrector: CorrectorKind::Dense,
target_compute_fraction: target,
target_spec_compute_fraction: target,
allow_reverse: false,
}
}
}
/// A constructed spectrum corrector, built from a `CorrectorKind` via
/// `SpectrumCorrection::from_kind`.
#[derive(Debug, Clone)]
pub enum SpectrumCorrection {
/// Pass-through: no parameters, output equals input.
Identity,
/// Affine denoiser (exposes `scale` and `bias`).
Affine(SpectrumDenoiser),
/// Banded corrector, used for the `BandNarrow`, `BandWide` and `Dense` kinds.
Band(BandedCorrector),
}
impl SpectrumCorrection {
    /// Builds a corrector of the requested `kind`, initialised via the wrapped
    /// type's `identity*` constructor for `n_fft`.
    ///
    /// `mel_path` only affects `CorrectorKind::BandWide`. On the mel path it uses
    /// `BandedCorrector::identity`, otherwise `BandedCorrector::identity_spectrum`.
    pub fn from_kind(kind: CorrectorKind, n_fft: usize, mel_path: bool) -> Self {
        match kind {
            CorrectorKind::Identity => Self::Identity,
            CorrectorKind::Affine => Self::Affine(SpectrumDenoiser::identity(n_fft)),
            // NOTE(review): BandNarrow builds the same corrector as BandWide on the mel
            // path. Confirm this is intentional.
            CorrectorKind::BandNarrow => Self::Band(BandedCorrector::identity(n_fft)),
            CorrectorKind::BandWide => {
                if mel_path {
                    Self::Band(BandedCorrector::identity(n_fft))
                } else {
                    Self::Band(BandedCorrector::identity_spectrum(n_fft))
                }
            }
            CorrectorKind::Dense => Self::Band(BandedCorrector::identity_dense(n_fft)),
        }
    }

    /// Applies the corrector to a batch of `batch` spectra of width `n_fft`.
    ///
    /// `Identity` returns a copy of `spectrum` unchanged.
    ///
    /// # Errors
    /// Propagates any error from the wrapped corrector's `apply_batch`. `Identity`
    /// never errors.
    pub fn apply_batch(&self, spectrum: &[f32], batch: usize, n_fft: usize) -> Result<Vec<f32>> {
        match self {
            Self::Identity => Ok(spectrum.to_vec()),
            Self::Affine(d) => d.apply_batch(spectrum, batch, n_fft),
            Self::Band(b) => b.apply_batch(spectrum, batch, n_fft),
        }
    }

    /// Runs one MSE training step mapping `input` towards `target` with learning
    /// rate `lr`, and returns the loss reported by the wrapped corrector.
    ///
    /// `Identity` has no parameters and always reports `0.0`.
    /// NOTE(review): that is not the actual MSE between `input` and `target`.
    /// Confirm callers do not aggregate it as a real loss.
    ///
    /// # Errors
    /// Propagates errors from the wrapped corrector's training step.
    pub fn train_step_mse(
        &mut self,
        input: &[f32],
        target: &[f32],
        batch: usize,
        n_fft: usize,
        lr: f32,
    ) -> Result<f32> {
        match self {
            Self::Identity => Ok(0.0),
            Self::Affine(d) => d.train_step_affine(input, target, batch, n_fft, lr),
            Self::Band(b) => b.train_step_mse(input, target, batch, n_fft, lr),
        }
    }

    /// Applies an externally supplied spectrum gradient `spec_grad` to the banded
    /// corrector.
    ///
    /// Only `Band` is updated. `Identity` and `Affine` are left unchanged by this
    /// method.
    pub fn train_step_spectrum_grad(
        &mut self,
        input: &[f32],
        spec_grad: &[f32],
        batch: usize,
        n_fft: usize,
        lr: f32,
    ) {
        match self {
            Self::Band(b) => {
                b.train_step_spectrum_grad(input, spec_grad, batch, n_fft, lr);
            }
            // Listed explicitly (no wildcard) so a new variant forces a decision here.
            // NOTE(review): the affine corrector receives no gradient update on this
            // path. Confirm that is intended.
            Self::Identity | Self::Affine(_) => {}
        }
    }

    /// Returns the banded corrector's dense right-hand side with `freq_mask`
    /// applied, or `None` for correctors that are not banded.
    pub fn dense_rhs_with_freq_mask(&self, freq_mask: &[f32]) -> Option<Vec<f32>> {
        match self {
            Self::Band(b) => Some(b.dense_rhs_with_freq_mask(freq_mask)),
            Self::Identity | Self::Affine(_) => None,
        }
    }

    /// Returns a copy of the corrector's `bias` vector, or `None` for `Identity`.
    pub fn bias(&self) -> Option<Vec<f32>> {
        match self {
            Self::Affine(d) => Some(d.bias.clone()),
            Self::Band(b) => Some(b.bias.clone()),
            Self::Identity => None,
        }
    }

    /// Returns copies of the affine corrector's `(scale, bias)` vectors, or `None`
    /// for non-affine correctors.
    pub fn affine_gain_bias(&self) -> Option<(Vec<f32>, Vec<f32>)> {
        match self {
            Self::Affine(d) => Some((d.scale.clone(), d.bias.clone())),
            Self::Identity | Self::Band(_) => None,
        }
    }
}
/// Returns the gate vector produced by `init_ternary_gates(n_fft)`.
///
/// NOTE(review): the name implies every entry is `GateMode::Forward`. That depends on
/// `init_ternary_gates` initialising to forward, so confirm in `ternary_gates`.
pub fn all_forward_gates(n_fft: usize) -> Vec<i8> {
init_ternary_gates(n_fft)
}
/// Rewrites every `GateMode::Reverse` entry in `gates` to `GateMode::Forward`, in place.
///
/// All other gate values are left untouched.
pub fn strip_reverse_gates(gates: &mut [i8]) {
    let reverse = GateMode::Reverse.to_i8();
    let forward = GateMode::Forward.to_i8();
    gates
        .iter_mut()
        .filter(|gate| **gate == reverse)
        .for_each(|gate| *gate = forward);
}
/// Derives the spectrum-path gates from the mel-path gates for `gate_layout`.
///
/// - `SingleSparse` / `AllForward`: `spec_gates` becomes a copy of `mel_gates`.
/// - `DualMelSpec`: every entry of `spec_gates` is set to `GateMode::Forward`.
///
/// # Panics
/// Panics if `mel_gates` and `spec_gates` have different lengths.
pub fn sync_spec_gates_for_layout(
    gate_layout: GateLayout,
    mel_gates: &[i8],
    spec_gates: &mut [i8],
) {
    assert_eq!(mel_gates.len(), spec_gates.len());
    match gate_layout {
        GateLayout::SingleSparse | GateLayout::AllForward => {
            spec_gates.copy_from_slice(mel_gates);
        }
        GateLayout::DualMelSpec => {
            let forward = GateMode::Forward.to_i8();
            spec_gates.fill(forward);
        }
    }
}