use crate::stt::SttProvider;
use crate::tts::encode_pcm_f32_to_wav_pcm16;
use crate::{Result, VoiceConfig, VoiceError};
use async_trait::async_trait;
use block2::{DynBlock, RcBlock};
use objc2::rc::Retained;
use objc2::AllocAnyThread;
use objc2_avf_audio::{AVAudioFormat, AVAudioPCMBuffer};
use objc2_foundation::{NSError, NSLocale, NSString, NSURL};
use objc2_speech::{
SFSpeechAudioBufferRecognitionRequest, SFSpeechRecognitionRequest, SFSpeechRecognitionResult,
SFSpeechRecognizer, SFSpeechRecognizerAuthorizationStatus, SFSpeechURLRecognitionRequest,
};
use std::path::PathBuf;
use std::sync::mpsc::sync_channel;
use tempfile::NamedTempFile;
use tokio::sync::mpsc::UnboundedSender;
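
/// A partial (interim) transcription hypothesis emitted while a streaming
/// recognition task is still running.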
#[derive(Debug, Clone)]
pub struct AppleSpeechPartial {
pub text: String,
}
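
/// Speech-to-text provider backed by Apple's Speech framework
/// (`SFSpeechRecognizer`). Recognition is on-device only by default; call
/// [`allow_server`](Self::allow_server) to permit Apple's server-based
/// fallback.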
#[derive(Debug, Clone)]
pub struct AppleSpeechSttProvider {
    locale: Option<String>,
    on_device_only: bool,
}

impl Default for AppleSpeechSttProvider {
    fn default() -> Self {
        // Match `new()` so a defaulted provider also stays on-device only.
        Self::new()
    }
}
impl AppleSpeechSttProvider {
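    /// Creates a provider that uses the system default locale and refuses
    /// server-based recognition.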
pub fn new() -> Self {
Self {
locale: None,
on_device_only: true,
}
}
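    /// Builds a provider from a [`VoiceConfig`], treating an empty
    /// `language` as "use the system default locale".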
pub fn from_config(config: &VoiceConfig) -> Self {
let locale = if config.language.is_empty() {
None
} else {
Some(config.language.clone())
};
Self {
locale,
on_device_only: true,
}
}
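    /// Sets the recognition locale, e.g. `"en-US"` (a BCP 47 identifier).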
pub fn with_locale(mut self, locale: impl Into<String>) -> Self {
self.locale = Some(locale.into());
self
}
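    /// Allows falling back to Apple's server-based recognition when the
    /// locale lacks on-device support.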
pub fn allow_server(mut self) -> Self {
self.on_device_only = false;
self
}
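    /// Transcribes mono f32 PCM `samples` at `sample_rate` Hz, invoking
    /// `on_partial` for each interim hypothesis and returning the final
    /// transcript. The Objective-C work runs on a blocking thread and
    /// forwards events back over an unbounded channel.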
pub async fn transcribe_streaming<F>(
&self,
samples: &[f32],
sample_rate: u32,
mut on_partial: F,
) -> Result<String>
where
F: FnMut(AppleSpeechPartial) + Send,
{
if samples.is_empty() {
return Ok(String::new());
}
let (event_tx, mut event_rx) = tokio::sync::mpsc::unbounded_channel::<StreamEvent>();
let samples_owned: Vec<f32> = samples.to_vec();
let locale = self.locale.clone();
let on_device_only = self.on_device_only;
        // The Objective-C calls block, so run them off the async runtime; the
        // spawned task is detached and reports back through `event_tx`.
        let _join = tokio::task::spawn_blocking(move || {
stream_blocking(samples_owned, sample_rate, locale, on_device_only, event_tx);
});
loop {
match event_rx.recv().await {
Some(StreamEvent::Partial(p)) => on_partial(p),
Some(StreamEvent::Final(text)) => return Ok(text),
Some(StreamEvent::Error(msg)) => return Err(VoiceError::Stt(msg)),
None => {
return Err(VoiceError::Stt(
"apple speech: streaming task ended without a result".into(),
))
}
}
}
}
}
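
/// Events forwarded from the blocking recognition thread to the async caller.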
enum StreamEvent {
Partial(AppleSpeechPartial),
Final(String),
Error(String),
}
#[async_trait]
impl SttProvider for AppleSpeechSttProvider {
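    // One-shot path: encode the samples to a temporary WAV file and hand the
    // Speech framework a file-URL request, letting Apple handle decoding.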
async fn transcribe(&self, samples: &[f32], sample_rate: u32) -> Result<String> {
if samples.is_empty() {
return Ok(String::new());
}
let wav_bytes = encode_pcm_f32_to_wav_pcm16(samples, sample_rate)
.map_err(|e| VoiceError::Stt(format!("apple speech: wav encode: {e}")))?;
let mut tmp = NamedTempFile::with_suffix(".wav")
.map_err(|e| VoiceError::Stt(format!("apple speech: tempfile: {e}")))?;
{
let f = tmp.as_file_mut();
std::io::Write::write_all(f, &wav_bytes)
.map_err(|e| VoiceError::Stt(format!("apple speech: tempfile write: {e}")))?;
f.sync_all()
.map_err(|e| VoiceError::Stt(format!("apple speech: tempfile sync: {e}")))?;
}
let path = tmp.path().to_path_buf();
let locale = self.locale.clone();
let on_device_only = self.on_device_only;
let result =
tokio::task::spawn_blocking(move || transcribe_blocking(path, locale, on_device_only))
.await
.map_err(|e| VoiceError::Stt(format!("apple speech: join error: {e}")))??;
        // Keep the temp file alive until recognition completes, then delete it.
        drop(tmp);
Ok(result)
}
}
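
/// Synchronous one-shot recognition of a WAV file. Builds the recognizer,
/// validates availability and authorization, then waits up to 60 s for the
/// task's final result. Must run on a blocking thread.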
fn transcribe_blocking(
wav_path: PathBuf,
locale: Option<String>,
on_device_only: bool,
) -> Result<String> {
let recognizer: Retained<SFSpeechRecognizer> = match locale {
Some(loc) => {
let ns_loc_id = NSString::from_str(&loc);
let alloc = NSLocale::alloc();
let ns_locale: Retained<NSLocale> =
NSLocale::initWithLocaleIdentifier(alloc, &ns_loc_id);
unsafe { SFSpeechRecognizer::initWithLocale(SFSpeechRecognizer::alloc(), &ns_locale) }
.ok_or_else(|| {
VoiceError::Stt(format!(
"apple speech: locale '{loc}' not supported on this device"
))
})?
}
None => unsafe { SFSpeechRecognizer::new() },
};
    if !unsafe { recognizer.isAvailable() } {
return Err(VoiceError::Stt(
"apple speech: recognizer unavailable — locale may lack on-device \
support, network may be unreachable, or Speech Recognition may be \
disabled in System Settings > Privacy & Security"
.into(),
));
}
let status = unsafe { SFSpeechRecognizer::authorizationStatus() };
if status != SFSpeechRecognizerAuthorizationStatus::Authorized {
let reason = match status {
SFSpeechRecognizerAuthorizationStatus::NotDetermined => {
"not yet requested — host app must call \
SFSpeechRecognizer.requestAuthorization at startup"
}
SFSpeechRecognizerAuthorizationStatus::Denied => {
"denied by user — re-grant in System Settings > Privacy & \
Security > Speech Recognition"
}
SFSpeechRecognizerAuthorizationStatus::Restricted => {
"restricted by parental controls or MDM"
}
            SFSpeechRecognizerAuthorizationStatus::Authorized => "authorized",
            _ => "unknown status",
};
return Err(VoiceError::Stt(format!(
"apple speech: authorization {reason}. Note: bundle Info.plist \
must declare NSSpeechRecognitionUsageDescription or the app \
will crash on first request."
)));
}
let ns_path = NSString::from_str(&wav_path.to_string_lossy());
let url: Retained<NSURL> = NSURL::fileURLWithPath(&ns_path);
let request: Retained<SFSpeechURLRecognitionRequest> = unsafe {
let alloc = SFSpeechURLRecognitionRequest::alloc();
SFSpeechURLRecognitionRequest::initWithURL(alloc, &url)
};
    unsafe {
        // One-shot transcription: only the final result is needed here.
        request.setShouldReportPartialResults(false);
        if on_device_only {
            // Fail rather than silently fall back to Apple's servers when the
            // locale lacks on-device support.
            request.setRequiresOnDeviceRecognition(true);
        }
    }
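    /// Terminal outcomes the recognition handler can deliver to this thread.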
enum TaskOutcome {
Final(String),
Failed(String),
}
    // The handler may fire multiple times; a bounded channel of size 1 is
    // enough because only the first terminal outcome matters.
    let (tx, rx) = sync_channel::<TaskOutcome>(1);
let handler: RcBlock<dyn Fn(*mut SFSpeechRecognitionResult, *mut NSError)> = RcBlock::new(
move |result_ptr: *mut SFSpeechRecognitionResult, err_ptr: *mut NSError| {
if !err_ptr.is_null() {
let msg = unsafe { (*err_ptr).localizedDescription() };
                let _ = tx.try_send(TaskOutcome::Failed(format!(
"apple speech: recognition error: {}",
msg.to_string()
)));
return;
}
if result_ptr.is_null() {
return;
}
let result: &SFSpeechRecognitionResult = unsafe { &*result_ptr };
let is_final: bool = unsafe { result.isFinal() };
if !is_final {
return;
}
let transcription = unsafe { result.bestTranscription() };
let formatted = unsafe { transcription.formattedString() };
            let _ = tx.try_send(TaskOutcome::Final(formatted.to_string()));
},
);
let handler_ref: &DynBlock<dyn Fn(*mut SFSpeechRecognitionResult, *mut NSError)> = &*handler;
let request_super: &SFSpeechRecognitionRequest = &**request;
let task =
unsafe { recognizer.recognitionTaskWithRequest_resultHandler(request_super, handler_ref) };
let outcome = rx.recv_timeout(std::time::Duration::from_secs(60));
let result_str = match outcome {
Ok(TaskOutcome::Final(s)) => Ok(s),
Ok(TaskOutcome::Failed(msg)) => {
unsafe { task.cancel() };
Err(VoiceError::Stt(msg))
}
Err(_) => {
unsafe { task.cancel() };
Err(VoiceError::Stt(
"apple speech: recognition timed out (60 s)".into(),
))
}
};
drop(task);
result_str
}
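
/// Blocking streaming recognition: copies `samples` into an
/// `AVAudioPCMBuffer`, feeds it to a buffer-based request with partial
/// results enabled, and forwards every partial, final, and error event over
/// `event_tx`.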
fn stream_blocking(
samples: Vec<f32>,
sample_rate: u32,
locale: Option<String>,
on_device_only: bool,
event_tx: UnboundedSender<StreamEvent>,
) {
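    // Report a fatal error to the caller and bail out of this thread.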
macro_rules! abort {
($msg:expr) => {{
let _ = event_tx.send(StreamEvent::Error($msg.into()));
return;
}};
}
let recognizer: Retained<SFSpeechRecognizer> = match locale {
Some(loc) => {
let ns_loc_id = NSString::from_str(&loc);
let alloc = NSLocale::alloc();
let ns_locale: Retained<NSLocale> =
NSLocale::initWithLocaleIdentifier(alloc, &ns_loc_id);
match unsafe {
SFSpeechRecognizer::initWithLocale(SFSpeechRecognizer::alloc(), &ns_locale)
} {
Some(r) => r,
None => abort!(format!(
"apple speech: locale '{loc}' not supported on this device"
)),
}
}
None => unsafe { SFSpeechRecognizer::new() },
};
if !unsafe { recognizer.isAvailable() } {
abort!(
"apple speech: recognizer unavailable — locale may lack on-device \
support, network may be unreachable, or Speech Recognition may be \
disabled in System Settings > Privacy & Security"
);
}
let status = unsafe { SFSpeechRecognizer::authorizationStatus() };
if status != SFSpeechRecognizerAuthorizationStatus::Authorized {
abort!(format!(
"apple speech: authorization status not Authorized; host must call \
SFSpeechRecognizer.requestAuthorization at startup and the bundle \
must declare NSSpeechRecognitionUsageDescription in Info.plist \
(status={status:?})"
));
}
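    // The "standard" format is deinterleaved native-endian f32, which matches
    // the layout of `samples`.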
let format_alloc = AVAudioFormat::alloc();
let format: Retained<AVAudioFormat> = match unsafe {
AVAudioFormat::initStandardFormatWithSampleRate_channels(
format_alloc,
sample_rate as f64,
1,
)
} {
Some(f) => f,
None => abort!(format!(
"apple speech: cannot construct AVAudioFormat for sample_rate={sample_rate} mono"
)),
};
let frame_capacity = samples.len() as u32;
let buffer_alloc = AVAudioPCMBuffer::alloc();
let buffer: Retained<AVAudioPCMBuffer> = match unsafe {
AVAudioPCMBuffer::initWithPCMFormat_frameCapacity(buffer_alloc, &format, frame_capacity)
} {
Some(b) => b,
None => abort!("apple speech: cannot allocate AVAudioPCMBuffer"),
};
    unsafe {
        buffer.setFrameLength(frame_capacity);
        // floatChannelData returns a C array of per-channel pointers; for this
        // mono buffer the first entry points at channel 0's sample storage.
        let cd: *mut std::ptr::NonNull<f32> = buffer.floatChannelData();
        if cd.is_null() {
            abort!("apple speech: AVAudioPCMBuffer.floatChannelData returned null");
        }
        let ch0_ptr: *mut f32 = (*cd).as_ptr();
        std::ptr::copy_nonoverlapping(samples.as_ptr(), ch0_ptr, samples.len());
    }
let request: Retained<SFSpeechAudioBufferRecognitionRequest> = unsafe {
let alloc = SFSpeechAudioBufferRecognitionRequest::alloc();
SFSpeechAudioBufferRecognitionRequest::init(alloc)
};
    unsafe {
        request.setShouldReportPartialResults(true);
        if on_device_only {
            request.setRequiresOnDeviceRecognition(true);
        }
        // All audio is available up front, so append it as one buffer and
        // signal end-of-stream; partial results still arrive incrementally.
        request.appendAudioPCMBuffer(&buffer);
        request.endAudio();
    }
let (done_tx, done_rx) = std::sync::mpsc::channel::<()>();
let done_tx_block = done_tx.clone();
let event_tx_block = event_tx.clone();
let handler: RcBlock<dyn Fn(*mut SFSpeechRecognitionResult, *mut NSError)> = RcBlock::new(
move |result_ptr: *mut SFSpeechRecognitionResult, err_ptr: *mut NSError| {
if !err_ptr.is_null() {
let msg = unsafe { (*err_ptr).localizedDescription() };
let _ = event_tx_block.send(StreamEvent::Error(format!(
"apple speech: recognition error: {}",
msg.to_string()
)));
let _ = done_tx_block.send(());
return;
}
if result_ptr.is_null() {
return;
}
let result: &SFSpeechRecognitionResult = unsafe { &*result_ptr };
let transcription = unsafe { result.bestTranscription() };
let formatted = unsafe { transcription.formattedString() };
let text = formatted.to_string();
let is_final = unsafe { result.isFinal() };
if is_final {
let _ = event_tx_block.send(StreamEvent::Final(text));
let _ = done_tx_block.send(());
} else {
let _ = event_tx_block.send(StreamEvent::Partial(AppleSpeechPartial { text }));
}
},
);
let handler_ref: &DynBlock<dyn Fn(*mut SFSpeechRecognitionResult, *mut NSError)> = &*handler;
let request_super: &SFSpeechRecognitionRequest = &**request;
    // Drop the local sender so `done_rx` disconnects instead of hanging if
    // the handler block is released without ever signalling completion.
    drop(done_tx);
let task =
unsafe { recognizer.recognitionTaskWithRequest_resultHandler(request_super, handler_ref) };
if done_rx
.recv_timeout(std::time::Duration::from_secs(60))
.is_err()
{
let _ = event_tx.send(StreamEvent::Error(
"apple speech: streaming recognition timed out (60 s)".into(),
));
}
    // Cancelling a task that already finished is harmless; this only matters
    // on the timeout path.
    unsafe { task.cancel() };
drop(task);
}