use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use log;
use serde::Deserialize;
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::http::Request;
use tokio_tungstenite::tungstenite::Message;
use crate::error::Result;
use crate::frames::{
AudioRawData, ControlFrame, DataFrame, Frame, FrameDirection, FrameHandler, FrameInner,
FrameProcessor, SystemFrame,
};
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0.0`
/// instead of panicking.
fn now() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs_f64(),
        Err(_) => 0.0,
    }
}
/// Audio encodings this handler can request from Deepgram's streaming TTS endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeepgramEncoding {
    /// 16-bit linear PCM (two bytes per sample).
    Linear16,
    /// μ-law companded audio (one byte per sample).
    Mulaw,
    /// A-law companded audio (one byte per sample).
    Alaw,
}

impl DeepgramEncoding {
    /// The value sent as the `encoding` query parameter.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Linear16 => "linear16",
            Self::Mulaw => "mulaw",
            Self::Alaw => "alaw",
        }
    }

    /// Parses an encoding name, ignoring ASCII case.
    ///
    /// Kept as an inherent method for backward compatibility. It delegates to
    /// the [`std::str::FromStr`] impl, so `"mulaw".parse::<DeepgramEncoding>()`
    /// behaves identically.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input (as given) when `s` is not
    /// one of `linear16`, `mulaw` or `alaw`.
    #[allow(clippy::should_implement_trait)] // public API kept; real impl is FromStr below
    pub fn from_str(s: &str) -> std::result::Result<Self, String> {
        s.parse()
    }

    /// Size in bytes of one mono sample in this encoding.
    pub fn bytes_per_sample(&self) -> u32 {
        match self {
            Self::Linear16 => 2,
            Self::Mulaw | Self::Alaw => 1,
        }
    }
}

impl std::str::FromStr for DeepgramEncoding {
    type Err = String;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "linear16" => Ok(Self::Linear16),
            "mulaw" => Ok(Self::Mulaw),
            "alaw" => Ok(Self::Alaw),
            // Echo the caller's original spelling, not the lowercased copy.
            _ => Err(format!(
                "Unsupported encoding '{}'. Must be one of: linear16, mulaw, alaw",
                s
            )),
        }
    }
}
/// Settings for a `DeepgramTtsHandler`.
#[derive(Debug, Clone)]
pub struct DeepgramTtsConfig {
    /// Deepgram API key, sent as `Authorization: Token <key>`. Must be non-empty.
    pub api_key: String,
    /// Voice/model name, sent (URL-encoded) as the `model` query parameter.
    pub voice: String,
    /// Audio encoding requested via the `encoding` query parameter.
    pub encoding: DeepgramEncoding,
    /// Sample rate requested from Deepgram and stamped on the output audio frames.
    pub sample_rate: u32,
    /// When `Some`, forwarded as the `mip_opt_out` query parameter; omitted when `None`.
    pub mip_opt_out: Option<bool>,
    /// WebSocket base URL. `/v1/speak?…` is appended verbatim, so it should not end in `/`.
    pub base_url: String,
}

impl Default for DeepgramTtsConfig {
    /// Voice `aura-2-helena-en`, `linear16` at 24000, the public endpoint, and an
    /// empty API key (which `DeepgramTtsHandler::new` rejects).
    fn default() -> Self {
        let voice = String::from("aura-2-helena-en");
        let base_url = String::from("wss://api.deepgram.com");
        Self {
            api_key: String::default(),
            voice,
            encoding: DeepgramEncoding::Linear16,
            sample_rate: 24_000,
            mip_opt_out: None,
            base_url,
        }
    }
}
/// A JSON text message received from Deepgram over the TTS WebSocket.
///
/// Only the fields this module inspects are modelled. Other JSON keys are
/// ignored by serde's default behaviour, since no `deny_unknown_fields` is set.
#[derive(Debug, Deserialize)]
struct DeepgramMessage {
    // JSON `type` field; `handle_text_message` dispatches on
    // "Metadata", "Flushed", "Cleared" and "Warning".
    #[serde(rename = "type")]
    msg_type: String,
    // Optional human-readable detail; only read for "Warning" messages.
    #[serde(default)]
    description: Option<String>,
}
/// Mutable per-connection state guarded by the handler's mutex and shared with
/// the WebSocket receive task.
struct TtsState {
    /// Queue feeding the send task; `None` while disconnected, which turns all
    /// outgoing sends into no-ops.
    ws_tx: Option<mpsc::Sender<Message>>,
    /// Set and cleared by `BotStartedSpeaking` / `BotStoppedSpeaking`; an
    /// interruption only sends `Clear` while this is `true`.
    bot_speaking: bool,
    /// Task that writes queued messages to the socket.
    send_task: Option<JoinHandle<()>>,
    /// Task that reads audio and control messages from the socket.
    receive_task: Option<JoinHandle<()>>,
    /// `now()` timestamp of the most recent `Flush`, used for latency logging.
    flush_sent_at: Option<f64>,
    /// Whether the first audio chunk after the latest `Flush` was already reported.
    first_audio_logged: bool,
}

impl TtsState {
    /// Disconnected state. `first_audio_logged` starts out `true`, so no
    /// first-audio line is printed until a `Flush` re-arms it.
    fn new() -> Self {
        TtsState {
            first_audio_logged: true,
            flush_sent_at: None,
            bot_speaking: false,
            ws_tx: None,
            send_task: None,
            receive_task: None,
        }
    }
}
/// Frame handler that streams LLM text to Deepgram's WebSocket TTS endpoint
/// (`/v1/speak`) and pushes the returned audio downstream as raw output audio.
///
/// Build it with `DeepgramTtsHandler::new` and wrap it with `into_processor`.
pub struct DeepgramTtsHandler {
    // Configuration passed to `new` (api_key checked to be non-empty).
    config: DeepgramTtsConfig,
    // Copy of `config.sample_rate`; used in the query string and on output frames.
    sample_rate: u32,
    // Connection state; a clone of this Arc is handed to the receive task.
    state: Arc<Mutex<TtsState>>,
}
impl DeepgramTtsHandler {
pub fn new(config: DeepgramTtsConfig) -> std::result::Result<Self, String> {
if config.api_key.is_empty() {
return Err("DeepgramTts: api_key is required".to_string());
}
let sample_rate = config.sample_rate;
Ok(Self {
config,
sample_rate,
state: Arc::new(Mutex::new(TtsState::new())),
})
}
pub fn into_processor(self) -> FrameProcessor {
FrameProcessor::new("DeepgramTts", Box::new(self), false)
}
async fn connect(&self, processor: FrameProcessor) {
let mut params = vec![
format!("model={}", urlencoding(&self.config.voice)),
format!("encoding={}", self.config.encoding.as_str()),
format!("sample_rate={}", self.sample_rate),
];
if let Some(mip) = self.config.mip_opt_out {
params.push(format!("mip_opt_out={}", mip));
}
let ws_url = format!(
"{}/v1/speak?{}",
self.config.base_url,
params.join("&")
);
log::info!("DeepgramTts: connecting to {}", ws_url);
let request = match Request::builder()
.uri(&ws_url)
.header("Host", extract_host(&self.config.base_url))
.header("Authorization", format!("Token {}", self.config.api_key))
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header(
"Sec-WebSocket-Key",
tokio_tungstenite::tungstenite::handshake::client::generate_key(),
)
.body(())
{
Ok(r) => r,
Err(e) => {
let _ = processor
.push_error(format!("DeepgramTts: request build failed: {}", e), false)
.await;
return;
}
};
let ws_stream = match tokio_tungstenite::connect_async(request).await {
Ok((s, _)) => s,
Err(e) => {
let _ = processor
.push_error(format!("DeepgramTts: connect failed: {}", e), false)
.await;
return;
}
};
let (sink, stream) = ws_stream.split();
let (ws_tx, ws_rx) = mpsc::channel::<Message>(64);
let send_task = tokio::spawn(run_send_task(sink, ws_rx));
let receive_task = tokio::spawn(run_receive_task(
stream,
processor.clone(),
self.sample_rate,
self.state.clone(),
));
{
let mut state = self.state.lock().await;
state.ws_tx = Some(ws_tx);
state.send_task = Some(send_task);
state.receive_task = Some(receive_task);
}
log::info!("DeepgramTts: connected (voice={}, encoding={}, sample_rate={})",
self.config.voice, self.config.encoding.as_str(), self.sample_rate);
}
async fn disconnect(&self) {
self.send_ws_json(r#"{"type":"Close"}"#).await;
let mut state = self.state.lock().await;
if let Some(h) = state.send_task.take() { h.abort(); }
if let Some(h) = state.receive_task.take() { h.abort(); }
state.ws_tx = None;
log::info!("DeepgramTts: disconnected");
}
async fn send_ws_json(&self, json: &str) {
let tx = { self.state.lock().await.ws_tx.clone() };
if let Some(tx) = tx {
let _ = tx.send(Message::Text(json.into())).await;
}
}
async fn send_speak(&self, text: &str) {
let tx = { self.state.lock().await.ws_tx.clone() };
if let Some(tx) = tx {
let ts = now();
println!("[{:.3}] [tts] speak ({} chars): {:?}", ts, text.len(), text);
let msg = serde_json::json!({
"type": "Speak",
"text": text
});
let _ = tx.send(Message::Text(msg.to_string().into())).await;
}
}
async fn send_flush(&self) {
let tx = { self.state.lock().await.ws_tx.clone() };
if let Some(tx) = tx {
let ts = now();
println!("[{:.3}] [tts] flush ← synthesis starts now", ts);
{
let mut state = self.state.lock().await;
state.flush_sent_at = Some(ts);
state.first_audio_logged = false;
}
let msg = r#"{"type":"Flush"}"#;
let _ = tx.send(Message::Text(msg.into())).await;
}
}
async fn send_clear(&self) {
let ts = now();
println!("[{:.3}] [tts] clear ← interruption", ts);
self.send_ws_json(r#"{"type":"Clear"}"#).await;
}
}
#[async_trait]
impl FrameHandler for DeepgramTtsHandler {
    /// Forwards every frame unchanged in `direction`, performing Deepgram side
    /// effects for the frames it cares about.
    ///
    /// `Start` is forwarded *before* connecting. For all other frames, the side
    /// effect (Speak / Flush / Clear / disconnect / state update) happens first
    /// and the frame is forwarded afterwards.
    async fn on_process_frame(
        &self,
        processor: &FrameProcessor,
        frame: Frame,
        direction: FrameDirection,
    ) -> Result<()> {
        if let FrameInner::System(SystemFrame::Start(_)) = &frame.inner {
            processor.push_frame(frame, direction).await?;
            self.connect(processor.clone()).await;
            return Ok(());
        }

        match &frame.inner {
            FrameInner::Data(DataFrame::LLMText(text)) if !text.is_empty() => {
                self.send_speak(text).await;
            }
            FrameInner::Control(ControlFrame::LLMFullResponseEnd) => {
                self.send_flush().await;
            }
            FrameInner::System(SystemFrame::BotStartedSpeaking) => {
                self.state.lock().await.bot_speaking = true;
            }
            FrameInner::System(SystemFrame::BotStoppedSpeaking) => {
                self.state.lock().await.bot_speaking = false;
            }
            FrameInner::System(SystemFrame::Interruption) => {
                let speaking = self.state.lock().await.bot_speaking;
                if speaking {
                    log::info!("DeepgramTts: interruption while speaking — sending Clear");
                    self.send_clear().await;
                }
            }
            FrameInner::Control(ControlFrame::End { .. })
            | FrameInner::System(SystemFrame::Cancel { .. }) => {
                self.disconnect().await;
            }
            // Everything else (including LLMFullResponseStart and empty text)
            // is just passed through.
            _ => {}
        }

        processor.push_frame(frame, direction).await?;
        Ok(())
    }
}
/// Connection type produced by `tokio_tungstenite::connect_async`.
type WsConnection =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

/// Write half of the Deepgram socket, owned by `run_send_task`.
type WsSink = futures::stream::SplitSink<WsConnection, Message>;

/// Read half of the Deepgram socket, owned by `run_receive_task`.
type WsStream = futures::stream::SplitStream<WsConnection>;
/// Writes queued messages to the socket until the channel closes (all senders
/// dropped) or a write fails, then closes the sink.
async fn run_send_task(mut sink: WsSink, mut rx: mpsc::Receiver<Message>) {
    loop {
        let outgoing = match rx.recv().await {
            Some(message) => message,
            None => break,
        };
        if sink.send(outgoing).await.is_ok() {
            continue;
        }
        log::warn!("DeepgramTts: send task — write failed, closing");
        break;
    }
    // Best effort: the peer may already be gone.
    let _ = sink.close().await;
    log::debug!("DeepgramTts: send task exited");
}
/// Reads from the Deepgram socket until it closes, errors, or the stream ends.
///
/// Non-empty binary messages are pushed downstream as mono raw output audio at
/// `sample_rate`. Text messages go to `handle_text_message`. A receive error is
/// reported via `push_error(.., false)` and ends the task.
async fn run_receive_task(
    mut stream: WsStream,
    processor: FrameProcessor,
    sample_rate: u32,
    state: Arc<Mutex<TtsState>>,
) {
    log::debug!("DeepgramTts: receive task started");
    while let Some(incoming) = stream.next().await {
        match incoming {
            Ok(Message::Binary(chunk)) if !chunk.is_empty() => {
                report_first_audio(&state).await;
                log::debug!("DeepgramTts: received {} audio bytes", chunk.len());
                let audio = AudioRawData::new(chunk.to_vec(), sample_rate, 1);
                let _ = processor
                    .push_frame(Frame::output_audio_raw(audio), FrameDirection::Downstream)
                    .await;
            }
            Ok(Message::Text(payload)) => {
                handle_text_message(payload.as_str(), &processor).await;
            }
            Ok(Message::Close(_)) => {
                log::info!("DeepgramTts: server closed WebSocket");
                break;
            }
            Err(err) => {
                let _ = processor
                    .push_error(format!("DeepgramTts: receive error: {}", err), false)
                    .await;
                break;
            }
            // Empty binary payloads, pings, pongs, and raw frames are ignored.
            _ => {}
        }
    }
    log::debug!("DeepgramTts: receive task exited");
}

/// Prints the "first_audio" timing line once after each `Flush`, including the
/// latency since the flush when its timestamp is known.
async fn report_first_audio(state: &Mutex<TtsState>) {
    let mut guard = state.lock().await;
    if guard.first_audio_logged {
        return;
    }
    guard.first_audio_logged = true;
    let ts = now();
    match guard.flush_sent_at {
        Some(flush_at) => println!(
            "[{:.3}] [tts] first_audio ← Deepgram latency: {:.3}s",
            ts,
            ts - flush_at
        ),
        None => println!("[{:.3}] [tts] first_audio", ts),
    }
}
async fn handle_text_message(text: &str, processor: &FrameProcessor) {
log::trace!("DeepgramTts: raw message: {}", text);
let msg: DeepgramMessage = match serde_json::from_str(text) {
Ok(m) => m,
Err(e) => {
log::warn!("DeepgramTts: parse error: {} — raw: {}", e, text);
return;
}
};
match msg.msg_type.as_str() {
"Metadata" => {
log::trace!("DeepgramTts: metadata received");
}
"Flushed" => {
log::debug!("DeepgramTts: Flushed — synthesis complete for current buffer");
}
"Cleared" => {
log::debug!("DeepgramTts: Cleared — buffer cleared after interruption");
}
"Warning" => {
let desc = msg.description.as_deref().unwrap_or("unknown warning");
log::warn!("DeepgramTts: server warning: {}", desc);
}
other => {
log::debug!("DeepgramTts: unhandled message type: '{}'", other);
}
}
}
/// Percent-encodes `s` for use as a URL query-parameter value.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other character is encoded per byte of its UTF-8
/// representation as `%XX` with uppercase hex. For example, `é` becomes
/// `%C3%A9`, not `%E9`, and `€` becomes `%E2%82%AC`, not `%20AC`.
fn urlencoding(s: &str) -> String {
    use std::fmt::Write as _;

    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(out, "%{:02X}", byte);
        }
    }
    out
}
/// Returns the `host[:port]` part of a `ws://` or `wss://` base URL, for use as
/// the `Host` header.
///
/// Leading scheme prefixes are stripped, and everything from the first `/`
/// onward is dropped. A string without a scheme is treated as already starting
/// with the host.
fn extract_host(url: &str) -> String {
    let without_scheme = url.trim_start_matches("wss://").trim_start_matches("ws://");
    let host = match without_scheme.find('/') {
        Some(slash) => &without_scheme[..slash],
        None => without_scheme,
    };
    String::from(host)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration plus a dummy API key, so `DeepgramTtsHandler::new` accepts it.
    fn keyed_config() -> DeepgramTtsConfig {
        DeepgramTtsConfig {
            api_key: "test-key".to_string(),
            ..DeepgramTtsConfig::default()
        }
    }

    #[test]
    fn test_encoding_from_str() {
        let accepted = [
            ("linear16", DeepgramEncoding::Linear16),
            ("mulaw", DeepgramEncoding::Mulaw),
            ("alaw", DeepgramEncoding::Alaw),
            ("LINEAR16", DeepgramEncoding::Linear16),
        ];
        for &(name, expected) in accepted.iter() {
            assert_eq!(
                DeepgramEncoding::from_str(name).unwrap(),
                expected,
                "parsing {:?}",
                name
            );
        }
        assert!(DeepgramEncoding::from_str("mp3").is_err());
    }

    #[test]
    fn test_encoding_as_str() {
        let names = [
            (DeepgramEncoding::Linear16, "linear16"),
            (DeepgramEncoding::Mulaw, "mulaw"),
            (DeepgramEncoding::Alaw, "alaw"),
        ];
        for &(encoding, name) in names.iter() {
            assert_eq!(encoding.as_str(), name);
        }
    }

    #[test]
    fn test_bytes_per_sample() {
        let sizes: [(DeepgramEncoding, u32); 3] = [
            (DeepgramEncoding::Linear16, 2),
            (DeepgramEncoding::Mulaw, 1),
            (DeepgramEncoding::Alaw, 1),
        ];
        for &(encoding, bytes) in sizes.iter() {
            assert_eq!(encoding.bytes_per_sample(), bytes);
        }
    }

    #[test]
    fn test_default_config() {
        let defaults = DeepgramTtsConfig::default();
        assert_eq!(defaults.voice, "aura-2-helena-en");
        assert_eq!(defaults.encoding, DeepgramEncoding::Linear16);
        assert_eq!(defaults.sample_rate, 24000);
        assert_eq!(defaults.base_url, "wss://api.deepgram.com");
        assert!(defaults.mip_opt_out.is_none());
    }

    #[test]
    fn test_handler_requires_api_key() {
        assert!(DeepgramTtsHandler::new(DeepgramTtsConfig::default()).is_err());
    }

    #[test]
    fn test_handler_with_api_key() {
        let handler = DeepgramTtsHandler::new(keyed_config()).unwrap();
        assert_eq!(handler.sample_rate, 24000);
        assert_eq!(handler.config.voice, "aura-2-helena-en");
    }

    #[test]
    fn test_handler_custom_encoding() {
        let handler = DeepgramTtsHandler::new(DeepgramTtsConfig {
            encoding: DeepgramEncoding::Mulaw,
            ..keyed_config()
        })
        .unwrap();
        assert_eq!(handler.config.encoding, DeepgramEncoding::Mulaw);
    }

    #[test]
    fn test_handler_custom_sample_rate() {
        let handler = DeepgramTtsHandler::new(DeepgramTtsConfig {
            sample_rate: 16000,
            ..keyed_config()
        })
        .unwrap();
        assert_eq!(handler.sample_rate, 16000);
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            ("wss://api.deepgram.com", "api.deepgram.com"),
            ("wss://custom.host.io/path", "custom.host.io"),
            ("ws://localhost:8080", "localhost:8080"),
        ];
        for &(url, host) in cases.iter() {
            assert_eq!(extract_host(url), host);
        }
    }

    #[test]
    fn test_urlencoding() {
        let cases = [
            ("aura-2-helena-en", "aura-2-helena-en"),
            ("model:v2", "model%3Av2"),
        ];
        for &(raw, encoded) in cases.iter() {
            assert_eq!(urlencoding(raw), encoded);
        }
    }
}