use std::str::FromStr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use cpal::{
traits::{DeviceTrait, HostTrait, StreamTrait},
Device, DeviceId, Sample, SampleFormat, SizedSample, Stream, SupportedStreamConfig,
};
use tauri::ipc::Channel;
/// Target rate of waveform events per second; `WaveformState` spaces events at least
/// `1_000 / WAVEFORM_FPS` milliseconds apart (integer division).
const WAVEFORM_FPS: u64 = 30;
/// Maximum number of min/max buckets carried in a waveform event's `peaks`.
const WAVEFORM_SAMPLE_COUNT: usize = 512;
/// Tauri-managed state holding the capture stream started by `create_stream`, if any.
///
/// At most one stream is held at a time: `create_stream` replaces the inner value and
/// `stop_stream` clears it, dropping whatever stream was stored before.
///
/// `Default` is derived: `Mutex<Option<_>>::default()` is `Mutex::new(None)`, i.e. "no stream".
#[derive(Default)]
pub struct ActiveInputStream(Mutex<Option<ActiveStream>>);
/// The capture stream currently owned by [`ActiveInputStream`].
///
/// The wrapped streams are never read; they are held only so that dropping the variant (when
/// `replace_active_stream` swaps it out or `stop_stream` clears it) drops the stream as well.
/// `stop_stream` therefore relies on the streams shutting down when dropped.
enum ActiveStream {
    /// A CPAL input stream built by `build_stream` (input devices everywhere, output devices
    /// on Windows).
    Cpal {
        _stream: Stream,
    },
    /// Output-device capture through PipeWire, which `create_stream` uses on Linux instead of
    /// CPAL.
    #[cfg(target_os = "linux")]
    PipewireOutput {
        _stream: crate::pipewire::PipewireOutputStream,
    },
}
/// Which side of a device to capture from.
///
/// Serialized as `"input"` / `"output"`.
/// - `Input` uses the device's default input config through CPAL.
/// - `Output` captures an output device. Linux routes it through PipeWire and Windows through
///   CPAL; other platforms reject it.
///
/// `PartialEq`/`Eq` are derived so callers can compare kinds with `==` instead of `matches!`.
#[derive(
    Clone, Copy, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize, specta::Type,
)]
#[serde(rename_all = "camelCase")]
pub enum CaptureSourceKind {
    Input,
    Output,
}
/// Description of the stream started by the `create_stream` command, returned to its caller.
#[derive(Debug, serde::Serialize, specta::Type)]
#[serde(rename_all = "camelCase")]
pub struct StreamInfo {
    /// The device id exactly as passed to `create_stream`.
    pub device_id: String,
    /// Whether the device is captured as an input or an output.
    pub source_kind: CaptureSourceKind,
    /// Sample rate from the CPAL stream config, or from the PipeWire stream format on Linux
    /// output capture.
    pub sample_rate: u32,
    /// Channel count from the CPAL stream config or the PipeWire stream format.
    pub channels: u16,
    /// CPAL sample format rendered with `to_string()`, or the literal `"f32-pipewire"` for
    /// PipeWire output capture.
    pub sample_format: String,
}
/// One throttled waveform snapshot of a captured buffer, built by `waveform_event`.
#[derive(Clone, Debug, serde::Serialize, specta::Type)]
#[serde(rename_all = "camelCase")]
pub struct WaveformEvent {
    /// Sequence number taken from `WaveformState`'s counter (wraps on overflow).
    sequence: u32,
    /// Root-mean-square over every sample of the buffer, each clamped to `[-1, 1]` first.
    #[specta(type = specta_typescript::Number)]
    rms: f32,
    /// Largest absolute sample value in the buffer, after clamping to `[-1, 1]`.
    #[specta(type = specta_typescript::Number)]
    peak: f32,
    /// Min/max envelope of the buffer, at most `WAVEFORM_SAMPLE_COUNT` buckets long.
    peaks: Vec<WaveformPeak>,
}
/// Smallest and largest clamped sample within one waveform bucket, across all channels.
///
/// As computed by `peak_range`, both bounds start from `0.0`, so `min <= 0.0 <= max`.
#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, specta::Type)]
#[serde(rename_all = "camelCase")]
pub struct WaveformPeak {
    /// Most negative sample in the bucket (or `0.0` if none is negative).
    #[specta(type = specta_typescript::Number)]
    min: f32,
    /// Most positive sample in the bucket (or `0.0` if none is positive).
    #[specta(type = specta_typescript::Number)]
    max: f32,
}
/// Error reported by a running capture stream.
#[derive(Clone, Debug, serde::Serialize, specta::Type)]
#[serde(rename_all = "camelCase")]
pub struct StreamErrorEvent {
    /// Human-readable error text; for CPAL streams this is the stream error's `to_string()`.
    pub(super) message: String,
}
/// Message sent on the `create_stream` event channel.
///
/// Adjacently tagged: serialized as `{ "event": "waveform", "data": { ... } }` or
/// `{ "event": "error", "data": { "message": ... } }`.
#[derive(Clone, Debug, serde::Serialize, specta::Type)]
#[serde(tag = "event", content = "data", rename_all = "camelCase")]
pub enum StreamEvent {
    /// Periodic waveform snapshot.
    Waveform(WaveformEvent),
    /// Stream error reported by the capture backend.
    Error(StreamErrorEvent),
}
/// Starts capturing from `device_id` and makes it the single active stream.
///
/// Waveform and error events are delivered on `on_event`. On Linux, output capture is served by
/// PipeWire; every other case goes through CPAL's default host. The previously active stream (if
/// any) is dropped only after the new one has been started.
///
/// # Errors
/// Returns a message when the device id cannot be parsed, the device is not found, no usable
/// stream config exists, or the stream cannot be built or started.
#[tauri::command]
#[specta::specta]
pub fn create_stream(
    device_id: String,
    kind: CaptureSourceKind,
    on_event: Channel<StreamEvent>,
    active_stream: tauri::State<'_, ActiveInputStream>,
) -> Result<StreamInfo, String> {
    #[cfg(target_os = "linux")]
    if let CaptureSourceKind::Output = kind {
        let pipewire = crate::pipewire::PipewireOutputStream::start(&device_id, on_event)?;
        let info = pipewire_stream_info(device_id, kind, pipewire.format());
        replace_active_stream(
            &active_stream,
            ActiveStream::PipewireOutput { _stream: pipewire },
        )?;
        return Ok(info);
    }

    let host = cpal::default_host();
    let device = DeviceId::from_str(&device_id)
        .map_err(|error| error.to_string())
        .and_then(|id| {
            host.device_by_id(&id)
                .ok_or_else(|| format!("Device not found: {device_id}"))
        })?;

    let config = stream_config(&device, kind)?;
    // Snapshot the negotiated format before `config` is consumed by `build_stream`.
    let info = StreamInfo {
        device_id,
        source_kind: kind,
        sample_rate: config.sample_rate(),
        channels: config.channels(),
        sample_format: config.sample_format().to_string(),
    };

    let stream = build_stream(device, config, on_event)?;
    stream.play().map_err(|error| error.to_string())?;
    replace_active_stream(&active_stream, ActiveStream::Cpal { _stream: stream })?;
    Ok(info)
}
/// Stops capture by dropping the active stream, if there is one.
///
/// # Errors
/// Returns a message if the active-stream mutex is poisoned.
#[tauri::command]
#[specta::specta]
pub fn stop_stream(active_stream: tauri::State<'_, ActiveInputStream>) -> Result<(), String> {
    match active_stream.0.lock() {
        Ok(mut slot) => {
            // Clearing the slot drops the previous stream; capture is expected to end on drop.
            *slot = None;
            Ok(())
        }
        Err(_) => Err("Active stream state is poisoned".to_string()),
    }
}
/// Installs `stream` as the active stream, dropping the previous one first.
///
/// # Errors
/// Returns a message if the active-stream mutex is poisoned (`stream` is then dropped).
fn replace_active_stream(
    active_stream: &tauri::State<'_, ActiveInputStream>,
    stream: ActiveStream,
) -> Result<(), String> {
    let mut slot = match active_stream.0.lock() {
        Ok(guard) => guard,
        Err(_) => return Err("Active stream state is poisoned".to_string()),
    };
    // Tear the old stream down before storing the new one.
    let previous = slot.take();
    drop(previous);
    *slot = Some(stream);
    Ok(())
}
/// Builds the [`StreamInfo`] for a PipeWire output capture from the format PipeWire negotiated.
///
/// PipeWire streams always report the sample format as `"f32-pipewire"`.
#[cfg(target_os = "linux")]
fn pipewire_stream_info(
    device_id: String,
    source_kind: CaptureSourceKind,
    format: crate::pipewire::PipewireStreamFormat,
) -> StreamInfo {
    let crate::pipewire::PipewireStreamFormat {
        sample_rate,
        channels,
    } = format;
    StreamInfo {
        device_id,
        source_kind,
        sample_rate,
        channels,
        sample_format: String::from("f32-pipewire"),
    }
}
/// Picks the CPAL stream configuration used to capture `kind` from `device`.
///
/// Input capture uses the device's default input config; output capture defers to the
/// platform-specific `output_stream_config`.
///
/// # Errors
/// Returns a message when the device has no suitable default config.
fn stream_config(
    device: &Device,
    kind: CaptureSourceKind,
) -> Result<SupportedStreamConfig, String> {
    match kind {
        CaptureSourceKind::Output => output_stream_config(device),
        CaptureSourceKind::Input => device.default_input_config().map_err(|err| {
            format!("No default input config for selected device: {err}")
        }),
    }
}
/// Windows: output capture reuses the output device's default output config.
///
/// `build_typed_stream` later builds an *input* stream with this config on the output device.
///
/// # Errors
/// Returns a message when the device has no default output config.
#[cfg(target_os = "windows")]
fn output_stream_config(device: &Device) -> Result<SupportedStreamConfig, String> {
    match device.default_output_config() {
        Ok(config) => Ok(config),
        Err(err) => Err(format!("No default output config for selected device: {err}")),
    }
}

/// Non-Windows: CPAL output capture is not supported and always errors.
///
/// On Linux, `create_stream` routes output capture to PipeWire before this is reached.
#[cfg(not(target_os = "windows"))]
fn output_stream_config(_device: &Device) -> Result<SupportedStreamConfig, String> {
    let message = "Output capture is only supported on Windows in the CPAL-first build";
    Err(message.to_owned())
}
fn build_stream(
device: Device,
supported_config: SupportedStreamConfig,
on_event: Channel<StreamEvent>,
) -> Result<Stream, String> {
let sample_format = supported_config.sample_format();
let config = supported_config.config();
let channels = usize::from(config.channels);
let state = WaveformState::new(on_event);
match sample_format {
SampleFormat::I8 => build_typed_stream::<i8>(device, config, channels, state),
SampleFormat::I16 => build_typed_stream::<i16>(device, config, channels, state),
SampleFormat::I24 => build_typed_stream::<cpal::I24>(device, config, channels, state),
SampleFormat::I32 => build_typed_stream::<i32>(device, config, channels, state),
SampleFormat::I64 => build_typed_stream::<i64>(device, config, channels, state),
SampleFormat::U8 => build_typed_stream::<u8>(device, config, channels, state),
SampleFormat::U16 => build_typed_stream::<u16>(device, config, channels, state),
SampleFormat::U24 => build_typed_stream::<cpal::U24>(device, config, channels, state),
SampleFormat::U32 => build_typed_stream::<u32>(device, config, channels, state),
SampleFormat::U64 => build_typed_stream::<u64>(device, config, channels, state),
SampleFormat::F32 => build_typed_stream::<f32>(device, config, channels, state),
SampleFormat::F64 => build_typed_stream::<f64>(device, config, channels, state),
SampleFormat::DsdU8 | SampleFormat::DsdU16 | SampleFormat::DsdU32 => Err(format!(
"Unsupported capture sample format: {sample_format}"
)),
_ => Err(format!(
"Unsupported capture sample format: {sample_format}"
)),
}
}
fn build_typed_stream<T>(
device: Device,
config: cpal::StreamConfig,
channels: usize,
state: WaveformState,
) -> Result<Stream, String>
where
T: SizedSample + Sample + Send + 'static,
f32: cpal::FromSample<T>,
{
let error_channel = state.on_event.clone();
device
.build_input_stream(
&config,
move |data: &[T], _| state.process(data, channels),
move |error| {
let message = error.to_string();
let _ = error_channel.send(StreamEvent::Error(StreamErrorEvent {
message: message.clone(),
}));
eprintln!("an error occurred on stream: {message}");
},
None,
)
.map_err(|error| error.to_string())
}
/// Throttles waveform updates for a capture stream and sends them on `on_event`.
///
/// `process` is called with every captured buffer. At most one [`StreamEvent::Waveform`] is sent
/// per [`WaveformState::EMIT_INTERVAL`]. Each event summarises only the buffer that arrives once
/// the interval has elapsed; buffers received in between are not accumulated.
pub(super) struct WaveformState {
    /// Event channel; `build_typed_stream` also clones it to report stream errors.
    pub(super) on_event: Channel<StreamEvent>,
    /// When the last waveform event was emitted; `None` until the first one.
    last_emit: Mutex<Option<Instant>>,
    /// Sequence number handed to the next emitted waveform event (wraps on overflow).
    sequence: AtomicU32,
}

impl WaveformState {
    /// Minimum spacing between two waveform events (`1_000 / WAVEFORM_FPS` ms).
    const EMIT_INTERVAL: Duration = Duration::from_millis(1_000 / WAVEFORM_FPS);

    /// Creates a throttler that has not emitted anything yet and starts at sequence 0.
    pub(super) fn new(on_event: Channel<StreamEvent>) -> Self {
        Self {
            on_event,
            last_emit: Mutex::new(None),
            sequence: AtomicU32::new(0),
        }
    }

    /// Summarises `data` into a waveform event and sends it, unless the previous event was sent
    /// less than [`Self::EMIT_INTERVAL`] ago.
    ///
    /// `data` holds interleaved frames of `channels` samples each; `channels == 0` is treated as
    /// mono. Send failures are ignored (best effort).
    pub(super) fn process<T>(&self, data: &[T], channels: usize)
    where
        T: Sample,
        f32: cpal::FromSample<T>,
    {
        // Check emptiness before `should_emit`. Otherwise an empty buffer would consume the
        // current emit slot, so the next real buffer waits a full interval. It would also burn
        // a sequence number, leaving a gap even though no event was dropped.
        if data.is_empty() || !self.should_emit() {
            return;
        }
        if let Some(event) = waveform_event(data, channels.max(1), self.next_sequence()) {
            let _ = self.on_event.send(StreamEvent::Waveform(event));
        }
    }

    /// Returns `true` and records the current instant when nothing has been emitted yet or at
    /// least [`Self::EMIT_INTERVAL`] has passed since the previous emission.
    fn should_emit(&self) -> bool {
        let now = Instant::now();
        // A poisoned lock (only possible after a panic while it was held) suppresses waveform
        // events instead of panicking on the audio thread.
        let Ok(mut last_emit) = self.last_emit.lock() else {
            return false;
        };
        if last_emit.is_some_and(|last| now.duration_since(last) < Self::EMIT_INTERVAL) {
            return false;
        }
        *last_emit = Some(now);
        true
    }

    /// Returns the current sequence number and advances the counter (wrapping on overflow).
    fn next_sequence(&self) -> u32 {
        self.sequence.fetch_add(1, Ordering::Relaxed)
    }
}
/// Builds a waveform snapshot of `data` tagged with `sequence`.
///
/// `peaks` holds up to [`WAVEFORM_SAMPLE_COUNT`] min/max buckets. `rms` and `peak` are computed
/// over every sample of every channel. Returns `None` when no peaks can be produced (e.g. an
/// empty buffer or `channels == 0`); the levels are not computed in that case.
fn waveform_event<T>(data: &[T], channels: usize, sequence: u32) -> Option<WaveformEvent>
where
    T: Sample,
    f32: cpal::FromSample<T>,
{
    let peaks = waveform_peaks(data, channels, WAVEFORM_SAMPLE_COUNT);
    (!peaks.is_empty()).then(|| {
        let (rms, peak) = levels(data);
        WaveformEvent {
            sequence,
            rms,
            peak,
            peaks,
        }
    })
}
/// Reduces interleaved `data` (frames of `channels` samples) to at most `target_len` min/max
/// buckets of whole frames.
///
/// Bucket boundaries are computed in `f32` with a floored start and a ceiled end. Adjacent
/// buckets may therefore share a frame when the frame count is not a multiple of the bucket
/// count. A trailing partial frame is included in the last bucket. Returns an empty vector when
/// `data` is empty, `channels == 0`, or `target_len == 0`.
fn waveform_peaks<T>(data: &[T], channels: usize, target_len: usize) -> Vec<WaveformPeak>
where
    T: Sample,
    f32: cpal::FromSample<T>,
{
    if data.is_empty() || channels == 0 || target_len == 0 {
        return Vec::new();
    }
    let total_frames = data.len().div_ceil(channels);
    let buckets = total_frames.min(target_len);
    let frames_per_bucket = total_frames as f32 / buckets as f32;

    let mut envelope = Vec::with_capacity(buckets);
    for bucket in 0..buckets {
        let first_frame = (bucket as f32 * frames_per_bucket).floor() as usize;
        let last_frame = ((bucket + 1) as f32 * frames_per_bucket).ceil() as usize;
        let lo = first_frame * channels;
        let hi = (last_frame.min(total_frames) * channels).min(data.len());
        envelope.push(peak_range(&data[lo..hi]));
    }
    envelope
}
/// Returns the smallest and largest clamped sample in `samples`.
///
/// Both bounds start from `0.0`, so an empty or all-positive slice yields `min == 0.0`, and an
/// empty or all-negative slice yields `max == 0.0`.
fn peak_range<T>(samples: &[T]) -> WaveformPeak
where
    T: Sample,
    f32: cpal::FromSample<T>,
{
    samples
        .iter()
        .map(|&raw| normalized_sample(raw))
        .fold(WaveformPeak { min: 0.0, max: 0.0 }, |acc, value| WaveformPeak {
            min: acc.min.min(value),
            max: acc.max.max(value),
        })
}
/// Returns `(rms, peak)` over `samples` after clamping each sample to `[-1, 1]`.
///
/// `rms` is the root of the mean of squares; `peak` is the largest absolute value. An empty slice
/// is reported as silence, `(0.0, 0.0)`.
fn levels<T>(samples: &[T]) -> (f32, f32)
where
    T: Sample,
    f32: cpal::FromSample<T>,
{
    if samples.is_empty() {
        return (0.0, 0.0);
    }
    let (sum_of_squares, loudest) = samples
        .iter()
        .map(|&raw| normalized_sample(raw))
        .fold((0.0f32, 0.0f32), |(sum, loudest), value| {
            (sum + value * value, loudest.max(value.abs()))
        });
    let mean_square = sum_of_squares / samples.len() as f32;
    (mean_square.sqrt(), loudest)
}
/// Converts `sample` to `f32` via CPAL's sample conversion and clamps it to `[-1, 1]`.
fn normalized_sample<T>(sample: T) -> f32
where
    T: Sample,
    f32: cpal::FromSample<T>,
{
    let value: f32 = sample.to_sample::<f32>();
    value.clamp(-1.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::{levels, normalized_sample, waveform_peaks, WaveformPeak};
    #[cfg(target_os = "linux")]
    use super::{pipewire_stream_info, CaptureSourceKind};

    #[test]
    fn preserves_antiphase_stereo_in_peak_envelope() {
        let peaks = waveform_peaks(&[1.0f32, -1.0], 2, 512);
        assert_eq!(
            peaks,
            vec![WaveformPeak {
                min: -1.0,
                max: 1.0
            }]
        );
    }

    #[test]
    fn keeps_uneven_stereo_peaks_per_frame() {
        let peaks = waveform_peaks(&[0.1f32, -0.8, -0.6, 0.4], 2, 512);
        assert_eq!(
            peaks,
            vec![
                WaveformPeak {
                    min: -0.8,
                    max: 0.1
                },
                WaveformPeak {
                    min: -0.6,
                    max: 0.4
                }
            ]
        );
    }

    #[test]
    fn computes_rms_and_peak() {
        let (rms, peak) = levels(&[-1.0, 0.0, 1.0]);
        assert!((rms - 0.8164966).abs() < 0.0001);
        assert_eq!(peak, 1.0);
    }

    #[test]
    fn computes_empty_levels_as_silence() {
        assert_eq!(levels::<f32>(&[]), (0.0, 0.0));
    }

    #[test]
    fn buckets_waveform_by_min_max_peaks() {
        let samples = vec![0.1, -0.8, 0.2, 0.7, -0.3, 0.4];
        assert_eq!(
            waveform_peaks(&samples, 1, 3),
            vec![
                WaveformPeak {
                    min: -0.8,
                    max: 0.1
                },
                WaveformPeak { min: 0.0, max: 0.7 },
                WaveformPeak {
                    min: -0.3,
                    max: 0.4
                }
            ]
        );
    }

    // Degenerate inputs must yield no buckets rather than panicking or dividing by zero.
    #[test]
    fn returns_no_peaks_for_degenerate_input() {
        assert!(waveform_peaks::<f32>(&[], 2, 512).is_empty());
        assert!(waveform_peaks(&[0.5f32], 0, 512).is_empty());
        assert!(waveform_peaks(&[0.5f32], 1, 0).is_empty());
    }

    // A buffer whose length is not a multiple of the channel count keeps its trailing samples.
    #[test]
    fn includes_trailing_partial_frame() {
        assert_eq!(
            waveform_peaks(&[0.2f32, -0.4, 0.6], 2, 512),
            vec![
                WaveformPeak {
                    min: -0.4,
                    max: 0.2
                },
                WaveformPeak { min: 0.0, max: 0.6 }
            ]
        );
    }

    #[test]
    fn caps_bucket_count_at_target_len() {
        let peaks = waveform_peaks(&[0.25f32; 1024], 1, 512);
        assert_eq!(peaks.len(), 512);
        assert!(peaks
            .iter()
            .all(|peak| *peak == WaveformPeak { min: 0.0, max: 0.25 }));
    }

    #[test]
    fn clamps_out_of_range_samples() {
        assert_eq!(normalized_sample(1.5f32), 1.0);
        assert_eq!(normalized_sample(-2.0f32), -1.0);
        assert_eq!(levels(&[2.0f32, -3.0]), (1.0, 1.0));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn maps_negotiated_pipewire_format_to_stream_info() {
        let info = pipewire_stream_info(
            "pipewire-output:54".to_string(),
            CaptureSourceKind::Output,
            crate::pipewire::PipewireStreamFormat {
                sample_rate: 44_100,
                channels: 6,
            },
        );
        assert_eq!(info.device_id, "pipewire-output:54");
        assert!(matches!(info.source_kind, CaptureSourceKind::Output));
        assert_eq!(info.sample_rate, 44_100);
        assert_eq!(info.channels, 6);
        assert_eq!(info.sample_format, "f32-pipewire");
    }
}