use std::path::Path;
use std::sync::Arc;
use serde_json::{json, Value};
use crate::detail::core_interop::CoreInterop;
use crate::error::{FoundryLocalError, Result};
use super::json_stream::JsonStream;
/// One timed segment of a transcription result.
///
/// Every `Option` field is left out of the serialized output when it is `None`.
// NOTE(review): the optional fields look like they mirror OpenAI's `verbose_json`
// segment schema — confirm against the core's actual output.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct TranscriptionSegment {
    /// Segment identifier as reported by the backend.
    pub id: i32,
    /// Seek position reported for this segment (semantics defined by the backend).
    pub seek: i32,
    /// Segment start time (presumably seconds — confirm).
    pub start: f64,
    /// Segment end time (same units as `start`).
    pub end: f64,
    /// Transcribed text of this segment.
    pub text: String,
    /// Token ids for the segment, when the backend provides them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<Vec<i32>>,
    /// Decoding temperature reported for the segment, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Average token log-probability, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
    /// Compression ratio, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compression_ratio: Option<f64>,
    /// No-speech probability, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_speech_prob: Option<f64>,
}
/// A single word with its timing, as returned in a transcription result.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct TranscriptionWord {
    /// The word text.
    pub word: String,
    /// Word start time (presumably seconds — confirm).
    pub start: f64,
    /// Word end time (same units as `start`).
    pub end: f64,
}
/// Result of an `audio_transcribe` command.
///
/// `text` is always present. The remaining fields are optional and are left
/// out of the serialized output when `None`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AudioTranscriptionResponse {
    /// Transcribed text.
    pub text: String,
    /// Language reported by the backend, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Audio duration reported by the backend (presumably seconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration: Option<f64>,
    /// Per-segment breakdown, when provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub segments: Option<Vec<TranscriptionSegment>>,
    /// Per-word timings, when provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<TranscriptionWord>>,
}
/// Optional per-request settings for [`AudioClient`].
///
/// Both fields default to `None`. Unset fields are not sent with requests.
#[derive(Clone, Debug, Default)]
pub struct AudioClientSettings {
    /// Language hint, sent as `Language` when set.
    language: Option<String>,
    /// Sampling temperature, sent as `Temperature` when set.
    temperature: Option<f64>,
}
impl AudioClientSettings {
    /// Builds the JSON object describing one transcription request.
    ///
    /// `Model` and `FileName` are always present. `Language` and
    /// `Temperature` are appended, in that order, only when they are configured.
    fn serialize(&self, model_id: &str, file_name: &str) -> Value {
        let mut request = serde_json::Map::new();
        request.insert("Model".to_owned(), json!(model_id));
        request.insert("FileName".to_owned(), json!(file_name));

        // Each unset option yields no entry. Set options are appended in a fixed order.
        let language = self
            .language
            .as_ref()
            .map(|lang| ("Language".to_owned(), json!(lang)));
        let temperature = self
            .temperature
            .map(|temp| ("Temperature".to_owned(), json!(temp)));
        request.extend(language.into_iter().chain(temperature));

        Value::Object(request)
    }
}
/// Stream of [`AudioTranscriptionResponse`] items produced by
/// `AudioClient::transcribe_streaming`.
pub type AudioTranscriptionStream = JsonStream<AudioTranscriptionResponse>;

/// Client that issues `audio_transcribe` commands for a single model.
///
/// The crate creates it internally. Request options are configured with the
/// consuming builder methods `language` and `temperature`.
pub struct AudioClient {
    /// Model identifier, sent as `Model` in every request.
    model_id: String,
    /// Shared handle used to execute commands against the core.
    core: Arc<CoreInterop>,
    /// Optional per-request settings.
    settings: AudioClientSettings,
}
impl AudioClient {
pub(crate) fn new(model_id: &str, core: Arc<CoreInterop>) -> Self {
Self {
model_id: model_id.to_owned(),
core,
settings: AudioClientSettings::default(),
}
}
pub fn language(mut self, lang: impl Into<String>) -> Self {
self.settings.language = Some(lang.into());
self
}
pub fn temperature(mut self, v: f64) -> Self {
self.settings.temperature = Some(v);
self
}
pub async fn transcribe(
&self,
audio_file_path: impl AsRef<Path>,
) -> Result<AudioTranscriptionResponse> {
let path_str =
audio_file_path
.as_ref()
.to_str()
.ok_or_else(|| FoundryLocalError::Validation {
reason: "audio file path is not valid UTF-8".into(),
})?;
Self::validate_path(path_str)?;
let request = self.settings.serialize(&self.model_id, path_str);
let params = json!({
"Params": {
"OpenAICreateRequest": serde_json::to_string(&request)?
}
});
let raw = self
.core
.execute_command_async("audio_transcribe".into(), Some(params))
.await?;
let parsed: AudioTranscriptionResponse = serde_json::from_str(&raw)?;
Ok(parsed)
}
pub async fn transcribe_streaming(
&self,
audio_file_path: impl AsRef<Path>,
) -> Result<AudioTranscriptionStream> {
let path_str =
audio_file_path
.as_ref()
.to_str()
.ok_or_else(|| FoundryLocalError::Validation {
reason: "audio file path is not valid UTF-8".into(),
})?;
Self::validate_path(path_str)?;
let request = self.settings.serialize(&self.model_id, path_str);
let params = json!({
"Params": {
"OpenAICreateRequest": serde_json::to_string(&request)?
}
});
let rx = self
.core
.execute_command_streaming_channel("audio_transcribe".into(), Some(params))
.await?;
Ok(AudioTranscriptionStream::new(rx))
}
fn validate_path(path: &str) -> Result<()> {
if path.trim().is_empty() {
return Err(FoundryLocalError::Validation {
reason: "audio_file_path must be a non-empty string".into(),
});
}
Ok(())
}
}