Skip to main content

rodio/
buffer.rs

//! A simple source of samples coming from a buffer.
//!
//! The `SamplesBuffer` struct can be used to treat a list of samples as a `Source`.
//!
//! # Example
//!
//! ```
//! use rodio::buffer::SamplesBuffer;
//! use core::num::NonZero;
//! let _ = SamplesBuffer::new(NonZero::new(1).unwrap(), NonZero::new(44100).unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
//! ```
//!
13
14use crate::common::{ChannelCount, SampleRate};
15use crate::math::{duration_to_float, NANOS_PER_SEC};
16use crate::source::{SeekError, UniformSourceIterator};
17use crate::{Float, Sample, Source};
18use std::sync::Arc;
19use std::time::Duration;
20
21/// A buffer of samples treated as a source.
22#[derive(Debug, Clone)]
23pub struct SamplesBuffer {
24    data: Arc<[Sample]>,
25    pos: usize,
26    channels: ChannelCount,
27    sample_rate: SampleRate,
28    duration: Duration,
29}
30
31impl SamplesBuffer {
32    /// Builds a new `SamplesBuffer`.
33    ///
34    /// # Panics
35    ///
36    /// - Panics if the samples rate is zero.
37    /// - Panics if the length of the buffer is larger than approximately 16 billion elements.
38    ///   This is because the calculation of the duration would overflow.
39    ///
40    pub fn new<D>(channels: ChannelCount, sample_rate: SampleRate, data: D) -> Self
41    where
42        D: Into<Vec<Sample>>,
43    {
44        let data: Arc<[Sample]> = data.into().into();
45        let duration_ns = NANOS_PER_SEC.checked_mul(data.len() as u64).unwrap()
46            / sample_rate.get() as u64
47            / channels.get() as u64;
48        let duration = Duration::new(
49            duration_ns / NANOS_PER_SEC,
50            (duration_ns % NANOS_PER_SEC) as u32,
51        );
52
53        Self {
54            data,
55            pos: 0,
56            channels,
57            sample_rate,
58            duration,
59        }
60    }
61
62    pub(crate) fn record_source(source: impl Source) -> Self {
63        let channel_count = source.channels();
64        let sample_rate = source.sample_rate();
65        let source = UniformSourceIterator::new(source, channel_count, sample_rate);
66        Self::new(
67            source.channels(),
68            source.sample_rate(),
69            source.into_iter().collect::<Vec<_>>(),
70        )
71    }
72}
73
74impl Source for SamplesBuffer {
75    #[inline]
76    fn current_span_len(&self) -> Option<usize> {
77        if self.pos >= self.data.len() {
78            Some(0)
79        } else {
80            Some(self.data.len())
81        }
82    }
83
84    #[inline]
85    fn channels(&self) -> ChannelCount {
86        self.channels
87    }
88
89    #[inline]
90    fn sample_rate(&self) -> SampleRate {
91        self.sample_rate
92    }
93
94    #[inline]
95    fn total_duration(&self) -> Option<Duration> {
96        Some(self.duration)
97    }
98
99    /// This jumps in memory till the sample for `pos`.
100    #[inline]
101    fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> {
102        // This is fast because all the samples are in memory already
103        // and due to the constant sample_rate we can jump to the right
104        // sample directly.
105
106        let curr_channel = self.pos % self.channels().get() as usize;
107        let new_pos = duration_to_float(pos)
108            * self.sample_rate().get() as Float
109            * self.channels().get() as Float;
110        // saturate pos at the end of the source
111        let new_pos = new_pos as usize;
112        let new_pos = new_pos.min(self.data.len());
113
114        // make sure the next sample is for the right channel
115        let new_pos = new_pos.next_multiple_of(self.channels().get() as usize);
116        let new_pos = new_pos - curr_channel;
117
118        self.pos = new_pos;
119        Ok(())
120    }
121}
122
123impl Iterator for SamplesBuffer {
124    type Item = Sample;
125
126    #[inline]
127    fn next(&mut self) -> Option<Self::Item> {
128        let sample = self.data.get(self.pos)?;
129        self.pos += 1;
130        Some(*sample)
131    }
132
133    #[inline]
134    fn size_hint(&self) -> (usize, Option<usize>) {
135        let remaining = self.data.len() - self.pos;
136        (remaining, Some(remaining))
137    }
138}
139
140impl ExactSizeIterator for SamplesBuffer {}
141
#[cfg(test)]
mod tests {
    use crate::buffer::SamplesBuffer;
    use crate::math::nz;
    use crate::source::Source;

    #[test]
    fn basic() {
        let _ = SamplesBuffer::new(nz!(1), nz!(44100), vec![0.0; 6]);
    }

    #[test]
    fn duration_basic() {
        // 6 samples / (2 channels * 2 samples per second per channel) = 1.5 s
        let total = SamplesBuffer::new(nz!(2), nz!(2), vec![0.0; 6])
            .total_duration()
            .expect("a SamplesBuffer always reports its duration");
        assert_eq!(total.as_secs(), 1);
        assert_eq!(total.subsec_nanos(), 500_000_000);
    }

    #[test]
    fn iteration() {
        let expected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut buf = SamplesBuffer::new(nz!(1), nz!(44100), expected.to_vec());
        for want in expected.iter().copied() {
            assert_eq!(buf.next(), Some(want));
        }
        assert_eq!(buf.next(), None);
    }

    mod try_seek {
        use super::*;
        use crate::common::{ChannelCount, Float, SampleRate};
        use crate::Sample;
        use std::time::Duration;

        /// Parity (0 or 1) of a sample whose value equals its own index.
        fn parity(sample: Sample) -> i32 {
            sample.trunc() as i32 % 2
        }

        #[test]
        fn channel_order_stays_correct() {
            const SAMPLE_RATE: SampleRate = nz!(100);
            const CHANNELS: ChannelCount = nz!(2);
            let samples: Vec<Sample> = (0..2000i16).map(|n| n as Sample).collect();
            let mut buf = SamplesBuffer::new(CHANNELS, SAMPLE_RATE, samples);

            buf.try_seek(Duration::from_secs(5)).unwrap();
            let first_after_seek = 5.0 * SAMPLE_RATE.get() as Float * CHANNELS.get() as Float;
            assert_eq!(buf.next(), Some(first_after_seek));

            assert_eq!(buf.next().map(parity), Some(1));
            assert_eq!(buf.next().map(parity), Some(0));

            buf.try_seek(Duration::from_secs(6)).unwrap();
            assert_eq!(buf.next().map(parity), Some(1));
        }
    }
}