use crate::constants::env::ai_code;
use crate::constants::oauth;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, watch, Mutex};
use tokio::time::{interval, Duration};
use tokio_tungstenite::tungstenite::Message;
/// Path of the speech-to-text WebSocket endpoint; `connect_voice_stream` appends it
/// (followed by the query string) to the WebSocket base URL.
const VOICE_STREAM_PATH: &str = "/api/ws/speech_to_text/voice_stream";
/// Period, in seconds, of the `KeepAlive` messages sent while a stream is open.
const KEEPALIVE_INTERVAL_SECS: u64 = 8;
/// Pair of timeout budgets used when finalizing a voice stream.
///
/// The only instance defined in this file is [`FINALIZE_TIMEOUTS_MS`], whose name
/// indicates that the values are milliseconds.
#[derive(Debug, Clone)]
pub struct FinalizeTimeouts {
    /// Overall upper bound for a finalize attempt.
    pub safety: u64,
    /// Budget associated with the "no data" case (see `FinalizeSource::NoDataTimeout`).
    pub no_data: u64,
}

/// Default finalize timeouts, in milliseconds.
pub const FINALIZE_TIMEOUTS_MS: FinalizeTimeouts = FinalizeTimeouts {
    no_data: 1_500,
    safety: 5_000,
};
/// Event sink for a voice stream; `connect_voice_stream` takes it as
/// `Arc<dyn VoiceStreamCallbacks>` and invokes it from the connection code.
pub trait VoiceStreamCallbacks: Send + Sync {
/// Receives transcript `text`; `is_final` is `true` for text reported as final.
fn on_transcript(&self, text: &str, is_final: bool);
/// Receives a human-readable error description. The code in this file always
/// passes `fatal = false`.
fn on_error(&self, error: &str, fatal: bool);
/// Signals that the stream has ended.
fn on_close(&self);
/// Hands over a handle to the connection once the WebSocket is connected.
fn on_ready(&self, connection: VoiceStreamConnection);
}
/// Identifies what caused a `finalize` call to complete.
#[derive(Debug, Clone, PartialEq)]
pub enum FinalizeSource {
    /// A transcript endpoint arrived after the stream was asked to close.
    PostClosestreamEndpoint,
    /// The "no data" timeout elapsed.
    NoDataTimeout,
    /// The overall safety timeout elapsed.
    SafetyTimeout,
    /// The WebSocket closed while finalization was pending.
    WsClose,
    /// The WebSocket was already closed when finalization was requested.
    WsAlreadyClosed,
}

impl fmt::Display for FinalizeSource {
    /// Writes the stable snake_case label of the variant (no padding is applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::PostClosestreamEndpoint => "post_closestream_endpoint",
            Self::NoDataTimeout => "no_data_timeout",
            Self::SafetyTimeout => "safety_timeout",
            Self::WsClose => "ws_close",
            Self::WsAlreadyClosed => "ws_already_closed",
        };
        f.write_str(label)
    }
}
/// Serde shape of a transcript-text event: a `type` tag plus the text payload.
///
/// NOTE(review): not referenced by any other code visible in this file; presumably the
/// standalone counterpart of `VoiceStreamServerMessage::TranscriptText` — confirm usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptTextEvent {
/// Value of the JSON `type` field.
#[serde(rename = "type")]
pub event_type: String,
/// Value of the JSON `data` field.
pub data: String,
}
/// Serde shape of a transcript-endpoint event, which carries only its `type` tag.
///
/// NOTE(review): not referenced by any other code visible in this file; presumably the
/// standalone counterpart of `VoiceStreamServerMessage::TranscriptEndpoint` — confirm usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptEndpointEvent {
/// Value of the JSON `type` field.
#[serde(rename = "type")]
pub event_type: String,
}
/// Serde shape of a transcript-error event: a `type` tag plus optional error details.
///
/// NOTE(review): not referenced by any other code visible in this file; presumably the
/// standalone counterpart of `VoiceStreamServerMessage::TranscriptError` — confirm usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptErrorEvent {
/// Value of the JSON `type` field.
#[serde(rename = "type")]
pub event_type: String,
/// Value of the JSON `errorCode` field, when present.
#[serde(rename = "errorCode")]
pub error_code: Option<String>,
/// Value of the JSON `description` field, when present.
pub description: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum VoiceStreamServerMessage {
TranscriptText { data: String },
TranscriptEndpoint,
TranscriptError {
#[serde(rename = "errorCode")]
error_code: Option<String>,
description: Option<String>,
},
Error { message: Option<String> },
}
#[derive(Debug)]
enum AudioFrame {
Binary(Vec<u8>),
}
#[derive(Debug)]
enum TranscriptResult {
Text(String),
Final(String),
Done,
}
#[derive(Clone)]
pub struct VoiceStreamConnection {
audio_tx: mpsc::UnboundedSender<AudioFrame>,
transcript_rx: mpsc::UnboundedReceiver<TranscriptResult>,
state: Arc<watch::Sender<ConnectionState>>,
on_tx: oneshot::Sender<FinalizeSource>,
on_rx: Mutex<Option<oneshot::Receiver<FinalizeSource>>>,
task_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
#[derive(Debug, Clone, PartialEq)]
enum ConnectionState {
Connected,
Closing,
}
impl VoiceStreamConnection {
pub fn send(&self, audio_chunk: &[u8]) {
let state = *self.state.borrow();
if state != ConnectionState::Connected {
return;
}
let frame = AudioFrame::Binary(audio_chunk.to_vec());
if let Err(e) = self.audio_tx.send(frame) {
eprintln!("[voice_stream] Failed to queue audio chunk: {:?}", e);
}
}
pub async fn finalize(&self) -> FinalizeSource {
let mut handle_lock = self.task_handle.lock().await;
if handle_lock.is_none() {
return FinalizeSource::WsAlreadyClosed;
}
drop(self.audio_tx.clone());
drop(self.audio_tx);
let mut rx_lock = self.on_rx.lock().await;
if let Some(rx) = rx_lock.take() {
match rx.await {
Ok(source) => source,
Err(_) => FinalizeSource::WsClose,
}
} else {
FinalizeSource::WsAlreadyClosed
}
}
pub fn close(&self) {
self.state.send_replace(ConnectionState::Closing);
}
pub fn is_connected(&self) -> bool {
let state = *self.state.borrow();
state == ConnectionState::Connected
}
}
pub fn is_voice_stream_available() -> bool {
std::env::var(ai_code::OAUTH_TOKEN).is_ok()
}
fn get_access_token() -> Option<String> {
std::env::var(ai_code::OAUTH_TOKEN).ok()
}
pub async fn connect_voice_stream(
callbacks: Arc<dyn VoiceStreamCallbacks>,
language: Option<String>,
keyterms: Option<Vec<String>>,
) -> Option<VoiceStreamConnection> {
let access_token = match get_access_token() {
Some(t) => t,
None => {
eprintln!("[voice_stream] No OAuth token available");
return None;
}
};
let ws_base_url = std::env::var(ai_code::VOICE_STREAM_BASE_URL)
.ok()
.unwrap_or_else(|| {
oauth::get_oauth_config()
.base_api_url
.replace("https://", "wss://")
.replace("http://", "ws://")
});
if std::env::var(ai_code::VOICE_STREAM_BASE_URL).is_ok() {
eprintln!(
"[voice_stream] Using VOICE_STREAM_BASE_URL override: {}",
ws_base_url
);
}
let mut params: Vec<(&str, &str)> = vec![
("encoding", "linear16"),
("sample_rate", "16000"),
("channels", "1"),
("endpointing_ms", "300"),
("utterance_end_ms", "1000"),
("language", language.as_deref().unwrap_or("en")),
("use_conversation_engine", "true"),
("stt_provider", "deepgram-nova3"),
];
if let Some(ref terms) = keyterms {
for term in terms {
params.push(("keyterms", term.as_str()));
}
}
let query: String = params
.iter()
.map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
.collect::<Vec<_>>()
.join("&");
let url = format!("{}{}?{}", ws_base_url, VOICE_STREAM_PATH, query);
eprintln!("[voice_stream] Connecting to {}", url);
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", access_token).parse().unwrap(),
);
headers.insert(
reqwest::header::USER_AGENT,
crate::utils::http::get_user_agent().parse().unwrap(),
);
headers.insert("x-app", "cli".parse().unwrap());
let (ws_stream, _response) =
match tokio_tungstenite::connect_async_with_headers(&url, headers).await {
Ok(r) => r,
Err(e) => {
eprintln!(
"[voice_stream] WebSocket connection failed: {}",
e
);
callbacks.on_error(
&format!("Voice stream connection error: {}", e),
false,
);
return None;
}
};
eprintln!("[voice_stream] WebSocket connected");
let (ws_write, ws_read) = ws_stream.split();
let (state_tx, _state_rx) = watch::channel(ConnectionState::Connected);
let (audio_tx, audio_rx) = mpsc::unbounded_channel::<AudioFrame>();
let (transcript_tx, transcript_rx) = mpsc::unbounded_channel::<TranscriptResult>();
let (on_tx, finalize_rx) = oneshot::channel::<FinalizeSource>();
let initial_keepalive = Message::Text("{\"type\":\"KeepAlive\"}".to_string());
if let Err(e) = ws_write.send(initial_keepalive).await {
eprintln!("[voice_stream] Failed to send initial KeepAlive: {:?}", e);
} else {
eprintln!("[voice_stream] Sent initial KeepAlive");
}
let task_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>> =
Arc::new(Mutex::new(None));
let task_handle_clone = task_handle.clone();
let handle = tokio::spawn(async move {
let mut keepalive_interval =
interval(Duration::from_secs(KEEPALIVE_INTERVAL_SECS));
keepalive_interval.tick().await;
let mut audio_rx = audio_rx;
let mut ws_read = ws_read;
let mut transcript_tx = transcript_tx;
let ws_write_keepalive = ws_write.clone();
tokio::spawn(async move {
let mut interval = keepalive_interval;
loop {
interval.tick().await;
let msg = Message::Text("{\"type\":\"KeepAlive\"}".to_string());
if let Err(e) = ws_write_keepalive.send(msg).await {
eprintln!("[voice_stream] Keepalive send failed: {:?}", e);
break;
} else {
eprintln!("[voice_stream] Sending periodic KeepAlive");
}
}
});
let mut finalized = false; let mut finalizing = false;
let mut last_transcript_text = String::new();
loop {
tokio::select! {
frame = audio_rx.recv() => {
match frame {
None => {
eprintln!("[voice_stream] Audio channel closed, finalizing");
finalized = true;
let close_msg = Message::Text("{\"type\":\"CloseStream\"}".to_string());
let ws_write_clone = ws_write.clone();
tokio::spawn(async move {
if let Err(e) = ws_write_clone.send(close_msg).await {
eprintln!("[voice_stream] Failed to send CloseStream: {:?}", e);
} else {
eprintln!("[voice_stream] Sent CloseStream (finalize)");
}
});
break;
}
Some(AudioFrame::Binary(data)) => {
if !finalized {
eprintln!(
"[voice_stream] Sending audio chunk: {} bytes",
data.len()
);
if let Err(e) = ws_write.send(Message::Binary(data)).await {
eprintln!("[voice_stream] Failed to send audio: {:?}", e);
break;
}
}
}
}
}
result = ws_read.next() => {
match result {
Some(Ok(Message::Text(text))) => {
eprintln!(
"[voice_stream] Message received ({} chars): {}",
text.len(),
text.chars().take(200).collect::<String>()
);
let msg: Result<VoiceStreamServerMessage, _> =
serde_json::from_str(&text);
match msg {
Ok(VoiceStreamServerMessage::TranscriptText { data }) => {
eprintln!(
"[voice_stream] TranscriptText: \"{}\"",
data
);
if finalized {
}
if !data.is_empty() {
last_transcript_text = data.clone();
let _ = transcript_tx.send(TranscriptResult::Text(data));
}
}
Ok(VoiceStreamServerMessage::TranscriptEndpoint) => {
eprintln!(
"[voice_stream] TranscriptEndpoint received, lastTranscriptText=\"{}\"",
last_transcript_text
);
let final_text = last_transcript_text.clone();
last_transcript_text.clear();
if !final_text.is_empty() {
let _ = transcript_tx.send(TranscriptResult::Final(final_text.clone()));
}
if finalized {
eprintln!(
"[voice_stream] Finalize resolved via post_closestream_endpoint"
);
let _ = on_tx.send(FinalizeSource::PostClosestreamEndpoint);
}
}
Ok(VoiceStreamServerMessage::TranscriptError { error_code, description }) => {
let desc = description
.or_else(|| error_code)
.unwrap_or_else(|| "unknown transcription error".to_string());
eprintln!("[voice_stream] TranscriptError: {}", desc);
if !finalizing {
callbacks.on_error(&desc, false);
}
}
Ok(VoiceStreamServerMessage::Error { message }) => {
let error_detail = message
.unwrap_or_else(|| {
serde_json::to_string(&msg).unwrap_or_default()
});
eprintln!("[voice_stream] Server error: {}", error_detail);
if !finalizing {
callbacks.on_error(&error_detail, false);
}
}
Err(e) => {
eprintln!("[voice_stream] Failed to parse message: {}", e);
}
}
}
Some(Ok(Message::Binary(data))) => {
eprintln!(
"[voice_stream] Binary message received: {} bytes",
data.len()
);
}
Some(Ok(Message::Close(close_frame))) => {
let code = close_frame
.as_ref()
.map(|c| c.code.as_u16())
.unwrap_or(1005);
let reason = close_frame
.as_ref()
.and_then(|c| c.reason.as_str())
.unwrap_or("");
eprintln!(
"[voice_stream] WebSocket closed: code={}, reason=\"{}\"",
code, reason
);
if !last_transcript_text.is_empty() {
eprintln!(
"[voice_stream] Promoting unreported interim transcript to final on close"
);
let final_text = last_transcript_text.clone();
last_transcript_text.clear();
callbacks.on_transcript(&final_text, true);
}
if let Err(_) = on_tx.send(FinalizeSource::WsClose) {
}
callbacks.on_close();
break;
}
Some(Ok(_)) => {
}
Some(Err(e)) => {
eprintln!("[voice_stream] WebSocket error: {:?}", e);
if !finalizing {
callbacks.on_error(
&format!("Voice stream connection error: {:?}", e),
false,
);
}
}
None => {
break;
}
}
}
}
}
eprintln!("[voice_stream] Listener task cleanup");
if !last_transcript_text.is_empty() {
eprintln!(
"[voice_stream] Promoting unreported interim transcript during task exit"
);
let final_text = last_transcript_text.clone();
callbacks.on_transcript(&final_text, true);
}
let _ = transcript_tx.send(TranscriptResult::Done);
if let Err(_) = on_tx.send(FinalizeSource::WsClose) {
}
callbacks.on_close();
{
let mut handle = task_handle_clone.lock().await;
*handle = None;
}
eprintln!("[voice_stream] Listener task ended");
});
{
let mut handle = task_handle_clone.lock().await;
*handle = Some(handle);
}
let connection = VoiceStreamConnection {
audio_tx,
transcript_rx,
state: Arc::new(state_tx),
on_tx,
on_rx: Mutex::new(Some(finalize_rx)),
task_handle: task_handle_clone,
};
callbacks.on_ready(connection.clone());
Some(connection)
}
/// Returns the JSON text of the `KeepAlive` control message.
pub fn keepalive_message() -> String {
    String::from("{\"type\":\"KeepAlive\"}")
}
/// Returns the JSON text of the `CloseStream` control message.
pub fn close_stream_message() -> String {
    String::from("{\"type\":\"CloseStream\"}")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a server payload, panicking if it is rejected.
    fn parse(json: &str) -> VoiceStreamServerMessage {
        serde_json::from_str(json).unwrap()
    }

    // Control messages must match the exact JSON text.
    #[test]
    fn test_keepalive_message_format() {
        assert_eq!(keepalive_message(), "{\"type\":\"KeepAlive\"}");
    }

    #[test]
    fn test_close_stream_message_format() {
        assert_eq!(close_stream_message(), "{\"type\":\"CloseStream\"}");
    }

    // Every finalize source renders as its snake_case label.
    #[test]
    fn test_finalize_source_display() {
        let expected = [
            (
                FinalizeSource::PostClosestreamEndpoint,
                "post_closestream_endpoint",
            ),
            (FinalizeSource::NoDataTimeout, "no_data_timeout"),
            (FinalizeSource::SafetyTimeout, "safety_timeout"),
            (FinalizeSource::WsClose, "ws_close"),
            (FinalizeSource::WsAlreadyClosed, "ws_already_closed"),
        ];
        for (source, label) in expected.iter() {
            assert_eq!(source.to_string(), *label);
        }
    }

    #[test]
    fn test_finalize_timeouts_constants() {
        let timeouts = FINALIZE_TIMEOUTS_MS;
        assert_eq!(timeouts.safety, 5_000);
        assert_eq!(timeouts.no_data, 1_500);
    }

    #[test]
    fn test_voice_stream_server_message_transcript_text() {
        if let VoiceStreamServerMessage::TranscriptText { data } =
            parse(r#"{"type":"TranscriptText","data":"hello world"}"#)
        {
            assert_eq!(data, "hello world");
        } else {
            panic!("Expected TranscriptText");
        }
    }

    #[test]
    fn test_voice_stream_server_message_transcript_endpoint() {
        let parsed = parse(r#"{"type":"TranscriptEndpoint"}"#);
        assert!(matches!(
            parsed,
            VoiceStreamServerMessage::TranscriptEndpoint
        ));
    }

    #[test]
    fn test_voice_stream_server_message_transcript_error() {
        let parsed = parse(
            r#"{"type":"TranscriptError","errorCode":"invalid_audio","description":"Bad audio format"}"#,
        );
        if let VoiceStreamServerMessage::TranscriptError {
            error_code,
            description,
        } = parsed
        {
            assert_eq!(error_code.as_deref(), Some("invalid_audio"));
            assert_eq!(description.as_deref(), Some("Bad audio format"));
        } else {
            panic!("Expected TranscriptError");
        }
    }

    #[test]
    fn test_voice_stream_server_message_error() {
        if let VoiceStreamServerMessage::Error { message } =
            parse(r#"{"type":"error","message":"Server error"}"#)
        {
            assert_eq!(message.as_deref(), Some("Server error"));
        } else {
            panic!("Expected Error");
        }
    }

    #[test]
    fn test_connection_state_values() {
        let connected = ConnectionState::Connected;
        let closing = ConnectionState::Closing;
        assert_eq!(connected, ConnectionState::Connected);
        assert_eq!(closing, ConnectionState::Closing);
        assert_ne!(connected, closing);
    }

    // Optional fields that are absent from the payload deserialize to `None`.
    #[test]
    fn test_voice_stream_server_message_transcript_error_no_desc() {
        let parsed = parse(r#"{"type":"TranscriptError","errorCode":"bad_format"}"#);
        if let VoiceStreamServerMessage::TranscriptError {
            error_code,
            description,
        } = parsed
        {
            assert_eq!(error_code.as_deref(), Some("bad_format"));
            assert!(description.is_none());
        } else {
            panic!("Expected TranscriptError");
        }
    }

    #[test]
    fn test_voice_stream_server_message_error_no_message() {
        if let VoiceStreamServerMessage::Error { message } = parse(r#"{"type":"error"}"#) {
            assert!(message.is_none());
        } else {
            panic!("Expected Error");
        }
    }
}