Skip to main content

audio_engine_core/decoder/
source.rs

1use std::fs::File;
2#[cfg(feature = "http")]
3use std::io::{Cursor, Read, Seek, SeekFrom};
4use std::path::Path;
5#[cfg(feature = "http")]
6use std::time::Duration;
7
8use symphonia::core::io::MediaSourceStream;
9use symphonia::core::probe::Hint;
10
11#[cfg(feature = "http")]
12use super::error::{network_error_to_decoder_error, with_network_retry, NetworkError};
13use super::error::{DecodeCancelToken, DecoderError};
14
15#[cfg(feature = "http")]
16const HTTP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
17#[cfg(feature = "http")]
18const HTTP_RANGE_STREAM_TIMEOUT: Duration = Duration::from_secs(30);
19#[cfg(feature = "http")]
20const HTTP_FULL_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120);
21pub(super) const BYTES_PER_MIB: usize = 1024 * 1024;
22pub(super) const F64_SAMPLE_BYTES: usize = std::mem::size_of::<f64>();
23#[cfg(feature = "http")]
24const NON_RANGE_DOWNLOAD_MEMORY_DIVISOR: usize = 8;
25#[cfg(feature = "http")]
26pub(super) const RANGE_PREFETCH: usize = 256 * 1024;
27
/// HTTP Basic authentication credentials for remote audio sources.
///
/// Ignored for local file paths. Only consulted by the HTTP source paths,
/// which require the `http` feature.
///
/// `Debug` is implemented by hand so the password is never printed: formatting
/// these credentials (directly or inside a larger struct) shows the username and
/// a `<redacted>` placeholder instead of the secret.
#[derive(Clone, Default)]
pub struct HttpCredentials {
    /// User name passed to `basic_auth` on each HTTP request.
    pub username: String,
    /// Password passed to `basic_auth` on each HTTP request; hidden from `Debug` output.
    pub password: String,
}

impl std::fmt::Debug for HttpCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpCredentials")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
37
38pub(super) fn configured_decode_memory_limit() -> (usize, usize) {
39    let budget = crate::diagnostics::decode_memory_budget();
40    (budget.limit_mb, budget.limit_bytes)
41}
42
43pub(super) fn bytes_to_mib(bytes: usize) -> usize {
44    bytes / BYTES_PER_MIB
45}
46
47pub(super) fn open_media_source(
48    path: &Path,
49    credentials: Option<&HttpCredentials>,
50    cancel_token: Option<DecodeCancelToken>,
51) -> Result<(MediaSourceStream, Hint), DecoderError> {
52    let path_str = path.to_string_lossy();
53    if cancel_token
54        .as_ref()
55        .is_some_and(DecodeCancelToken::is_cancelled)
56    {
57        return Err(DecoderError::Canceled);
58    }
59
60    if path_str.starts_with("http://") || path_str.starts_with("https://") {
61        #[cfg(feature = "http")]
62        {
63            open_http_media_source(path_str.as_ref(), credentials, cancel_token)
64        }
65        #[cfg(not(feature = "http"))]
66        {
67            let _ = credentials;
68            Err(DecoderError::Probe(
69                "HTTP sources require the `http` feature of audio-engine-core".to_string(),
70            ))
71        }
72    } else {
73        let file = File::open(path)?;
74        let mss = MediaSourceStream::new(Box::new(file), Default::default());
75        let mut hint = Hint::new();
76        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
77            hint.with_extension(ext);
78        }
79        Ok((mss, hint))
80    }
81}
82
/// Opens an HTTP(S) URL, preferring Range-based streaming over a full download.
///
/// A [`RangeStream`] is tried first and used when it supports Range requests with a
/// known length. In every other case except cancellation — including when stream setup
/// failed — the whole resource is downloaded into memory with `download_full_source`.
/// Cancellation is re-checked before that fallback starts.
///
/// # Errors
/// [`DecoderError::Canceled`] on cancellation; otherwise any error from the full download.
#[cfg(feature = "http")]
fn open_http_media_source(
    url: &str,
    credentials: Option<&HttpCredentials>,
    cancel_token: Option<DecodeCancelToken>,
) -> Result<(MediaSourceStream, Hint), DecoderError> {
    match RangeStream::new(url.to_string(), credentials.cloned(), cancel_token.clone()) {
        Ok(stream) if stream.is_usable_range_stream() => {
            log::info!("HTTP URL supports Range requests, streaming: {}", url);
            let mss = MediaSourceStream::new(Box::new(stream), Default::default());
            return Ok((mss, hint_from_url(url)));
        }
        Err(DecoderError::Canceled) => return Err(DecoderError::Canceled),
        Ok(_) => log::info!(
            "HTTP URL does not support Range, falling back to full download: {}",
            url
        ),
        // Keep the reason: otherwise an auth or not-found failure is logged as
        // "no Range support".
        Err(e) => log::warn!(
            "HTTP Range stream setup failed ({}), falling back to full download: {}",
            e,
            url
        ),
    }

    // Cancellation inside `RangeStream::new` is raised as a `NetworkError` and may not
    // come back as `DecoderError::Canceled`; never start the fallback download after cancel.
    if cancel_token
        .as_ref()
        .is_some_and(DecodeCancelToken::is_cancelled)
    {
        return Err(DecoderError::Canceled);
    }

    let cursor = download_full_source(url, credentials, cancel_token.as_ref())?;
    let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
    Ok((mss, hint_from_url(url)))
}
108
/// Builds a Symphonia probe hint from the file extension in `url`'s path, if any.
#[cfg(feature = "http")]
fn hint_from_url(url: &str) -> Hint {
    let mut hint = Hint::new();
    if let Some(ext) = url_extension(url) {
        hint.with_extension(ext);
    }
    hint
}

/// Extracts the file extension from the last path segment of `url`.
///
/// The query string and fragment are ignored. Only the final `/`-separated segment is
/// inspected, so a dot in the host name (`https://a.io/x`) or a trailing slash does not
/// produce a bogus extension. Extensions that are empty, longer than five characters,
/// or not purely ASCII-alphanumeric are rejected.
#[cfg(feature = "http")]
fn url_extension(url: &str) -> Option<&str> {
    let path = url.split(['?', '#']).next().unwrap_or(url);
    let last_segment = path.rsplit('/').next().unwrap_or(path);
    let (_, ext) = last_segment.rsplit_once('.')?;
    let plausible =
        !ext.is_empty() && ext.len() <= 5 && ext.chars().all(|c| c.is_ascii_alphanumeric());
    plausible.then_some(ext)
}
122
/// Downloads the entire resource at `url` into memory.
///
/// Used when the server cannot be streamed with Range requests. The download is capped
/// at the configured decode memory budget divided by `NON_RANGE_DOWNLOAD_MEMORY_DIVISOR`.
/// The cap is checked up front against `Content-Length` — taken from the HEAD response,
/// or from the GET response when HEAD gave none. It is enforced again while the body is
/// read, so an unknown length cannot grow the buffer past the cap.
///
/// The HEAD request only serves the size pre-check. If the server answers it with a
/// non-success status (some servers reject HEAD, e.g. with 405), the download proceeds
/// with the GET, and any genuine failure is reported by the GET itself.
///
/// # Errors
/// - [`DecoderError::Canceled`] if `cancel_token` is cancelled before or during download.
/// - [`DecoderError::Network`] if the client cannot be built, the size exceeds the cap,
///   or reading the body fails.
/// - HEAD transport failures and GET failures are converted with
///   `network_error_to_decoder_error`.
#[cfg(feature = "http")]
fn download_full_source(
    url: &str,
    credentials: Option<&HttpCredentials>,
    cancel_token: Option<&DecodeCancelToken>,
) -> Result<Cursor<Vec<u8>>, DecoderError> {
    let (_, max_memory_bytes) = configured_decode_memory_limit();
    let max_download_bytes = max_memory_bytes / NON_RANGE_DOWNLOAD_MEMORY_DIVISOR;

    let client = reqwest::blocking::Client::builder()
        .timeout(HTTP_FULL_DOWNLOAD_TIMEOUT)
        .connect_timeout(HTTP_CONNECT_TIMEOUT)
        .build()
        .map_err(|e| {
            DecoderError::Network(NetworkError::Other(format!(
                "Failed to create HTTP client: {}",
                e
            )))
        })?;

    let content_length: Option<u64> = with_network_retry("HTTP full-download HEAD", || {
        if cancel_token.is_some_and(DecodeCancelToken::is_cancelled) {
            return Err(NetworkError::Other("Decode cancelled".to_string()));
        }
        let mut head_req = client.head(url);
        if let Some(creds) = credentials {
            head_req = head_req.basic_auth(&creds.username, Some(&creds.password));
        }
        let response = head_req.send().map_err(NetworkError::from)?;
        if response_network_error(&response).is_some() {
            // HEAD is only a size pre-check: continue without a length rather than
            // failing a server that rejects HEAD but serves GET.
            log::debug!(
                "HEAD {} returned status {}; continuing without size pre-check",
                url,
                response.status()
            );
            return Ok(None);
        }
        Ok(response
            .headers()
            .get("content-length")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse().ok()))
    })
    .map_err(|e| download_retry_error(e, cancel_token))?;

    match content_length {
        Some(len) => {
            checked_download_capacity(Some(len), max_download_bytes)?;
            log::info!(
                "Downloading {} MB file (server does not support Range)",
                len / BYTES_PER_MIB as u64
            );
        }
        None => log::warn!(
            "Content-Length unknown before download; size limit is enforced while reading"
        ),
    }

    let mut response = with_network_retry("HTTP full-download GET", || {
        if cancel_token.is_some_and(DecodeCancelToken::is_cancelled) {
            return Err(NetworkError::Other("Decode cancelled".to_string()));
        }
        let mut req = client.get(url);
        if let Some(creds) = credentials {
            req = req.basic_auth(&creds.username, Some(&creds.password));
        }
        let response = req.send().map_err(NetworkError::from)?;
        if let Some(e) = response_network_error(&response) {
            return Err(e);
        }
        Ok(response)
    })
    .map_err(|e| download_retry_error(e, cancel_token))?;

    let download_capacity = checked_download_capacity(
        content_length.or(response.content_length()),
        max_download_bytes,
    )?;
    let initial_capacity = download_capacity.unwrap_or(RANGE_PREFETCH);
    let mut buffer = Vec::with_capacity(initial_capacity);
    let mut chunk = [0_u8; 64 * 1024];
    loop {
        if cancel_token.is_some_and(DecodeCancelToken::is_cancelled) {
            return Err(DecoderError::Canceled);
        }

        let n = match response.read(&mut chunk) {
            Ok(0) => break,
            Ok(n) => n,
            // The `Read` contract treats `Interrupted` as retryable.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(DecoderError::Network(NetworkError::from_io(e))),
        };

        buffer.extend_from_slice(&chunk[..n]);
        if buffer.len() > max_download_bytes {
            let actual_mb = bytes_to_mib(buffer.len());
            return Err(DecoderError::Network(NetworkError::Other(format!(
                "Downloaded file exceeds memory limit: {} MB (limit: {} MB)",
                actual_mb,
                bytes_to_mib(max_download_bytes)
            ))));
        }
    }

    log::debug!(
        "Downloaded {} bytes into buffer with initial capacity {}",
        buffer.len(),
        initial_capacity
    );
    Ok(Cursor::new(buffer))
}

/// Converts a failure from a `with_network_retry` loop into a `DecoderError`.
///
/// If the cancel token is set, the result is reported as `Canceled`. The retry closures
/// signal cancellation as `NetworkError::Other("Decode cancelled")`.
/// NOTE(review): it is not visible here whether `network_error_to_decoder_error` maps that
/// to `DecoderError::Canceled`, so the token is re-checked explicitly.
#[cfg(feature = "http")]
fn download_retry_error(
    error: NetworkError,
    cancel_token: Option<&DecodeCancelToken>,
) -> DecoderError {
    if cancel_token.is_some_and(DecodeCancelToken::is_cancelled) {
        DecoderError::Canceled
    } else {
        network_error_to_decoder_error(error)
    }
}
226
/// Validates an optional `Content-Length` against the non-Range download cap.
///
/// Returns `Ok(None)` when the length is unknown and `Ok(Some(len))` when it is at most
/// `max_download_bytes`. Otherwise returns a network error that states both sizes in MB.
#[cfg(feature = "http")]
fn checked_download_capacity(
    content_length: Option<u64>,
    max_download_bytes: usize,
) -> Result<Option<usize>, DecoderError> {
    match content_length {
        None => Ok(None),
        // Fits under a usize cap, so the narrowing cast cannot truncate.
        Some(len) if len <= max_download_bytes as u64 => Ok(Some(len as usize)),
        Some(len) => Err(DecoderError::Network(NetworkError::Other(format!(
            "File too large for non-Range download: {} MB (limit: {} MB). \
             Server must support Range requests for files this size. \
             Increase DECODE_MAX_MEMORY_MB env var if needed.",
            len / BYTES_PER_MIB as u64,
            bytes_to_mib(max_download_bytes)
        )))),
    }
}
249
/// Maps a non-2xx HTTP status to `NetworkError::HttpStatus`, or `None` for any 2xx.
///
/// `206 Partial Content` is a 2xx status, so Range replies are accepted.
#[cfg(feature = "http")]
fn response_network_error(response: &reqwest::blocking::Response) -> Option<NetworkError> {
    let status = response.status();
    if status.is_success() {
        None
    } else {
        Some(NetworkError::HttpStatus(status.as_u16()))
    }
}
256
/// Fetches up to `len` bytes of `url` starting at byte `start` with one Range GET.
///
/// Returns an empty vector for `len == 0` without making a request. The body is read
/// through `Read::take(len)`, so at most `len` bytes are returned even if the server
/// sends more. A shorter body is returned as-is.
///
/// # Errors
/// - `NetworkError::Other("Decode cancelled")` if the token is set before or after the request.
/// - `NetworkError::Other` if `start + len - 1` overflows.
/// - `NetworkError::Other` if the server ignored the Range header (a non-206 2xx reply)
///   for a request that does not start at byte 0. Such a body starts at byte 0 and would
///   otherwise be mistaken for data at `start`.
/// - Transport, HTTP status, or body read errors.
#[cfg(feature = "http")]
pub(super) fn fetch_range_once(
    client: &reqwest::blocking::Client,
    url: &str,
    credentials: Option<&HttpCredentials>,
    start: u64,
    len: usize,
    cancel_token: Option<&DecodeCancelToken>,
) -> Result<Vec<u8>, NetworkError> {
    if len == 0 {
        return Ok(Vec::new());
    }
    if cancel_token.is_some_and(DecodeCancelToken::is_cancelled) {
        return Err(NetworkError::Other("Decode cancelled".to_string()));
    }

    // HTTP byte ranges are inclusive on both ends.
    let end = start
        .checked_add(len as u64 - 1)
        .ok_or_else(|| NetworkError::Other("Range end overflow".into()))?;

    let mut req = client
        .get(url)
        .header("Range", format!("bytes={}-{}", start, end));
    if let Some(creds) = credentials {
        req = req.basic_auth(&creds.username, Some(&creds.password));
    }

    let response = req.send().map_err(NetworkError::from)?;
    if let Some(e) = response_network_error(&response) {
        return Err(e);
    }

    // Any 2xx other than 206 means the Range header was ignored and the body starts at
    // byte 0; that is only the requested data when `start` is 0.
    let status = response.status().as_u16();
    if status != 206 && start != 0 {
        return Err(NetworkError::Other(format!(
            "Server ignored Range request for bytes {}-{} (HTTP {})",
            start, end, status
        )));
    }

    // Bound the read so a full-file reply is never buffered in its entirety.
    let mut data = Vec::with_capacity(len);
    response
        .take(len as u64)
        .read_to_end(&mut data)
        .map_err(NetworkError::from_io)?;
    if cancel_token.is_some_and(DecodeCancelToken::is_cancelled) {
        return Err(NetworkError::Other("Decode cancelled".to_string()));
    }

    Ok(data)
}
296
/// Seekable, lazily fetched view of a remote resource backed by HTTP Range requests.
///
/// Keeps one contiguous window of the resource in memory: `buf`, whose first byte sits
/// at absolute offset `buf_start`. The window is refilled with Range GETs as reads require.
#[cfg(feature = "http")]
struct RangeStream {
    // Request configuration.
    /// Resource being streamed.
    url: String,
    /// Optional Basic-auth credentials sent with every request.
    credentials: Option<HttpCredentials>,
    /// Client shared by all requests made by this stream.
    client: reqwest::blocking::Client,
    /// Checked before network work and on every read.
    cancel_token: Option<DecodeCancelToken>,

    // What the server reported.
    /// Total resource size, when known.
    content_length: Option<u64>,
    /// Whether the server was observed to honour Range requests.
    supports_range: bool,

    // Read state.
    /// Absolute offset of the next byte `read` will return.
    pos: u64,
    /// Absolute offset of `buf[0]` within the resource.
    buf_start: u64,
    /// Currently buffered window of the resource.
    buf: Vec<u8>,
}
309
#[cfg(feature = "http")]
impl RangeStream {
    /// Connects to `url` and determines whether it can be streamed with Range requests.
    ///
    /// A HEAD request supplies `Content-Length` and `Accept-Ranges: bytes`. A
    /// `Range: bytes=0-0` GET is then sent as a probe when either:
    /// - the HEAD is answered with a non-success status (some servers reject HEAD, e.g.
    ///   with 405), or
    /// - the HEAD does not establish both Range support and a length.
    ///
    /// The probe marks Range as supported on a `206` status or a `Content-Range` total;
    /// that total, when present, becomes the length. Otherwise the probe's
    /// `Content-Length` is used if no length is known yet. Both requests run inside
    /// `with_network_retry`.
    ///
    /// When the stream is usable (Range supported and length known), the first
    /// `RANGE_PREFETCH` bytes — or the whole resource, if smaller — are fetched eagerly.
    ///
    /// # Errors
    /// Fails if:
    /// - the HTTP client cannot be built,
    /// - a request fails at the transport level,
    /// - the probe GET returns a non-success status, or
    /// - the cancel token is set.
    fn new(
        url: String,
        credentials: Option<HttpCredentials>,
        cancel_token: Option<DecodeCancelToken>,
    ) -> Result<Self, DecoderError> {
        let client = reqwest::blocking::Client::builder()
            .timeout(HTTP_RANGE_STREAM_TIMEOUT)
            .connect_timeout(HTTP_CONNECT_TIMEOUT)
            .build()
            .map_err(|e| {
                DecoderError::Network(NetworkError::Other(format!(
                    "Failed to create HTTP client: {}",
                    e
                )))
            })?;

        let (content_length, supports_range) =
            with_network_retry("HTTP stream initialization", || {
                if cancel_token
                    .as_ref()
                    .is_some_and(DecodeCancelToken::is_cancelled)
                {
                    return Err(NetworkError::Other("Decode cancelled".to_string()));
                }

                let mut head_req = client.head(&url);
                if let Some(ref creds) = credentials {
                    head_req = head_req.basic_auth(&creds.username, Some(&creds.password));
                }

                let head_response = head_req.send().map_err(NetworkError::from)?;
                // A rejected HEAD is not fatal: fall through to the Range probe below,
                // which surfaces genuine failures (401, 404, ...) with its own status.
                let head_response = if response_network_error(&head_response).is_some() {
                    log::debug!(
                        "HEAD {} returned status {}; probing Range support with a GET",
                        url,
                        head_response.status()
                    );
                    None
                } else {
                    Some(head_response)
                };

                let mut content_length: Option<u64> = head_response.as_ref().and_then(|r| {
                    r.headers()
                        .get("content-length")
                        .and_then(|v| v.to_str().ok())
                        .and_then(|s| s.parse().ok())
                });
                let mut supports_range = head_response
                    .as_ref()
                    .map(|r| {
                        r.headers()
                            .get("accept-ranges")
                            .and_then(|v| v.to_str().ok())
                            .map(|s| s == "bytes")
                            .unwrap_or(false)
                    })
                    .unwrap_or(false);

                if !supports_range || content_length.is_none() {
                    if cancel_token
                        .as_ref()
                        .is_some_and(DecodeCancelToken::is_cancelled)
                    {
                        return Err(NetworkError::Other("Decode cancelled".to_string()));
                    }
                    let mut range_req = client.get(&url).header("Range", "bytes=0-0");
                    if let Some(ref creds) = credentials {
                        range_req = range_req.basic_auth(&creds.username, Some(&creds.password));
                    }
                    let range_response = range_req.send().map_err(NetworkError::from)?;
                    if let Some(e) = response_network_error(&range_response) {
                        return Err(e);
                    }
                    let range_status = range_response.status();
                    let range_content_total = range_response
                        .headers()
                        .get("content-range")
                        .and_then(|v| v.to_str().ok())
                        .and_then(|s| s.split('/').next_back().and_then(|s| s.parse().ok()));
                    let range_content_length = range_response
                        .headers()
                        .get("content-length")
                        .and_then(|v| v.to_str().ok())
                        .and_then(|s| s.parse().ok());

                    if range_status.as_u16() == 206 || range_content_total.is_some() {
                        supports_range = true;
                        content_length = range_content_total.or(content_length);
                    } else if content_length.is_none() {
                        content_length = range_content_length;
                    }
                }

                Ok((content_length, supports_range))
            })
            .map_err(network_error_to_decoder_error)?;

        // Prefetch only for a usable stream: without a known length the stream is not
        // usable (see `is_usable_range_stream`), so the request would be wasted.
        let initial_buf = match content_length {
            Some(total) if supports_range && total > 0 => {
                let initial_fetch_len =
                    usize::try_from(total).map_or(RANGE_PREFETCH, |t| RANGE_PREFETCH.min(t));
                with_network_retry("HTTP stream initial range GET", || {
                    fetch_range_once(
                        &client,
                        &url,
                        credentials.as_ref(),
                        0,
                        initial_fetch_len,
                        cancel_token.as_ref(),
                    )
                })
                .map_err(network_error_to_decoder_error)?
            }
            _ => Vec::new(),
        };

        Ok(Self {
            url,
            credentials,
            client,
            buf: initial_buf,
            buf_start: 0,
            pos: 0,
            content_length,
            supports_range,
            cancel_token,
        })
    }

    /// True when the server honours Range requests and the total length is known.
    fn is_usable_range_stream(&self) -> bool {
        self.supports_range && self.content_length.is_some()
    }

    /// Fetches `len` bytes at `start` with this stream's client, credentials and token.
    fn fetch_range(&mut self, start: u64, len: usize) -> Result<Vec<u8>, DecoderError> {
        fetch_range_once(
            &self.client,
            &self.url,
            self.credentials.as_ref(),
            start,
            len,
            self.cancel_token.as_ref(),
        )
        .map_err(network_error_to_decoder_error)
    }

    /// Ensures the byte at `self.pos` is buffered, fetching a new window if it is not.
    ///
    /// If `pos` already lies inside the current window, nothing is fetched, even when
    /// fewer than `need` bytes remain. `Read::read` then returns a short count, and the
    /// next call starts a new window at the first unbuffered byte, so overlapping data is
    /// never downloaded twice. Otherwise a window of `max(need, RANGE_PREFETCH)` bytes
    /// starting at `pos` is fetched, clipped to the known content length. Nothing is
    /// fetched at or past the end of the resource.
    fn ensure_buffered(&mut self, need: usize) -> std::io::Result<()> {
        let buf_end = self.buf_start + self.buf.len() as u64;
        if (self.buf_start..buf_end).contains(&self.pos) {
            return Ok(());
        }

        let mut fetch_len = need.max(RANGE_PREFETCH);
        if let Some(total) = self.content_length {
            // Clamp without truncating: the remainder may exceed usize on 32-bit targets.
            let remaining = usize::try_from(total.saturating_sub(self.pos)).unwrap_or(usize::MAX);
            fetch_len = fetch_len.min(remaining);
        }
        if fetch_len == 0 {
            return Ok(());
        }

        let data = self
            .fetch_range(self.pos, fetch_len)
            .map_err(|e| std::io::Error::other(e.to_string()))?;
        self.buf_start = self.pos;
        self.buf = data;
        Ok(())
    }
}
479
#[cfg(feature = "http")]
impl Read for RangeStream {
    /// Copies bytes from the current position into `buf` and advances the position.
    ///
    /// Data is fetched with Range requests via `ensure_buffered`. Returns `Ok(0)` for an
    /// empty `buf` or when nothing is available at the current position.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let cancelled = self
            .cancel_token
            .as_ref()
            .is_some_and(DecodeCancelToken::is_cancelled);
        if cancelled {
            // NOTE(review): std documents `Interrupted` as retryable, and helpers such as
            // `read_exact`/`read_to_end` retry on it; confirm no caller loops on this stream.
            return Err(std::io::Error::new(
                std::io::ErrorKind::Interrupted,
                "Decode cancelled",
            ));
        }

        self.ensure_buffered(buf.len())?;

        let offset = (self.pos - self.buf_start) as usize;
        // An offset past the end of the window yields an empty slice, i.e. nothing to copy.
        let window: &[u8] = self.buf.get(offset..).unwrap_or_default();
        let copied = window.len().min(buf.len());
        if copied == 0 {
            return Ok(0);
        }

        buf[..copied].copy_from_slice(&window[..copied]);
        self.pos += copied as u64;
        Ok(copied)
    }
}
508
#[cfg(feature = "http")]
impl Seek for RangeStream {
    /// Moves the logical read position; no request is made until the next read.
    ///
    /// Seeking past the end of the resource is allowed; reads there return `Ok(0)`.
    ///
    /// # Errors
    /// - `ErrorKind::Unsupported` for `SeekFrom::End` when the content length is unknown.
    /// - `ErrorKind::InvalidInput` when the target is negative or exceeds `i64::MAX`.
    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
        // Widen to i128 so `base + offset` cannot overflow for any u64 base / i64 offset.
        let target = match pos {
            SeekFrom::Start(p) => i128::from(p),
            SeekFrom::Current(d) => i128::from(self.pos) + i128::from(d),
            SeekFrom::End(d) => {
                let Some(len) = self.content_length else {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Unsupported,
                        "content-length unknown",
                    ));
                };
                i128::from(len) + i128::from(d)
            }
        };
        if target < 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "negative seek",
            ));
        }
        // Positions are kept at or below i64::MAX so `pos + buffer length` arithmetic in
        // the read path cannot overflow a u64.
        if target > i128::from(i64::MAX) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "seek position overflow",
            ));
        }
        // 0 <= target <= i64::MAX, so the cast is lossless.
        self.pos = target as u64;
        Ok(self.pos)
    }
}
536
#[cfg(feature = "http")]
impl symphonia::core::io::MediaSource for RangeStream {
    /// Total size of the remote resource, when the server reported one.
    fn byte_len(&self) -> Option<u64> {
        self.content_length
    }

    /// Seeking is advertised exactly when the stream is usable for Range streaming
    /// (Range supported and length known).
    fn is_seekable(&self) -> bool {
        self.is_usable_range_stream()
    }
}
547
#[cfg(all(test, feature = "http"))]
mod tests {
    use super::*;

    /// Builds a `RangeStream` directly, without any network traffic.
    fn offline_stream(supports_range: bool, content_length: Option<u64>) -> RangeStream {
        RangeStream {
            url: "https://example.test/song.flac".to_string(),
            credentials: None,
            client: reqwest::blocking::Client::builder()
                .build()
                .expect("client fixture"),
            buf: Vec::new(),
            buf_start: 0,
            pos: 0,
            content_length,
            supports_range,
            cancel_token: None,
        }
    }

    #[test]
    fn range_stream_requires_range_support_and_content_length() {
        assert!(offline_stream(true, Some(1024)).is_usable_range_stream());
        assert!(!offline_stream(false, Some(1024)).is_usable_range_stream());
        assert!(!offline_stream(true, None).is_usable_range_stream());
    }
}