use serde::{Deserialize, Serialize};
#[cfg(feature = "schema")]
use schemars::JsonSchema;
/// One preprocessing step of a pipeline configuration.
///
/// Serialized as an internally tagged object: the `"type"` field carries the
/// variant name (for example `{"type": "PhonemeRaw"}`) and the remaining keys
/// are the variant's fields. Fields annotated with `#[serde(default ...)]` may
/// be omitted from the input; all other fields are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
#[serde(tag = "type")]
pub enum PreprocessingStep {
/// Mel-spectrogram settings. Every field has a default, so
/// `{"type": "MelSpectrogram"}` alone is a valid configuration.
MelSpectrogram {
/// Optional preset name; `None` when omitted.
// NOTE(review): how a preset interacts with the explicit fields below is not
// visible in this file — confirm against the code that consumes this step.
#[serde(default)]
preset: Option<String>,
/// Defaults to `default_n_mels()` (80) when omitted.
#[serde(default = "default_n_mels")]
n_mels: usize,
/// Defaults to `default_sample_rate()` (16000) when omitted.
#[serde(default = "default_sample_rate")]
sample_rate: u32,
/// Defaults to `default_fft_size()` (400) when omitted.
#[serde(default = "default_fft_size")]
fft_size: usize,
/// Defaults to `default_hop_length()` (160) when omitted.
#[serde(default = "default_hop_length")]
hop_length: usize,
/// Defaults to `MelScaleType::default()` (`Slaney`) when omitted.
#[serde(default)]
mel_scale: MelScaleType,
/// Defaults to `default_max_frames()` (`Some(3000)`) when omitted.
// The default is `Some(3000)`, not `None`: only a missing key triggers the
// default function.
#[serde(default = "default_max_frames")]
max_frames: Option<usize>,
},
/// Tokenization settings; `vocab_file` and `tokenizer_type` are required.
Tokenize {
vocab_file: String,
tokenizer_type: TokenizerType,
/// `None` when omitted.
#[serde(default)]
max_length: Option<usize>,
},
/// Normalization settings; both vectors are required.
// NOTE(review): expected length/meaning of `mean` and `std` (e.g. one entry
// per channel) is not established here — confirm with the executor.
Normalize {
mean: Vec<f32>,
std: Vec<f32>,
},
/// Resize settings; `width` and `height` are required.
Resize {
width: usize,
height: usize,
/// Defaults to `InterpolationMethod::default()` (`Bilinear`) when omitted.
#[serde(default)]
interpolation: InterpolationMethod,
},
/// Center-crop settings; both dimensions are required.
CenterCrop {
width: usize,
height: usize,
},
/// Audio-decoding settings; both fields are required.
AudioDecode {
sample_rate: u32,
channels: usize,
},
/// Reshape settings; `shape` is required.
Reshape {
shape: Vec<usize>,
},
/// Phonemization settings without a tokens file; every field is optional.
PhonemeRaw {
/// Defaults to `PhonemizerBackend::default()` (`MisakiDictionary`).
#[serde(default)]
backend: PhonemizerBackend,
/// `None` when omitted.
#[serde(default)]
language: Option<String>,
},
/// Phonemization settings; only `tokens_file` is required.
Phonemize {
tokens_file: String,
/// Defaults to `PhonemizerBackend::default()` (`MisakiDictionary`).
#[serde(default)]
backend: PhonemizerBackend,
/// `None` when omitted.
#[serde(default)]
dict_file: Option<String>,
/// `None` when omitted.
#[serde(default)]
language: Option<String>,
/// Defaults to `default_add_padding()` (`true`) when omitted.
#[serde(default = "default_add_padding")]
add_padding: bool,
/// Defaults to `false` when omitted.
#[serde(default)]
normalize_text: bool,
/// `None` when omitted.
// NOTE(review): whether this is a token count or a token id is not visible
// here — confirm before relying on the `u8` range.
#[serde(default)]
silence_tokens: Option<u8>,
},
}
impl PreprocessingStep {
pub fn step_name(&self) -> &'static str {
match self {
PreprocessingStep::MelSpectrogram { .. } => "MelSpectrogram",
PreprocessingStep::Tokenize { .. } => "Tokenize",
PreprocessingStep::Normalize { .. } => "Normalize",
PreprocessingStep::Resize { .. } => "Resize",
PreprocessingStep::CenterCrop { .. } => "CenterCrop",
PreprocessingStep::AudioDecode { .. } => "AudioDecode",
PreprocessingStep::Reshape { .. } => "Reshape",
PreprocessingStep::PhonemeRaw { .. } => "PhonemeRaw",
PreprocessingStep::Phonemize { .. } => "Phonemize",
}
}
}
/// One postprocessing step of a pipeline configuration.
///
/// Serialized as an internally tagged object: the `"type"` field carries the
/// variant name (for example `{"type": "CodecDecode"}`) and the remaining keys
/// are the variant's fields. Fields annotated with `#[serde(default ...)]` may
/// be omitted from the input; all other fields are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
#[serde(tag = "type")]
pub enum PostprocessingStep {
/// BPE decoding settings; `vocab_file` is required.
BPEDecode {
vocab_file: String,
},
/// Argmax settings.
Argmax {
/// `None` when omitted.
#[serde(default)]
dim: Option<usize>,
},
/// Softmax settings.
Softmax {
/// `None` when omitted.
#[serde(default)]
dim: Option<usize>,
},
/// Top-k settings; `k` is required.
TopK {
k: usize,
/// `None` when omitted.
#[serde(default)]
dim: Option<usize>,
},
/// Threshold settings; `threshold` is required.
Threshold {
threshold: f32,
/// Defaults to `false` when omitted.
#[serde(default)]
return_indices: bool,
},
/// Temperature-sampling settings; `temperature` is required.
TemperatureSample {
temperature: f32,
/// `None` when omitted.
#[serde(default)]
top_k: Option<usize>,
/// `None` when omitted.
#[serde(default)]
top_p: Option<f32>,
},
/// Denormalization settings; both vectors are required.
// NOTE(review): expected length/meaning of `mean` and `std` is not
// established in this file — confirm with the executor.
Denormalize {
mean: Vec<f32>,
std: Vec<f32>,
},
/// Mean-pooling settings.
MeanPool {
/// Defaults to `default_pool_dim()` (1) when omitted.
#[serde(default = "default_pool_dim")]
dim: usize,
},
/// CTC decoding settings; `vocab_file` is required.
CTCDecode {
vocab_file: String,
/// Defaults to `0` when omitted.
#[serde(default)]
blank_index: usize,
},
/// TTS audio-encoding settings; `sample_rate` is required.
TTSAudioEncode {
sample_rate: u32,
/// Defaults to `default_tts_postprocess()` (`true`) when omitted.
#[serde(default = "default_tts_postprocess")]
apply_postprocessing: bool,
/// Defaults to `false` when omitted.
#[serde(default)]
trim_trailing_silence: bool,
},
/// Whisper decoding settings; `tokenizer_file` is required.
WhisperDecode {
tokenizer_file: String,
},
/// Codec decoding settings; `decoder_model`, `sample_rate` and
/// `token_pattern` are required.
CodecDecode {
decoder_model: String,
sample_rate: u32,
// NOTE(review): the tests in this file pass a regular expression with one
// capture group here (`<\|speech_(\d+)\|>`); presumably it is compiled as a
// regex by the executor — confirm.
token_pattern: String,
/// Defaults to `default_tts_postprocess()` (`true`) when omitted.
#[serde(default = "default_tts_postprocess")]
apply_postprocessing: bool,
},
}
impl PostprocessingStep {
pub fn step_name(&self) -> &'static str {
match self {
PostprocessingStep::BPEDecode { .. } => "BPEDecode",
PostprocessingStep::Argmax { .. } => "Argmax",
PostprocessingStep::Softmax { .. } => "Softmax",
PostprocessingStep::TopK { .. } => "TopK",
PostprocessingStep::Threshold { .. } => "Threshold",
PostprocessingStep::TemperatureSample { .. } => "TemperatureSample",
PostprocessingStep::Denormalize { .. } => "Denormalize",
PostprocessingStep::MeanPool { .. } => "MeanPool",
PostprocessingStep::CTCDecode { .. } => "CTCDecode",
PostprocessingStep::TTSAudioEncode { .. } => "TTSAudioEncode",
PostprocessingStep::WhisperDecode { .. } => "WhisperDecode",
PostprocessingStep::CodecDecode { .. } => "CodecDecode",
}
}
}
/// Configuration-level choice of phonemizer backend.
///
/// No serde renaming is applied, so each variant is (de)serialized under its
/// Rust name (for example `"EspeakNG"`). `MisakiDictionary` is the `Default`
/// and is therefore what `#[serde(default)]` fields of this type fall back to.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
pub enum PhonemizerBackend {
CmuDictionary,
EspeakNG,
#[default]
MisakiDictionary,
OpenPhonemizer,
}
impl PhonemizerBackend {
    /// Builds the boxed runtime phonemizer that corresponds to this variant.
    ///
    /// Each variant consumes only the argument it needs; the others are
    /// ignored:
    /// - `CmuDictionary` is given an owned copy of `dict_path` (or `None`),
    /// - `MisakiDictionary` and `OpenPhonemizer` are given an owned copy of
    ///   `base_path`,
    /// - `EspeakNG` is given `language`, falling back to `"en-us"` when it
    ///   is `None`.
    pub fn create(
        &self,
        base_path: &str,
        dict_path: Option<&str>,
        language: Option<&str>,
    ) -> Box<dyn crate::execution::preprocessing::backends::PhonemizerBackend> {
        use crate::execution::preprocessing::backends as runtime;

        match self {
            Self::CmuDictionary => {
                let dictionary = dict_path.map(str::to_owned);
                Box::new(runtime::CmuDictionaryBackend::new(dictionary))
            }
            Self::MisakiDictionary => {
                let root = base_path.to_owned();
                Box::new(runtime::MisakiBackend::new(root))
            }
            Self::EspeakNG => {
                let voice = language.unwrap_or("en-us").to_owned();
                Box::new(runtime::EspeakBackend::new(voice))
            }
            Self::OpenPhonemizer => {
                let root = base_path.to_owned();
                Box::new(runtime::OpenPhonemizerBackend::new(root))
            }
        }
    }
}
/// Mel scale selector.
///
/// Variants are (de)serialized in lowercase (`"slaney"`, `"htk"`) because of
/// `rename_all = "lowercase"`. `Slaney` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
#[serde(rename_all = "lowercase")]
pub enum MelScaleType {
#[default]
Slaney,
Htk,
}
/// Tokenizer kind selected by `PreprocessingStep::Tokenize`.
///
/// No serde renaming is applied, so variants are (de)serialized under their
/// Rust names (`"BPE"`, `"WordPiece"`, `"SentencePiece"`). The type has no
/// `Default`, so a value must always be supplied explicitly.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
pub enum TokenizerType {
BPE,
WordPiece,
SentencePiece,
}
/// Interpolation choice used by `PreprocessingStep::Resize`.
///
/// No serde renaming is applied, so variants are (de)serialized under their
/// Rust names. `Bilinear` is the `Default`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
pub enum InterpolationMethod {
Nearest,
#[default]
Bilinear,
Bicubic,
}
/// Serde default for `Phonemize::add_padding`: `true`.
fn default_add_padding() -> bool {
    const ADD_PADDING: bool = true;
    ADD_PADDING
}
/// Serde default for `MelSpectrogram::n_mels`: 80.
fn default_n_mels() -> usize {
    const N_MELS: usize = 80;
    N_MELS
}
/// Serde default for `MelSpectrogram::sample_rate`: 16000.
fn default_sample_rate() -> u32 {
    const SAMPLE_RATE: u32 = 16_000;
    SAMPLE_RATE
}
/// Serde default for `MelSpectrogram::fft_size`: 400.
fn default_fft_size() -> usize {
    const FFT_SIZE: usize = 400;
    FFT_SIZE
}
/// Serde default for `MelSpectrogram::hop_length`: 160.
fn default_hop_length() -> usize {
    const HOP_LENGTH: usize = 160;
    HOP_LENGTH
}
/// Serde default for `MelSpectrogram::max_frames`: `Some(3000)`.
fn default_max_frames() -> Option<usize> {
    const MAX_FRAMES: usize = 3_000;
    Some(MAX_FRAMES)
}
/// Serde default for `MeanPool::dim`: 1.
fn default_pool_dim() -> usize {
    const POOL_DIM: usize = 1;
    POOL_DIM
}
/// Serde default for the `apply_postprocessing` field of `TTSAudioEncode`
/// and `CodecDecode`: `true`.
fn default_tts_postprocess() -> bool {
    const APPLY_POSTPROCESSING: bool = true;
    APPLY_POSTPROCESSING
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value`, parses the resulting JSON back into `T`, serializes
    /// that again, and asserts both JSON strings are identical.
    fn assert_stable_reserialization<T>(value: &T)
    where
        T: serde::Serialize + for<'de> serde::Deserialize<'de>,
    {
        let serialized = serde_json::to_string(value).unwrap();
        let deserialized: T = serde_json::from_str(&serialized).unwrap();
        let reserialized = serde_json::to_string(&deserialized).unwrap();
        assert_eq!(serialized, reserialized);
    }

    /// An explicit backend and language are parsed and survive re-serialization.
    #[test]
    fn phoneme_raw_serde_round_trip() {
        let input = r#"{
"type": "PhonemeRaw",
"backend": "EspeakNG",
"language": "en-us"
}"#;
        let step: PreprocessingStep = serde_json::from_str(input).unwrap();

        if let PreprocessingStep::PhonemeRaw { backend, language } = &step {
            assert!(matches!(backend, PhonemizerBackend::EspeakNG));
            assert_eq!(language.as_deref(), Some("en-us"));
        } else {
            panic!("Expected PhonemeRaw variant");
        }

        assert_stable_reserialization(&step);
    }

    /// With only the tag present, the backend falls back to `MisakiDictionary`
    /// and the language to `None`.
    #[test]
    fn phoneme_raw_defaults() {
        let input = r#"{"type": "PhonemeRaw"}"#;
        let step: PreprocessingStep = serde_json::from_str(input).unwrap();

        if let PreprocessingStep::PhonemeRaw { backend, language } = &step {
            assert!(matches!(backend, PhonemizerBackend::MisakiDictionary));
            assert!(language.is_none());
        } else {
            panic!("Expected PhonemeRaw variant");
        }
    }

    /// A fully specified `CodecDecode` step is parsed field by field and
    /// survives re-serialization.
    #[test]
    fn codec_decode_serde_round_trip() {
        let input = r#"{
"type": "CodecDecode",
"decoder_model": "neucodec_mini_decoder.onnx",
"sample_rate": 24000,
"token_pattern": "<\\|speech_(\\d+)\\|>",
"apply_postprocessing": true
}"#;
        let step: PostprocessingStep = serde_json::from_str(input).unwrap();

        if let PostprocessingStep::CodecDecode {
            decoder_model,
            sample_rate,
            token_pattern,
            apply_postprocessing,
        } = &step
        {
            assert_eq!(decoder_model, "neucodec_mini_decoder.onnx");
            assert_eq!(*sample_rate, 24000);
            assert_eq!(token_pattern, r"<\|speech_(\d+)\|>");
            assert!(*apply_postprocessing);
        } else {
            panic!("Expected CodecDecode variant");
        }

        assert_stable_reserialization(&step);
    }

    /// Omitting `apply_postprocessing` yields `true`.
    #[test]
    fn codec_decode_default_apply_postprocessing() {
        let input = r#"{
"type": "CodecDecode",
"decoder_model": "decoder.onnx",
"sample_rate": 24000,
"token_pattern": "<\\|speech_(\\d+)\\|>"
}"#;
        let step: PostprocessingStep = serde_json::from_str(input).unwrap();

        if let PostprocessingStep::CodecDecode {
            apply_postprocessing,
            ..
        } = &step
        {
            assert!(*apply_postprocessing);
        } else {
            panic!("Expected CodecDecode variant");
        }
    }
}