use axum::body::Bytes;
use axum::extract::State;
use axum::http::StatusCode;
use axum::http::header;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Json, Response};
use futures_util::StreamExt;
use futures_util::stream::Stream;
use serde::Serialize;
use std::sync::Arc;
use arc_swap::ArcSwap;
use super::config::{RuntimeLimits, pool_retry_after_ms, pool_retry_after_secs};
use super::metrics::MetricsRegistry;
use gigastt_core::inference::Engine;
/// Shared state handed to every HTTP handler in this module.
pub struct AppState {
    /// Inference engine; handlers check out per-request resources from `engine.pool`.
    pub engine: Arc<Engine>,
    /// Runtime limits (body size, pool checkout timeout). Read with `load()` on each
    /// request, so a new value swapped into the `ArcSwap` applies to subsequent requests.
    pub limits: Arc<ArcSwap<RuntimeLimits>>,
    /// Metrics sink. When `None`, the `metrics` handler returns 404 and handlers skip
    /// all gauge/counter/histogram updates.
    pub metrics_registry: Option<Arc<MetricsRegistry>>,
    /// Cancelled on shutdown: `readiness` then reports `shutting_down`, and the SSE
    /// inference task stops between chunks.
    pub shutdown: tokio_util::sync::CancellationToken,
    /// Tracker for spawned blocking inference tasks; presumably waited on during graceful
    /// shutdown — NOTE(review): confirm in the server setup code.
    pub tracker: tokio_util::task::TaskTracker,
}
pub async fn metrics(State(state): State<Arc<AppState>>) -> Response {
match &state.metrics_registry {
Some(registry) => (
StatusCode::OK,
[(
header::CONTENT_TYPE,
"text/plain; version=0.0.4; charset=utf-8",
)],
registry.render_prometheus(),
)
.into_response(),
None => (
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "metrics endpoint disabled",
"code": "metrics_disabled",
})),
)
.into_response(),
}
}
/// JSON body returned by the `health` liveness handler.
#[derive(Debug, Clone, Serialize)]
pub struct HealthResponse {
    /// Always `"ok"` when the handler responds.
    pub status: String,
    /// Identifier of the served model.
    pub model: String,
    /// Crate version of this server build.
    pub version: String,
}
/// JSON body returned by the `models` handler: model metadata plus live pool counters.
#[derive(Debug, Clone, Serialize)]
pub struct ModelInfo {
    /// Stable model identifier.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Crate version of this server build.
    pub version: String,
    /// Encoder precision: `"int8"` or `"fp32"`.
    pub encoder: String,
    /// Vocabulary size reported by the engine.
    pub vocab_size: usize,
    /// Sample rate advertised to clients.
    pub sample_rate: u32,
    /// Total number of pool slots.
    pub pool_size: usize,
    /// Pool slots free when the request was served.
    pub pool_available: usize,
    /// Audio container formats accepted by the transcription endpoints.
    pub supported_formats: Vec<String>,
    /// Supported input sample rates.
    pub supported_rates: Vec<u32>,
    /// Whether speaker diarization is available in this build/model.
    pub diarization: bool,
}
/// JSON body returned by the `transcribe` handler.
#[derive(Serialize)]
pub struct TranscribeResponse {
    /// Full transcript text.
    pub text: String,
    /// Per-word details as produced by the engine.
    pub words: Vec<gigastt_core::inference::WordInfo>,
    /// Seconds, copied from the engine result's `duration_s` (presumably the audio length).
    pub duration: f64,
}
/// Handler error type: a fully built HTTP response (status code + JSON body).
type ApiError = Response;
fn api_error(status: StatusCode, msg: &str, code: &str) -> ApiError {
(
status,
Json(serde_json::json!({"error": msg, "code": code})),
)
.into_response()
}
/// 503 response for a pool checkout that timed out.
///
/// Carries a `Retry-After` header (seconds) and a `retry_after_ms` field in the JSON body,
/// both derived from `limits`, so clients can back off before retrying.
fn api_timeout_error(limits: &RuntimeLimits) -> ApiError {
    let retry_after_header = [(
        header::RETRY_AFTER,
        pool_retry_after_secs(limits).to_string(),
    )];
    let payload = serde_json::json!({
        "error": "Server busy, try again later",
        "code": "timeout",
        "retry_after_ms": pool_retry_after_ms(limits),
    });
    (StatusCode::SERVICE_UNAVAILABLE, retry_after_header, Json(payload)).into_response()
}
fn api_pool_closed_error() -> ApiError {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "Server is shutting down",
"code": "pool_closed",
})),
)
.into_response()
}
/// JSON body returned by the `readiness` handler.
#[derive(Debug, Clone, Serialize)]
pub struct ReadinessResponse {
    /// `"ready"` or `"not_ready"`.
    pub status: String,
    /// Free pool slots (reported as 0 while shutting down).
    pub pool_available: usize,
    /// Total pool slots.
    pub pool_total: usize,
    /// Why the server is not ready (e.g. `shutting_down`, `pool_exhausted`);
    /// omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
/// Liveness probe: always reports `ok` together with the model id and crate version.
///
/// Performs no engine or pool checks; `readiness` covers those.
pub async fn health(State(_state): State<Arc<AppState>>) -> Json<HealthResponse> {
    let version = env!("CARGO_PKG_VERSION");
    Json(HealthResponse {
        status: String::from("ok"),
        model: String::from("gigaam-v3-e2e-rnnt"),
        version: version.to_owned(),
    })
}
pub async fn readiness(State(state): State<Arc<AppState>>) -> Response {
if state.shutdown.is_cancelled() {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ReadinessResponse {
status: "not_ready".into(),
pool_available: 0,
pool_total: state.engine.pool.total(),
reason: Some("shutting_down".into()),
}),
)
.into_response();
}
let available = state.engine.pool.available();
if let Some(ref reg) = state.metrics_registry {
reg.gauge_set("gigastt_pool_available", &[], available as i64);
reg.gauge_set(
"gigastt_pool_waiters",
&[],
state.engine.pool.waiters() as i64,
);
}
if available == 0 {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ReadinessResponse {
status: "not_ready".into(),
pool_available: 0,
pool_total: state.engine.pool.total(),
reason: Some("pool_exhausted".into()),
}),
)
.into_response();
}
Json(ReadinessResponse {
status: "ready".into(),
pool_available: available,
pool_total: state.engine.pool.total(),
reason: None,
})
.into_response()
}
/// Describes the loaded model and the current state of the inference pool.
///
/// When metrics are enabled, also refreshes the `gigastt_pool_available` and
/// `gigastt_pool_waiters` gauges.
pub async fn models(State(state): State<Arc<AppState>>) -> Json<ModelInfo> {
    let engine = &state.engine;
    #[cfg(feature = "diarization")]
    let diarization = engine.has_speaker_encoder();
    #[cfg(not(feature = "diarization"))]
    let diarization = false;
    // Sample the pool once so the exported gauge and the response body agree.
    let pool_available = engine.pool.available();
    if let Some(ref reg) = state.metrics_registry {
        reg.gauge_set("gigastt_pool_available", &[], pool_available as i64);
        reg.gauge_set("gigastt_pool_waiters", &[], engine.pool.waiters() as i64);
    }
    Json(ModelInfo {
        id: "gigaam-v3-e2e-rnnt".into(),
        name: "GigaAM v3 RNN-T".into(),
        version: env!("CARGO_PKG_VERSION").into(),
        encoder: if engine.is_int8() {
            "int8".into()
        } else {
            "fp32".into()
        },
        vocab_size: engine.vocab_size(),
        sample_rate: 16000,
        pool_size: engine.pool.total(),
        pool_available,
        supported_formats: vec![
            "wav".into(),
            "mp3".into(),
            "m4a".into(),
            "ogg".into(),
            "flac".into(),
        ],
        supported_rates: super::config::SUPPORTED_RATES.to_vec(),
        diarization,
    })
}
pub async fn transcribe(
State(state): State<Arc<AppState>>,
body: Bytes,
) -> Result<Json<TranscribeResponse>, ApiError> {
if body.is_empty() {
return Err(api_error(
StatusCode::BAD_REQUEST,
"Empty request body",
"empty_body",
));
}
let limits = state.limits.load();
if body.len() > limits.body_limit_bytes {
return Err(api_error(
StatusCode::PAYLOAD_TOO_LARGE,
"Request body exceeds the configured size limit",
"payload_too_large",
));
}
let checkout_start = std::time::Instant::now();
let guard = match tokio::time::timeout(
std::time::Duration::from_secs(limits.pool_checkout_timeout_secs),
state.engine.pool.checkout(),
)
.await
{
Ok(Ok(guard)) => guard,
Ok(Err(_pool_closed)) => return Err(api_pool_closed_error()),
Err(_timeout) => {
if let Some(ref reg) = state.metrics_registry {
reg.counter_inc("gigastt_pool_timeouts_total", &[], 1);
reg.histogram_record(
"gigastt_pool_checkout_duration_seconds",
&[],
checkout_start.elapsed().as_secs_f64(),
);
}
return Err(api_timeout_error(&limits));
}
};
if let Some(ref reg) = state.metrics_registry {
reg.histogram_record(
"gigastt_pool_checkout_duration_seconds",
&[],
checkout_start.elapsed().as_secs_f64(),
);
}
let mut reservation = guard.into_owned();
let engine = state.engine.clone();
let inference_start = std::time::Instant::now();
let span = tracing::Span::current();
let result = tokio::task::spawn_blocking(move || {
let _enter = span.enter();
let r = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
engine.transcribe_bytes_shared(body, &mut reservation)
}));
match r {
Ok(inference_result) => inference_result,
Err(_) => {
tracing::error!("Panic in REST transcribe — triplet recovered");
Err(gigastt_core::error::GigasttError::Inference {
source: anyhow::anyhow!("Inference thread panicked").into(),
})
}
}
})
.await;
if let Some(ref reg) = state.metrics_registry {
reg.histogram_record(
"gigastt_inference_duration_seconds",
&[],
inference_start.elapsed().as_secs_f64(),
);
}
match result {
Ok(Ok(result)) => Ok(Json(TranscribeResponse {
text: result.text,
words: result.words,
duration: result.duration_s,
})),
Ok(Err(e)) => {
tracing::error!("Transcription error: {e}");
Err(api_error(
StatusCode::UNPROCESSABLE_ENTITY,
"Transcription failed. Check audio format.",
"transcription_error",
))
}
Err(e) => {
tracing::error!("spawn_blocking join error: {e}");
Err(api_error(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
"internal",
))
}
}
}
pub async fn transcribe_stream(
State(state): State<Arc<AppState>>,
body: Bytes,
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
if body.is_empty() {
return Err(api_error(
StatusCode::BAD_REQUEST,
"Empty request body",
"empty_body",
));
}
let limits = state.limits.load();
if body.len() > limits.body_limit_bytes {
return Err(api_error(
StatusCode::PAYLOAD_TOO_LARGE,
"Request body exceeds the configured size limit",
"payload_too_large",
));
}
let samples = tokio::task::spawn_blocking(move || {
gigastt_core::inference::audio::decode_audio_bytes_shared(body)
})
.await
.map_err(|e| {
tracing::error!("spawn_blocking join error: {e}");
api_error(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
"internal",
)
})?
.map_err(|e| {
tracing::error!("Audio decode error: {e:#}");
api_error(
StatusCode::UNPROCESSABLE_ENTITY,
"Failed to decode audio file. Check format (WAV, MP3, M4A, OGG, FLAC supported).",
"invalid_audio",
)
})?;
let checkout_start = std::time::Instant::now();
let guard = match tokio::time::timeout(
std::time::Duration::from_secs(limits.pool_checkout_timeout_secs),
state.engine.pool.checkout(),
)
.await
{
Ok(Ok(guard)) => guard,
Ok(Err(_pool_closed)) => return Err(api_pool_closed_error()),
Err(_timeout) => {
if let Some(ref reg) = state.metrics_registry {
reg.counter_inc("gigastt_pool_timeouts_total", &[], 1);
reg.histogram_record(
"gigastt_pool_checkout_duration_seconds",
&[],
checkout_start.elapsed().as_secs_f64(),
);
}
return Err(api_timeout_error(&limits));
}
};
if let Some(ref reg) = state.metrics_registry {
reg.histogram_record(
"gigastt_pool_checkout_duration_seconds",
&[],
checkout_start.elapsed().as_secs_f64(),
);
}
let mut reservation = guard.into_owned();
let (tx, rx) = tokio::sync::mpsc::channel::<
Result<gigastt_core::inference::TranscriptSegment, String>,
>(16);
let engine = state.engine.clone();
let cancel = state.shutdown.clone();
let tracker = state.tracker.clone();
let span = tracing::Span::current();
tracker.spawn_blocking(move || {
let _enter = span.enter();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut stream_state = engine.create_state(false);
let chunk_size = 16000;
for chunk in samples.chunks(chunk_size) {
if cancel.is_cancelled() {
tracing::info!("SSE transcription cancelled by shutdown");
return;
}
match engine.process_chunk(chunk, &mut stream_state, &mut reservation) {
Ok(segs) => {
for seg in segs {
if tx.blocking_send(Ok(seg)).is_err() {
return;
}
}
}
Err(e) => {
let _ = tx.blocking_send(Err(format!("{e}")));
return;
}
}
}
if let Some(seg) = engine.flush_state(&mut stream_state) {
let _ = tx.blocking_send(Ok(seg));
}
}));
if result.is_err() {
tracing::error!("Panic in SSE inference task — triplet recovered");
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx)
.map(|result| {
let event = match result {
Ok(seg) => {
let msg = if seg.is_final {
serde_json::json!({"type": "final", "text": seg.text, "timestamp": seg.timestamp, "words": seg.words})
} else {
serde_json::json!({"type": "partial", "text": seg.text, "timestamp": seg.timestamp, "words": seg.words})
};
Event::default().data(msg.to_string())
}
Err(_) => {
let msg = serde_json::json!({"type": "error", "message": "Transcription failed.", "code": "inference_error"});
Event::default().data(msg.to_string())
}
};
Ok(event)
});
Ok(Sse::new(stream).keep_alive(
KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text(""),
))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `AppState` around `engine` with fresh shutdown/tracker handles.
    fn make_state(
        engine: Arc<Engine>,
        limits: RuntimeLimits,
        metrics_registry: Option<Arc<MetricsRegistry>>,
    ) -> Arc<AppState> {
        Arc::new(AppState {
            engine,
            limits: Arc::new(ArcSwap::from_pointee(limits)),
            metrics_registry,
            shutdown: tokio_util::sync::CancellationToken::new(),
            tracker: tokio_util::task::TaskTracker::new(),
        })
    }

    /// Limits that reject any request body larger than 10 bytes.
    fn tiny_body_limits() -> RuntimeLimits {
        RuntimeLimits {
            body_limit_bytes: 10,
            ..RuntimeLimits::default()
        }
    }

    /// Buffers a small response body (≤ 1 KiB) and parses it as JSON.
    async fn json_body(body: axum::body::Body) -> serde_json::Value {
        let bytes = axum::body::to_bytes(body, 1024).await.unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "ok".into(),
            model: "test".into(),
            version: "0.3.0".into(),
        };
        let v = serde_json::to_value(&resp).unwrap();
        assert_eq!(v["status"], "ok");
        assert_eq!(v["model"], "test");
        assert_eq!(v["version"], "0.3.0");
    }

    #[test]
    fn test_transcribe_response_serialization() {
        let resp = TranscribeResponse {
            text: "hello".into(),
            words: vec![],
            duration: 1.5,
        };
        let v = serde_json::to_value(&resp).unwrap();
        assert_eq!(v["text"], "hello");
        assert_eq!(v["duration"], 1.5);
        assert_eq!(v["words"], serde_json::json!([]));
    }

    #[test]
    fn test_readiness_response_ready_serialization() {
        let resp = ReadinessResponse {
            status: "ready".into(),
            pool_available: 3,
            pool_total: 4,
            reason: None,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["status"], "ready");
        assert_eq!(json["pool_available"], 3);
        assert_eq!(json["pool_total"], 4);
        // `skip_serializing_if` must omit the key entirely, not emit `null`.
        assert!(json.get("reason").is_none());
    }

    #[test]
    fn test_readiness_response_not_ready_serialization() {
        let resp = ReadinessResponse {
            status: "not_ready".into(),
            pool_available: 0,
            pool_total: 4,
            reason: Some("pool_exhausted".into()),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["status"], "not_ready");
        assert_eq!(json["pool_available"], 0);
        assert_eq!(json["pool_total"], 4);
        assert_eq!(json["reason"], "pool_exhausted");
    }

    #[tokio::test]
    async fn test_api_error_basic() {
        let resp = api_error(StatusCode::BAD_REQUEST, "bad request", "bad_request");
        let (parts, body) = resp.into_parts();
        assert_eq!(parts.status, StatusCode::BAD_REQUEST);
        let v = json_body(body).await;
        assert_eq!(v["error"], "bad request");
        assert_eq!(v["code"], "bad_request");
    }

    #[tokio::test]
    async fn test_api_timeout_error_includes_retry_after() {
        let limits = RuntimeLimits::default();
        let resp = api_timeout_error(&limits);
        let (parts, body) = resp.into_parts();
        assert_eq!(parts.status, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(
            parts.headers.get(header::RETRY_AFTER).unwrap().to_str().unwrap(),
            pool_retry_after_secs(&limits).to_string()
        );
        let v = json_body(body).await;
        assert_eq!(v["code"], "timeout");
        assert_eq!(v["retry_after_ms"], pool_retry_after_ms(&limits));
    }

    #[tokio::test]
    async fn test_api_pool_closed_error_no_retry() {
        let resp = api_pool_closed_error();
        let (parts, body) = resp.into_parts();
        assert_eq!(parts.status, StatusCode::SERVICE_UNAVAILABLE);
        assert!(parts.headers.get(header::RETRY_AFTER).is_none());
        let v = json_body(body).await;
        assert_eq!(v["code"], "pool_closed");
        assert!(v.get("retry_after_ms").is_none());
    }

    #[tokio::test]
    async fn test_readiness_when_shutdown_cancelled() {
        let state = make_state(test_engine(), RuntimeLimits::default(), None);
        state.shutdown.cancel();
        let resp = readiness(State(state)).await;
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        let v = json_body(resp.into_body()).await;
        assert_eq!(v["status"], "not_ready");
        assert_eq!(v["reason"], "shutting_down");
    }

    #[tokio::test]
    async fn test_readiness_when_pool_exhausted() {
        // Uses a private engine: holding every reservation would starve other tests.
        let engine = fresh_engine();
        let _guards: Vec<_> = (0..engine.pool.total())
            .map(|_| engine.pool.checkout_blocking().unwrap())
            .collect();
        let state = make_state(engine, RuntimeLimits::default(), None);
        let resp = readiness(State(state)).await;
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        let v = json_body(resp.into_body()).await;
        assert_eq!(v["status"], "not_ready");
        assert_eq!(v["reason"], "pool_exhausted");
    }

    #[tokio::test]
    async fn test_transcribe_payload_too_large() {
        let state = make_state(test_engine(), tiny_body_limits(), None);
        let body = Bytes::from(vec![0u8; 100]);
        match transcribe(State(state), body).await {
            Err(resp) => assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE),
            Ok(_) => panic!("expected payload_too_large error"),
        }
    }

    #[tokio::test]
    async fn test_models_with_metrics() {
        let state = make_state(
            test_engine(),
            RuntimeLimits::default(),
            Some(Arc::new(MetricsRegistry::new())),
        );
        let resp = models(State(state)).await;
        let json = serde_json::to_value(&*resp).unwrap();
        assert_eq!(json["id"], "gigaam-v3-e2e-rnnt");
    }

    #[tokio::test]
    async fn test_readiness_with_metrics() {
        let state = make_state(
            fresh_engine(),
            RuntimeLimits::default(),
            Some(Arc::new(MetricsRegistry::new())),
        );
        let resp = readiness(State(state)).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_transcribe_pool_closed() {
        let engine = fresh_engine();
        engine.pool.close();
        let state = make_state(engine, RuntimeLimits::default(), None);
        let body = Bytes::from(vec![0u8; 100]);
        match transcribe(State(state), body).await {
            Err(resp) => assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE),
            Ok(_) => panic!("expected pool_closed error"),
        }
    }

    #[tokio::test]
    async fn test_transcribe_stream_invalid_audio() {
        let state = make_state(test_engine(), RuntimeLimits::default(), None);
        let body = Bytes::from(vec![0u8; 100]);
        match transcribe_stream(State(state), body).await {
            Err(resp) => assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY),
            Ok(_) => panic!("expected invalid_audio error"),
        }
    }

    #[tokio::test]
    async fn test_transcribe_stream_payload_too_large() {
        let state = make_state(test_engine(), tiny_body_limits(), None);
        let body = Bytes::from(vec![0u8; 100]);
        match transcribe_stream(State(state), body).await {
            Err(resp) => assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE),
            Ok(_) => panic!("expected payload_too_large error"),
        }
    }

    #[tokio::test]
    async fn test_transcribe_stream_pool_closed() {
        let engine = fresh_engine();
        engine.pool.close();
        let state = make_state(engine, RuntimeLimits::default(), None);
        // Must be decodable so the handler reaches the pool checkout.
        let body = minimal_wav();
        match transcribe_stream(State(state), body).await {
            Err(resp) => assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE),
            Ok(_) => panic!("expected pool_closed error"),
        }
    }

    #[tokio::test]
    async fn test_transcribe_with_metrics() {
        let state = make_state(
            test_engine(),
            RuntimeLimits::default(),
            Some(Arc::new(MetricsRegistry::new())),
        );
        if transcribe(State(state), short_wav()).await.is_err() {
            panic!("transcribe with metrics failed");
        }
    }

    #[tokio::test]
    async fn test_transcribe_stream_with_metrics() {
        let state = make_state(
            test_engine(),
            RuntimeLimits::default(),
            Some(Arc::new(MetricsRegistry::new())),
        );
        if transcribe_stream(State(state), short_wav()).await.is_err() {
            panic!("transcribe_stream with metrics failed");
        }
    }
}
/// Returns a process-wide engine (pool size 1) shared by tests that never exhaust or
/// close its pool; it is loaded lazily on first use.
#[cfg(test)]
fn test_engine() -> Arc<Engine> {
    static SHARED_ENGINE: std::sync::OnceLock<Arc<Engine>> = std::sync::OnceLock::new();
    let shared = SHARED_ENGINE.get_or_init(|| {
        let loaded =
            Engine::load_with_pool_size(&gigastt_core::model::default_model_dir(), 1).unwrap();
        Arc::new(loaded)
    });
    Arc::clone(shared)
}
/// Loads a brand-new engine (pool size 1) for tests that mutate pool state, such as
/// closing the pool or holding every reservation.
#[cfg(test)]
fn fresh_engine() -> Arc<Engine> {
    let model_dir = gigastt_core::model::default_model_dir();
    let engine = Engine::load_with_pool_size(&model_dir, 1).unwrap();
    Arc::new(engine)
}
/// Builds a mono 16-bit PCM WAV file at `sample_rate` containing `num_samples` zero
/// (silent) samples, with a canonical 44-byte header.
#[cfg(test)]
fn silent_pcm16_wav(sample_rate: u32, num_samples: u32) -> Bytes {
    const HEADER_LEN: u32 = 44;
    let channels: u16 = 1;
    let bits_per_sample: u16 = 16;
    let block_align: u16 = channels * bits_per_sample / 8;
    let data_size = num_samples * u32::from(block_align);

    let mut wav = Vec::with_capacity((HEADER_LEN + data_size) as usize);
    // RIFF header: the size field counts every byte after itself.
    wav.extend_from_slice(b"RIFF");
    wav.extend_from_slice(&(HEADER_LEN + data_size - 8).to_le_bytes());
    wav.extend_from_slice(b"WAVE");
    // `fmt ` chunk: 16-byte PCM description, format tag 1 = integer PCM.
    wav.extend_from_slice(b"fmt ");
    wav.extend_from_slice(&16u32.to_le_bytes());
    wav.extend_from_slice(&1u16.to_le_bytes());
    wav.extend_from_slice(&channels.to_le_bytes());
    wav.extend_from_slice(&sample_rate.to_le_bytes());
    wav.extend_from_slice(&(sample_rate * u32::from(block_align)).to_le_bytes());
    wav.extend_from_slice(&block_align.to_le_bytes());
    wav.extend_from_slice(&bits_per_sample.to_le_bytes());
    // `data` chunk: all-zero sample bytes.
    wav.extend_from_slice(b"data");
    wav.extend_from_slice(&data_size.to_le_bytes());
    wav.resize(wav.len() + data_size as usize, 0);
    Bytes::from(wav)
}

/// Smallest decodable WAV: two silent samples at 16 kHz.
#[cfg(test)]
fn minimal_wav() -> Bytes {
    silent_pcm16_wav(16000, 2)
}

/// 100 ms of silence at 16 kHz.
#[cfg(test)]
fn short_wav() -> Bytes {
    let sample_rate = 16000u32;
    let duration_s = 0.1f32;
    let num_samples = (sample_rate as f32 * duration_s) as u32;
    silent_pcm16_wav(sample_rate, num_samples)
}