// rustvani 0.1.6
//
// Voice AI framework for Rust — real-time speech pipelines with STT, LLM, TTS, and Dhara conversation flows
use std::sync::Arc;

use tokio::sync::mpsc;

use crate::frames::{AudioRawData, Frame, FrameDirection, FrameProcessor};
use crate::transport::{BaseTransport, OutputMessage, TransportParams};

// ---------------------------------------------------------------------------
// ChannelMessage
// ---------------------------------------------------------------------------

/// Messages flowing through the channel transport.
///
/// This is the bidirectional counterpart to [`OutputMessage`]: it is used
/// for *both* incoming data (outside world → pipeline) and outgoing data
/// (pipeline → outside world).
/// A message carried by the channel transport in either direction.
///
/// Where [`OutputMessage`] only describes data leaving the pipeline, this
/// type is shared by both directions: data arriving from the outside world
/// into the pipeline, and data the pipeline sends back out.
#[derive(Clone, Debug)]
pub enum ChannelMessage {
    /// A chunk of raw PCM audio bytes.
    Audio(Vec<u8>),
    /// A serialised JSON text payload (RAVI protocol messages travel here).
    Text(String),
    /// An interruption signal.
    Interruption,
}

// ---------------------------------------------------------------------------
// ChannelTransport
// ---------------------------------------------------------------------------

/// Transport that bridges a pipeline and the outside world through Tokio
/// mpsc channels.
///
/// Incoming [`ChannelMessage`]s are read from a receiver supplied at
/// construction; pipeline output is collected from an internal
/// [`OutputMessage`] channel. Both receivers are consumed by the transport's
/// `run` method.
pub struct ChannelTransport {
    // Shared transport core; `input()` / `output()` delegate to it.
    base: Arc<BaseTransport>,
    // Messages from the outside world. Wrapped in `Option` so `run` can take
    // it exactly once (a second `run` call panics).
    incoming_rx: std::sync::Mutex<Option<mpsc::Receiver<ChannelMessage>>>,
    // Receiving end of the output channel whose sender is registered with
    // `base` in `new`; also taken once by `run`.
    audio_out_rx: std::sync::Mutex<Option<mpsc::Receiver<OutputMessage>>>,
    // Sample rate attached to every incoming audio chunk.
    sample_rate: u32,
    // Channel count attached to every incoming audio chunk (always >= 1).
    channels: u16,
}

// Capacity of the bounded channel carrying pipeline output to `run`.
const AUDIO_OUT_CHANNEL_CAP: usize = 150;

impl ChannelTransport {
    pub fn new(
        name: &str,
        params: TransportParams,
        incoming_rx: mpsc::Receiver<ChannelMessage>,
    ) -> Self {
        let base = Arc::new(BaseTransport::new(name, params.clone()));

        let (audio_out_tx, audio_out_rx) = mpsc::channel::<OutputMessage>(AUDIO_OUT_CHANNEL_CAP);
        base.set_audio_out_tx(audio_out_tx);

        let sample_rate = params.audio_in_sample_rate.unwrap_or(16_000);
        let channels = params.audio_in_channels.max(1);

        Self {
            base,
            incoming_rx: std::sync::Mutex::new(Some(incoming_rx)),
            audio_out_rx: std::sync::Mutex::new(Some(audio_out_rx)),
            sample_rate,
            channels,
        }
    }

    /// The input FrameProcessor — place first in the pipeline.
    pub fn input(&self) -> FrameProcessor {
        self.base.input()
    }

    /// The output FrameProcessor — place last in the pipeline.
    pub fn output(&self) -> FrameProcessor {
        self.base.output()
    }

    /// Drive the channel transport until either direction closes.
    ///
    /// Arm 1 — `incoming_rx.recv()`: incoming messages from the outside world.
    ///
    ///   - `Audio(bytes)`  → raw PCM audio → pipeline via `push_audio_frame`.
    ///   - `Text(json)`    → parsed as RAVI or legacy interruption → pushed
    ///                       into the pipeline via `push_tx`.
    ///   - `Interruption`  → `InterruptionFrame` downstream.
    ///
    /// Arm 2 — `audio_out_rx.recv()`: outgoing pipeline messages.
    ///
    ///   - `Audio(bytes)`  → `ChannelMessage::Audio`.
    ///   - `Text(json)`    → `ChannelMessage::Text` (RAVI protocol messages).
    ///   - `Interruption`  → drains stale audio, then sends
    ///                       `ChannelMessage::Interruption`.
    pub async fn run(
        &self,
        push_tx: mpsc::Sender<(Frame, FrameDirection)>,
        outgoing_tx: mpsc::Sender<ChannelMessage>,
    ) {
        let mut incoming_rx = self
            .incoming_rx
            .lock()
            .unwrap()
            .take()
            .expect("run called more than once on the same ChannelTransport");

        let mut audio_out_rx = self
            .audio_out_rx
            .lock()
            .unwrap()
            .take()
            .expect("run called more than once on the same ChannelTransport");

        let base = self.base.clone();
        let sample_rate = self.sample_rate;
        let channels = self.channels;

        loop {
            tokio::select! {
                // ----------------------------------------------------------------
                // Arm 1: incoming messages → pipeline
                // ----------------------------------------------------------------
                msg = incoming_rx.recv() => {
                    match msg {
                        Some(ChannelMessage::Audio(bytes)) => {
                            let data = AudioRawData::new(bytes, sample_rate, channels);
                            base.push_audio_frame(data).await;
                        }

                        Some(ChannelMessage::Text(text)) => {
                            handle_incoming_text(&text, &push_tx).await;
                        }

                        Some(ChannelMessage::Interruption) => {
                            let _ = push_tx
                                .send((Frame::interruption(), FrameDirection::Downstream))
                                .await;
                        }

                        None => {
                            log::debug!("ChannelTransport: incoming channel closed");
                            break;
                        }
                    }
                }

                // ----------------------------------------------------------------
                // Arm 2: outgoing pipeline messages → outside world
                // ----------------------------------------------------------------
                output_msg = audio_out_rx.recv() => {
                    match output_msg {
                        Some(OutputMessage::Audio(bytes)) => {
                            if outgoing_tx.send(ChannelMessage::Audio(bytes)).await.is_err() {
                                log::warn!("ChannelTransport: outgoing channel closed");
                                break;
                            }
                        }

                        Some(OutputMessage::Text(json)) => {
                            if outgoing_tx.send(ChannelMessage::Text(json)).await.is_err() {
                                log::warn!("ChannelTransport: outgoing channel closed");
                                break;
                            }
                        }

                        Some(OutputMessage::Interruption) => {
                            // Drain stale audio chunks queued before the marker.
                            while let Ok(queued) = audio_out_rx.try_recv() {
                                match queued {
                                    OutputMessage::Audio(_) => {}    // discard
                                    OutputMessage::Interruption => break,
                                    OutputMessage::Text(_) => {}     // discard
                                }
                            }

                            if outgoing_tx.send(ChannelMessage::Interruption).await.is_err() {
                                log::warn!("ChannelTransport: outgoing channel closed");
                                break;
                            }
                            log::debug!("ChannelTransport: sent interruption");
                        }

                        None => {
                            log::debug!("ChannelTransport: audio out channel closed");
                            break;
                        }
                    }
                }
            }
        }

        let _ = push_tx
            .send((Frame::end(), FrameDirection::Downstream))
            .await;
    }
}

// ---------------------------------------------------------------------------
// Incoming text message handler
// ---------------------------------------------------------------------------

/// Parse an incoming text message and push the appropriate frame.
///
/// Two protocols are recognised:
///
/// 1. **RAVI** (`label == "ravi"`) — parsed into a `RaviClientMessage`
///    frame and sent downstream.
///
/// 2. **Legacy interruption** (`type == "client_interruption"`) — kept for
///    backward-compatibility.
async fn handle_incoming_text(
    text: &str,
    push_tx: &mpsc::Sender<(Frame, FrameDirection)>,
) {
    /// Look up `key` in `value` and return it only if it is a JSON string.
    fn str_field<'a>(value: &'a serde_json::Value, key: &str) -> Option<&'a str> {
        value.get(key).and_then(serde_json::Value::as_str)
    }

    // Anything that is not JSON is logged and ignored.
    let parsed: serde_json::Value = match serde_json::from_str(text) {
        Ok(value) => value,
        Err(_) => {
            log::warn!("ChannelTransport: ignoring non-JSON text message");
            return;
        }
    };

    // A missing or non-string "type" is treated as the empty string.
    let kind = str_field(&parsed, "type").unwrap_or("");

    if str_field(&parsed, "label") == Some("ravi") {
        // RAVI messages must carry a string "id"; "data" (any JSON value) is
        // re-serialised and passed through as text.
        match str_field(&parsed, "id") {
            Some(msg_id) => {
                let data = parsed.get("data").map(ToString::to_string);
                let frame = Frame::ravi_client_message(msg_id, kind, data);
                let _ = push_tx.send((frame, FrameDirection::Downstream)).await;

                log::trace!("ChannelTransport: RAVI '{}' (id={})", kind, msg_id);
            }
            None => {
                log::warn!("ChannelTransport: RAVI message missing 'id' field — dropping");
            }
        }
    } else if kind == "client_interruption" {
        // Legacy: bare client interruption without RAVI label.
        log::info!("ChannelTransport: legacy client-initiated interruption");
        let _ = push_tx
            .send((Frame::interruption(), FrameDirection::Downstream))
            .await;
    }
}