use std::collections::VecDeque;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use crate::audio::{process_recorded_audio, RawRecordedAudio};
use super::Transcriber;
/// Maximum number of segments that may be pending at once; `enqueue` rejects
/// new segments (returns `false`) when the queue already holds this many.
const MAX_QUEUE_SIZE: usize = 10;
/// A recorded audio segment waiting to be transcribed.
///
/// The worker forwards `samples`, `sample_rate` and `channels` unchanged into
/// a `RawRecordedAudio` before audio processing and transcription.
#[derive(Debug, Clone)]
pub struct QueuedSegment {
    /// Raw sample data for the segment.
    // NOTE(review): presumably interleaved when `channels > 1` — confirm
    // against the recorder that produces these segments.
    pub samples: Vec<f32>,
    /// Sample rate of `samples` (presumably in Hz — confirm against caller).
    pub sample_rate: u32,
    /// Number of audio channels in `samples`.
    pub channels: u16,
    /// Optional path associated with this segment; when present it is passed
    /// (lossily converted to a `String`) to
    /// `TranscriptionCallback::on_transcription_complete`.
    pub wav_path: Option<PathBuf>,
}
/// Observer notified by `TranscriptionQueue` about queue and transcription
/// events.
///
/// Hooks are invoked both from the threads that call `enqueue`/`clear` and
/// from the background worker thread, hence the `Send + Sync + 'static` bound.
/// Every hook runs while the queue's internal callback mutex is held, and
/// `on_queue_update` additionally runs while the segment queue is locked, so
/// implementations must not call back into `set_callback`, `clear_callback`,
/// `enqueue` or `clear` (std mutexes are not reentrant).
pub trait TranscriptionCallback: Send + Sync + 'static {
/// Called by the worker after a segment's audio has been processed
/// successfully, immediately before transcription of that segment begins.
fn on_transcription_started(&self);
/// Called when transcription of a segment succeeds, with the resulting text
/// and the segment's `wav_path` (lossily converted to a string) if it had one.
fn on_transcription_complete(&self, text: String, wav_path: Option<String>);
/// Called with the error value when either audio processing or transcription
/// of a segment fails.
fn on_transcription_error(&self, error: String);
/// Called after a transcription attempt ends, whether it succeeded or failed.
/// Not called when audio processing failed (in that case
/// `on_transcription_started` was not called either).
fn on_transcription_finished(&self);
/// Called with the new number of pending segments after a segment is
/// enqueued, after the worker dequeues one, and (with `0`) after `clear`.
fn on_queue_update(&self, depth: usize);
}
/// Bounded FIFO of audio segments drained by a background worker thread that
/// transcribes them and reports results through a `TranscriptionCallback`.
///
/// All state is held behind `Arc`s so it can be shared with the worker thread.
pub struct TranscriptionQueue {
// Pending segments; pushed at the back by `enqueue`, popped from the front by the worker.
queue: Arc<Mutex<VecDeque<QueuedSegment>>>,
// Set when a worker is started and cleared by `stop_worker`; a worker exits
// once this is false and the queue is empty.
worker_active: Arc<AtomicBool>,
// Bumped on every start/restart; a worker whose id no longer matches the
// current value exits at its next loop iteration.
worker_generation: Arc<AtomicUsize>,
// Last recorded queue length, readable without taking the queue lock.
queue_count: Arc<AtomicUsize>,
// Optional observer for queue and transcription events.
callback: Arc<Mutex<Option<Arc<dyn TranscriptionCallback>>>>,
}
impl TranscriptionQueue {
pub fn new() -> Self {
Self {
queue: Arc::new(Mutex::new(VecDeque::new())),
worker_active: Arc::new(AtomicBool::new(false)),
worker_generation: Arc::new(AtomicUsize::new(0)),
queue_count: Arc::new(AtomicUsize::new(0)),
callback: Arc::new(Mutex::new(None)),
}
}
pub fn set_callback(&self, callback: Arc<dyn TranscriptionCallback>) {
*self.callback.lock().unwrap() = Some(callback);
}
pub fn clear_callback(&self) {
*self.callback.lock().unwrap() = None;
}
pub fn queue_depth(&self) -> usize {
self.queue_count.load(Ordering::SeqCst)
}
pub fn is_worker_active(&self) -> bool {
self.worker_active.load(Ordering::SeqCst)
}
pub fn enqueue(&self, segment: QueuedSegment) -> bool {
let mut queue = self.queue.lock().unwrap();
if queue.len() >= MAX_QUEUE_SIZE {
return false;
}
queue.push_back(segment);
let depth = queue.len();
self.queue_count.store(depth, Ordering::SeqCst);
if let Some(ref cb) = *self.callback.lock().unwrap() {
cb.on_queue_update(depth);
}
true
}
pub fn start_worker(&self, model_path: PathBuf) {
if self.worker_active.load(Ordering::SeqCst) {
return; }
self.worker_active.store(true, Ordering::SeqCst);
let worker_id = self.worker_generation.fetch_add(1, Ordering::SeqCst) + 1;
let queue = Arc::clone(&self.queue);
let worker_active = Arc::clone(&self.worker_active);
let worker_generation = Arc::clone(&self.worker_generation);
let queue_count = Arc::clone(&self.queue_count);
let callback = Arc::clone(&self.callback);
thread::spawn(move || {
let mut transcriber = Transcriber::with_model_path(model_path.clone());
if model_path.exists() {
if let Err(e) = transcriber.load_model() {
tracing::error!("[TranscriptionQueue] Failed to load model: {}", e);
} else {
tracing::info!(
"[TranscriptionQueue] Model loaded at startup from: {}",
model_path.display()
);
}
} else {
tracing::warn!(
"[TranscriptionQueue] Model not found at startup: {}",
model_path.display()
);
}
loop {
if worker_generation.load(Ordering::SeqCst) != worker_id {
break;
}
if !worker_active.load(Ordering::SeqCst) {
let remaining = {
let q = queue.lock().unwrap();
q.len()
};
if remaining == 0 {
break;
}
}
let segment = {
let mut q = queue.lock().unwrap();
let seg = q.pop_front();
let depth = q.len();
queue_count.store(depth, Ordering::SeqCst);
if seg.is_some() {
if let Some(ref cb) = *callback.lock().unwrap() {
cb.on_queue_update(depth);
}
}
seg
};
match segment {
Some(seg) => {
let raw_audio = RawRecordedAudio {
samples: seg.samples,
sample_rate: seg.sample_rate,
channels: seg.channels,
};
let wav_path_str = seg
.wav_path
.as_ref()
.map(|p| p.to_string_lossy().to_string());
match process_recorded_audio(raw_audio) {
Ok(processed) => {
if let Some(ref cb) = *callback.lock().unwrap() {
cb.on_transcription_started();
}
match transcriber.transcribe(&processed) {
Ok(text) => {
if let Some(ref cb) = *callback.lock().unwrap() {
cb.on_transcription_complete(text, wav_path_str);
}
}
Err(e) => {
if let Some(ref cb) = *callback.lock().unwrap() {
cb.on_transcription_error(e);
}
}
}
if let Some(ref cb) = *callback.lock().unwrap() {
cb.on_transcription_finished();
}
}
Err(e) => {
if let Some(ref cb) = *callback.lock().unwrap() {
cb.on_transcription_error(e);
}
}
}
}
None => {
thread::sleep(std::time::Duration::from_millis(50));
}
}
}
if worker_generation.load(Ordering::SeqCst) == worker_id {
worker_active.store(false, Ordering::SeqCst);
}
tracing::info!("[TranscriptionQueue] Worker thread exiting");
});
}
pub fn stop_worker(&self) {
self.worker_active.store(false, Ordering::SeqCst);
}
pub fn restart_worker(&self, model_path: PathBuf) {
self.worker_generation.fetch_add(1, Ordering::SeqCst);
self.worker_active.store(false, Ordering::SeqCst);
self.start_worker(model_path);
}
pub fn clear(&self) {
let mut queue = self.queue.lock().unwrap();
queue.clear();
self.queue_count.store(0, Ordering::SeqCst);
if let Some(ref cb) = *self.callback.lock().unwrap() {
cb.on_queue_update(0);
}
}
}
impl Default for TranscriptionQueue {
fn default() -> Self {
Self::new()
}
}