use std::borrow::Cow;
use std::io::Cursor;
use async_trait::async_trait;
use super::normalize::{collapse_punctuation, normalise_markdown_for_tts, strip_emojis_for_tts};
use super::ssml::{apply_ssml_hints, strip_voice_markers};
use super::{Result, VoiceError};
/// Audio format requested from the Edge TTS service: 24 kHz, 48 kbit/s, mono MP3.
pub const EDGE_AUDIO_FORMAT: &str = "audio-24khz-48kbitrate-mono-mp3";
/// MIME type advertised for the final voice note (Ogg container, Opus codec).
pub const VOICE_NOTE_MIME: &str = "audio/ogg; codecs=opus";
/// A fully rendered voice note: encoded audio plus a plain-text transcript.
#[derive(Debug, Clone)]
pub struct VoiceNote {
    // Ogg/Opus-encoded audio (output of `transcode_mp3_to_opus_ogg`).
    pub audio_bytes: Vec<u8>,
    // Always set to `VOICE_NOTE_MIME` by `synthesize_voice_note`.
    pub mimetype: &'static str,
    // Original input text with voice markers removed (`strip_voice_markers`).
    pub transcript: String,
}
/// Abstraction over a text-to-speech backend so the network-bound Edge client
/// can be swapped for a stub in tests.
#[async_trait]
pub trait TtsProvider: Send + Sync {
    /// Synthesizes `body` (plain text or SSML) with the given voice and
    /// returns the raw encoded audio bytes.
    async fn synthesize_raw(&self, body: &str, voice_id: &str) -> Result<Vec<u8>>;
}
/// `TtsProvider` backed by the Microsoft Edge TTS service via `msedge_tts`.
#[derive(Debug, Clone)]
pub struct EdgeTtsProvider {
    // Speaking-rate adjustment passed straight into `SpeechConfig.rate`
    // (negative slows speech down; see `Default` which uses -8).
    pub rate: i32,
    // Edge audio format string, e.g. `EDGE_AUDIO_FORMAT`.
    pub audio_format: String,
}
impl Default for EdgeTtsProvider {
fn default() -> Self {
Self {
rate: -8,
audio_format: EDGE_AUDIO_FORMAT.to_string(),
}
}
}
#[async_trait]
impl TtsProvider for EdgeTtsProvider {
    /// Synthesizes `body` via the Edge service. If the first attempt yields
    /// zero bytes (Edge silently rejecting the SSML), retries once with all
    /// tags stripped before giving up with `EmptySynthesis`.
    async fn synthesize_raw(&self, body: &str, voice_id: &str) -> Result<Vec<u8>> {
        if voice_id.trim().is_empty() {
            return Err(VoiceError::EmptyVoiceId);
        }

        let first_attempt = call_edge(body, voice_id, &self.audio_format, self.rate).await?;
        if !first_attempt.is_empty() {
            return Ok(first_attempt);
        }

        // Zero bytes with no error: Edge most likely rejected the SSML body.
        let plain = strip_ssml_tags(body);
        tracing::warn!(
            ssml_body_len = body.len(),
            plain_body_len = plain.len(),
            ssml_body = %body,
            "voice: edge returned 0 bytes; retrying with plain text",
        );

        let retry = call_edge(&plain, voice_id, &self.audio_format, self.rate).await?;
        if retry.is_empty() {
            return Err(VoiceError::EmptySynthesis);
        }
        tracing::warn!(
            "voice: plain-text fallback succeeded — SSML body was rejected by edge",
        );
        Ok(retry)
    }
}
/// Runs the blocking `msedge_tts` client on tokio's blocking thread pool and
/// returns the synthesized audio bytes. All string inputs are cloned into the
/// closure since the task outlives the borrows.
async fn call_edge(body: &str, voice: &str, audio_format: &str, rate: i32) -> Result<Vec<u8>> {
    let text = body.to_owned();
    let voice_name = voice.to_owned();
    let format = audio_format.to_owned();
    tokio::task::spawn_blocking(move || -> Result<Vec<u8>> {
        let mut client = msedge_tts::tts::client::connect()
            .map_err(|e| VoiceError::Edge(format!("connect: {e}")))?;
        let speech_cfg = msedge_tts::tts::SpeechConfig {
            voice_name,
            audio_format: format,
            pitch: 0,
            rate,
            volume: 0,
        };
        client
            .synthesize(&text, &speech_cfg)
            .map(|synthesized| synthesized.audio_bytes)
            .map_err(|e| VoiceError::Edge(format!("synthesize: {e}")))
    })
    .await
    .map_err(|e| VoiceError::Edge(format!("synthesize join: {e}")))?
}
/// Removes everything between `<` and `>` (inclusive) and collapses all runs
/// of whitespace to single spaces, producing a plain-text fallback body.
/// Note: a naive scanner — it does not understand quoted attribute values.
fn strip_ssml_tags(input: &str) -> String {
    let mut kept = String::with_capacity(input.len());
    let mut inside_tag = false;
    input.chars().for_each(|c| match c {
        '<' => inside_tag = true,
        '>' => inside_tag = false,
        _ if !inside_tag => kept.push(c),
        _ => {}
    });
    kept.split_whitespace().collect::<Vec<_>>().join(" ")
}
/// Async wrapper: copies the MP3 bytes and offloads the CPU-bound
/// MP3 → Opus/Ogg transcode to tokio's blocking thread pool.
pub async fn transcode_mp3_to_opus_ogg(mp3: &[u8]) -> Result<Vec<u8>> {
    let input = mp3.to_vec();
    let handle = tokio::task::spawn_blocking(move || transcode_mp3_to_opus_ogg_blocking(&input));
    handle
        .await
        .map_err(|e| VoiceError::Ffmpeg(format!("transcode join: {e}")))?
}
/// Decodes an MP3 byte stream to f32 PCM with symphonia, downmixes to mono,
/// resamples to 24 kHz, encodes 20 ms Opus frames, and hand-builds an Ogg
/// container (OpusHead + OpusTags headers followed by audio packets).
///
/// Errors are all surfaced as `VoiceError::Ffmpeg` with a stage-specific
/// message (probe, decode, encode, mux).
fn transcode_mp3_to_opus_ogg_blocking(mp3: &[u8]) -> Result<Vec<u8>> {
    use ogg::PacketWriteEndInfo;
    use opus_wave::{Application, Channels, OpusEncoder, SampleRate};
    use symphonia::core::audio::SampleBuffer;
    use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
    use symphonia::core::errors::Error as SymError;
    use symphonia::core::formats::FormatOptions;
    use symphonia::core::io::MediaSourceStream;
    use symphonia::core::meta::MetadataOptions;
    use symphonia::core::probe::Hint;

    if mp3.is_empty() {
        return Err(VoiceError::Ffmpeg("mp3 input is empty".into()));
    }

    // Probe the buffer as an MP3 container/stream.
    let cursor = Cursor::new(mp3.to_vec());
    let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
    let mut hint = Hint::new();
    hint.with_extension("mp3");
    let probed = symphonia::default::get_probe()
        .format(
            &hint,
            mss,
            &FormatOptions::default(),
            &MetadataOptions::default(),
        )
        .map_err(|e| VoiceError::Ffmpeg(format!("probe mp3: {e}")))?;
    let mut format = probed.format;

    // Pick the first track with a recognized codec and build its decoder.
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
        .ok_or_else(|| VoiceError::Ffmpeg("no audio track in mp3".into()))?;
    let track_id = track.id;
    let source_rate = track
        .codec_params
        .sample_rate
        .ok_or_else(|| VoiceError::Ffmpeg("mp3 has no sample rate".into()))?;
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &DecoderOptions::default())
        .map_err(|e| VoiceError::Ffmpeg(format!("mp3 decoder init: {e}")))?;

    // Decode every packet into interleaved f32 PCM at the source rate,
    // downmixing multi-channel audio to mono by averaging channels.
    let mut pcm_at_source: Vec<f32> = Vec::new();
    let mut sample_buf: Option<SampleBuffer<f32>> = None;
    let mut input_channels: usize = 1;
    loop {
        let packet = match format.next_packet() {
            Ok(p) => p,
            // Both UnexpectedEof and ResetRequired are treated as end-of-stream.
            Err(SymError::IoError(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                break;
            }
            Err(SymError::ResetRequired) => break,
            Err(e) => return Err(VoiceError::Ffmpeg(format!("mp3 packet: {e}"))),
        };
        if packet.track_id() != track_id {
            continue;
        }
        let decoded = match decoder.decode(&packet) {
            Ok(d) => d,
            // Skip corrupt frames rather than failing the whole transcode.
            Err(SymError::DecodeError(_)) => continue,
            Err(e) => return Err(VoiceError::Ffmpeg(format!("mp3 decode: {e}"))),
        };
        if sample_buf.is_none() {
            // Lazily size the reusable conversion buffer from the first frame's spec.
            let spec = *decoded.spec();
            input_channels = spec.channels.count().max(1);
            sample_buf = Some(SampleBuffer::<f32>::new(decoded.capacity() as u64, spec));
        }
        let sb = sample_buf.as_mut().unwrap();
        sb.copy_interleaved_ref(decoded);
        let interleaved = sb.samples();
        if input_channels == 1 {
            pcm_at_source.extend_from_slice(interleaved);
        } else {
            // Mono downmix: arithmetic mean of each interleaved sample group.
            for chunk in interleaved.chunks_exact(input_channels) {
                let avg = chunk.iter().sum::<f32>() / input_channels as f32;
                pcm_at_source.push(avg);
            }
        }
    }
    if pcm_at_source.is_empty() {
        return Err(VoiceError::Ffmpeg("mp3 decoded to empty PCM".into()));
    }

    // Resample to the 24 kHz rate the Opus encoder below is configured for.
    const TARGET_RATE: u32 = 24_000;
    let pcm_24k = if source_rate == TARGET_RATE {
        pcm_at_source
    } else {
        resample_linear(&pcm_at_source, source_rate, TARGET_RATE)
    };

    let mut encoder = OpusEncoder::new(SampleRate::Hz24000, Channels::Mono, Application::Voip)
        .map_err(|e| VoiceError::Ffmpeg(format!("opus encoder init: {e:?}")))?;
    const FRAME_MS: u32 = 20;
    // 480 samples per 20 ms frame at 24 kHz.
    let frame_size_per_channel: i32 = (TARGET_RATE * FRAME_MS / 1000) as i32;
    let frame_size_usize = frame_size_per_channel as usize;
    let mut output = Vec::<u8>::new();
    // Arbitrary fixed serial number for the single logical Ogg bitstream.
    let serial: u32 = 0xCA5C_ADE0; {
        let mut writer = ogg::PacketWriter::new(Cursor::new(&mut output));
        // OpusHead identification header (RFC 7845 §5.1): magic, version 1,
        // 1 channel, pre-skip 312 (48 kHz samples; presumably the encoder
        // lookahead — TODO confirm against opus_wave), input sample rate,
        // output gain 0, channel mapping family 0.
        let mut head = Vec::with_capacity(19);
        head.extend_from_slice(b"OpusHead");
        head.push(1); head.push(1); head.extend_from_slice(&312u16.to_le_bytes());
        head.extend_from_slice(&TARGET_RATE.to_le_bytes());
        head.extend_from_slice(&0i16.to_le_bytes()); head.push(0);
        writer
            .write_packet(
                Cow::<[u8]>::Owned(head),
                serial,
                PacketWriteEndInfo::EndPage,
                0,
            )
            .map_err(|e| VoiceError::Ffmpeg(format!("ogg OpusHead: {e}")))?;
        // OpusTags comment header: vendor string plus a zero user-comment count.
        let vendor = b"nexo-microapp-sdk";
        let mut tags = Vec::with_capacity(8 + 4 + vendor.len() + 4);
        tags.extend_from_slice(b"OpusTags");
        tags.extend_from_slice(&(vendor.len() as u32).to_le_bytes());
        tags.extend_from_slice(vendor);
        tags.extend_from_slice(&0u32.to_le_bytes()); writer
            .write_packet(
                Cow::<[u8]>::Owned(tags),
                serial,
                PacketWriteEndInfo::EndPage,
                0,
            )
            .map_err(|e| VoiceError::Ffmpeg(format!("ogg OpusTags: {e}")))?;
        // NOTE(review): any trailing samples shorter than one full 20 ms frame
        // are silently dropped (integer division below).
        let total_frames = pcm_24k.len() / frame_size_usize;
        if total_frames == 0 {
            return Err(VoiceError::Ffmpeg(
                "mp3 too short to encode a single 20 ms opus frame".into(),
            ));
        }
        let mut packet_buf = vec![0u8; 4000];
        let packet_buf_cap: i32 = packet_buf.len() as i32;
        // Ogg Opus granule positions count 48 kHz samples regardless of the
        // encoding rate, hence the ×2 per 24 kHz frame below.
        let mut granule_48k: u64 = 0;
        for i in 0..total_frames {
            let start = i * frame_size_usize;
            let frame = &pcm_24k[start..start + frame_size_usize];
            let len = encoder
                .encode_float(
                    frame,
                    frame_size_per_channel,
                    &mut packet_buf,
                    packet_buf_cap,
                )
                .map_err(|e| VoiceError::Ffmpeg(format!("opus encode: {e:?}")))?;
            // NOTE(review): a zero-length encode is skipped; if that happened on
            // the final frame, no EndStream packet would be written — confirm
            // whether opus_wave can return 0 here.
            if len <= 0 {
                continue;
            }
            granule_48k += (frame_size_usize as u64) * 2; let info = if i + 1 == total_frames {
                PacketWriteEndInfo::EndStream
            } else {
                PacketWriteEndInfo::NormalPacket
            };
            let bytes = packet_buf[..len as usize].to_vec();
            writer
                .write_packet(Cow::<[u8]>::Owned(bytes), serial, info, granule_48k)
                .map_err(|e| VoiceError::Ffmpeg(format!("ogg audio: {e}")))?;
        }
    }
    if output.is_empty() {
        return Err(VoiceError::Ffmpeg("produced 0 bytes".into()));
    }
    Ok(output)
}
/// Linear-interpolation resampler from `from_hz` to `to_hz`.
/// Identity rates or empty input are returned as a straight copy. The last
/// input sample is clamped when interpolation would read past the end.
/// (No low-pass filtering is applied, so downsampling can alias.)
fn resample_linear(input: &[f32], from_hz: u32, to_hz: u32) -> Vec<f32> {
    if from_hz == to_hz || input.is_empty() {
        return input.to_vec();
    }
    // Source samples advanced per output sample.
    let step = f64::from(from_hz) / f64::from(to_hz);
    let total = ((input.len() as f64) / step).floor() as usize;
    let last = input.len() - 1;
    (0..total)
        .map(|n| {
            let pos = n as f64 * step;
            let lo = pos.floor() as usize;
            let hi = (lo + 1).min(last);
            let t = (pos - lo as f64) as f32;
            let (a, b) = (input[lo], input[hi]);
            a + (b - a) * t
        })
        .collect()
}
/// End-to-end voice-note pipeline: normalize the input text, apply SSML
/// hints, synthesize MP3 audio via `provider`, transcode to Ogg/Opus, and
/// pair it with a marker-free transcript.
///
/// Errors: `EmptyVoiceId` for a blank voice, `EmptyText` when normalization
/// leaves nothing to speak, plus whatever the provider/transcoder return.
pub async fn synthesize_voice_note(
    text: &str,
    voice_id: &str,
    provider: &dyn TtsProvider,
) -> Result<VoiceNote> {
    if voice_id.trim().is_empty() {
        return Err(VoiceError::EmptyVoiceId);
    }

    // Text pipeline: markdown cleanup → SSML hints → emoji strip → punctuation collapse.
    let prepared = {
        let md = normalise_markdown_for_tts(text);
        let hinted = apply_ssml_hints(&md);
        let no_emoji = strip_emojis_for_tts(&hinted);
        collapse_punctuation(no_emoji.trim())
    };
    if prepared.is_empty() {
        return Err(VoiceError::EmptyText);
    }

    // Diagnostics: marker counts come from the raw text, tag counts from the
    // prepared SSML body.
    let markers = count_markers(text);
    let tags = count_ssml_tags(&prepared);
    let preview: String = prepared.chars().take(400).collect();
    tracing::info!(
        marker_pause = markers.pause,
        marker_em = markers.em,
        marker_strong = markers.strong,
        marker_spell = markers.spell,
        marker_slow = markers.slow,
        marker_fast = markers.fast,
        ssml_break = tags.break_,
        ssml_emphasis = tags.emphasis,
        ssml_say_as = tags.say_as,
        ssml_prosody = tags.prosody,
        body_len = prepared.len(),
        body_preview = %preview,
        "voice: ssml pipeline ready",
    );

    let mp3 = provider.synthesize_raw(&prepared, voice_id).await?;
    let audio_bytes = transcode_mp3_to_opus_ogg(&mp3).await?;
    Ok(VoiceNote {
        audio_bytes,
        mimetype: VOICE_NOTE_MIME,
        transcript: strip_voice_markers(text),
    })
}
/// Per-kind tally of voice markers found in raw (pre-SSML) input text.
#[derive(Default)]
struct MarkerCounts {
    pause: usize,
    em: usize,
    strong: usize,
    spell: usize,
    slow: usize,
    fast: usize,
}

/// Counts occurrences of each voice-marker token via plain substring search.
/// `[pause=` matches the opening of parameterized pauses; the rest match
/// their literal opening tags.
fn count_markers(input: &str) -> MarkerCounts {
    let tally = |needle: &str| input.matches(needle).count();
    MarkerCounts {
        pause: tally("[pause="),
        em: tally("[em]"),
        strong: tally("[strong]"),
        spell: tally("[spell]"),
        slow: tally("[slow]"),
        fast: tally("[fast]"),
    }
}
/// Per-kind tally of SSML opening tags found in a prepared body.
#[derive(Default)]
struct TagCounts {
    break_: usize,
    emphasis: usize,
    say_as: usize,
    prosody: usize,
}

/// Counts SSML opening tags by substring search. Each needle includes the
/// trailing space, so only attribute-bearing opening tags are counted
/// (closing tags and bare forms like `<break/>` are not).
fn count_ssml_tags(input: &str) -> TagCounts {
    let tally = |needle: &str| input.matches(needle).count();
    TagCounts {
        break_: tally("<break "),
        emphasis: tally("<emphasis "),
        say_as: tally("<say-as "),
        prosody: tally("<prosody "),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double for `TtsProvider` that returns a canned result instead of
    /// hitting the Edge network service.
    struct StubProvider {
        canned: Result<Vec<u8>>,
    }

    #[async_trait]
    impl TtsProvider for StubProvider {
        async fn synthesize_raw(&self, _body: &str, _voice: &str) -> Result<Vec<u8>> {
            match &self.canned {
                Ok(b) => Ok(b.clone()),
                // VoiceError doesn't implement Clone, so rebuild the variants
                // these tests care about; anything else degrades to Edge("stub").
                Err(e) => Err(match e {
                    VoiceError::Edge(s) => VoiceError::Edge(s.clone()),
                    VoiceError::EmptySynthesis => VoiceError::EmptySynthesis,
                    _ => VoiceError::Edge("stub".into()),
                }),
            }
        }
    }

    // Whitespace-only text must be rejected before any provider call.
    #[tokio::test]
    async fn synthesize_voice_note_rejects_empty_text() {
        let p = StubProvider {
            canned: Ok(b"x".to_vec()),
        };
        let r = synthesize_voice_note(" ", "es-MX-DaliaNeural", &p).await;
        assert!(matches!(r, Err(VoiceError::EmptyText)));
    }

    // A blank voice id must be rejected before any provider call.
    #[tokio::test]
    async fn synthesize_voice_note_rejects_empty_voice() {
        let p = StubProvider {
            canned: Ok(b"x".to_vec()),
        };
        let r = synthesize_voice_note("hola", "", &p).await;
        assert!(matches!(r, Err(VoiceError::EmptyVoiceId)));
    }

    // Tag stripping removes both self-closing and paired elements and
    // collapses the whitespace left behind.
    #[test]
    fn strip_ssml_tags_drops_break_and_say_as() {
        let s = strip_ssml_tags(
            r#"hola <break time="200ms"/> mundo <say-as interpret-as="characters">SIC</say-as>"#,
        );
        assert_eq!(s, "hola mundo SIC");
    }

    // Each marker kind is tallied independently from the raw text.
    #[test]
    fn marker_counts_tracks_each_kind() {
        let raw = "[pause=400ms] [em]foo[/em] [strong]bar[/strong] [spell]X[/spell]";
        let c = count_markers(raw);
        assert_eq!(c.pause, 1);
        assert_eq!(c.em, 1);
        assert_eq!(c.strong, 1);
        assert_eq!(c.spell, 1);
    }
}