use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use serde_json::json;
use tokio_util::sync::CancellationToken;
use crate::detail::core_interop::CoreInterop;
use crate::error::{FoundryLocalError, Result};
/// Audio format and queueing options used when a live transcription session starts.
#[derive(Debug, Clone)]
pub struct LiveAudioTranscriptionOptions {
    /// Sent to the native core as `SampleRate` on start (default 16000).
    pub sample_rate: u32,
    /// Sent to the native core as `Channels` on start (default 1).
    pub channels: u32,
    /// Sent to the native core as `BitsPerSample` on start (default 16).
    pub bits_per_sample: u32,
    /// Optional language hint; only sent (as `Language`) when set.
    pub language: Option<String>,
    /// Capacity of the bounded queue between `append` and the native push loop;
    /// `append` waits while the queue is full (default 100).
    pub push_queue_capacity: usize,
}

impl Default for LiveAudioTranscriptionOptions {
    /// 16000 samples/s, one channel, 16 bits per sample, no language hint,
    /// and a push queue of 100 chunks.
    fn default() -> Self {
        let (sample_rate, channels, bits_per_sample) = (16_000, 1, 16);
        let push_queue_capacity = 100;
        Self {
            sample_rate,
            channels,
            bits_per_sample,
            language: None,
            push_queue_capacity,
        }
    }
}
/// Raw JSON payload returned by the native `audio_stream_push` / `audio_stream_stop`
/// commands, before it is reshaped into a [`LiveAudioTranscriptionResponse`].
#[derive(Debug, Clone, serde::Deserialize)]
struct LiveAudioTranscriptionRaw {
    // Missing `is_final` deserializes as `false`.
    #[serde(default)]
    is_final: bool,
    // Missing `text` deserializes as an empty string; callers treat empty text as "no result".
    #[serde(default)]
    text: String,
    start_time: Option<f64>,
    end_time: Option<f64>,
    id: Option<String>,
}
/// One piece of transcribed content.
///
/// `LiveAudioTranscriptionResponse::from_raw` fills `text` and `transcript` with the same
/// string; both are kept for callers that expect either field name.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContentPart {
    pub text: String,
    pub transcript: String,
}
/// A transcription result yielded by [`LiveAudioTranscriptionStream`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LiveAudioTranscriptionResponse {
    /// Transcribed content; `from_raw` always produces exactly one part.
    pub content: Vec<ContentPart>,
    /// Copied from the native payload's `is_final` flag (`false` when absent).
    pub is_final: bool,
    /// Copied from the native payload when present.
    pub start_time: Option<f64>,
    /// Copied from the native payload when present.
    pub end_time: Option<f64>,
    /// Optional identifier copied from the native payload.
    pub id: Option<String>,
}
impl LiveAudioTranscriptionResponse {
pub fn from_json(json: &str) -> Result<Self> {
serde_json::from_str::<LiveAudioTranscriptionRaw>(json)
.map(Self::from_raw)
.map_err(FoundryLocalError::from)
}
fn from_raw(raw: LiveAudioTranscriptionRaw) -> Self {
Self {
content: vec![ContentPart {
transcript: raw.text.clone(),
text: raw.text,
}],
is_final: raw.is_final,
start_time: raw.start_time,
end_time: raw.end_time,
id: raw.id,
}
}
}
/// Structured error object that the native core may embed in an error message.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct CoreErrorResponse {
    /// Machine-readable error code.
    pub code: String,
    /// Human-readable description.
    pub message: String,
    /// Read from the JSON key `isTransient`; `false` when absent.
    #[serde(rename = "isTransient", default)]
    pub is_transient: bool,
}

impl CoreErrorResponse {
    /// Attempts to interpret `error_string` as a JSON [`CoreErrorResponse`].
    ///
    /// Returns `None` for anything that is not such a JSON object.
    pub fn try_parse(error_string: &str) -> Option<Self> {
        match serde_json::from_str::<Self>(error_string) {
            Ok(parsed) => Some(parsed),
            Err(_) => None,
        }
    }
}
/// Stream of transcription results (or push errors) for a live session.
///
/// Yields `None` once every sender of the underlying channel has been dropped
/// and all buffered items have been consumed.
pub struct LiveAudioTranscriptionStream {
    rx: tokio::sync::mpsc::UnboundedReceiver<Result<LiveAudioTranscriptionResponse>>,
}

impl futures_core::Stream for LiveAudioTranscriptionStream {
    type Item = Result<LiveAudioTranscriptionResponse>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The only field is an `Unpin` receiver, so unpinning is sound and free.
        let this = self.get_mut();
        this.rx.poll_recv(cx)
    }
}
/// Mutable bookkeeping for a live session, kept behind the session's async mutex.
struct SessionState {
    /// Handle returned by the native `audio_stream_start` command; `None` when idle.
    session_handle: Option<String>,
    /// Set once `start` succeeds; cleared again when `stop` finalizes.
    started: bool,
    /// Set as soon as `stop` begins; `start` resets it.
    stopped: bool,
    /// Producer side of the bounded audio queue consumed by the push loop.
    push_tx: Option<tokio::sync::mpsc::Sender<Vec<u8>>>,
    /// Session-owned sender for results; dropped on stop so the stream can end.
    output_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<LiveAudioTranscriptionResponse>>>,
    /// Receiver handed out (once) by `get_stream`.
    output_rx: Option<tokio::sync::mpsc::UnboundedReceiver<Result<LiveAudioTranscriptionResponse>>>,
    /// Join handle of the blocking push loop task.
    push_loop_handle: Option<tokio::task::JoinHandle<()>>,
}

impl SessionState {
    /// Idle state: neither started nor stopped, no handle, no channels, no push loop.
    fn new() -> Self {
        let (started, stopped) = (false, false);
        Self {
            started,
            stopped,
            session_handle: None,
            push_tx: None,
            push_loop_handle: None,
            output_tx: None,
            output_rx: None,
        }
    }
}
/// A live (streaming) audio transcription session for a single model.
///
/// Typical use: `start()`, then `get_stream()` to read results, repeated `append()`
/// calls with raw audio bytes, and finally `stop()`.
pub struct LiveAudioTranscriptionSession {
    model_id: String,
    core: Arc<CoreInterop>,
    /// Options used by the next `start()`; `start()` takes a copy, so later edits do not
    /// affect a session that is already running.
    pub settings: LiveAudioTranscriptionOptions,
    state: tokio::sync::Mutex<SessionState>,
}
impl LiveAudioTranscriptionSession {
pub(crate) fn new(model_id: &str, core: Arc<CoreInterop>) -> Self {
Self {
model_id: model_id.to_owned(),
core,
settings: LiveAudioTranscriptionOptions::default(),
state: tokio::sync::Mutex::new(SessionState::new()),
}
}
pub async fn start(&self, ct: Option<CancellationToken>) -> Result<()> {
let mut state = self.state.lock().await;
if state.started {
return Err(FoundryLocalError::Validation {
reason: "Streaming session already started. Call stop() first.".into(),
});
}
let active_settings = self.settings.clone();
let (output_tx, output_rx) =
tokio::sync::mpsc::unbounded_channel::<Result<LiveAudioTranscriptionResponse>>();
let (push_tx, push_rx) =
tokio::sync::mpsc::channel::<Vec<u8>>(active_settings.push_queue_capacity);
let request = self.build_start_request(&active_settings);
let core = Arc::clone(&self.core);
let start_future = tokio::task::spawn_blocking(move || {
core.execute_command("audio_stream_start", Some(&request))
});
let session_handle = self.await_start(start_future, ct).await?;
if session_handle.is_empty() {
return Err(FoundryLocalError::CommandExecution {
reason: "Native core did not return a session handle.".into(),
});
}
let push_loop_core = Arc::clone(&self.core);
let push_loop_output_tx = output_tx.clone();
let handle_clone = session_handle.clone();
let push_loop_handle = tokio::task::spawn_blocking(move || {
Self::push_loop(push_loop_core, handle_clone, push_rx, push_loop_output_tx);
});
state.session_handle = Some(session_handle);
state.started = true;
state.stopped = false;
state.push_tx = Some(push_tx);
state.output_tx = Some(output_tx);
state.output_rx = Some(output_rx);
state.push_loop_handle = Some(push_loop_handle);
Ok(())
}
pub async fn append(&self, pcm_data: &[u8], ct: Option<CancellationToken>) -> Result<()> {
let tx = {
let state = self.state.lock().await;
if !state.started || state.stopped {
return Err(FoundryLocalError::Validation {
reason: "No active streaming session. Call start() first.".into(),
});
}
state
.push_tx
.clone()
.ok_or_else(|| FoundryLocalError::Internal {
reason: "Push channel not available — session may be in an invalid state"
.into(),
})?
};
let data = pcm_data.to_vec();
if let Some(token) = &ct {
tokio::select! {
result = tx.send(data) => {
result.map_err(|_| FoundryLocalError::CommandExecution {
reason: "Push channel closed — session has been stopped".into(),
})
}
_ = token.cancelled() => {
Err(FoundryLocalError::CommandExecution {
reason: "Append cancelled".into(),
})
}
}
} else {
tx.send(data)
.await
.map_err(|_| FoundryLocalError::CommandExecution {
reason: "Push channel closed — session has been stopped".into(),
})
}
}
pub async fn get_stream(&self) -> Result<LiveAudioTranscriptionStream> {
let mut state = self.state.lock().await;
let rx = state
.output_rx
.take()
.ok_or_else(|| FoundryLocalError::Validation {
reason: "No active streaming session, or stream already taken. \
Call start() first and only call get_stream() once."
.into(),
})?;
Ok(LiveAudioTranscriptionStream { rx })
}
pub async fn stop(&self, ct: Option<CancellationToken>) -> Result<()> {
let mut state = self.state.lock().await;
if !state.started || state.stopped {
return Ok(());
}
state.stopped = true;
self.drain_push_loop(&mut state).await;
let stop_result = self.stop_native_session(&state, ct).await;
Self::write_final_result(&stop_result, &state);
self.finalize_state(&mut state);
stop_result?;
Ok(())
}
fn build_start_request(&self, settings: &LiveAudioTranscriptionOptions) -> serde_json::Value {
let mut params = json!({
"Model": self.model_id,
"SampleRate": settings.sample_rate.to_string(),
"Channels": settings.channels.to_string(),
"BitsPerSample": settings.bits_per_sample.to_string(),
});
if let Some(ref lang) = settings.language {
params["Language"] = json!(lang);
}
json!({ "Params": params })
}
async fn await_start(
&self,
start_future: tokio::task::JoinHandle<Result<String>>,
ct: Option<CancellationToken>,
) -> Result<String> {
let join_result = start_future
.await
.map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("Start audio stream task join error: {e}"),
})?;
if let Some(token) = ct {
if token.is_cancelled() {
if let Ok(ref handle) = join_result {
if !handle.is_empty() {
let params = json!({
"Params": { "SessionHandle": handle }
});
let _ = self
.core
.execute_command("audio_stream_stop", Some(¶ms));
}
}
return Err(FoundryLocalError::CommandExecution {
reason: "Start cancelled".into(),
});
}
}
join_result
}
async fn drain_push_loop(&self, state: &mut SessionState) {
state.push_tx.take();
if let Some(handle) = state.push_loop_handle.take() {
let _ = handle.await;
}
}
async fn stop_native_session(
&self,
state: &SessionState,
_ct: Option<CancellationToken>,
) -> Result<String> {
let session_handle = state
.session_handle
.as_ref()
.ok_or_else(|| FoundryLocalError::Internal {
reason: "Session handle missing during stop".into(),
})?
.clone();
let params = json!({ "Params": { "SessionHandle": session_handle } });
let core = Arc::clone(&self.core);
tokio::task::spawn_blocking(move || {
core.execute_command("audio_stream_stop", Some(¶ms))
})
.await
.map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("Stop audio stream task join error: {e}"),
})?
}
fn write_final_result(stop_result: &Result<String>, state: &SessionState) {
let _ = stop_result
.as_ref()
.ok()
.filter(|d| !d.is_empty())
.and_then(|d| serde_json::from_str::<LiveAudioTranscriptionRaw>(d).ok())
.filter(|r| !r.text.is_empty())
.and_then(|raw| {
state.output_tx.as_ref().map(|tx| {
let _ = tx.send(Ok(LiveAudioTranscriptionResponse::from_raw(raw)));
})
});
}
fn finalize_state(&self, state: &mut SessionState) {
state.output_tx.take();
state.session_handle = None;
state.started = false;
}
fn push_loop(
core: Arc<CoreInterop>,
session_handle: String,
mut push_rx: tokio::sync::mpsc::Receiver<Vec<u8>>,
output_tx: tokio::sync::mpsc::UnboundedSender<Result<LiveAudioTranscriptionResponse>>,
) {
while let Some(audio_data) = push_rx.blocking_recv() {
let params = json!({
"Params": { "SessionHandle": &session_handle }
});
let data = match core.execute_command_with_binary(
"audio_stream_push",
Some(¶ms),
&audio_data,
) {
Ok(d) => d,
Err(e) => {
let code = match &e {
FoundryLocalError::CommandExecution { reason } => {
CoreErrorResponse::try_parse(reason)
.map(|ei| ei.code)
.unwrap_or_else(|| "UNKNOWN".into())
}
_ => "UNKNOWN".into(),
};
let _ = output_tx.send(Err(FoundryLocalError::CommandExecution {
reason: format!("Push failed (code={code}): {e}"),
}));
drop(output_tx);
return;
}
};
if let Ok(raw) = serde_json::from_str::<LiveAudioTranscriptionRaw>(&data) {
if !raw.text.is_empty() {
let _ = output_tx.send(Ok(LiveAudioTranscriptionResponse::from_raw(raw)));
}
}
}
}
}
impl Drop for LiveAudioTranscriptionSession {
    /// Best-effort cleanup for a session dropped without `stop()`.
    ///
    /// Drops the channel senders held by the session. If a native session is still
    /// running, calls the native `audio_stream_stop` command synchronously, because `drop`
    /// cannot await. Any error from that call is ignored.
    fn drop(&mut self) {
        // `&mut self` guarantees no lock guard is alive, so no locking is needed and,
        // unlike `try_lock`, this cannot fail.
        let state = self.state.get_mut();
        state.push_tx.take();
        state.output_tx.take();
        if state.started && !state.stopped {
            if let Some(ref handle) = state.session_handle {
                let params = serde_json::json!({
                    "Params": { "SessionHandle": handle }
                });
                // NOTE(review): the detached push loop may still be pushing a queued chunk
                // when this stop runs; `drop` cannot wait for it.
                let _ = self
                    .core
                    .execute_command("audio_stream_stop", Some(&params));
            }
            state.session_handle = None;
            state.started = false;
            state.stopped = true;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_json_parses_text_and_is_final() {
        let json = r#"{"is_final":true,"text":"hello world","start_time":null,"end_time":null}"#;
        let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
        assert_eq!(result.content.len(), 1);
        assert_eq!(result.content[0].text, "hello world");
        assert_eq!(result.content[0].transcript, "hello world");
        assert!(result.is_final);
    }

    #[test]
    fn from_json_maps_timing_fields() {
        let json = r#"{"is_final":false,"text":"partial","start_time":1.5,"end_time":3.0}"#;
        let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
        assert_eq!(result.content[0].text, "partial");
        assert!(!result.is_final);
        assert_eq!(result.start_time, Some(1.5));
        assert_eq!(result.end_time, Some(3.0));
    }

    #[test]
    fn from_json_empty_text_parses_successfully() {
        let json = r#"{"is_final":true,"text":"","start_time":null,"end_time":null}"#;
        let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
        assert_eq!(result.content[0].text, "");
        assert!(result.is_final);
    }

    #[test]
    fn from_json_only_start_time_sets_start_time() {
        let json = r#"{"is_final":true,"text":"word","start_time":2.0,"end_time":null}"#;
        let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
        assert_eq!(result.start_time, Some(2.0));
        assert_eq!(result.end_time, None);
        assert_eq!(result.content[0].text, "word");
    }

    #[test]
    fn from_json_invalid_json_returns_error() {
        let result = LiveAudioTranscriptionResponse::from_json("not valid json");
        assert!(result.is_err());
    }

    /// Every field is optional on the wire: `is_final`/`text` have serde defaults and the
    /// remaining fields are `Option`s.
    #[test]
    fn from_json_missing_fields_use_defaults() {
        let result = LiveAudioTranscriptionResponse::from_json("{}").unwrap();
        assert!(!result.is_final);
        assert_eq!(result.content.len(), 1);
        assert_eq!(result.content[0].text, "");
        assert_eq!(result.content[0].transcript, "");
        assert_eq!(result.start_time, None);
        assert_eq!(result.end_time, None);
        assert!(result.id.is_none());
    }

    #[test]
    fn from_json_content_has_text_and_transcript() {
        let json = r#"{"is_final":true,"text":"test","start_time":null,"end_time":null}"#;
        let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
        assert_eq!(result.content[0].text, "test");
        assert_eq!(result.content[0].transcript, "test");
    }

    #[test]
    fn from_json_parses_id_when_present() {
        let json =
            r#"{"is_final":true,"text":"hi","id":"evt_123","start_time":null,"end_time":null}"#;
        let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
        assert_eq!(result.id.as_deref(), Some("evt_123"));
    }

    #[test]
    fn from_json_id_defaults_to_none() {
        let json = r#"{"is_final":true,"text":"hi","start_time":null,"end_time":null}"#;
        let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
        assert!(result.id.is_none());
    }

    #[test]
    fn options_default_values() {
        let options = LiveAudioTranscriptionOptions::default();
        assert_eq!(options.sample_rate, 16000);
        assert_eq!(options.channels, 1);
        assert_eq!(options.bits_per_sample, 16);
        assert_eq!(options.language, None);
        assert_eq!(options.push_queue_capacity, 100);
    }

    #[test]
    fn core_error_response_try_parse_valid_json() {
        let json =
            r#"{"code":"ASR_SESSION_NOT_FOUND","message":"Session not found","isTransient":false}"#;
        let error = CoreErrorResponse::try_parse(json).unwrap();
        assert_eq!(error.code, "ASR_SESSION_NOT_FOUND");
        assert_eq!(error.message, "Session not found");
        assert!(!error.is_transient);
    }

    #[test]
    fn core_error_response_try_parse_invalid_json_returns_none() {
        let result = CoreErrorResponse::try_parse("not json");
        assert!(result.is_none());
    }

    #[test]
    fn core_error_response_try_parse_transient_error() {
        let json = r#"{"code":"BUSY","message":"Model busy","isTransient":true}"#;
        let error = CoreErrorResponse::try_parse(json).unwrap();
        assert!(error.is_transient);
    }

    /// `isTransient` carries `#[serde(default)]`, so omitting it must parse as `false`.
    #[test]
    fn core_error_response_try_parse_missing_is_transient_defaults_to_false() {
        let error = CoreErrorResponse::try_parse(r#"{"code":"X","message":"m"}"#).unwrap();
        assert_eq!(error.code, "X");
        assert_eq!(error.message, "m");
        assert!(!error.is_transient);
    }

    /// `message` is required, so an object without it is not a core error response.
    #[test]
    fn core_error_response_try_parse_missing_required_field_returns_none() {
        assert!(CoreErrorResponse::try_parse(r#"{"code":"X"}"#).is_none());
    }
}