use derive_more::{Display, IsVariant};
use smol_str::format_smolstr;
use crate::{
array::Array,
dtype::Dtype,
error::{
CapExceededPayload, DtypeMismatchPayload, EmptyInputPayload, Error, LayerKeyedPayload,
RankMismatchPayload, Result,
},
ops,
};
use super::model::TtsModel;
/// Voice identifier used by [`TtsGenConfig::default`].
pub const DEFAULT_VOICE: &str = "af_heart";
/// Language code used by [`TtsGenConfig::default`].
pub const DEFAULT_LANGUAGE: &str = "en";
/// Sampling temperature used by [`TtsGenConfig::default`].
pub const DEFAULT_TEMPERATURE: f32 = 0.7;
/// Token budget used by [`TtsGenConfig::default`].
pub const DEFAULT_MAX_TOKENS: usize = 1200;
/// Streaming interval used by [`TtsGenConfig::default`].
pub const DEFAULT_STREAMING_INTERVAL: f32 = 2.0;
/// Largest input text, in bytes (1 MiB), accepted by [`tts_generate_with_reference`];
/// longer inputs are rejected with a cap-exceeded error.
pub const MAX_TEXT_BYTES: usize = 1 << 20;
/// Audio file format selectable through [`TtsGenConfig::with_audio_format`].
///
/// `Display` prints the same lowercase name returned by [`AudioFormat::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Display, IsVariant)]
#[display("{}", self.as_str())]
pub enum AudioFormat {
    /// WAV (the default).
    #[default]
    Wav,
    /// FLAC.
    Flac,
}
impl AudioFormat {
    /// Lowercase name of the format: `"wav"` or `"flac"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            AudioFormat::Wav => "wav",
            AudioFormat::Flac => "flac",
        }
    }
}
/// How input text is split into segments that are synthesized one at a time.
///
/// `Display` prints the same lowercase name returned by [`TextSegmentation::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Display, IsVariant)]
#[display("{}", self.as_str())]
pub enum TextSegmentation {
    /// One segment per line; blank lines are skipped (the default).
    #[default]
    Newlines,
    /// The entire input is a single segment, unless it is blank.
    Whole,
}
impl TextSegmentation {
    /// Lowercase name of the mode: `"newlines"` or `"whole"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            TextSegmentation::Newlines => "newlines",
            TextSegmentation::Whole => "whole",
        }
    }
}
/// Generation settings applied to every segment of a `tts_generate` call.
///
/// Start from [`TtsGenConfig::new`] (or [`Default`]) and adjust with the `with_*` builders.
#[derive(Debug, Clone, PartialEq)]
pub struct TtsGenConfig {
    /// Voice identifier forwarded to the model.
    voice: String,
    /// Language code forwarded to the model.
    language: String,
    /// Speed setting forwarded to the model; presumably a rate multiplier where `1.0` is
    /// normal speed — TODO confirm against the model implementations.
    speed: f32,
    /// Sampling temperature forwarded to the model.
    temperature: f32,
    /// Top-p value forwarded to the model; how `0.0` is interpreted is model-defined
    /// (presumably "disabled").
    top_p: f32,
    /// Top-k value forwarded to the model; how `0` is interpreted is model-defined
    /// (presumably "disabled").
    top_k: i32,
    /// Optional repetition penalty forwarded to the model; `None` requests none.
    repetition_penalty: Option<f32>,
    /// Token budget forwarded with every segment.
    max_tokens: usize,
    /// How the input text is split before synthesis.
    segmentation: TextSegmentation,
    /// Requested output format; not read by the generation code in this module.
    audio_format: AudioFormat,
    /// Streaming interval forwarded with every segment (units are model-defined,
    /// presumably seconds).
    streaming_interval: f32,
}
impl Default for TtsGenConfig {
    fn default() -> Self {
        Self {
            voice: String::from(DEFAULT_VOICE),
            language: String::from(DEFAULT_LANGUAGE),
            speed: 1.0,
            temperature: DEFAULT_TEMPERATURE,
            top_p: 0.0,
            top_k: 0,
            repetition_penalty: None,
            max_tokens: DEFAULT_MAX_TOKENS,
            segmentation: TextSegmentation::Newlines,
            audio_format: AudioFormat::Wav,
            streaming_interval: DEFAULT_STREAMING_INTERVAL,
        }
    }
}
impl TtsGenConfig {
    /// Configuration populated with the module defaults; identical to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }
    /// Replace the voice identifier.
    #[must_use]
    pub fn with_voice(self, voice: impl Into<String>) -> Self {
        Self {
            voice: voice.into(),
            ..self
        }
    }
    /// Replace the language code.
    #[must_use]
    pub fn with_language(self, language: impl Into<String>) -> Self {
        Self {
            language: language.into(),
            ..self
        }
    }
    /// Replace the speed setting (stored as given, no validation).
    #[must_use]
    pub fn with_speed(self, speed: f32) -> Self {
        Self { speed, ..self }
    }
    /// Replace the sampling temperature (stored as given, no validation).
    #[must_use]
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature,
            ..self
        }
    }
    /// Replace the top-p value (stored as given, no validation).
    #[must_use]
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self { top_p, ..self }
    }
    /// Replace the top-k value (stored as given, no validation).
    #[must_use]
    pub fn with_top_k(self, top_k: i32) -> Self {
        Self { top_k, ..self }
    }
    /// Replace the optional repetition penalty; `None` clears it.
    #[must_use]
    pub fn with_repetition_penalty(self, repetition_penalty: Option<f32>) -> Self {
        Self {
            repetition_penalty,
            ..self
        }
    }
    /// Replace the per-segment token budget.
    #[must_use]
    pub fn with_max_tokens(self, max_tokens: usize) -> Self {
        Self { max_tokens, ..self }
    }
    /// Replace the text segmentation mode.
    #[must_use]
    pub fn with_segmentation(self, segmentation: TextSegmentation) -> Self {
        Self {
            segmentation,
            ..self
        }
    }
    /// Replace the requested audio format.
    #[must_use]
    pub fn with_audio_format(self, audio_format: AudioFormat) -> Self {
        Self {
            audio_format,
            ..self
        }
    }
    /// Replace the streaming interval.
    #[must_use]
    pub fn with_streaming_interval(self, streaming_interval: f32) -> Self {
        Self {
            streaming_interval,
            ..self
        }
    }
    /// Voice identifier.
    #[inline(always)]
    #[must_use]
    pub fn voice(&self) -> &str {
        self.voice.as_str()
    }
    /// Language code.
    #[inline(always)]
    #[must_use]
    pub fn language(&self) -> &str {
        self.language.as_str()
    }
    /// Speed setting.
    #[inline(always)]
    #[must_use]
    pub fn speed(&self) -> f32 {
        self.speed
    }
    /// Sampling temperature.
    #[inline(always)]
    #[must_use]
    pub fn temperature(&self) -> f32 {
        self.temperature
    }
    /// Top-p value.
    #[inline(always)]
    #[must_use]
    pub fn top_p(&self) -> f32 {
        self.top_p
    }
    /// Top-k value.
    #[inline(always)]
    #[must_use]
    pub fn top_k(&self) -> i32 {
        self.top_k
    }
    /// Optional repetition penalty.
    #[inline(always)]
    #[must_use]
    pub fn repetition_penalty(&self) -> Option<f32> {
        self.repetition_penalty
    }
    /// Per-segment token budget.
    #[inline(always)]
    #[must_use]
    pub fn max_tokens(&self) -> usize {
        self.max_tokens
    }
    /// Text segmentation mode.
    #[inline(always)]
    #[must_use]
    pub fn segmentation(&self) -> TextSegmentation {
        self.segmentation
    }
    /// Requested audio format.
    #[inline(always)]
    #[must_use]
    pub fn audio_format(&self) -> AudioFormat {
        self.audio_format
    }
    /// Streaming interval.
    #[inline(always)]
    #[must_use]
    pub fn streaming_interval(&self) -> f32 {
        self.streaming_interval
    }
}
/// Optional reference material (an audio tensor and/or its transcript) forwarded
/// unchanged to every segment; how a model uses it is model-defined.
///
/// `Default` yields no reference audio and no reference text.
#[derive(Debug, Clone, Copy, Default)]
pub struct TtsReference<'a> {
    /// Borrowed reference audio, if any.
    ref_audio: Option<&'a Array>,
    /// Borrowed reference transcript, if any.
    ref_text: Option<&'a str>,
}
impl<'a> TtsReference<'a> {
    /// Bundle optional reference audio and reference text.
    #[must_use]
    pub const fn new(ref_audio: Option<&'a Array>, ref_text: Option<&'a str>) -> Self {
        TtsReference {
            ref_text,
            ref_audio,
        }
    }
    /// Reference audio, if one was supplied.
    #[inline(always)]
    #[must_use]
    pub fn ref_audio(&self) -> Option<&'a Array> {
        let Self { ref_audio, .. } = *self;
        ref_audio
    }
    /// Reference transcript, if one was supplied.
    #[inline(always)]
    #[must_use]
    pub fn ref_text(&self) -> Option<&'a str> {
        let Self { ref_text, .. } = *self;
        ref_text
    }
}
/// Borrowed inputs for synthesizing one segment of text with a [`TtsModel`].
///
/// The generator in this module fills it from a [`TtsGenConfig`], a [`TtsReference`],
/// and the segment's slice of the input text.
#[derive(Debug, Clone, Copy)]
pub struct TtsSegment<'a> {
    /// Text to synthesize for this segment.
    text: &'a str,
    /// Voice identifier.
    voice: &'a str,
    /// Language code.
    language: &'a str,
    /// Speed setting.
    speed: f32,
    /// Sampling temperature.
    temperature: f32,
    /// Top-p value.
    top_p: f32,
    /// Top-k value.
    top_k: i32,
    /// Optional repetition penalty.
    repetition_penalty: Option<f32>,
    /// Token budget.
    max_tokens: usize,
    /// Streaming interval.
    streaming_interval: f32,
    /// Position of this segment; the generator passes its zero-based index.
    segment_idx: usize,
    /// Optional reference audio.
    ref_audio: Option<&'a Array>,
    /// Optional reference transcript.
    ref_text: Option<&'a str>,
}
impl<'a> TtsSegment<'a> {
    /// Bundle per-segment inputs; each argument maps onto the accessor of the same name.
    // One argument per field keeps this usable in `const` context; a builder would not be.
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub const fn new(
        text: &'a str,
        voice: &'a str,
        language: &'a str,
        speed: f32,
        temperature: f32,
        top_p: f32,
        top_k: i32,
        repetition_penalty: Option<f32>,
        max_tokens: usize,
        streaming_interval: f32,
        segment_idx: usize,
        ref_audio: Option<&'a Array>,
        ref_text: Option<&'a str>,
    ) -> Self {
        TtsSegment {
            text,
            voice,
            language,
            speed,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            max_tokens,
            streaming_interval,
            segment_idx,
            ref_audio,
            ref_text,
        }
    }
    /// Text to synthesize.
    #[inline(always)]
    #[must_use]
    pub fn text(&self) -> &'a str {
        let Self { text, .. } = *self;
        text
    }
    /// Voice identifier.
    #[inline(always)]
    #[must_use]
    pub fn voice(&self) -> &'a str {
        let Self { voice, .. } = *self;
        voice
    }
    /// Language code.
    #[inline(always)]
    #[must_use]
    pub fn language(&self) -> &'a str {
        let Self { language, .. } = *self;
        language
    }
    /// Speed setting.
    #[inline(always)]
    #[must_use]
    pub fn speed(&self) -> f32 {
        let Self { speed, .. } = *self;
        speed
    }
    /// Sampling temperature.
    #[inline(always)]
    #[must_use]
    pub fn temperature(&self) -> f32 {
        let Self { temperature, .. } = *self;
        temperature
    }
    /// Top-p value.
    #[inline(always)]
    #[must_use]
    pub fn top_p(&self) -> f32 {
        let Self { top_p, .. } = *self;
        top_p
    }
    /// Top-k value.
    #[inline(always)]
    #[must_use]
    pub fn top_k(&self) -> i32 {
        let Self { top_k, .. } = *self;
        top_k
    }
    /// Optional repetition penalty.
    #[inline(always)]
    #[must_use]
    pub fn repetition_penalty(&self) -> Option<f32> {
        let Self {
            repetition_penalty, ..
        } = *self;
        repetition_penalty
    }
    /// Token budget.
    #[inline(always)]
    #[must_use]
    pub fn max_tokens(&self) -> usize {
        let Self { max_tokens, .. } = *self;
        max_tokens
    }
    /// Streaming interval.
    #[inline(always)]
    #[must_use]
    pub fn streaming_interval(&self) -> f32 {
        let Self {
            streaming_interval, ..
        } = *self;
        streaming_interval
    }
    /// Position of this segment.
    #[inline(always)]
    #[must_use]
    pub fn segment_idx(&self) -> usize {
        let Self { segment_idx, .. } = *self;
        segment_idx
    }
    /// Optional reference audio.
    #[inline(always)]
    #[must_use]
    pub fn ref_audio(&self) -> Option<&'a Array> {
        let Self { ref_audio, .. } = *self;
        ref_audio
    }
    /// Optional reference transcript.
    #[inline(always)]
    #[must_use]
    pub fn ref_text(&self) -> Option<&'a str> {
        let Self { ref_text, .. } = *self;
        ref_text
    }
}
/// One piece of synthesized audio plus the metadata needed to play or stitch it.
#[derive(Debug)]
pub struct AudioChunk {
    /// Audio tensor; chunks built by the generator in this module are rank-1 `F32`.
    audio: Array,
    /// Samples per second.
    sample_rate: u32,
    /// Index of the text segment this audio came from.
    segment_idx: usize,
    /// Whether this is a partial streaming chunk (the generator here always sets `false`).
    is_streaming_chunk: bool,
    /// Whether this chunk belongs to the final segment.
    is_final_chunk: bool,
}
impl AudioChunk {
    /// Wrap an audio tensor together with its metadata.
    #[must_use]
    pub fn new(
        audio: Array,
        sample_rate: u32,
        segment_idx: usize,
        is_streaming_chunk: bool,
        is_final_chunk: bool,
    ) -> Self {
        AudioChunk {
            audio,
            sample_rate,
            segment_idx,
            is_streaming_chunk,
            is_final_chunk,
        }
    }
    /// Borrow the audio tensor.
    #[inline(always)]
    #[must_use]
    pub fn audio_ref(&self) -> &Array {
        &self.audio
    }
    /// Samples per second.
    #[inline(always)]
    #[must_use]
    pub fn sample_rate(&self) -> u32 {
        self.sample_rate
    }
    /// Index of the originating text segment.
    #[inline(always)]
    #[must_use]
    pub fn segment_idx(&self) -> usize {
        self.segment_idx
    }
    /// Whether this is a partial streaming chunk.
    #[inline(always)]
    #[must_use]
    pub fn is_streaming_chunk(&self) -> bool {
        self.is_streaming_chunk
    }
    /// Whether this chunk belongs to the final segment.
    #[inline(always)]
    #[must_use]
    pub fn is_final_chunk(&self) -> bool {
        self.is_final_chunk
    }
    /// Number of samples: the extent of the tensor's first axis, or 0 when it has no axes.
    #[inline(always)]
    #[must_use]
    pub fn len_samples(&self) -> usize {
        match self.audio.shape().first() {
            Some(&n) => n,
            None => 0,
        }
    }
    /// `true` when the chunk holds no samples.
    #[inline(always)]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len_samples() == 0
    }
    /// Playback length in seconds (`len_samples / sample_rate`); 0.0 when the rate is 0.
    #[inline(always)]
    #[must_use]
    pub fn duration_seconds(&self) -> f64 {
        match self.sample_rate {
            // Avoid dividing by zero for a chunk with no usable rate.
            0 => 0.0,
            rate => self.len_samples() as f64 / f64::from(rate),
        }
    }
    /// Consume the chunk and return the audio tensor.
    #[must_use]
    pub fn into_audio(self) -> Array {
        let Self { audio, .. } = self;
        audio
    }
    /// Copy the samples out as `f32`.
    ///
    /// Takes `&mut self` because `Array::to_vec` does.
    ///
    /// # Errors
    /// Propagates any error from `Array::to_vec`.
    pub fn samples(&mut self) -> Result<Vec<f32>> {
        self.audio.to_vec::<f32>()
    }
}
/// Split `text` into `(start, end)` byte ranges according to `mode`.
///
/// Every returned range is a `char`-boundary slice of `text` containing at least one
/// non-whitespace character. Blank (or empty) input therefore yields an empty `Vec`.
///
/// * [`TextSegmentation::Whole`] — a single range covering all of `text`.
/// * [`TextSegmentation::Newlines`] — one range per line, splitting on `'\n'`. A single
///   trailing `'\r'` on a line is excluded from its range, so text with CRLF line endings
///   does not pass a stray carriage return to the model as part of the segment.
fn segment_ranges(text: &str, mode: TextSegmentation) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    match mode {
        TextSegmentation::Whole => push_if_nonblank(&mut out, text, 0, text.len()),
        TextSegmentation::Newlines => {
            let mut start = 0;
            for line in text.split('\n') {
                let end = start + line.len();
                // '\r' is a single byte, so `end - 1` is still a char boundary.
                let content_end = if line.ends_with('\r') { end - 1 } else { end };
                push_if_nonblank(&mut out, text, start, content_end);
                // Step over the one-byte '\n' separator to the next line's first byte.
                start = end + 1;
            }
        }
    }
    out
}
/// Append `(start, end)` to `out` unless `text[start..end]` is empty or all whitespace.
///
/// # Panics
/// If `start..end` is out of bounds or not on `char` boundaries of `text`.
fn push_if_nonblank(out: &mut Vec<(usize, usize)>, text: &str, start: usize, end: usize) {
    if !text[start..end].trim().is_empty() {
        out.push((start, end));
    }
}
pub struct TtsGenerator<'a, M> {
model: &'a M,
text: &'a str,
cfg: &'a TtsGenConfig,
reference: TtsReference<'a>,
segments: Vec<(usize, usize)>,
next_segment: usize,
done: bool,
}
impl<M: TtsModel> TtsGenerator<'_, M> {
#[must_use]
pub fn segment_count(&self) -> usize {
self.segments.len()
}
fn synthesize(&self, idx: usize) -> Result<AudioChunk> {
let (start, end) = self.segments[idx];
let segment = TtsSegment::new(
&self.text[start..end],
self.cfg.voice(),
self.cfg.language(),
self.cfg.speed(),
self.cfg.temperature(),
self.cfg.top_p(),
self.cfg.top_k(),
self.cfg.repetition_penalty(),
self.cfg.max_tokens(),
self.cfg.streaming_interval(),
idx,
self.reference.ref_audio(),
self.reference.ref_text(),
);
let audio = self.model.synthesize_segment(&segment)?;
let shape = audio.shape();
if shape.len() != 1 {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
format_smolstr!("tts_generate: segment {idx}"),
Error::RankMismatch(RankMismatchPayload::new(
"tts_generate: `synthesize_segment` must return a rank-1 [samples] audio tensor",
shape.len() as u32,
shape,
)),
)));
}
let dtype = audio.dtype()?;
if dtype != Dtype::F32 {
return Err(Error::DtypeMismatch(DtypeMismatchPayload::new(
Dtype::F32,
dtype,
)));
}
Ok(AudioChunk::new(
audio,
self.model.sample_rate(),
idx,
false,
idx + 1 == self.segments.len(),
))
}
}
impl<M: TtsModel> Iterator for TtsGenerator<'_, M> {
    type Item = Result<AudioChunk>;
    /// Synthesize and return the next segment's audio.
    ///
    /// Yields `Some(Ok(chunk))` for each segment in order. On the first error it yields
    /// `Some(Err(e))` once and then returns `None` forever; the failing segment is not retried.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done || self.next_segment >= self.segments.len() {
            self.done = true;
            return None;
        }
        let idx = self.next_segment;
        match self.synthesize(idx) {
            Ok(chunk) => {
                self.next_segment += 1;
                if self.next_segment >= self.segments.len() {
                    self.done = true;
                }
                Some(Ok(chunk))
            }
            Err(e) => {
                // Fail fast: an error ends the stream.
                self.done = true;
                Some(Err(e))
            }
        }
    }
    /// Bounds on how many items are still to come.
    ///
    /// Once iteration has ended (including after an error) this is exactly `(0, Some(0))`.
    /// Otherwise at least one more item (a chunk or an error) will be yielded, and at most
    /// one per remaining segment, since an error ends iteration early.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.done {
            return (0, Some(0));
        }
        let remaining = self.segments.len().saturating_sub(self.next_segment);
        (remaining.min(1), Some(remaining))
    }
}
// `next` keeps returning `None` once `done` is set, so the iterator is fused.
impl<M: TtsModel> std::iter::FusedIterator for TtsGenerator<'_, M> {}
/// Start synthesizing `text` with `cfg` and no reference audio or text.
///
/// Equivalent to [`tts_generate_with_reference`] with an empty [`TtsReference`].
///
/// # Errors
/// Same as [`tts_generate_with_reference`].
pub fn tts_generate<'a, M: TtsModel>(
    model: &'a M,
    text: &'a str,
    cfg: &'a TtsGenConfig,
) -> Result<TtsGenerator<'a, M>> {
    let no_reference = TtsReference::new(None, None);
    tts_generate_with_reference(model, text, cfg, no_reference)
}
pub fn tts_generate_with_reference<'a, M: TtsModel>(
model: &'a M,
text: &'a str,
cfg: &'a TtsGenConfig,
reference: TtsReference<'a>,
) -> Result<TtsGenerator<'a, M>> {
if text.len() > MAX_TEXT_BYTES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"tts_generate: input text size (split the request into smaller calls)",
"MAX_TEXT_BYTES",
MAX_TEXT_BYTES as u64,
text.len() as u64,
)));
}
let segments = segment_ranges(text, cfg.segmentation());
if segments.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"tts_generate: input text has no non-blank segments (provide non-empty text)",
)));
}
Ok(TtsGenerator {
model,
text,
cfg,
reference,
segments,
next_segment: 0,
done: false,
})
}
/// Synthesize all of `text` without reference material and return one audio tensor.
///
/// # Errors
/// Same as [`join_audio_with_reference`].
pub fn join_audio<M: TtsModel>(model: &M, text: &str, cfg: &TtsGenConfig) -> Result<Array> {
    let no_reference = TtsReference::new(None, None);
    join_audio_with_reference(model, text, cfg, no_reference)
}
pub fn join_audio_with_reference<M: TtsModel>(
model: &M,
text: &str,
cfg: &TtsGenConfig,
reference: TtsReference<'_>,
) -> Result<Array> {
let mut chunks: Vec<Array> = Vec::new();
for chunk in tts_generate_with_reference(model, text, cfg, reference)? {
chunks.push(chunk?.into_audio());
}
match chunks.len() {
1 => Ok(chunks.into_iter().next().expect("len checked == 1")),
_ => {
let refs: Vec<&Array> = chunks.iter().collect();
ops::shape::concatenate(&refs, 0)
}
}
}
#[cfg(test)]
mod tests;