1#![forbid(unsafe_code)]
7#![allow(clippy::cast_sign_loss)]
8#![allow(clippy::cast_possible_wrap)]
9#![allow(clippy::cast_lossless)]
10#![allow(clippy::needless_pass_by_value)]
11#![allow(clippy::needless_range_loop)]
12#![allow(clippy::get_first)]
13#![allow(clippy::doc_markdown)]
14
15use bytes::{Bytes, BytesMut};
16
17use crate::error::{GraphError, GraphResult};
18use crate::frame::FilterFrame;
19use crate::node::{Node, NodeId, NodeState, NodeType};
20use crate::port::{AudioPortFormat, InputPort, OutputPort, PortFormat, PortId, PortType};
21
22use oximedia_audio::{AudioBuffer, AudioFrame, ChannelLayout};
23use oximedia_core::SampleFormat;
24
/// Direction in which a fade ramps the gain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FadeDirection {
    /// Gain ramps up from silence (0.0) to unity (1.0).
    In,
    /// Gain ramps down from unity (1.0) to silence (0.0).
    Out,
}
33
34#[derive(Clone, Debug)]
36pub struct FadeConfig {
37 pub direction: FadeDirection,
39 pub duration_samples: usize,
41 pub position: usize,
43 pub active: bool,
45}
46
47impl FadeConfig {
48 #[must_use]
50 pub fn new(direction: FadeDirection, duration_samples: usize) -> Self {
51 Self {
52 direction,
53 duration_samples,
54 position: 0,
55 active: true,
56 }
57 }
58
59 #[must_use]
61 pub fn fade_in(duration_samples: usize) -> Self {
62 Self::new(FadeDirection::In, duration_samples)
63 }
64
65 #[must_use]
67 pub fn fade_out(duration_samples: usize) -> Self {
68 Self::new(FadeDirection::Out, duration_samples)
69 }
70
71 #[must_use]
73 pub fn gain_at_position(&self, sample_offset: usize) -> f64 {
74 if !self.active || self.duration_samples == 0 {
75 return 1.0;
76 }
77
78 let pos = (self.position + sample_offset).min(self.duration_samples);
79 let t = pos as f64 / self.duration_samples as f64;
80
81 match self.direction {
82 FadeDirection::In => t,
83 FadeDirection::Out => 1.0 - t,
84 }
85 }
86
87 #[must_use]
89 pub fn is_complete(&self) -> bool {
90 self.position >= self.duration_samples
91 }
92
93 pub fn advance(&mut self, samples: usize) {
95 self.position = self.position.saturating_add(samples);
96 if self.is_complete() {
97 self.active = false;
98 }
99 }
100}
101
102#[derive(Clone, Debug)]
104pub struct VolumeConfig {
105 pub gain: f64,
107 pub fade: Option<FadeConfig>,
109 pub normalize_peak: bool,
111 pub target_peak: f64,
113 pub soft_clip: bool,
115 pub soft_clip_threshold: f64,
117}
118
119impl Default for VolumeConfig {
120 fn default() -> Self {
121 Self {
122 gain: 1.0,
123 fade: None,
124 normalize_peak: false,
125 target_peak: 1.0,
126 soft_clip: false,
127 soft_clip_threshold: 0.9,
128 }
129 }
130}
131
132impl VolumeConfig {
133 #[must_use]
135 pub fn new(gain: f64) -> Self {
136 Self {
137 gain,
138 ..Default::default()
139 }
140 }
141
142 #[must_use]
144 pub fn from_db(db: f64) -> Self {
145 Self::new(Self::db_to_linear(db))
146 }
147
148 #[must_use]
150 pub fn with_db_gain(mut self, db: f64) -> Self {
151 self.gain = Self::db_to_linear(db);
152 self
153 }
154
155 #[must_use]
157 pub fn with_fade_in(mut self, duration_samples: usize) -> Self {
158 self.fade = Some(FadeConfig::fade_in(duration_samples));
159 self
160 }
161
162 #[must_use]
164 pub fn with_fade_out(mut self, duration_samples: usize) -> Self {
165 self.fade = Some(FadeConfig::fade_out(duration_samples));
166 self
167 }
168
169 #[must_use]
171 pub fn with_peak_normalization(mut self, target_peak: f64) -> Self {
172 self.normalize_peak = true;
173 self.target_peak = target_peak;
174 self
175 }
176
177 #[must_use]
179 pub fn with_soft_clip(mut self, threshold: f64) -> Self {
180 self.soft_clip = true;
181 self.soft_clip_threshold = threshold;
182 self
183 }
184
185 #[must_use]
187 pub fn db_to_linear(db: f64) -> f64 {
188 10.0_f64.powf(db / 20.0)
189 }
190
191 #[must_use]
193 pub fn linear_to_db(linear: f64) -> f64 {
194 if linear <= 0.0 {
195 f64::NEG_INFINITY
196 } else {
197 20.0 * linear.log10()
198 }
199 }
200}
201
202pub struct VolumeFilter {
221 id: NodeId,
222 name: String,
223 state: NodeState,
224 config: VolumeConfig,
225 inputs: Vec<InputPort>,
226 outputs: Vec<OutputPort>,
227}
228
229impl VolumeFilter {
230 #[must_use]
232 pub fn new(id: NodeId, name: impl Into<String>, config: VolumeConfig) -> Self {
233 let audio_format = PortFormat::Audio(AudioPortFormat::any());
234
235 Self {
236 id,
237 name: name.into(),
238 state: NodeState::Idle,
239 config,
240 inputs: vec![InputPort::new(PortId(0), "input", PortType::Audio)
241 .with_format(audio_format.clone())],
242 outputs: vec![
243 OutputPort::new(PortId(0), "output", PortType::Audio).with_format(audio_format)
244 ],
245 }
246 }
247
248 #[must_use]
250 pub fn config(&self) -> &VolumeConfig {
251 &self.config
252 }
253
254 pub fn set_gain(&mut self, gain: f64) {
256 self.config.gain = gain;
257 }
258
259 pub fn set_gain_db(&mut self, db: f64) {
261 self.config.gain = VolumeConfig::db_to_linear(db);
262 }
263
264 pub fn start_fade_in(&mut self, duration_samples: usize) {
266 self.config.fade = Some(FadeConfig::fade_in(duration_samples));
267 }
268
269 pub fn start_fade_out(&mut self, duration_samples: usize) {
271 self.config.fade = Some(FadeConfig::fade_out(duration_samples));
272 }
273
274 fn frame_to_samples(frame: &AudioFrame) -> Vec<Vec<f64>> {
276 let channels = frame.channels.count();
277 let sample_count = frame.sample_count();
278
279 if sample_count == 0 {
280 return vec![Vec::new(); channels];
281 }
282
283 let mut output = vec![Vec::with_capacity(sample_count); channels];
284
285 match &frame.samples {
286 AudioBuffer::Interleaved(data) => {
287 Self::convert_interleaved(data, frame.format, channels, &mut output);
288 }
289 AudioBuffer::Planar(planes) => {
290 Self::convert_planar(planes, frame.format, &mut output);
291 }
292 }
293
294 output
295 }
296
297 fn convert_interleaved(
299 data: &Bytes,
300 format: SampleFormat,
301 channels: usize,
302 output: &mut [Vec<f64>],
303 ) {
304 let bytes_per_sample = format.bytes_per_sample();
305 if bytes_per_sample == 0 || channels == 0 {
306 return;
307 }
308
309 let sample_count = data.len() / (bytes_per_sample * channels);
310
311 for i in 0..sample_count {
312 for ch in 0..channels {
313 let offset = (i * channels + ch) * bytes_per_sample;
314 if offset + bytes_per_sample <= data.len() {
315 let sample =
316 Self::bytes_to_f64(&data[offset..offset + bytes_per_sample], format);
317 output[ch].push(sample);
318 }
319 }
320 }
321 }
322
323 fn convert_planar(planes: &[Bytes], format: SampleFormat, output: &mut [Vec<f64>]) {
325 let bytes_per_sample = format.bytes_per_sample();
326 if bytes_per_sample == 0 {
327 return;
328 }
329
330 for (ch, plane) in planes.iter().enumerate() {
331 if ch >= output.len() {
332 break;
333 }
334 let sample_count = plane.len() / bytes_per_sample;
335 for i in 0..sample_count {
336 let offset = i * bytes_per_sample;
337 if offset + bytes_per_sample <= plane.len() {
338 let sample =
339 Self::bytes_to_f64(&plane[offset..offset + bytes_per_sample], format);
340 output[ch].push(sample);
341 }
342 }
343 }
344 }
345
346 fn bytes_to_f64(bytes: &[u8], format: SampleFormat) -> f64 {
348 match format {
349 SampleFormat::U8 => {
350 if bytes.is_empty() {
351 return 0.0;
352 }
353 (f64::from(bytes[0]) - 128.0) / 128.0
354 }
355 SampleFormat::S16 => {
356 if bytes.len() < 2 {
357 return 0.0;
358 }
359 let sample = i16::from_le_bytes([bytes[0], bytes[1]]);
360 f64::from(sample) / f64::from(i16::MAX)
361 }
362 SampleFormat::S32 => {
363 if bytes.len() < 4 {
364 return 0.0;
365 }
366 let sample = i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
367 f64::from(sample) / f64::from(i32::MAX)
368 }
369 SampleFormat::F32 => {
370 if bytes.len() < 4 {
371 return 0.0;
372 }
373 f64::from(f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
374 }
375 SampleFormat::F64 => {
376 if bytes.len() < 8 {
377 return 0.0;
378 }
379 f64::from_le_bytes([
380 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
381 ])
382 }
383 _ => 0.0,
384 }
385 }
386
387 fn soft_clip(sample: f64, threshold: f64) -> f64 {
389 if sample.abs() <= threshold {
390 sample
391 } else {
392 let sign = sample.signum();
393 let excess = sample.abs() - threshold;
394 let range = 1.0 - threshold;
395 let compressed = threshold + range * (excess / range).tanh();
397 sign * compressed
398 }
399 }
400
401 fn find_peak(samples: &[Vec<f64>]) -> f64 {
403 samples
404 .iter()
405 .flat_map(|ch| ch.iter())
406 .map(|s| s.abs())
407 .fold(0.0_f64, f64::max)
408 }
409
410 fn process_samples(&mut self, samples: &mut [Vec<f64>]) {
412 let sample_count = samples.get(0).map_or(0, Vec::len);
413
414 let norm_gain = if self.config.normalize_peak {
416 let peak = Self::find_peak(samples);
417 if peak > 0.0 {
418 self.config.target_peak / peak
419 } else {
420 1.0
421 }
422 } else {
423 1.0
424 };
425
426 for sample_idx in 0..sample_count {
427 let fade_gain = if let Some(ref fade) = self.config.fade {
429 fade.gain_at_position(sample_idx)
430 } else {
431 1.0
432 };
433
434 let total_gain = self.config.gain * norm_gain * fade_gain;
436
437 for channel in samples.iter_mut() {
439 if sample_idx < channel.len() {
440 let mut sample = channel[sample_idx] * total_gain;
441
442 if self.config.soft_clip {
444 sample = Self::soft_clip(sample, self.config.soft_clip_threshold);
445 }
446
447 channel[sample_idx] = sample;
448 }
449 }
450 }
451
452 if let Some(ref mut fade) = self.config.fade {
454 fade.advance(sample_count);
455 }
456 }
457
458 fn samples_to_frame(
460 samples: Vec<Vec<f64>>,
461 format: SampleFormat,
462 sample_rate: u32,
463 channels: ChannelLayout,
464 ) -> AudioFrame {
465 let channel_count = channels.count();
466 if samples.is_empty() || samples[0].is_empty() || channel_count == 0 {
467 return AudioFrame::new(format, sample_rate, channels);
468 }
469
470 let sample_count = samples[0].len();
471 let bytes_per_sample = format.bytes_per_sample();
472 let mut buffer = BytesMut::with_capacity(sample_count * channel_count * bytes_per_sample);
473
474 for i in 0..sample_count {
475 for ch in 0..channel_count {
476 let sample = if ch < samples.len() && i < samples[ch].len() {
477 samples[ch][i]
478 } else {
479 0.0
480 };
481 Self::f64_to_bytes(sample, format, &mut buffer);
482 }
483 }
484
485 let mut frame = AudioFrame::new(format, sample_rate, channels);
486 frame.samples = AudioBuffer::Interleaved(buffer.freeze());
487 frame
488 }
489
490 fn f64_to_bytes(sample: f64, format: SampleFormat, buffer: &mut BytesMut) {
492 let clamped = sample.clamp(-1.0, 1.0);
493
494 match format {
495 SampleFormat::U8 => {
496 let value = ((clamped * 128.0) + 128.0) as u8;
497 buffer.extend_from_slice(&[value]);
498 }
499 SampleFormat::S16 => {
500 let value = (clamped * f64::from(i16::MAX)) as i16;
501 buffer.extend_from_slice(&value.to_le_bytes());
502 }
503 SampleFormat::S32 => {
504 let value = (clamped * f64::from(i32::MAX)) as i32;
505 buffer.extend_from_slice(&value.to_le_bytes());
506 }
507 SampleFormat::F32 => {
508 #[allow(clippy::cast_possible_truncation)]
509 let value = clamped as f32;
510 buffer.extend_from_slice(&value.to_le_bytes());
511 }
512 SampleFormat::F64 => {
513 buffer.extend_from_slice(&clamped.to_le_bytes());
514 }
515 _ => {}
516 }
517 }
518}
519
520impl Node for VolumeFilter {
521 fn id(&self) -> NodeId {
522 self.id
523 }
524
525 fn name(&self) -> &str {
526 &self.name
527 }
528
529 fn node_type(&self) -> NodeType {
530 NodeType::Filter
531 }
532
533 fn state(&self) -> NodeState {
534 self.state
535 }
536
537 fn set_state(&mut self, state: NodeState) -> GraphResult<()> {
538 if !self.state.can_transition_to(state) {
539 return Err(GraphError::InvalidStateTransition {
540 node: self.id,
541 from: self.state.to_string(),
542 to: state.to_string(),
543 });
544 }
545 self.state = state;
546 Ok(())
547 }
548
549 fn inputs(&self) -> &[InputPort] {
550 &self.inputs
551 }
552
553 fn outputs(&self) -> &[OutputPort] {
554 &self.outputs
555 }
556
557 fn process(&mut self, input: Option<FilterFrame>) -> GraphResult<Option<FilterFrame>> {
558 let frame = match input {
559 Some(FilterFrame::Audio(frame)) => frame,
560 Some(_) => {
561 return Err(GraphError::PortTypeMismatch {
562 expected: "Audio".to_string(),
563 actual: "Video".to_string(),
564 });
565 }
566 None => return Ok(None),
567 };
568
569 let mut samples = Self::frame_to_samples(&frame);
571
572 self.process_samples(&mut samples);
574
575 let output_frame = Self::samples_to_frame(
577 samples,
578 frame.format,
579 frame.sample_rate,
580 frame.channels.clone(),
581 );
582
583 Ok(Some(FilterFrame::Audio(output_frame)))
584 }
585
586 fn reset(&mut self) -> GraphResult<()> {
587 self.config.fade = None;
588 self.set_state(NodeState::Idle)
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mono F32 frame at 48 kHz holding the single sample `value`.
    fn mono_f32_frame(value: f32) -> AudioFrame {
        let mut frame = AudioFrame::new(SampleFormat::F32, 48000, ChannelLayout::Mono);
        let mut samples = BytesMut::new();
        samples.extend_from_slice(&value.to_le_bytes());
        frame.samples = AudioBuffer::Interleaved(samples.freeze());
        frame
    }

    /// Extracts the first F32 sample of an interleaved audio result.
    ///
    /// Panics when the result is not an interleaved audio frame, so a test
    /// cannot pass without its assertion actually running.
    fn first_f32_sample(result: Option<FilterFrame>) -> f32 {
        let frame = match result {
            Some(FilterFrame::Audio(frame)) => frame,
            _ => panic!("expected an audio frame"),
        };

        match &frame.samples {
            AudioBuffer::Interleaved(data) => {
                assert!(data.len() >= 4, "expected at least one F32 sample");
                f32::from_le_bytes([data[0], data[1], data[2], data[3]])
            }
            _ => panic!("expected interleaved output"),
        }
    }

    #[test]
    fn test_db_to_linear() {
        let linear = VolumeConfig::db_to_linear(0.0);
        assert!((linear - 1.0).abs() < f64::EPSILON);

        let linear = VolumeConfig::db_to_linear(-6.0);
        assert!((linear - 0.501).abs() < 0.01);

        let linear = VolumeConfig::db_to_linear(6.0);
        assert!((linear - 1.995).abs() < 0.01);
    }

    #[test]
    fn test_linear_to_db() {
        let db = VolumeConfig::linear_to_db(1.0);
        assert!(db.abs() < f64::EPSILON);

        let db = VolumeConfig::linear_to_db(0.5);
        assert!((db - (-6.02)).abs() < 0.1);

        let db = VolumeConfig::linear_to_db(0.0);
        assert!(db.is_infinite() && db.is_sign_negative());
    }

    #[test]
    fn test_fade_config() {
        let mut fade = FadeConfig::fade_in(1000);
        assert!(fade.active);
        assert!(!fade.is_complete());
        assert!(fade.gain_at_position(0).abs() < f64::EPSILON);
        assert!((fade.gain_at_position(500) - 0.5).abs() < f64::EPSILON);
        assert!((fade.gain_at_position(1000) - 1.0).abs() < f64::EPSILON);

        fade.advance(500);
        assert!(!fade.is_complete());

        fade.advance(500);
        assert!(fade.is_complete());
        assert!(!fade.active);
    }

    #[test]
    fn test_fade_out() {
        let fade = FadeConfig::fade_out(1000);
        assert!((fade.gain_at_position(0) - 1.0).abs() < f64::EPSILON);
        assert!((fade.gain_at_position(500) - 0.5).abs() < f64::EPSILON);
        assert!(fade.gain_at_position(1000).abs() < f64::EPSILON);
    }

    #[test]
    fn test_volume_config() {
        let config = VolumeConfig::from_db(-6.0)
            .with_fade_in(1000)
            .with_peak_normalization(0.9)
            .with_soft_clip(0.8);

        assert!((config.gain - 0.501).abs() < 0.01);
        assert!(config.fade.is_some());
        assert!(config.normalize_peak);
        assert!((config.target_peak - 0.9).abs() < f64::EPSILON);
        assert!(config.soft_clip);
        assert!((config.soft_clip_threshold - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_soft_clip() {
        // Below the threshold the sample passes through unchanged.
        let result = VolumeFilter::soft_clip(0.5, 0.9);
        assert!((result - 0.5).abs() < f64::EPSILON);

        // Above the threshold it is compressed into (threshold, 1.0).
        let result = VolumeFilter::soft_clip(1.5, 0.9);
        assert!(result > 0.9);
        assert!(result < 1.0);

        let result = VolumeFilter::soft_clip(-1.5, 0.9);
        assert!(result < -0.9);
        assert!(result > -1.0);
    }

    #[test]
    fn test_find_peak() {
        let samples = vec![vec![0.5, -0.8, 0.3], vec![0.2, 0.9, -0.1]];
        let peak = VolumeFilter::find_peak(&samples);
        assert!((peak - 0.9).abs() < f64::EPSILON);
    }

    #[test]
    fn test_volume_filter_creation() {
        let config = VolumeConfig::new(0.5);
        let filter = VolumeFilter::new(NodeId(1), "volume", config);

        assert_eq!(filter.id(), NodeId(1));
        assert_eq!(filter.name(), "volume");
        assert_eq!(filter.node_type(), NodeType::Filter);
    }

    #[test]
    fn test_volume_filter_ports() {
        let config = VolumeConfig::default();
        let filter = VolumeFilter::new(NodeId(0), "test", config);

        assert_eq!(filter.inputs().len(), 1);
        assert_eq!(filter.outputs().len(), 1);
        assert_eq!(filter.inputs()[0].port_type, PortType::Audio);
    }

    #[test]
    fn test_set_gain() {
        let config = VolumeConfig::default();
        let mut filter = VolumeFilter::new(NodeId(0), "test", config);

        filter.set_gain(0.5);
        assert!((filter.config().gain - 0.5).abs() < f64::EPSILON);

        filter.set_gain_db(-6.0);
        assert!((filter.config().gain - 0.501).abs() < 0.01);
    }

    #[test]
    fn test_start_fade() {
        let config = VolumeConfig::default();
        let mut filter = VolumeFilter::new(NodeId(0), "test", config);

        filter.start_fade_in(1000);
        assert_eq!(
            filter.config().fade.as_ref().map(|fade| fade.direction),
            Some(FadeDirection::In)
        );

        filter.start_fade_out(2000);
        assert_eq!(
            filter.config().fade.as_ref().map(|fade| fade.direction),
            Some(FadeDirection::Out)
        );
    }

    #[test]
    fn test_process_none() {
        let config = VolumeConfig::default();
        let mut filter = VolumeFilter::new(NodeId(0), "test", config);

        let result = filter.process(None).expect("process should succeed");
        assert!(result.is_none());
    }

    #[test]
    fn test_process_with_gain() {
        let config = VolumeConfig::new(0.5);
        let mut filter = VolumeFilter::new(NodeId(0), "test", config);

        let result = filter
            .process(Some(FilterFrame::Audio(mono_f32_frame(1.0))))
            .expect("process should succeed");

        let sample = first_f32_sample(result);
        assert!((sample - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_state_transitions() {
        let config = VolumeConfig::default();
        let mut filter = VolumeFilter::new(NodeId(0), "test", config);

        assert!(filter.set_state(NodeState::Processing).is_ok());
        assert_eq!(filter.state(), NodeState::Processing);

        assert!(filter.reset().is_ok());
        assert_eq!(filter.state(), NodeState::Idle);
        assert!(filter.config().fade.is_none());
    }

    #[test]
    fn test_bytes_conversion_roundtrip() {
        let original = 0.5;
        let mut buffer = BytesMut::new();

        VolumeFilter::f64_to_bytes(original, SampleFormat::F32, &mut buffer);
        let converted = VolumeFilter::bytes_to_f64(&buffer, SampleFormat::F32);

        assert!((original - converted).abs() < 0.0001);
    }

    #[test]
    fn test_peak_normalization() {
        let config = VolumeConfig::new(1.0).with_peak_normalization(1.0);
        let mut filter = VolumeFilter::new(NodeId(0), "test", config);

        let result = filter
            .process(Some(FilterFrame::Audio(mono_f32_frame(0.5))))
            .expect("process should succeed");

        // A peak of 0.5 normalized to a target of 1.0 doubles the sample.
        let sample = first_f32_sample(result);
        assert!((sample - 1.0).abs() < 0.01);
    }
}