stt-cli 0.2.1

Speech-to-text CLI using the Groq and OpenAI APIs
// src/hotkey_service.rs
//
// This module handles global hotkey registration and event handling.

use crate::audio::stream_manager::AudioStreamManager;
use crate::audio_state::RecordingState;
use crate::config::AppConfig;
use crate::providers::TranscriptionProvider;
use crate::shutdown_handler::{ExitPriority, ShutdownManager};
use anyhow::Result;
use global_hotkey::{
    hotkey::{HotKey, Modifiers},
    GlobalHotKeyEvent, GlobalHotKeyManager,
};
use tower::ServiceExt;
use tower::Service;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::{broadcast, mpsc, Mutex as TokioMutex};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
// Import crossbeam error types
use crossbeam_channel::RecvTimeoutError;

/// Context needed for the hotkey service to start pipeline tasks
///
/// At most one pipeline is tracked at a time. It is represented by
/// `pipeline_cancel_token` and `pipeline_handle`; both are `None` when no pipeline is running.
pub struct AppContext {
    /// Audio stream manager; the pipeline obtains its broadcast audio receiver from it.
    pub stream_manager: Arc<TokioMutex<AudioStreamManager>>,
    /// Transcription backend shared by every pipeline run.
    pub provider: Arc<TokioMutex<Box<dyn TranscriptionProvider + Send + Sync>>>,
    /// Cancellation token of the running pipeline, if any.
    /// This is a `std::sync::Mutex`, so guards should not be kept alive across `.await`.
    pub pipeline_cancel_token: Mutex<Option<CancellationToken>>,
    /// Join handle of the running pipeline driver task, if any.
    /// This is a `std::sync::Mutex`, so guards should not be kept alive across `.await`.
    pub pipeline_handle: Mutex<Option<JoinHandle<()>>>,
    /// Application configuration; the pipeline reads text-insertion options
    /// (`auto_capitalize`, `auto_punctuate`) from it.
    pub config: AppConfig,
}

/// Registers global hotkeys and toggles recording (and, optionally, the transcription
/// pipeline) each time a registered hotkey is pressed.
pub struct HotkeyService {
    /// Hotkey manager that holds the global registrations.
    manager: GlobalHotKeyManager,
    /// Shared recording flag, toggled on every hotkey press.
    recording_state: Arc<RecordingState>,
    /// Hotkeys added via `register_hotkey`; unregistered again by the shutdown handler.
    registered_hotkeys: Vec<HotKey>,
    /// Shutdown broadcast; `run` listens on a resubscribed copy of it.
    shutdown_rx: broadcast::Receiver<()>,
    /// Pipeline context; while `None`, presses only toggle `recording_state`.
    app_context: Option<Arc<AppContext>>,
    /// Handle of the OS thread that forwards hotkey events; joined by the shutdown handler.
    blocking_thread: Mutex<Option<std::thread::JoinHandle<()>>>,
    /// Cleared to ask the forwarding thread to stop; it re-checks this between
    /// 100 ms receive timeouts.
    thread_running: Arc<AtomicBool>,
}

impl HotkeyService {
    /// Create a hotkey service with no hotkeys registered and no pipeline context attached.
    ///
    /// # Errors
    /// Returns an error if the [`GlobalHotKeyManager`] cannot be created.
    pub fn new(
        recording_state: Arc<RecordingState>,
        shutdown_rx: broadcast::Receiver<()>,
    ) -> Result<Self> {
        let manager = GlobalHotKeyManager::new()?;
        Ok(Self {
            manager,
            recording_state,
            registered_hotkeys: Vec::new(),
            shutdown_rx,
            app_context: None,
            blocking_thread: Mutex::new(None),
            thread_running: Arc::new(AtomicBool::new(false)),
        })
    }

    /// Parse `hotkey_str` into a [`HotKey`], register it with the hotkey manager, and
    /// remember it so it can be unregistered on shutdown.
    ///
    /// # Errors
    /// Returns an error if the string cannot be parsed or the registration fails.
    pub fn register_hotkey(&mut self, hotkey_str: &str) -> Result<()> {
        let hotkey = HotKey::from_str(hotkey_str)?;
        self.manager.register(hotkey.clone())?;
        self.registered_hotkeys.push(hotkey);
        Ok(())
    }

    /// Set the application context for pipeline operations.
    ///
    /// Replaces any previously set context. `run` reads it on every hotkey press.
    pub fn set_app_context(&mut self, context: Arc<AppContext>) {
        self.app_context = Some(context);
    }

    /// Run the hotkey event loop until a shutdown signal arrives or the event source closes.
    ///
    /// Hotkey events arrive on a blocking channel. A dedicated OS thread therefore polls it
    /// with a 100 ms timeout and forwards events into an async channel. The thread re-checks
    /// `thread_running` after every timeout.
    ///
    /// Each key *press* toggles `recording_state`. When an [`AppContext`] is set, the press
    /// also starts or stops the transcription pipeline. The thread handle is stored in
    /// `blocking_thread` so the shutdown handler can join it.
    pub async fn run(&self) {
        info!("Hotkey service started");
        let event_receiver = GlobalHotKeyEvent::receiver();

        // Forward blocking hotkey events into an async channel.
        let (async_tx, mut async_rx) = mpsc::unbounded_channel();
        let mut shutdown_rx = self.shutdown_rx.resubscribe();

        self.thread_running.store(true, Ordering::SeqCst);
        let thread_running = Arc::clone(&self.thread_running);

        // One thread suffices: it blocks on the hotkey channel itself.
        let blocking_thread = std::thread::spawn(move || {
            tracing::info!("[hotkey_blocking_thread] Started");
            while thread_running.load(Ordering::SeqCst) {
                match event_receiver.recv_timeout(Duration::from_millis(100)) {
                    Ok(event) => {
                        if !thread_running.load(Ordering::SeqCst) {
                            break;
                        }
                        // The async loop has stopped receiving; nothing left to forward to.
                        if async_tx.send(event).is_err() {
                            break;
                        }
                    }
                    // The timeout only exists so the termination flag is re-checked regularly.
                    Err(RecvTimeoutError::Timeout) => {}
                    Err(RecvTimeoutError::Disconnected) => break,
                }
            }
            tracing::info!("[hotkey_blocking_thread] Exiting");
        });

        *self.blocking_thread.lock().unwrap() = Some(blocking_thread);

        // Process events or exit on shutdown.
        loop {
            tokio::select! {
                _ = shutdown_rx.recv() => {
                    info!("Hotkey service received shutdown signal");
                    // Ask the forwarding thread to stop at its next poll.
                    self.thread_running.store(false, Ordering::SeqCst);
                    break;
                }
                maybe_event = async_rx.recv() => {
                    match maybe_event {
                        Some(event) => {
                            if event.state == global_hotkey::HotKeyState::Pressed {
                                self.on_hotkey_pressed().await;
                            }
                        }
                        None => {
                            // The forwarding thread has exited, so no more events can arrive.
                            self.thread_running.store(false, Ordering::SeqCst);
                            break;
                        }
                    }
                }
            }
        }

        info!("Hotkey service stopped");
    }

    /// React to one hotkey press.
    ///
    /// Flips the recording state. When an [`AppContext`] is set, it then starts the
    /// pipeline (recording now active) or stops it (recording now paused).
    async fn on_hotkey_pressed(&self) {
        debug!("Hotkey pressed");
        let is_now_active = self.recording_state.toggle();
        info!(
            "Recording {}",
            if is_now_active { "started" } else { "paused" }
        );

        if let Some(app_context) = &self.app_context {
            if is_now_active {
                self.start_pipeline_task(app_context).await;
            } else {
                self.stop_pipeline_task(app_context).await;
            }
        }
    }

    /// Start the pipeline task when recording is toggled on.
    ///
    /// Does nothing if a pipeline is already running, i.e. a token or handle is still stored
    /// in `app_context`. The token slot is reserved in a short critical section, so the
    /// `std::sync::Mutex` guards are never held across an `.await`.
    ///
    /// If setup fails, the reservation is released again. Otherwise a stale token would
    /// block every later start.
    async fn start_pipeline_task(&self, app_context: &Arc<AppContext>) {
        let token = {
            let mut token_guard = app_context.pipeline_cancel_token.lock().unwrap();
            let handle_guard = app_context.pipeline_handle.lock().unwrap();
            if token_guard.is_some() || handle_guard.is_some() {
                debug!("Pipeline already running; not starting another");
                return;
            }
            let token = CancellationToken::new();
            *token_guard = Some(token.clone());
            token
        };

        match Self::spawn_pipeline(app_context, token.clone()).await {
            Some(handle) => {
                let mut handle_guard = app_context.pipeline_handle.lock().unwrap();
                if token.is_cancelled() {
                    // A stop ran while we were setting up and already cleared the token slot.
                    // The driver sees the cancelled token and exits, so don't keep its handle
                    // (it would block the next start).
                    debug!("Pipeline was cancelled during startup");
                } else {
                    *handle_guard = Some(handle);
                }
            }
            None => {
                // Setup failed (already logged): release the reservation so a later press can retry.
                *app_context.pipeline_cancel_token.lock().unwrap() = None;
            }
        }
    }

    /// Build and spawn the pipeline for one recording session; the pipeline stops when `token`
    /// is cancelled.
    ///
    /// Returns the driver task's handle, or `None` after logging the cause if the audio
    /// receiver or the text inserter could not be obtained. Nothing is spawned on failure.
    async fn spawn_pipeline(
        app_context: &Arc<AppContext>,
        token: CancellationToken,
    ) -> Option<JoinHandle<()>> {
        use crate::audio::constants::{CHUNK_DURATION_MS, SAMPLE_RATE};
        use crate::pipeline::chunking::ChunkingManager;
        use crate::pipeline::layers::wav_conversion::WavConversionLayer;
        use crate::pipeline::services::transcription::TranscriptionService;
        use crate::pipeline::types::{AudioChunk, AudioRequest, AudioResponse, ProcessedData};
        use crate::platform::{EnigoTextInserter, InsertOptions, PlatformTextInserterHandler};
        use crate::transcription::result_handler::TranscriptionResultHandler;
        use std::time::SystemTime;
        use tower::ServiceBuilder;

        // The stream-manager guard is a temporary, released at the end of this statement.
        let maybe_receiver = app_context.stream_manager.lock().await.get_receiver();
        let bcast_receiver = match maybe_receiver {
            Some(receiver) => receiver,
            None => {
                error!("Failed to get audio receiver from stream manager");
                return None;
            }
        };

        let inserter = match EnigoTextInserter::new() {
            Ok(inserter) => Box::new(inserter),
            Err(e) => {
                error!("Failed to create text inserter: {}", e);
                return None;
            }
        };
        let handler = Arc::new(TokioMutex::new(
            Box::new(PlatformTextInserterHandler::new(inserter))
                as Box<dyn TranscriptionResultHandler + Send + Sync>,
        ));

        // Everything that can fail has succeeded; only now start background tasks.
        let (mpsc_tx, mut mpsc_rx) = mpsc::channel(8);
        Self::spawn_audio_forwarder(bcast_receiver.resubscribe(), mpsc_tx);

        let provider = Arc::clone(&app_context.provider);
        let config = app_context.config.clone();

        info!("Hotkey pressed: starting pipeline task");

        let pipeline_handle = tokio::spawn(async move {
            let mut chunking_manager =
                ChunkingManager::new(SAMPLE_RATE, Duration::from_millis(CHUNK_DURATION_MS));
            let mut pipeline_service = ServiceBuilder::new()
                .layer(WavConversionLayer)
                .service(TranscriptionService::new(provider));

            loop {
                tokio::select! {
                    _ = token.cancelled() => {
                        info!("Pipeline driver received shutdown signal");
                        break;
                    }
                    opt = mpsc_rx.recv() => {
                        match opt {
                            Some(samples) => {
                                for complete_chunk in chunking_manager.add_samples(&samples) {
                                    let request = AudioRequest(AudioChunk {
                                        timestamp: SystemTime::now(),
                                        data: complete_chunk,
                                        is_speech: None,
                                    });
                                    if let Err(e) = pipeline_service.ready().await {
                                        error!("Pipeline service not ready: {}", e);
                                        continue;
                                    }
                                    match pipeline_service.call(request).await {
                                        Ok(AudioResponse { result_data: ProcessedData::Transcription(text), .. }) => {
                                            let options = InsertOptions {
                                                auto_capitalize: config.auto_capitalize,
                                                auto_punctuate: config.auto_punctuate,
                                                ..Default::default()
                                            };
                                            let mut guard = handler.lock().await;
                                            if let Err(e) = guard.handle_result(&text, options) {
                                                error!("Failed to handle transcription result: {}", e);
                                            }
                                        }
                                        Ok(AudioResponse { result_data: other_data, .. }) => {
                                            debug!("Received non-transcription data: {:?}", other_data);
                                        }
                                        Err(e) => {
                                            error!("Pipeline call failed: {}", e);
                                        }
                                    }
                                }
                            }
                            None => {
                                info!("Audio source channel closed");
                                break;
                            }
                        }
                    }
                }
            }
            info!("Pipeline driver task finished.");
        });

        Some(pipeline_handle)
    }

    /// Forward audio chunks from the broadcast stream into the pipeline's bounded queue.
    ///
    /// `RecvError::Lagged` only means some chunks were overwritten before they were read,
    /// so forwarding continues after a warning. Ending on it would silently stop the
    /// pipeline whenever transcription falls behind.
    ///
    /// The task stops when the broadcast channel closes or the pipeline drops its receiver.
    fn spawn_audio_forwarder<T>(mut source: broadcast::Receiver<T>, sink: mpsc::Sender<T>)
    where
        T: Clone + Send + 'static,
    {
        tokio::spawn(async move {
            loop {
                match source.recv().await {
                    Ok(chunk) => {
                        if sink.send(chunk).await.is_err() {
                            break;
                        }
                    }
                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
                        warn!("Audio forwarder lagged behind; {} chunks were dropped", skipped);
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
        });
    }

    /// Stop the running pipeline task, if any (called when a press toggles recording off).
    ///
    /// The token and handle are taken out in one short critical section, so no
    /// `std::sync::Mutex` guard is held while waiting for the task. If the task does not
    /// finish within 2 s, it is aborted. A detached pipeline could otherwise keep inserting
    /// text and overlap with the next session.
    async fn stop_pipeline_task(&self, app_context: &Arc<AppContext>) {
        let (token, handle) = {
            let mut token_guard = app_context.pipeline_cancel_token.lock().unwrap();
            let mut handle_guard = app_context.pipeline_handle.lock().unwrap();
            (token_guard.take(), handle_guard.take())
        };

        // Cancel the task if running.
        if let Some(token) = token {
            token.cancel();
            info!("Hotkey released: cancellation signal sent to pipeline");
        }

        // Wait for the task to complete; `&mut handle` keeps it available for `abort`.
        if let Some(mut handle) = handle {
            match tokio::time::timeout(Duration::from_secs(2), &mut handle).await {
                Ok(Ok(())) => info!("Pipeline task finished after cancellation"),
                Ok(Err(e)) => error!("Pipeline task panicked: {}", e),
                Err(_) => {
                    warn!("Timeout waiting for pipeline task to finish");
                    handle.abort();
                }
            }
        }
    }

    /// Register a shutdown handler for the hotkey service.
    ///
    /// The handler:
    /// 1. clears `thread_running` so the forwarding thread stops polling;
    /// 2. unregisters all hotkeys (failures are only logged at debug level);
    /// 3. waits up to 2 s for the forwarding thread to be joined.
    ///
    /// The join itself runs on a helper OS thread. If it times out, that helper is simply
    /// left behind.
    pub fn register_shutdown_handler(service: Arc<Self>, shutdown_manager: &ShutdownManager) {
        shutdown_manager.register(
            "Stop hotkey service",
            ExitPriority::Normal,
            move || async move {
                // Signal the blocking thread to terminate.
                service.thread_running.store(false, Ordering::SeqCst);

                // Unregister all hotkeys.
                if let Err(e) = service.manager.unregister_all(&service.registered_hotkeys) {
                    debug!("Failed to unregister hotkeys during shutdown: {}", e);
                }

                // Take the thread handle in its own scope so the std mutex guard is released
                // before the await below.
                let handle = {
                    let mut guard = service.blocking_thread.lock().unwrap();
                    guard.take()
                };

                if let Some(handle) = handle {
                    // `JoinHandle::join` blocks, so join on a helper thread and report back.
                    let (tx, rx) = tokio::sync::oneshot::channel();
                    std::thread::spawn(move || {
                        let join_result = handle.join();
                        let _ = tx.send(join_result);
                    });

                    match tokio::time::timeout(Duration::from_secs(2), rx).await {
                        Ok(Ok(Ok(()))) => debug!("Hotkey blocking thread joined successfully"),
                        Ok(Ok(Err(e))) => error!("Hotkey blocking thread panicked: {:?}", e),
                        Ok(Err(_)) => error!("Failed to receive join result"),
                        Err(_) => warn!("Timeout waiting for hotkey blocking thread to join"),
                    }
                }

                info!("Hotkey service shutdown handler completed");
            },
        );
    }
}