use crate::audio::{encode_wav, start_audio_capture};
use crate::config::Config;
use anyhow::{Context, Result};
use cpal::Stream;
use reqwest::Client;
use serde_json::Value;
use std::sync::Arc;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
// On-device transcription backend; `TranscriptionService::start` hands off to it when the
// configuration provides an on-device config. The explicit `#[path]` is needed because the
// directory name contains a hyphen and therefore cannot double as the module identifier.
#[path = "whisper-microphone/mod.rs"]
mod on_device;
/// Per-chunk outcome delivered on the channel returned by `TranscriptionService::start`.
///
/// Every variant carries the `chunk_id` it refers to, so a consumer can match results to
/// chunks regardless of the order in which they arrive.
#[derive(Clone, Debug)]
pub enum TranscriptionEvent {
    /// A chunk was transcribed successfully.
    Transcription {
        /// Identifier of the chunk the text belongs to.
        chunk_id: usize,
        /// Transcribed text; may be empty.
        text: String,
    },
    /// Transcribing a chunk, or appending its text to the output file, failed.
    Error {
        /// Identifier of the chunk that failed.
        chunk_id: usize,
        /// Human-readable description of the failure.
        error: String,
    },
}
/// Captures audio and turns it into a stream of `TranscriptionEvent`s.
///
/// Build one with `new`, `new_no_api` or `new_on_device`, then call `start`.
pub struct TranscriptionService {
    /// Settings read by `start`: output file, on-device config, chunk duration, model and
    /// endpoint.
    config: Config,
    /// Bearer token attached to remote transcription requests; `None` sends them
    /// unauthenticated.
    api_key: Option<String>,
}
impl TranscriptionService {
pub fn new(config: Config, api_key: String) -> Result<Self> {
Ok(Self {
config,
api_key: Some(api_key),
})
}
pub fn new_no_api(config: Config) -> Result<Self> {
Ok(Self { config, api_key: None })
}
pub fn new_on_device(config: Config) -> Result<Self> {
Ok(Self {
config,
api_key: None,
})
}
pub async fn start(&mut self) -> Result<(UnboundedReceiver<TranscriptionEvent>, Stream)> {
let (event_tx, event_rx) = unbounded_channel::<TranscriptionEvent>();
let transcript_sink = if let Some(path) = &self.config.out_file {
let file = OpenOptions::new()
.create(true)
.append(true)
.open(path)
.await
.with_context(|| format!("opening output file {}", path.display()))?;
Some(Arc::new(tokio::sync::Mutex::new(file)))
} else {
None
};
if let Some(on_device_cfg) = self.config.on_device_config().cloned() {
let handle = tokio::runtime::Handle::current();
let stream = on_device::start_on_device_transcription(
on_device_cfg,
event_tx.clone(),
transcript_sink.clone(),
handle,
)?;
return Ok((event_rx, stream));
}
let (sample_tx, mut sample_rx) = unbounded_channel::<Vec<f32>>();
let (_stream, audio_config) = start_audio_capture(sample_tx)?;
let client = Client::new();
let chunk_duration_secs = self.config.chunk_duration_secs.max(1);
let samples_per_chunk = (audio_config.sample_rate as usize)
.saturating_mul(chunk_duration_secs)
.saturating_mul(audio_config.channels as usize);
let model = Arc::new(self.config.model.clone());
let endpoint = Arc::new(self.config.endpoint.clone());
let api_key = self.api_key.clone();
tokio::spawn(async move {
let mut buffer = Vec::with_capacity(samples_per_chunk * 2);
let mut chunk_id = 0usize;
while let Some(data) = sample_rx.recv().await {
buffer.extend(data);
while buffer.len() >= samples_per_chunk {
let chunk_samples = buffer.drain(..samples_per_chunk).collect::<Vec<_>>();
let client = client.clone();
let api_key = api_key.clone();
let sample_rate = audio_config.sample_rate;
let channels = audio_config.channels;
let current_chunk = chunk_id;
let chunk_model = model.clone();
let chunk_endpoint = endpoint.clone();
let chunk_sink = transcript_sink.clone();
let event_sender = event_tx.clone();
tokio::spawn(async move {
match transcribe_chunk(
client,
api_key,
sample_rate,
channels,
chunk_samples,
current_chunk,
chunk_model,
chunk_endpoint,
)
.await
{
Ok(text) => {
let _ = event_sender.send(TranscriptionEvent::Transcription {
chunk_id: current_chunk,
text: text.clone(),
});
if let Some(writer) = chunk_sink {
let record_text = if text.is_empty() {
"<silence>"
} else {
text.as_str()
};
if let Err(err) =
append_transcript(writer, current_chunk, record_text).await
{
let _ = event_sender.send(TranscriptionEvent::Error {
chunk_id: current_chunk,
error: format!("File write failed: {err}"),
});
}
}
}
Err(err) => {
let _ = event_sender.send(TranscriptionEvent::Error {
chunk_id: current_chunk,
error: err.to_string(),
});
}
}
});
chunk_id += 1;
}
}
});
Ok((event_rx, _stream))
}
}
/// Uploads one chunk of captured audio to the transcription endpoint and returns its text.
///
/// The samples are WAV-encoded with `encode_wav` and posted as a multipart form holding a
/// `model` text field and a `file` part named `chunk-{chunk_id}.wav`. When `api_key` is
/// `Some`, it is attached as a bearer token.
///
/// The returned text is the trimmed value of the response's `text` field. A response with
/// no string `text` field yields an empty string rather than an error.
///
/// # Errors
///
/// Fails if WAV encoding fails, the MIME type is rejected, the request cannot be sent, the
/// server replies with an error status, or the response body is not valid JSON.
async fn transcribe_chunk(
    client: Client,
    api_key: Option<String>,
    sample_rate: u32,
    channels: u16,
    samples: Vec<f32>,
    chunk_id: usize,
    model: Arc<String>,
    endpoint: Arc<String>,
) -> Result<String> {
    let wav_bytes = encode_wav(&samples, sample_rate, channels)?;
    let audio_part = reqwest::multipart::Part::bytes(wav_bytes)
        .file_name(format!("chunk-{chunk_id}.wav"))
        .mime_str("audio/wav")?;
    let body = reqwest::multipart::Form::new()
        .text("model", (*model).clone())
        .part("file", audio_part);

    let request = client.post(endpoint.as_str());
    let request = match api_key.as_ref() {
        Some(key) => request.bearer_auth(key),
        None => request,
    };

    let response = request.multipart(body).send().await?.error_for_status()?;
    let payload: Value = response.json().await?;

    let transcript = match payload.get("text").and_then(|field| field.as_str()) {
        Some(raw) => raw.trim().to_string(),
        None => String::new(),
    };
    Ok(transcript)
}
/// Appends one `Chunk {chunk_id}: {text}` line to the shared transcript file.
///
/// The file's mutex is held across both the write and the flush, so callers sharing the
/// same handle are serialized.
///
/// # Errors
///
/// Returns the underlying I/O error if writing or flushing fails.
pub(super) async fn append_transcript(
    writer: Arc<tokio::sync::Mutex<tokio::fs::File>>,
    chunk_id: usize,
    text: &str,
) -> Result<()> {
    let line = format!("Chunk {chunk_id}: {text}\n");
    let mut file = writer.lock().await;
    file.write_all(line.as_bytes()).await?;
    file.flush().await?;
    Ok(())
}