use crate::inference::Engine;
use crate::protocol::ServerMessage;
use anyhow::Result;
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tokio_tungstenite::tungstenite::Message;
/// Maximum number of client connections served at the same time; this is the
/// number of permits in the semaphore created by `run`.
const MAX_CONCURRENT_CONNECTIONS: usize = 4;
pub async fn run(engine: Engine, port: u16) -> Result<()> {
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let listener = TcpListener::bind(&addr).await?;
let engine = Arc::new(engine);
let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_CONNECTIONS));
tracing::info!("gigastt server listening on ws://{addr}");
loop {
tokio::select! {
result = listener.accept() => {
let (stream, peer) = result?;
let engine = engine.clone();
let permit = semaphore.clone().acquire_owned().await?;
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, peer, engine).await {
tracing::error!("Connection error from {peer}: {e}");
}
drop(permit);
});
}
_ = tokio::signal::ctrl_c() => {
tracing::info!("Shutting down server");
break;
}
}
}
Ok(())
}
/// Serves a single WebSocket client until it disconnects.
///
/// Performs the WebSocket handshake (messages and frames are capped at
/// 512 KiB), sends a `ServerMessage::Ready` announcing a 48000 Hz input rate,
/// and then treats every binary message as little-endian 16-bit PCM. The audio
/// is reduced by a factor of three and passed to `Engine::process_chunk`; each
/// returned segment is sent back as a `Final` or `Partial` message. An
/// inference failure is logged and reported to the client as an `Error`
/// message without closing the connection.
///
/// # Errors
///
/// Returns an error if the handshake fails, if reading from or writing to the
/// socket fails, or if a server message cannot be serialized.
async fn handle_connection(
    stream: tokio::net::TcpStream,
    peer: SocketAddr,
    engine: Arc<Engine>,
) -> Result<()> {
    let ws_config = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
        max_message_size: Some(512 * 1024),
        max_frame_size: Some(512 * 1024),
        ..Default::default()
    };
    let ws_stream =
        tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await?;
    let (mut sink, mut source) = ws_stream.split();
    tracing::info!("Client connected: {peer}");

    let ready = ServerMessage::Ready {
        model: "gigaam-v3-e2e-rnnt".into(),
        sample_rate: 48000,
        version: crate::protocol::PROTOCOL_VERSION.into(),
    };
    sink.send(Message::Text(serde_json::to_string(&ready)?)).await?;

    let mut stream_state = engine.create_state();
    // Samples left over from the previous message that did not fill a complete
    // group of three. Carrying them forward keeps the decimation phase
    // continuous across message boundaries instead of emitting a short,
    // mis-weighted output sample at the end of every message.
    let mut carry: Vec<i16> = Vec::with_capacity(2);

    while let Some(msg) = source.next().await {
        match msg? {
            Message::Binary(data) => {
                let samples = decimate_pcm16_by_3(&data, &mut carry);
                match engine.process_chunk(&samples, &mut stream_state) {
                    Ok(segments) => {
                        for seg in segments {
                            let reply = if seg.is_final {
                                ServerMessage::Final {
                                    text: seg.text,
                                    timestamp: seg.timestamp,
                                }
                            } else {
                                ServerMessage::Partial {
                                    text: seg.text,
                                    timestamp: seg.timestamp,
                                }
                            };
                            sink.send(Message::Text(serde_json::to_string(&reply)?)).await?;
                        }
                    }
                    Err(e) => {
                        // Full error chain goes to the log; the client only
                        // gets a generic message.
                        tracing::error!("Inference error for {peer}: {e:#}");
                        let err = ServerMessage::Error {
                            message: "Inference failed. Please check audio format.".into(),
                            code: "inference_error".into(),
                        };
                        sink.send(Message::Text(serde_json::to_string(&err)?)).await?;
                    }
                }
            }
            Message::Close(_) => break,
            // Every other message kind is ignored by this handler.
            _ => {}
        }
    }
    tracing::info!("Client disconnected: {peer}");
    Ok(())
}

/// Decodes little-endian 16-bit PCM from `bytes` and averages every three
/// consecutive samples into one output sample.
///
/// `carry` holds samples left over from the previous call; they are placed in
/// front of the newly decoded samples. On return it holds the zero to two
/// trailing samples that did not complete a group of three. A trailing odd
/// byte in `bytes` is discarded.
fn decimate_pcm16_by_3(bytes: &[u8], carry: &mut Vec<i16>) -> Vec<i16> {
    const FACTOR: usize = 3;

    let mut pending: Vec<i16> = Vec::with_capacity(carry.len() + bytes.len() / 2);
    // `append` moves the leftovers out and leaves `carry` empty for reuse.
    pending.append(carry);
    pending.extend(
        bytes
            .chunks_exact(2)
            .map(|pair| i16::from_le_bytes([pair[0], pair[1]])),
    );

    let groups = pending.chunks_exact(FACTOR);
    carry.extend_from_slice(groups.remainder());
    groups
        .map(|group| {
            // Sum in i32 so three i16 values cannot overflow; the mean of
            // i16 values always fits back into i16.
            let sum: i32 = group.iter().map(|&s| i32::from(s)).sum();
            (sum / FACTOR as i32) as i16
        })
        .collect()
}