use tower::{Layer, Service};
use std::task::{Context, Poll};
use std::pin::Pin;
use futures::future::BoxFuture;
use crate::pipeline::types::{AudioRequest, AudioChunk, AudioResponse, PipelineError};
/// Tower [`Layer`] that wraps a service so each [`AudioRequest`]'s audio chunk has its silent
/// stretches removed before being forwarded.
///
/// The settings are copied into every [`SilenceRemovalService`] this layer produces.
#[derive(Clone, Debug)]
pub struct SilenceRemovalLayer {
    /// RMS level a 20 ms analysis window must exceed to be treated as non-silent.
    pub threshold: f32,
    /// Silences no longer than this do not split a kept region. Chunks no longer than twice this
    /// duration are passed through untouched.
    pub min_silence_duration_ms: u32,
    /// When `true`, kept regions separated by at most 300 ms are additionally joined.
    pub keep_short_pauses: bool,
    /// Extra audio kept before and after each non-silent region, in milliseconds.
    pub padding_ms: u32,
}

impl SilenceRemovalLayer {
    /// Builds a layer from its four tuning parameters, storing them as given.
    #[must_use]
    pub fn new(threshold: f32, min_silence_duration_ms: u32, keep_short_pauses: bool, padding_ms: u32) -> Self {
        Self {
            padding_ms,
            keep_short_pauses,
            min_silence_duration_ms,
            threshold,
        }
    }
}
impl<S> Layer<S> for SilenceRemovalLayer {
    type Service = SilenceRemovalService<S>;

    /// Wraps `inner` in a [`SilenceRemovalService`] carrying a copy of this layer's settings.
    fn layer(&self, inner: S) -> Self::Service {
        // Every setting is `Copy`, so destructuring through the reference copies them out.
        let Self {
            threshold,
            min_silence_duration_ms,
            keep_short_pauses,
            padding_ms,
        } = *self;

        SilenceRemovalService {
            inner,
            threshold,
            min_silence_duration_ms,
            keep_short_pauses,
            padding_ms,
        }
    }
}
#[derive(Clone, Debug)]
pub struct SilenceRemovalService<S> {
inner: S,
threshold: f32,
min_silence_duration_ms: u32,
keep_short_pauses: bool,
padding_ms: u32,
}
impl<S> Service<AudioRequest> for SilenceRemovalService<S>
where
S: Service<AudioRequest, Response = AudioResponse, Error = PipelineError> + Send + 'static + Clone,
S::Future: Send + 'static,
{
type Response = AudioResponse;
type Error = PipelineError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: AudioRequest) -> Self::Future {
let mut chunk = req.0.clone();
let threshold = self.threshold;
let min_silence_duration_ms = self.min_silence_duration_ms;
let keep_short_pauses = self.keep_short_pauses;
let padding_ms = self.padding_ms;
let mut inner = self.inner.clone();
Box::pin(async move {
let processed_chunk = process_silence_removal(
chunk,
threshold,
min_silence_duration_ms,
keep_short_pauses,
padding_ms,
);
let req = AudioRequest(processed_chunk);
inner.call(req).await
})
}
}
/// Sample rate assumed by every millisecond-to-sample conversion in silence removal.
///
/// NOTE(review): `AudioChunk` exposes no sample rate in this file, so input is assumed to be
/// 16 kHz. Confirm this against whatever produces the audio.
const ASSUMED_SAMPLE_RATE_HZ: f64 = 16_000.0;
/// Length of each RMS analysis window.
const ANALYSIS_WINDOW_MS: u32 = 20;
/// Longest gap between kept regions that is bridged when `keep_short_pauses` is set.
const SHORT_PAUSE_MS: u32 = 300;
/// Length of the all-zero chunk returned when no window rises above the threshold.
const SILENT_FALLBACK_MS: u32 = 200;

/// Converts milliseconds to a sample count at [`ASSUMED_SAMPLE_RATE_HZ`], rounding to nearest.
fn ms_to_samples(ms: u32) -> usize {
    ((f64::from(ms) / 1000.0) * ASSUMED_SAMPLE_RATE_HZ).round() as usize
}

/// Removes silent stretches from `chunk.data`, keeping only the (padded) non-silent regions.
///
/// Behaviour:
/// - Chunks of at most `2 * min_silence_duration_ms` are returned unchanged.
/// - The audio is analysed in 20 ms windows. A window is non-silent when its RMS exceeds
///   `threshold`.
/// - Each non-silent region is widened by `padding_ms` on both sides, clamped to the chunk.
/// - Regions separated by a silence no longer than `min_silence_duration_ms` are joined.
/// - When `keep_short_pauses` is set, regions separated by at most 300 ms are also joined.
/// - If no window is non-silent, the result holds 200 ms of zeros.
///
/// `timestamp` and `is_speech` are copied from the input unchanged.
/// NOTE(review): the timestamp is not shifted when leading silence is dropped — confirm callers
/// expect that.
fn process_silence_removal(
    chunk: AudioChunk,
    threshold: f32,
    min_silence_duration_ms: u32,
    keep_short_pauses: bool,
    padding_ms: u32,
) -> AudioChunk {
    // Decide on pass-through before doing any analysis. Saturating so a very large setting
    // cannot overflow `u32` (which panics in debug builds).
    if chunk.data.len() <= ms_to_samples(min_silence_duration_ms.saturating_mul(2)) {
        return chunk;
    }

    let mut segments = detect_voiced_segments(
        &chunk.data,
        threshold,
        ms_to_samples(min_silence_duration_ms),
        ms_to_samples(padding_ms),
    );
    if keep_short_pauses {
        segments = bridge_short_pauses(segments, ms_to_samples(SHORT_PAUSE_MS));
    }

    let data = if segments.is_empty() {
        vec![0.0; ms_to_samples(SILENT_FALLBACK_MS)]
    } else {
        concat_segments(&chunk.data, &segments)
    };

    AudioChunk {
        timestamp: chunk.timestamp,
        data,
        is_speech: chunk.is_speech,
    }
}

/// Scans `data` in fixed windows and returns padded `(start, end)` ranges of non-silent audio.
///
/// A region that starts within `min_silence_samples` of the previous region's end extends that
/// region instead of opening a new one. The returned ranges are therefore ascending and
/// non-overlapping.
fn detect_voiced_segments(
    data: &[f32],
    threshold: f32,
    min_silence_samples: usize,
    padding_samples: usize,
) -> Vec<(usize, usize)> {
    let len = data.len();
    let window_size = ms_to_samples(ANALYSIS_WINDOW_MS);
    let mut segments: Vec<(usize, usize)> = Vec::new();
    let mut in_silence = true;

    for (index, window) in data.chunks(window_size).enumerate() {
        let window_start = index * window_size;
        let window_end = window_start + window.len();
        let mean_square = window.iter().map(|sample| sample.powi(2)).sum::<f32>() / window.len() as f32;
        let rms = mean_square.sqrt();

        if rms > threshold {
            if in_silence {
                in_silence = false;
                let start = window_start.saturating_sub(padding_samples);
                // The silence since the previous region was too short to split it: keep
                // extending that region (its end is rewritten when silence resumes).
                let continues_previous = matches!(
                    segments.last(),
                    Some(&(_, prev_end)) if start <= prev_end.saturating_add(min_silence_samples)
                );
                if !continues_previous {
                    segments.push((start, start));
                }
            }
        } else if !in_silence {
            in_silence = true;
            if let Some(last) = segments.last_mut() {
                last.1 = window_end.saturating_add(padding_samples).min(len);
            }
        }
    }

    // Audio still non-silent at the end: the open region runs to the end of the data.
    if !in_silence {
        if let Some(last) = segments.last_mut() {
            last.1 = len;
        }
    }
    segments
}

/// Joins consecutive ranges whose gap is positive and at most `max_gap_samples`.
fn bridge_short_pauses(segments: Vec<(usize, usize)>, max_gap_samples: usize) -> Vec<(usize, usize)> {
    let mut merged: Vec<(usize, usize)> = Vec::with_capacity(segments.len());
    for (start, end) in segments {
        match merged.last_mut() {
            Some(prev) if start > prev.1 && start - prev.1 <= max_gap_samples => prev.1 = end,
            _ => merged.push((start, end)),
        }
    }
    merged
}

/// Concatenates the samples covered by `segments`, clamping each range to `data`.
fn concat_segments(data: &[f32], segments: &[(usize, usize)]) -> Vec<f32> {
    let total: usize = segments.iter().map(|&(start, end)| end.saturating_sub(start)).sum();
    let mut out = Vec::with_capacity(total);
    for &(start, end) in segments {
        let start = start.min(data.len());
        let end = end.min(data.len());
        if start < end {
            out.extend_from_slice(&data[start..end]);
        }
    }
    out
}