use crate::config::MoshiConfig;
use crate::proto::{MOSHI_SAMPLE_RATE, OGG_SERIAL, OPUS_FRAME_SAMPLES, ctrl, mt};
use async_trait::async_trait;
use byteorder::WriteBytesExt;
use futures_util::{SinkExt, StreamExt};
use ogg::PacketWriteEndInfo;
use std::collections::VecDeque;
use std::io::Write;
use std::sync::{Arc, Mutex};
use tokio::io::AsyncWriteExt as _;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use vona_core::{
AudioInputFrame, AudioOutputFrame, BackendCapabilities, BackendError, BackendStep,
ControlEvent, ExternalContextEvent, SessionConfig, SpeechToSpeechBackend,
};
// Full WebSocket connection to the Moshi server over plain TCP or TLS
// (`MaybeTlsStream`).
type WsStream =
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
// Write half of `WsStream` obtained from `StreamExt::split`.
type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
// Shared write half. A `tokio::sync::Mutex` is used because callers hold the
// guard across the `.send(..).await` of a WebSocket message.
type SharedWsTx = Arc<tokio::sync::Mutex<WsSink>>;
/// Speech-to-speech backend that streams audio to and from a Moshi server
/// over a WebSocket.
pub struct MoshiBackend {
// Connection settings (server URL and whether invalid TLS certificates are
// accepted), read by `start_session` for every new session.
config: MoshiConfig,
}
impl MoshiBackend {
pub fn new(config: MoshiConfig) -> Self {
Self { config }
}
}
/// State of one live Moshi conversation, created by `start_session`.
///
/// Outgoing PCM is buffered, Opus-encoded and wrapped in Ogg pages here.
/// Incoming audio and text arrive from two background tasks through bounded
/// channels. Dropping the session aborts those tasks, so a session that is
/// discarded without `end_session` does not leave them running.
pub struct MoshiSession {
    /// Shared write half of the WebSocket.
    ws_tx: SharedWsTx,
    /// Opus encoder for outgoing audio. It is only locked for the duration of
    /// a synchronous encode call, never across an `.await`.
    encoder: Mutex<opus::Encoder>,
    /// Mono PCM already resampled to `MOSHI_SAMPLE_RATE`, waiting until a
    /// full `OPUS_FRAME_SAMPLES` frame is available.
    pcm_pending: VecDeque<f32>,
    /// Number of samples handed to the encoder so far; the Ogg granule
    /// position of each page is derived from it.
    total_samples: u64,
    /// Ogg page writer. Its `Vec<u8>` sink is flushed to the socket and
    /// cleared after each page.
    ogg_writer: ogg::PacketWriter<'static, Vec<u8>>,
    /// Scratch output buffer for `encode_float`.
    encode_buf: Vec<u8>,
    /// Decoded server PCM produced by the decoder task.
    audio_rx: mpsc::Receiver<Vec<f32>>,
    /// Server text fragments forwarded by the receive task.
    text_rx: mpsc::Receiver<String>,
    /// Sequence number assigned to the next emitted `AudioOutputFrame`.
    output_seq: u64,
    /// Task reading frames from the WebSocket.
    recv_task: tokio::task::JoinHandle<()>,
    /// Task decoding the server's Ogg/Opus stream.
    decode_task: tokio::task::JoinHandle<()>,
}

impl Drop for MoshiSession {
    /// Aborts the background tasks.
    ///
    /// Without this, a session dropped without `end_session` would leave both
    /// tasks detached, still holding the WebSocket read half. Aborting an
    /// already finished or already aborted task is a no-op.
    fn drop(&mut self) {
        self.recv_task.abort();
        self.decode_task.abort();
    }
}
/// Serialises an Ogg Opus identification header (`OpusHead`, RFC 7845 §5.1).
///
/// Layout, 19 bytes, multi-byte fields little-endian:
/// - magic `OpusHead`
/// - version 1
/// - 1 channel
/// - pre-skip 3840
/// - input sample rate 48 000 Hz
/// - output gain 0
/// - channel mapping family 0
///
/// # Errors
/// Propagates any I/O error from `w`.
fn write_opus_head<W: Write>(w: &mut W) -> std::io::Result<()> {
    const VERSION: u8 = 1;
    const CHANNEL_COUNT: u8 = 1;
    const PRE_SKIP: u16 = 3840;
    const INPUT_RATE_HZ: u32 = 48_000;
    const OUTPUT_GAIN: i16 = 0;
    const MAPPING_FAMILY: u8 = 0;

    let mut header = Vec::with_capacity(19);
    header.extend_from_slice(b"OpusHead");
    header.push(VERSION);
    header.push(CHANNEL_COUNT);
    header.extend_from_slice(&PRE_SKIP.to_le_bytes());
    header.extend_from_slice(&INPUT_RATE_HZ.to_le_bytes());
    header.extend_from_slice(&OUTPUT_GAIN.to_le_bytes());
    header.push(MAPPING_FAMILY);
    w.write_all(&header)
}
/// Serialises an Ogg Opus comment header (`OpusTags`, RFC 7845 §5.2).
///
/// The header contains:
/// - the magic `OpusTags`
/// - a little-endian `u32` vendor-string length
/// - the vendor string `VonaRS/MoshiBackend`
/// - a `u32` user-comment count of zero
///
/// # Errors
/// Propagates any I/O error from `w`.
fn write_opus_tags<W: Write>(w: &mut W) -> std::io::Result<()> {
    const VENDOR: &[u8] = b"VonaRS/MoshiBackend";
    const USER_COMMENT_COUNT: u32 = 0;

    let mut tags = Vec::with_capacity(8 + 4 + VENDOR.len() + 4);
    tags.extend_from_slice(b"OpusTags");
    tags.extend_from_slice(&(VENDOR.len() as u32).to_le_bytes());
    tags.extend_from_slice(VENDOR);
    tags.extend_from_slice(&USER_COMMENT_COUNT.to_le_bytes());
    w.write_all(&tags)
}
/// Downmixes interleaved PCM to mono by averaging the samples of each frame.
///
/// When `channels <= 1`, a copy of the input is returned unchanged. If
/// `samples.len()` is not a multiple of `channels`, the trailing partial frame
/// is averaged over the samples actually present. Dividing it by `channels`
/// instead would attenuate it.
fn mix_to_mono(samples: &[f32], channels: u16) -> Vec<f32> {
    if channels <= 1 {
        return samples.to_vec();
    }
    samples
        .chunks(usize::from(channels))
        // `frame.len()` equals `channels` for every complete frame; it is
        // smaller only for a truncated final frame.
        .map(|frame| frame.iter().sum::<f32>() / frame.len() as f32)
        .collect()
}
/// Resamples mono PCM from `src_hz` to `dst_hz` using linear interpolation.
///
/// A copy of the input is returned unchanged when:
/// - the rates are equal,
/// - the input is empty, or
/// - either rate is zero. A zero rate has no meaningful conversion, and
///   computing with it would divide by zero and request a `usize::MAX`-length
///   output.
///
/// Otherwise the output length is `round(len * dst_hz / src_hz)`, with a
/// minimum of 1.
///
/// No anti-alias filter is applied when downsampling. Each call is
/// stateless: the last input sample is held rather than interpolated toward
/// the next chunk.
fn resample_mono(samples: &[f32], src_hz: u32, dst_hz: u32) -> Vec<f32> {
    if src_hz == dst_hz || src_hz == 0 || dst_hz == 0 || samples.is_empty() {
        return samples.to_vec();
    }
    let src = f64::from(src_hz);
    let dst = f64::from(dst_hz);
    let out_len = (((samples.len() as f64) / src * dst).round() as usize).max(1);
    (0..out_len)
        .map(|i| {
            let src_pos = i as f64 * src / dst;
            let idx = src_pos as usize;
            let frac = (src_pos - idx as f64) as f32;
            let a = samples.get(idx).copied().unwrap_or(0.0);
            // Past the end, hold the last available sample.
            let b = samples.get(idx + 1).copied().unwrap_or(a);
            a + (b - a) * frac
        })
        .collect()
}
/// Wraps raw Ogg bytes as a Moshi AUDIO message.
///
/// The result is a binary WebSocket frame whose first byte is the `mt::AUDIO`
/// tag, followed by `data` verbatim.
fn ogg_ws_message(data: &[u8]) -> Message {
    let payload: Vec<u8> = std::iter::once(mt::AUDIO)
        .chain(data.iter().copied())
        .collect();
    Message::Binary(payload)
}
#[async_trait]
impl SpeechToSpeechBackend for MoshiBackend {
type Session = MoshiSession;
fn capabilities(&self) -> BackendCapabilities {
BackendCapabilities {
supports_full_duplex: true,
supports_control_stream: false,
supports_context_injection: false,
supports_pause_resume: false,
supports_style_conditioning: false,
supports_word_timestamps: false,
}
}
async fn start_session(&self, _config: SessionConfig) -> Result<MoshiSession, BackendError> {
let ws_stream = if self.config.accept_invalid_certs {
let tls = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(true)
.build()
.map_err(|e| BackendError::Start(format!("TLS connector build failed: {e}")))?;
let connector = tokio_tungstenite::Connector::NativeTls(tls);
let (stream, _) = tokio_tungstenite::connect_async_tls_with_config(
self.config.url.as_str(),
None,
false,
Some(connector),
)
.await
.map_err(|e| BackendError::Start(format!("WebSocket connect failed: {e}")))?;
stream
} else {
let (stream, _) = tokio_tungstenite::connect_async(self.config.url.as_str())
.await
.map_err(|e| BackendError::Start(format!("WebSocket connect failed: {e}")))?;
stream
};
let (ws_sink, mut ws_source) = ws_stream.split();
let ws_tx: SharedWsTx = Arc::new(tokio::sync::Mutex::new(ws_sink));
loop {
match ws_source.next().await {
None => {
return Err(BackendError::Start(
"server closed connection before handshake".into(),
));
}
Some(Err(e)) => {
return Err(BackendError::Start(format!(
"WS error during handshake: {e}"
)));
}
Some(Ok(Message::Binary(bin))) if !bin.is_empty() && bin[0] == mt::HANDSHAKE => {
let mut reply = vec![mt::HANDSHAKE];
reply.extend_from_slice(&0u32.to_le_bytes()); reply.extend_from_slice(&0u32.to_le_bytes()); ws_tx
.lock()
.await
.send(Message::Binary(reply))
.await
.map_err(|e| BackendError::Start(format!("handshake send failed: {e}")))?;
break;
}
Some(Ok(_)) => continue,
}
}
tracing::info!(url = %self.config.url, "Moshi handshake complete");
let (audio_tx, audio_rx) = mpsc::channel::<Vec<f32>>(128);
let (text_tx, text_rx) = mpsc::channel::<String>(64);
let (mut ogg_write_half, ogg_read_half) = tokio::io::duplex(512 * 1024);
let recv_task = tokio::spawn(async move {
while let Some(msg) = ws_source.next().await {
match msg {
Ok(Message::Binary(bin)) if !bin.is_empty() => match bin[0] {
mt::AUDIO if ogg_write_half.write_all(&bin[1..]).await.is_err() => {
break; }
mt::AUDIO => {}
mt::TEXT => {
let text = String::from_utf8_lossy(&bin[1..]).into_owned();
let _ = text_tx.send(text).await;
}
mt::ERROR => {
let msg = String::from_utf8_lossy(&bin[1..]);
tracing::error!(moshi_error = %msg, "Moshi server error");
}
_ => {}
},
Ok(Message::Close(_)) | Err(_) => break,
_ => {}
}
}
});
let decode_task = tokio::spawn(async move {
let mut pr = ogg::reading::async_api::PacketReader::new(ogg_read_half);
let decoder = opus::Decoder::new(MOSHI_SAMPLE_RATE, opus::Channels::Mono);
let mut decoder = match decoder {
Ok(d) => d,
Err(e) => {
tracing::error!("Moshi: Opus decoder init failed: {e}");
return;
}
};
let mut pcm_buf = vec![0f32; MOSHI_SAMPLE_RATE as usize * 10];
while let Some(packet_result) = pr.next().await {
match packet_result {
Err(e) => {
tracing::warn!("Moshi: OGG read error: {e}");
break;
}
Ok(packet) => {
if packet.data.starts_with(b"OpusHead")
|| packet.data.starts_with(b"OpusTags")
{
continue;
}
match decoder.decode_float(&packet.data, &mut pcm_buf, false) {
Ok(size) if size > 0 => {
if audio_tx.send(pcm_buf[..size].to_vec()).await.is_err() {
break; }
}
Ok(_) => {}
Err(e) => tracing::warn!("Moshi: Opus decode error: {e}"),
}
}
}
}
});
let encoder = opus::Encoder::new(
MOSHI_SAMPLE_RATE,
opus::Channels::Mono,
opus::Application::Voip,
)
.map_err(|e| BackendError::Start(format!("Opus encoder init: {e}")))?;
let mut ogg_writer = ogg::PacketWriter::new(Vec::new());
let mut head_buf = Vec::new();
write_opus_head(&mut head_buf)
.map_err(|e| BackendError::Start(format!("OpusHead write: {e}")))?;
ogg_writer
.write_packet(head_buf, OGG_SERIAL, PacketWriteEndInfo::EndPage, 0)
.map_err(|e| BackendError::Start(format!("OGG head packet: {e}")))?;
let mut tags_buf = Vec::new();
write_opus_tags(&mut tags_buf)
.map_err(|e| BackendError::Start(format!("OpusTags write: {e}")))?;
ogg_writer
.write_packet(tags_buf, OGG_SERIAL, PacketWriteEndInfo::EndPage, 0)
.map_err(|e| BackendError::Start(format!("OGG tags packet: {e}")))?;
{
let data = ogg_writer.inner_mut();
if !data.is_empty() {
let msg = ogg_ws_message(data);
ws_tx
.lock()
.await
.send(msg)
.await
.map_err(|e| BackendError::Start(format!("OGG header send: {e}")))?;
data.clear();
}
}
Ok(MoshiSession {
ws_tx,
encoder: Mutex::new(encoder),
pcm_pending: VecDeque::new(),
total_samples: 0,
ogg_writer,
encode_buf: vec![0u8; 8192],
audio_rx,
text_rx,
output_seq: 0,
recv_task,
decode_task,
})
}
async fn step(
&self,
session: &mut MoshiSession,
input: AudioInputFrame,
) -> Result<BackendStep, BackendError> {
let mono = mix_to_mono(&input.samples, input.channels);
let resampled = resample_mono(&mono, input.sample_rate_hz, MOSHI_SAMPLE_RATE);
session.pcm_pending.extend(resampled.iter().copied());
while session.pcm_pending.len() >= OPUS_FRAME_SAMPLES {
let chunk: Vec<f32> = session.pcm_pending.drain(..OPUS_FRAME_SAMPLES).collect();
session.total_samples += OPUS_FRAME_SAMPLES as u64;
let size = {
let enc = &session.encoder;
let buf = &mut session.encode_buf;
enc.lock()
.unwrap()
.encode_float(&chunk, buf)
.map_err(|e| BackendError::Step(format!("Opus encode: {e}")))?
};
if size > 0 {
let packet = session.encode_buf[..size].to_vec();
session
.ogg_writer
.write_packet(
packet,
OGG_SERIAL,
PacketWriteEndInfo::EndPage,
session.total_samples,
)
.map_err(|e| BackendError::Step(format!("OGG write: {e}")))?;
let data = session.ogg_writer.inner_mut();
if !data.is_empty() {
let msg = ogg_ws_message(data);
session
.ws_tx
.lock()
.await
.send(msg)
.await
.map_err(|e| BackendError::Step(format!("WS send: {e}")))?;
data.clear();
}
}
}
let mut output_audio = Vec::new();
while let Ok(pcm) = session.audio_rx.try_recv() {
let seq = session.output_seq;
session.output_seq += 1;
output_audio.push(AudioOutputFrame {
sequence: seq,
sample_rate_hz: MOSHI_SAMPLE_RATE,
channels: 1,
samples: pcm,
is_filler: false,
});
}
let mut control_events = Vec::new();
while let Ok(text) = session.text_rx.try_recv() {
control_events.push(ControlEvent::TranscriptFragment {
text,
final_fragment: false,
});
}
Ok(BackendStep {
output_audio,
control_events,
transcript: None,
finished: false,
debug_payload: None,
})
}
async fn inject_event(
&self,
_session: &mut MoshiSession,
_event: ExternalContextEvent,
) -> Result<(), BackendError> {
Ok(())
}
async fn end_session(&self, session: MoshiSession) -> Result<(), BackendError> {
let end_turn = vec![mt::CONTROL, ctrl::END_TURN];
let mut tx = session.ws_tx.lock().await;
let _ = tx.send(Message::Binary(end_turn)).await;
let _ = tx.send(Message::Close(None)).await;
drop(tx);
session.recv_task.abort();
session.decode_task.abort();
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::MoshiConfig;

    #[test]
    fn backend_new_stores_config() {
        let cfg = MoshiConfig {
            url: "wss://test.local/api/chat".into(),
            accept_invalid_certs: true,
        };
        let backend = MoshiBackend::new(cfg.clone());
        assert_eq!(backend.config.url, cfg.url);
        assert_eq!(
            backend.config.accept_invalid_certs,
            cfg.accept_invalid_certs
        );
    }

    #[test]
    fn backend_capabilities_full_duplex() {
        let backend = MoshiBackend::new(MoshiConfig::default());
        let caps = backend.capabilities();
        assert!(caps.supports_full_duplex, "Moshi is a full-duplex backend");
    }

    #[test]
    fn backend_capabilities_no_context_injection() {
        let backend = MoshiBackend::new(MoshiConfig::default());
        let caps = backend.capabilities();
        assert!(
            !caps.supports_context_injection,
            "Moshi does not support external context injection"
        );
    }

    #[test]
    fn backend_capabilities_no_style_conditioning() {
        let backend = MoshiBackend::new(MoshiConfig::default());
        let caps = backend.capabilities();
        assert!(!caps.supports_style_conditioning);
    }

    /// Assumes nothing is listening on 127.0.0.1:19998.
    #[tokio::test]
    async fn start_session_fails_when_server_unreachable() {
        let backend = MoshiBackend::new(MoshiConfig {
            url: "ws://127.0.0.1:19998/api/chat".into(),
            accept_invalid_certs: false,
        });
        let result = backend.start_session(SessionConfig::default()).await;
        assert!(
            result.is_err(),
            "expected BackendError when Moshi server is not reachable"
        );
        if let Err(BackendError::Start(msg)) = result {
            assert!(!msg.is_empty(), "error message should describe the failure");
        }
    }

    #[test]
    fn opus_head_matches_rfc7845_layout() {
        let mut buf: Vec<u8> = Vec::new();
        write_opus_head(&mut buf).unwrap();
        assert_eq!(buf.len(), 19);
        assert_eq!(&buf[..8], b"OpusHead");
        // version 1, 1 channel, pre-skip 3840 LE, 48 kHz LE, gain 0, family 0
        assert_eq!(
            &buf[8..],
            &[1u8, 1, 0x00, 0x0F, 0x80, 0xBB, 0x00, 0x00, 0x00, 0x00, 0x00][..]
        );
    }

    #[test]
    fn opus_tags_has_vendor_and_no_comments() {
        let mut buf: Vec<u8> = Vec::new();
        write_opus_tags(&mut buf).unwrap();
        let vendor = b"VonaRS/MoshiBackend";
        assert_eq!(&buf[..8], b"OpusTags");
        assert_eq!(&buf[8..12], &(vendor.len() as u32).to_le_bytes());
        assert_eq!(&buf[12..12 + vendor.len()], vendor);
        assert_eq!(&buf[12 + vendor.len()..], &0u32.to_le_bytes());
    }

    #[test]
    fn mix_to_mono_averages_interleaved_frames() {
        assert_eq!(mix_to_mono(&[1.0, 3.0, -2.0, 2.0], 2), vec![2.0, 0.0]);
        assert_eq!(mix_to_mono(&[0.5, 0.25], 1), vec![0.5, 0.25]);
    }

    #[test]
    fn resample_mono_identity_and_factor_two() {
        assert_eq!(resample_mono(&[1.0, 2.0], 24_000, 24_000), vec![1.0, 2.0]);
        assert_eq!(
            resample_mono(&[0.0, 1.0, 2.0, 3.0], 48_000, 24_000),
            vec![0.0, 2.0]
        );
        assert_eq!(
            resample_mono(&[0.0, 1.0], 24_000, 48_000),
            vec![0.0, 0.5, 1.0, 1.0]
        );
    }

    #[test]
    fn ogg_ws_message_prefixes_audio_tag() {
        match ogg_ws_message(&[7, 8]) {
            Message::Binary(bytes) => {
                assert_eq!(bytes[0], mt::AUDIO);
                assert_eq!(&bytes[1..], &[7u8, 8][..]);
            }
            other => panic!("expected a binary frame, got {other:?}"),
        }
    }
}