Skip to main content

kael_media/
lib.rs

//! Audio playback and video frame decoding primitives for GPUI.
2
3#![deny(missing_docs)]
4
5use ffmpeg_next as ffmpeg;
6use parking_lot::Mutex;
7use rodio::{Decoder, OutputStream, Sink, Source, buffer::SamplesBuffer};
8use std::{
9    collections::hash_map::DefaultHasher,
10    fmt,
11    fs::{self, File, OpenOptions},
12    hash::{Hash, Hasher},
13    io::{self, BufReader, Cursor, Read, Seek, SeekFrom, Write},
14    path::{Path, PathBuf},
15    rc::Rc,
16    sync::Arc,
17    sync::atomic::{AtomicU64, Ordering},
18    time::{Duration, Instant},
19};
20use thiserror::Error;
21
/// Object-safe bundle of the I/O capabilities a reader-backed media source must provide.
trait MediaReadSeek: Read + Seek + Send + Sync {}

// Any thread-safe, seekable reader qualifies automatically.
impl<R: Read + Seek + Send + Sync> MediaReadSeek for R {}

/// Factory called whenever a reader-backed source needs a fresh reader over its contents.
type MediaReaderFactory = dyn Fn() -> io::Result<Box<dyn MediaReadSeek>> + Send + Sync;
27
28/// Internal backing state for keyed reader-based media sources.
29#[doc(hidden)]
30pub struct ReaderMediaSource {
31    key: Arc<str>,
32    open: Arc<MediaReaderFactory>,
33    staged_path: Mutex<Option<PathBuf>>,
34}
35
/// A location FFmpeg can open: a filesystem path or a URL string.
enum ResolvedMediaInput {
    /// A file on disk, either supplied directly or staged from bytes or a reader.
    Path(PathBuf),
    /// A URL passed to FFmpeg unchanged.
    Url(Arc<str>),
}
40
// NOTE(review): counter for staged reader inputs; its users are outside this view, presumably
// deriving unique staging file names — TODO confirm.
static STAGED_READER_COUNTER: AtomicU64 = AtomicU64::new(0);
// Most frames `MediaDecoder::decode_video_frames` will buffer before failing.
const MAX_DECODED_VIDEO_FRAMES: usize = 256;
// Most pixel bytes (128 MiB) `MediaDecoder::decode_video_frames` will buffer before failing.
const MAX_DECODED_VIDEO_BYTES: u64 = 128 << 20;
44
45/// A source of media content that can be played back.
46#[derive(Clone)]
47pub enum MediaSource {
48    /// Media content loaded from a file on disk.
49    File(PathBuf),
50    /// Media content loaded from a URL that FFmpeg can open directly.
51    Url(Arc<str>),
52    /// Media content already available in memory.
53    Bytes(Arc<[u8]>),
54    /// Media content opened on demand from a keyed reader factory.
55    Reader(Arc<ReaderMediaSource>),
56}
57
58impl MediaSource {
59    /// Create a media source backed by a file on disk.
60    pub fn file(path: impl Into<PathBuf>) -> Self {
61        Self::File(path.into())
62    }
63
64    /// Create a media source backed by a URL that FFmpeg can open directly.
65    pub fn url(url: impl Into<Arc<str>>) -> Self {
66        Self::Url(url.into())
67    }
68
69    /// Create a media source backed by in-memory bytes.
70    pub fn bytes(bytes: impl Into<Arc<[u8]>>) -> Self {
71        Self::Bytes(bytes.into())
72    }
73
74    /// Create a media source backed by a keyed reader factory.
75    ///
76    /// The key participates in hashing and equality for asset caching, so it must uniquely identify
77    /// the underlying media content for the lifetime of the source.
78    pub fn reader<R>(
79        key: impl Into<Arc<str>>,
80        open: impl Fn() -> io::Result<R> + Send + Sync + 'static,
81    ) -> Self
82    where
83        R: Read + Seek + Send + Sync + 'static,
84    {
85        let open =
86            Arc::new(move || open().map(|reader| -> Box<dyn MediaReadSeek> { Box::new(reader) }));
87        Self::Reader(Arc::new(ReaderMediaSource {
88            key: key.into(),
89            open,
90            staged_path: Mutex::new(None),
91        }))
92    }
93
94    /// Create a media source backed by compile-time bytes.
95    pub fn from_static_bytes(bytes: &'static [u8]) -> Self {
96        Self::Bytes(Arc::<[u8]>::from(bytes))
97    }
98
99    fn open_reader(&self) -> Result<MediaReader, AudioPlaybackError> {
100        match self {
101            Self::File(path) => Ok(MediaReader::File(BufReader::new(File::open(path)?))),
102            Self::Bytes(bytes) => Ok(MediaReader::Bytes(Cursor::new(bytes.clone()))),
103            Self::Reader(source) => Ok(MediaReader::Reader((source.open)()?)),
104            Self::Url(_) => Err(AudioPlaybackError::UnsupportedSource(
105                "url-backed media cannot be opened as a direct rodio reader".into(),
106            )),
107        }
108    }
109
110    fn direct_reader_supported(&self) -> bool {
111        !matches!(self, Self::Url(_))
112    }
113
114    fn resolve_ffmpeg_input(&self) -> Result<ResolvedMediaInput, MediaDecodeError> {
115        match self {
116            Self::File(path) => Ok(ResolvedMediaInput::Path(path.clone())),
117            Self::Url(url) => Ok(ResolvedMediaInput::Url(url.clone())),
118            Self::Bytes(bytes) => Ok(ResolvedMediaInput::Path(stage_bytes(bytes)?)),
119            Self::Reader(source) => Ok(ResolvedMediaInput::Path(
120                source.stage_to_path().map_err(MediaDecodeError::from_io)?,
121            )),
122        }
123    }
124}
125
126impl From<PathBuf> for MediaSource {
127    fn from(value: PathBuf) -> Self {
128        Self::File(value)
129    }
130}
131
132impl From<&Path> for MediaSource {
133    fn from(value: &Path) -> Self {
134        Self::File(value.to_path_buf())
135    }
136}
137
138impl From<Arc<[u8]>> for MediaSource {
139    fn from(value: Arc<[u8]>) -> Self {
140        Self::Bytes(value)
141    }
142}
143
144impl From<Arc<str>> for MediaSource {
145    fn from(value: Arc<str>) -> Self {
146        Self::Url(value)
147    }
148}
149
150impl From<Vec<u8>> for MediaSource {
151    fn from(value: Vec<u8>) -> Self {
152        Self::Bytes(Arc::<[u8]>::from(value))
153    }
154}
155
156impl From<&'static [u8]> for MediaSource {
157    fn from(value: &'static [u8]) -> Self {
158        Self::from_static_bytes(value)
159    }
160}
161
162impl fmt::Debug for MediaSource {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        match self {
165            Self::File(path) => f.debug_tuple("File").field(path).finish(),
166            Self::Url(url) => f.debug_tuple("Url").field(url).finish(),
167            Self::Bytes(bytes) => f
168                .debug_tuple("Bytes")
169                .field(&format_args!("{} bytes", bytes.len()))
170                .finish(),
171            Self::Reader(source) => f.debug_tuple("Reader").field(&source.key).finish(),
172        }
173    }
174}
175
176impl PartialEq for MediaSource {
177    fn eq(&self, other: &Self) -> bool {
178        match (self, other) {
179            (Self::File(left), Self::File(right)) => left == right,
180            (Self::Url(left), Self::Url(right)) => left == right,
181            (Self::Bytes(left), Self::Bytes(right)) => left == right,
182            (Self::Reader(left), Self::Reader(right)) => left.key == right.key,
183            _ => false,
184        }
185    }
186}
187
188impl Eq for MediaSource {}
189
190impl Hash for MediaSource {
191    fn hash<H: Hasher>(&self, state: &mut H) {
192        std::mem::discriminant(self).hash(state);
193        match self {
194            Self::File(path) => path.hash(state),
195            Self::Url(url) => url.hash(state),
196            Self::Bytes(bytes) => bytes.hash(state),
197            Self::Reader(source) => source.key.hash(state),
198        }
199    }
200}
201
/// Dimensions and length of a decoded video stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VideoMetadata {
    /// Width of decoded frames, in pixels.
    pub width: u32,
    /// Height of decoded frames, in pixels.
    pub height: u32,
    /// Length of the stream, when FFmpeg reports one.
    pub duration: Option<Duration>,
}
212
/// One decoded video frame whose pixels are stored as BGRA.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VideoFrame {
    /// Raw BGRA pixel bytes.
    pub data: Arc<[u8]>,
    /// Width of the frame, in pixels.
    pub width: u32,
    /// Height of the frame, in pixels.
    pub height: u32,
    /// When this frame should be presented, relative to the start of the stream.
    pub timestamp: Duration,
}
225
226/// An error that can occur while decoding video content.
227#[derive(Debug, Error)]
228pub enum MediaDecodeError {
229    /// The requested source type is not supported by the current decoder path.
230    #[error("unsupported source: {0}")]
231    UnsupportedSource(String),
232    /// The media source does not contain a video stream.
233    #[error("no video stream found")]
234    NoVideoStream,
235    /// FFmpeg failed to open or decode the media source.
236    #[error("ffmpeg decode error: {0}")]
237    Decode(String),
238}
239
240/// A file-backed decoder for media metadata and video frames.
241#[derive(Clone, Debug)]
242pub struct MediaDecoder {
243    source: MediaSource,
244}
245
246struct OpenedVideoStream {
247    input_context: ffmpeg::format::context::Input,
248    decoder: ffmpeg::decoder::Video,
249    scaler: ffmpeg::software::scaling::context::Context,
250    video_stream_index: usize,
251    time_base: ffmpeg::Rational,
252    metadata: VideoMetadata,
253}
254
255/// A sequential decoder for file-backed video frames.
256///
257/// This stream decodes frames on demand and can be restarted when playback seeks backward.
258pub struct VideoFrameStream {
259    source: MediaSource,
260    input_context: ffmpeg::format::context::Input,
261    decoder: ffmpeg::decoder::Video,
262    scaler: ffmpeg::software::scaling::context::Context,
263    video_stream_index: usize,
264    time_base: ffmpeg::Rational,
265    metadata: VideoMetadata,
266    sent_eof: bool,
267}
268
269impl MediaDecoder {
270    /// Create a decoder for the given media source.
271    pub fn new(source: impl Into<MediaSource>) -> Self {
272        Self {
273            source: source.into(),
274        }
275    }
276
277    /// Return the source associated with this decoder.
278    pub fn source(&self) -> &MediaSource {
279        &self.source
280    }
281
282    /// Read the primary video stream metadata from the media source.
283    pub fn video_metadata(&self) -> Result<VideoMetadata, MediaDecodeError> {
284        Ok(VideoFrameStream::new(self.source.clone())?.metadata())
285    }
286
287    /// Decode all frames from the primary video stream up to an in-memory safety cap.
288    ///
289    /// Use [`VideoFrameStream`] when decoding larger videos incrementally.
290    pub fn decode_video_frames(&self) -> Result<Vec<VideoFrame>, MediaDecodeError> {
291        let mut stream = VideoFrameStream::new(self.source.clone())?;
292        let mut frames = Vec::new();
293        let mut decoded_bytes = 0u64;
294        while let Some(frame) = stream.next_frame()? {
295            push_decoded_video_frame(&mut frames, &mut decoded_bytes, frame)?;
296        }
297
298        Ok(frames)
299    }
300}
301
302fn push_decoded_video_frame(
303    frames: &mut Vec<VideoFrame>,
304    decoded_bytes: &mut u64,
305    frame: VideoFrame,
306) -> Result<(), MediaDecodeError> {
307    if frames.len() >= MAX_DECODED_VIDEO_FRAMES {
308        return Err(MediaDecodeError::Decode(format!(
309            "video decode exceeded {} frames; use VideoFrameStream for larger videos",
310            MAX_DECODED_VIDEO_FRAMES
311        )));
312    }
313
314    let frame_bytes = u64::try_from(frame.data.len()).unwrap_or(u64::MAX);
315    let next_total = decoded_bytes.saturating_add(frame_bytes);
316    if next_total > MAX_DECODED_VIDEO_BYTES {
317        return Err(MediaDecodeError::Decode(format!(
318            "video decode exceeded {} bytes; use VideoFrameStream for larger videos",
319            MAX_DECODED_VIDEO_BYTES
320        )));
321    }
322
323    *decoded_bytes = next_total;
324    frames.push(frame);
325    Ok(())
326}
327
328impl VideoFrameStream {
329    /// Create a new sequential frame stream for the given media source.
330    pub fn new(source: impl Into<MediaSource>) -> Result<Self, MediaDecodeError> {
331        let source = source.into();
332        let OpenedVideoStream {
333            input_context,
334            decoder,
335            scaler,
336            video_stream_index,
337            time_base,
338            metadata,
339        } = open_video_stream(&source)?;
340
341        Ok(Self {
342            source,
343            input_context,
344            decoder,
345            scaler,
346            video_stream_index,
347            time_base,
348            metadata,
349            sent_eof: false,
350        })
351    }
352
353    /// Return the source associated with this stream.
354    pub fn source(&self) -> &MediaSource {
355        &self.source
356    }
357
358    /// Return the decoded video metadata.
359    pub fn metadata(&self) -> VideoMetadata {
360        self.metadata
361    }
362
363    /// Restart decoding from the beginning of the media source.
364    pub fn restart(&mut self) -> Result<(), MediaDecodeError> {
365        *self = Self::new(self.source.clone())?;
366        Ok(())
367    }
368
369    /// Decode and return the next video frame, or `None` after end-of-stream.
370    pub fn next_frame(&mut self) -> Result<Option<VideoFrame>, MediaDecodeError> {
371        loop {
372            let mut decoded = ffmpeg::util::frame::video::Video::empty();
373            if self.decoder.receive_frame(&mut decoded).is_ok() {
374                return decode_video_frame(&mut self.scaler, &decoded, self.time_base).map(Some);
375            }
376
377            if self.sent_eof {
378                return Ok(None);
379            }
380
381            if let Some(packet) = self.next_video_packet() {
382                self.decoder
383                    .send_packet(&packet)
384                    .map_err(ffmpeg_decode_error)?;
385            } else {
386                self.decoder.send_eof().map_err(ffmpeg_decode_error)?;
387                self.sent_eof = true;
388            }
389        }
390    }
391
392    fn next_video_packet(&mut self) -> Option<ffmpeg::Packet> {
393        for (stream, packet) in self.input_context.packets() {
394            if stream.index() == self.video_stream_index {
395                return Some(packet);
396            }
397        }
398
399        None
400    }
401}
402
403enum MediaReader {
404    File(BufReader<File>),
405    Bytes(Cursor<Arc<[u8]>>),
406    Reader(Box<dyn MediaReadSeek>),
407}
408
409impl Read for MediaReader {
410    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
411        match self {
412            Self::File(reader) => reader.read(buf),
413            Self::Bytes(reader) => reader.read(buf),
414            Self::Reader(reader) => reader.read(buf),
415        }
416    }
417}
418
419impl Seek for MediaReader {
420    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
421        match self {
422            Self::File(reader) => reader.seek(pos),
423            Self::Bytes(reader) => reader.seek(pos),
424            Self::Reader(reader) => reader.seek(pos),
425        }
426    }
427}
428
/// Where a media handle currently is in its playback lifecycle.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PlaybackState {
    /// Playback is running and the position advances.
    Playing,
    /// Playback is suspended and the current position is kept.
    Paused,
    /// Nothing is playing. This is the initial state, and the position is typically at the
    /// start (after a stop) or the end (after playback finishes).
    #[default]
    Stopped,
}
440
441/// An error that can occur while preparing or controlling audio playback.
442#[derive(Debug, Error)]
443pub enum AudioPlaybackError {
444    /// The media source could not be opened.
445    #[error("io error: {0}")]
446    Io(#[from] std::io::Error),
447    /// The media source cannot be opened through the direct rodio path.
448    #[error("unsupported source: {0}")]
449    UnsupportedSource(String),
450    /// The media data could not be decoded.
451    #[error("decoder error: {0}")]
452    Decoder(String),
453    /// The host audio output stream could not be created.
454    #[error("audio output error: {0}")]
455    Output(String),
456}
457
458struct AudioEngine {
459    _stream: OutputStream,
460    sink: Sink,
461}
462
/// Audio fully decoded by FFmpeg and kept in memory, used when no direct decoder is available.
struct DecodedAudio {
    // Number of interleaved channels in `samples`.
    channels: u16,
    // Sample rate of `samples`, in Hz.
    sample_rate: u32,
    // Packed (interleaved) `f32` PCM samples.
    samples: Arc<[f32]>,
    // Total length of the decoded audio.
    duration: Duration,
}
469
470struct AudioHandleState {
471    source: MediaSource,
472    volume: f32,
473    duration: Option<Duration>,
474    decoded_audio: Option<Arc<DecodedAudio>>,
475    position: Duration,
476    started_at: Option<Instant>,
477    state: PlaybackState,
478    engine: Option<AudioEngine>,
479    generation: u64,
480}
481
482struct AudioPlaybackRequest {
483    generation: u64,
484    source: MediaSource,
485    volume: f32,
486    position: Duration,
487    duration: Option<Duration>,
488    decoded_audio: Option<Arc<DecodedAudio>>,
489    playback_state: PlaybackState,
490}
491
492struct AudioProbeRequest {
493    generation: u64,
494    source: MediaSource,
495    decoded_audio: Option<Arc<DecodedAudio>>,
496}
497
498/// A clonable controller for audio playback.
499#[derive(Clone)]
500pub struct AudioHandle {
501    state: Rc<Mutex<AudioHandleState>>,
502}
503
504impl fmt::Debug for AudioHandle {
505    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
506        let state = self.state.lock();
507        f.debug_struct("AudioHandle")
508            .field("source", &state.source)
509            .field("volume", &state.volume)
510            .field("duration", &state.duration)
511            .field("position", &state.current_position())
512            .field("state", &state.state)
513            .finish()
514    }
515}
516
517impl AudioHandle {
518    /// Create a new audio handle for the given source.
519    pub fn new(source: impl Into<MediaSource>) -> Self {
520        Self {
521            state: Rc::new(Mutex::new(AudioHandleState {
522                source: source.into(),
523                volume: 1.0,
524                duration: None,
525                decoded_audio: None,
526                position: Duration::ZERO,
527                started_at: None,
528                state: PlaybackState::Stopped,
529                engine: None,
530                generation: 0,
531            })),
532        }
533    }
534
535    /// Start or resume playback.
536    pub fn play(&self) -> Result<(), AudioPlaybackError> {
537        let request = {
538            let mut state = self.state.lock();
539            state.refresh_finished();
540            if state.state == PlaybackState::Playing {
541                return Ok(());
542            }
543
544            if state.state == PlaybackState::Paused
545                && let Some(engine) = state.engine.as_ref()
546            {
547                engine.sink.play();
548                state.started_at = Some(Instant::now());
549                state.state = PlaybackState::Playing;
550                state.generation += 1;
551                return Ok(());
552            }
553
554            let requested_position = if state
555                .duration
556                .is_some_and(|duration| state.position >= duration)
557            {
558                Duration::ZERO
559            } else {
560                state.position
561            };
562
563            AudioPlaybackRequest {
564                generation: state.generation,
565                source: state.source.clone(),
566                volume: state.volume,
567                position: requested_position,
568                duration: state.duration,
569                decoded_audio: state.decoded_audio.clone(),
570                playback_state: state.state,
571            }
572        };
573
574        let (engine, duration, position, decoded_audio) = create_engine_with_cache(
575            &request.source,
576            request.volume,
577            request.position,
578            request.decoded_audio,
579        )?;
580
581        let mut state = self.state.lock();
582        state.refresh_finished();
583        if state.generation != request.generation || state.state == PlaybackState::Playing {
584            return Ok(());
585        }
586
587        if let Some(decoded_audio) = decoded_audio {
588            state.decoded_audio = Some(decoded_audio);
589        }
590        state.duration = duration.or(state.duration);
591        state.position = position;
592        state.started_at = Some(Instant::now());
593        state.state = PlaybackState::Playing;
594        engine.sink.set_volume(state.volume.max(0.0));
595        state.engine = Some(engine);
596        state.generation += 1;
597        Ok(())
598    }
599
600    /// Pause playback, preserving the current position.
601    pub fn pause(&self) {
602        let mut state = self.state.lock();
603        state.refresh_finished();
604        if state.state != PlaybackState::Playing {
605            return;
606        }
607
608        state.position = state.current_position();
609        state.started_at = None;
610        if let Some(engine) = state.engine.as_ref() {
611            engine.sink.pause();
612        }
613        state.state = PlaybackState::Paused;
614        state.generation += 1;
615    }
616
617    /// Stop playback and reset the position to the start.
618    pub fn stop(&self) {
619        let mut state = self.state.lock();
620        if let Some(engine) = state.engine.take() {
621            engine.sink.stop();
622        }
623        state.position = Duration::ZERO;
624        state.started_at = None;
625        state.state = PlaybackState::Stopped;
626        state.generation += 1;
627    }
628
629    /// Seek to the given playback position.
630    pub fn seek(&self, position: Duration) -> Result<(), AudioPlaybackError> {
631        let request = {
632            let mut state = self.state.lock();
633            state.refresh_finished();
634            AudioPlaybackRequest {
635                generation: state.generation,
636                source: state.source.clone(),
637                volume: state.volume,
638                position,
639                duration: state.duration,
640                decoded_audio: state.decoded_audio.clone(),
641                playback_state: state.state,
642            }
643        };
644
645        let (duration, decoded_audio) = match request.duration {
646            Some(duration) => (Some(duration), request.decoded_audio.clone()),
647            None => probe_duration_with_cache(&request.source, request.decoded_audio.clone())?,
648        };
649        let clamped_position = duration
650            .map(|duration| position.min(duration))
651            .unwrap_or(position);
652        let (engine, duration, actual_position, decoded_audio) =
653            if request.playback_state == PlaybackState::Playing {
654                let (engine, actual_duration, actual_position, decoded_audio) =
655                    create_engine_with_cache(
656                        &request.source,
657                        request.volume,
658                        clamped_position,
659                        decoded_audio,
660                    )?;
661                (
662                    Some(engine),
663                    actual_duration.or(duration),
664                    actual_position,
665                    decoded_audio,
666                )
667            } else {
668                (None, duration, clamped_position, decoded_audio)
669            };
670
671        let mut state = self.state.lock();
672        state.refresh_finished();
673        if state.generation != request.generation {
674            return Ok(());
675        }
676
677        if let Some(decoded_audio) = decoded_audio {
678            state.decoded_audio = Some(decoded_audio);
679        }
680        state.duration = duration.or(state.duration);
681        state.position = actual_position;
682        state.started_at = if request.playback_state == PlaybackState::Playing {
683            Some(Instant::now())
684        } else {
685            None
686        };
687        if let Some(engine) = engine {
688            engine.sink.set_volume(state.volume.max(0.0));
689            state.engine = Some(engine);
690        } else {
691            state.engine = None;
692        }
693        state.generation += 1;
694
695        Ok(())
696    }
697
698    /// Set the playback volume where `1.0` is the original amplitude.
699    pub fn set_volume(&self, volume: f32) {
700        let mut state = self.state.lock();
701        let clamped_volume = volume.max(0.0);
702        state.volume = clamped_volume;
703        if let Some(engine) = state.engine.as_ref() {
704            engine.sink.set_volume(clamped_volume);
705        }
706    }
707
708    /// Return the current playback volume.
709    pub fn volume(&self) -> f32 {
710        self.state.lock().volume
711    }
712
713    /// Return the current playback state.
714    pub fn state(&self) -> PlaybackState {
715        let mut state = self.state.lock();
716        state.refresh_finished();
717        state.state
718    }
719
720    /// Return the current playback position.
721    pub fn position(&self) -> Duration {
722        let mut state = self.state.lock();
723        state.refresh_finished();
724        state.current_position()
725    }
726
727    /// Return the total duration if it can be determined from the media source.
728    pub fn duration(&self) -> Result<Option<Duration>, AudioPlaybackError> {
729        let request = {
730            let state = self.state.lock();
731            if state.duration.is_some() {
732                return Ok(state.duration);
733            }
734
735            AudioProbeRequest {
736                generation: state.generation,
737                source: state.source.clone(),
738                decoded_audio: state.decoded_audio.clone(),
739            }
740        };
741
742        let (duration, decoded_audio) =
743            probe_duration_with_cache(&request.source, request.decoded_audio)?;
744
745        let mut state = self.state.lock();
746        if state.generation == request.generation && state.duration.is_none() {
747            if let Some(decoded_audio) = decoded_audio {
748                state.decoded_audio = Some(decoded_audio);
749            }
750            state.duration = duration;
751        }
752        Ok(state.duration.or(duration))
753    }
754
755    /// Return the source that this audio handle will play.
756    pub fn source(&self) -> MediaSource {
757        self.state.lock().source.clone()
758    }
759}
760
761/// Probe the total duration for an audio source without constructing a playback handle.
762pub fn probe_audio_duration(
763    source: impl Into<MediaSource>,
764) -> Result<Option<Duration>, AudioPlaybackError> {
765    let source = source.into();
766    probe_duration(&source, &mut None)
767}
768
769impl AudioHandleState {
770    fn current_position(&self) -> Duration {
771        let position = if self.state == PlaybackState::Playing {
772            self.started_at
773                .map(|started_at| self.position + started_at.elapsed())
774                .unwrap_or(self.position)
775        } else {
776            self.position
777        };
778
779        self.duration
780            .map(|duration| position.min(duration))
781            .unwrap_or(position)
782    }
783
784    fn refresh_finished(&mut self) {
785        if self.state != PlaybackState::Playing {
786            return;
787        }
788
789        let finished = self
790            .engine
791            .as_ref()
792            .is_some_and(|engine| engine.sink.empty());
793        let position = self.current_position();
794        let reached_end = self.duration.is_some_and(|duration| position >= duration);
795        if !finished && !reached_end {
796            return;
797        }
798
799        self.position = self.duration.unwrap_or(position);
800        self.started_at = None;
801        self.state = PlaybackState::Stopped;
802        self.engine = None;
803    }
804}
805
806fn probe_duration(
807    source: &MediaSource,
808    decoded_audio: &mut Option<Arc<DecodedAudio>>,
809) -> Result<Option<Duration>, AudioPlaybackError> {
810    match try_create_decoder(source)? {
811        Some(decoder) => Ok(decoder.total_duration()),
812        None => {
813            let decoded_audio = ensure_decoded_audio(source, decoded_audio)
814                .map_err(|decode_error| AudioPlaybackError::Decoder(decode_error.to_string()))?;
815            Ok(Some(decoded_audio.duration))
816        }
817    }
818}
819
820fn probe_duration_with_cache(
821    source: &MediaSource,
822    decoded_audio: Option<Arc<DecodedAudio>>,
823) -> Result<(Option<Duration>, Option<Arc<DecodedAudio>>), AudioPlaybackError> {
824    let mut decoded_audio = decoded_audio;
825    let duration = probe_duration(source, &mut decoded_audio)?;
826    Ok((duration, decoded_audio))
827}
828
829fn create_engine_with_cache(
830    source: &MediaSource,
831    volume: f32,
832    position: Duration,
833    decoded_audio: Option<Arc<DecodedAudio>>,
834) -> Result<
835    (
836        AudioEngine,
837        Option<Duration>,
838        Duration,
839        Option<Arc<DecodedAudio>>,
840    ),
841    AudioPlaybackError,
842> {
843    let mut decoded_audio = decoded_audio;
844    let (engine, duration, clamped_position) =
845        create_engine(source, volume, position, &mut decoded_audio)?;
846    Ok((engine, duration, clamped_position, decoded_audio))
847}
848
849fn create_engine(
850    source: &MediaSource,
851    volume: f32,
852    position: Duration,
853    decoded_audio: &mut Option<Arc<DecodedAudio>>,
854) -> Result<(AudioEngine, Option<Duration>, Duration), AudioPlaybackError> {
855    let (stream, stream_handle) = OutputStream::try_default()
856        .map_err(|error| AudioPlaybackError::Output(error.to_string()))?;
857    let sink = Sink::try_new(&stream_handle)
858        .map_err(|error| AudioPlaybackError::Output(error.to_string()))?;
859    let (duration, clamped_position) = match try_create_decoder(source)? {
860        Some(decoder) => {
861            let duration = decoder.total_duration();
862            let clamped_position = duration
863                .map(|duration| position.min(duration))
864                .unwrap_or(position);
865            sink.append(decoder.skip_duration(clamped_position));
866            (duration, clamped_position)
867        }
868        None => {
869            let decoded_audio = ensure_decoded_audio(source, decoded_audio)
870                .map_err(|decode_error| AudioPlaybackError::Decoder(decode_error.to_string()))?;
871            let clamped_position = position.min(decoded_audio.duration);
872            sink.append(
873                SamplesBuffer::new(
874                    decoded_audio.channels,
875                    decoded_audio.sample_rate,
876                    decoded_audio.samples.as_ref().to_vec(),
877                )
878                .skip_duration(clamped_position),
879            );
880            (Some(decoded_audio.duration), clamped_position)
881        }
882    };
883    sink.set_volume(volume.max(0.0));
884    Ok((
885        AudioEngine {
886            _stream: stream,
887            sink,
888        },
889        duration,
890        clamped_position,
891    ))
892}
893
894fn ensure_decoded_audio(
895    source: &MediaSource,
896    decoded_audio: &mut Option<Arc<DecodedAudio>>,
897) -> Result<Arc<DecodedAudio>, MediaDecodeError> {
898    if let Some(decoded_audio) = decoded_audio.as_ref() {
899        return Ok(decoded_audio.clone());
900    }
901
902    let decoded = Arc::new(decode_audio(source)?);
903    *decoded_audio = Some(decoded.clone());
904    Ok(decoded)
905}
906
907fn decode_audio(source: &MediaSource) -> Result<DecodedAudio, MediaDecodeError> {
908    ffmpeg::init().map_err(ffmpeg_decode_error)?;
909
910    let mut input_context = source
911        .resolve_ffmpeg_input()?
912        .open_input()
913        .map_err(ffmpeg_decode_error)?;
914    let input_stream = input_context
915        .streams()
916        .best(ffmpeg::media::Type::Audio)
917        .ok_or_else(|| MediaDecodeError::Decode("no audio stream found".into()))?;
918    let audio_stream_index = input_stream.index();
919
920    let context_decoder =
921        ffmpeg::codec::context::Context::from_parameters(input_stream.parameters())
922            .map_err(ffmpeg_decode_error)?;
923    let mut decoder = context_decoder
924        .decoder()
925        .audio()
926        .map_err(ffmpeg_decode_error)?;
927
928    let channel_layout = if decoder.channel_layout().is_empty() {
929        ffmpeg::ChannelLayout::default(decoder.channels().into())
930    } else {
931        decoder.channel_layout()
932    };
933    let sample_rate = decoder.rate();
934    let mut resampler = ffmpeg::software::resampling::context::Context::get(
935        decoder.format(),
936        channel_layout,
937        sample_rate,
938        ffmpeg::format::Sample::F32(ffmpeg::format::sample::Type::Packed),
939        channel_layout,
940        sample_rate,
941    )
942    .map_err(ffmpeg_decode_error)?;
943
944    let mut samples = Vec::new();
945
946    let mut receive_and_process_decoded_frames = |decoder: &mut ffmpeg::decoder::Audio,
947                                                  samples: &mut Vec<f32>|
948     -> Result<(), MediaDecodeError> {
949        let mut decoded = ffmpeg::util::frame::Audio::empty();
950        while decoder.receive_frame(&mut decoded).is_ok() {
951            let mut output = ffmpeg::util::frame::Audio::empty();
952            resampler
953                .run(&decoded, &mut output)
954                .map_err(ffmpeg_decode_error)?;
955            samples.extend_from_slice(output.plane::<f32>(0));
956        }
957
958        Ok(())
959    };
960
961    for (stream, packet) in input_context.packets() {
962        if stream.index() == audio_stream_index {
963            decoder.send_packet(&packet).map_err(ffmpeg_decode_error)?;
964            receive_and_process_decoded_frames(&mut decoder, &mut samples)?;
965        }
966    }
967
968    decoder.send_eof().map_err(ffmpeg_decode_error)?;
969    receive_and_process_decoded_frames(&mut decoder, &mut samples)?;
970
971    loop {
972        let mut output = ffmpeg::util::frame::Audio::empty();
973        let delayed = resampler.flush(&mut output).map_err(ffmpeg_decode_error)?;
974        if output.samples() > 0 {
975            samples.extend_from_slice(output.plane::<f32>(0));
976        }
977        if delayed.is_none() {
978            break;
979        }
980    }
981
982    let channels = channel_layout.channels() as u16;
983    let duration = if channels == 0 || sample_rate == 0 {
984        Duration::ZERO
985    } else {
986        Duration::from_secs_f64(samples.len() as f64 / channels as f64 / sample_rate as f64)
987    };
988
989    Ok(DecodedAudio {
990        channels,
991        sample_rate,
992        samples: Arc::<[f32]>::from(samples),
993        duration,
994    })
995}
996
997fn open_video_stream(source: &MediaSource) -> Result<OpenedVideoStream, MediaDecodeError> {
998    ffmpeg::init().map_err(ffmpeg_decode_error)?;
999
1000    let input_context = source
1001        .resolve_ffmpeg_input()?
1002        .open_input()
1003        .map_err(ffmpeg_decode_error)?;
1004    let input_stream = input_context
1005        .streams()
1006        .best(ffmpeg::media::Type::Video)
1007        .ok_or(MediaDecodeError::NoVideoStream)?;
1008    let video_stream_index = input_stream.index();
1009    let time_base = input_stream.time_base();
1010    let duration = if input_stream.duration() > 0 {
1011        Some(duration_from_time_base(input_stream.duration(), time_base))
1012    } else if input_context.duration() > 0 {
1013        Some(duration_from_time_base(
1014            input_context.duration(),
1015            ffmpeg::util::mathematics::rescale::TIME_BASE,
1016        ))
1017    } else {
1018        None
1019    };
1020
1021    let context_decoder =
1022        ffmpeg::codec::context::Context::from_parameters(input_stream.parameters())
1023            .map_err(ffmpeg_decode_error)?;
1024    let decoder = context_decoder
1025        .decoder()
1026        .video()
1027        .map_err(ffmpeg_decode_error)?;
1028    let width = decoder.width();
1029    let height = decoder.height();
1030    let scaler = ffmpeg::software::scaling::context::Context::get(
1031        decoder.format(),
1032        width,
1033        height,
1034        ffmpeg::format::Pixel::BGRA,
1035        width,
1036        height,
1037        ffmpeg::software::scaling::flag::Flags::BILINEAR,
1038    )
1039    .map_err(ffmpeg_decode_error)?;
1040
1041    Ok(OpenedVideoStream {
1042        input_context,
1043        decoder,
1044        scaler,
1045        video_stream_index,
1046        time_base,
1047        metadata: VideoMetadata {
1048            width,
1049            height,
1050            duration,
1051        },
1052    })
1053}
1054
1055fn decode_video_frame(
1056    scaler: &mut ffmpeg::software::scaling::context::Context,
1057    decoded: &ffmpeg::util::frame::video::Video,
1058    time_base: ffmpeg::Rational,
1059) -> Result<VideoFrame, MediaDecodeError> {
1060    let mut bgra_frame = ffmpeg::util::frame::video::Video::empty();
1061    scaler
1062        .run(decoded, &mut bgra_frame)
1063        .map_err(ffmpeg_decode_error)?;
1064
1065    Ok(VideoFrame {
1066        data: Arc::<[u8]>::from(copy_bgra_frame(&bgra_frame)),
1067        width: bgra_frame.width(),
1068        height: bgra_frame.height(),
1069        timestamp: duration_from_time_base(decoded.timestamp().unwrap_or_default(), time_base),
1070    })
1071}
1072
1073fn try_create_decoder(
1074    source: &MediaSource,
1075) -> Result<Option<Decoder<MediaReader>>, AudioPlaybackError> {
1076    if !source.direct_reader_supported() {
1077        return Ok(None);
1078    }
1079
1080    match Decoder::new(source.open_reader()?) {
1081        Ok(decoder) => Ok(Some(decoder)),
1082        Err(_) => Ok(None),
1083    }
1084}
1085
1086impl ReaderMediaSource {
1087    fn stage_to_path(&self) -> io::Result<PathBuf> {
1088        if let Some(path) = self.staged_path.lock().clone() {
1089            return Ok(path);
1090        }
1091
1092        let path = staged_media_dir().join(format!(
1093            "reader-{:016x}-{}",
1094            hash_value(&self.key),
1095            STAGED_READER_COUNTER.fetch_add(1, Ordering::Relaxed)
1096        ));
1097        write_path_atomically(&path, |file| {
1098            let mut reader = (self.open)()?;
1099            io::copy(&mut reader, file)?;
1100            Ok(())
1101        })?;
1102        *self.staged_path.lock() = Some(path.clone());
1103        Ok(path)
1104    }
1105}
1106
1107impl ResolvedMediaInput {
1108    fn open_input(&self) -> Result<ffmpeg::format::context::Input, ffmpeg::Error> {
1109        match self {
1110            Self::Path(path) => ffmpeg::format::input(path),
1111            Self::Url(url) => {
1112                ffmpeg::format::network::init();
1113                ffmpeg::format::input(url.as_ref())
1114            }
1115        }
1116    }
1117}
1118
1119fn stage_bytes(bytes: &Arc<[u8]>) -> Result<PathBuf, MediaDecodeError> {
1120    let path = staged_media_dir().join(format!("bytes-{:016x}", hash_value(bytes)));
1121    write_path_atomically(&path, |file| {
1122        file.write_all(bytes.as_ref())?;
1123        Ok(())
1124    })
1125    .map_err(MediaDecodeError::from_io)?;
1126    Ok(path)
1127}
1128
/// Directory under the system temp dir where in-memory and reader-backed media are staged.
fn staged_media_dir() -> PathBuf {
    let mut directory = std::env::temp_dir();
    directory.push("kael-media");
    directory
}
1132
1133fn write_path_atomically(
1134    path: &Path,
1135    populate: impl FnOnce(&mut File) -> io::Result<()>,
1136) -> io::Result<()> {
1137    if path.exists() {
1138        return Ok(());
1139    }
1140
1141    if let Some(parent) = path.parent() {
1142        fs::create_dir_all(parent)?;
1143    }
1144
1145    let temporary_path = path.with_extension(format!(
1146        "tmp-{}-{}",
1147        std::process::id(),
1148        STAGED_READER_COUNTER.fetch_add(1, Ordering::Relaxed)
1149    ));
1150    let mut file = OpenOptions::new()
1151        .write(true)
1152        .create_new(true)
1153        .open(&temporary_path)?;
1154    let populate_result = populate(&mut file).and_then(|_| file.flush());
1155    if let Err(error) = populate_result {
1156        let _ = fs::remove_file(&temporary_path);
1157        return Err(error);
1158    }
1159
1160    match fs::rename(&temporary_path, path) {
1161        Ok(()) => Ok(()),
1162        Err(_error) if path.exists() => {
1163            let _ = fs::remove_file(&temporary_path);
1164            Ok(())
1165        }
1166        Err(error) => {
1167            let _ = fs::remove_file(&temporary_path);
1168            Err(error)
1169        }
1170    }
1171}
1172
/// Hashes `value` with the standard library's `DefaultHasher` (a fresh instance per call).
fn hash_value(value: &impl Hash) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
1178
1179impl MediaDecodeError {
1180    fn from_io(error: io::Error) -> Self {
1181        Self::Decode(error.to_string())
1182    }
1183}
1184
1185fn ffmpeg_decode_error(error: ffmpeg::Error) -> MediaDecodeError {
1186    MediaDecodeError::Decode(error.to_string())
1187}
1188
1189fn duration_from_time_base(timestamp: i64, time_base: ffmpeg::Rational) -> Duration {
1190    if timestamp <= 0 {
1191        return Duration::ZERO;
1192    }
1193
1194    Duration::from_secs_f64((timestamp as f64) * f64::from(time_base))
1195}
1196
1197fn copy_bgra_frame(frame: &ffmpeg::util::frame::video::Video) -> Box<[u8]> {
1198    let width = frame.width() as usize;
1199    let height = frame.height() as usize;
1200    let row_len = width * 4;
1201    let stride = frame.stride(0);
1202    let source = frame.data(0);
1203    let mut bytes = vec![0u8; row_len * height];
1204
1205    for row in 0..height {
1206        let source_offset = row * stride;
1207        let destination_offset = row * row_len;
1208        bytes[destination_offset..destination_offset + row_len]
1209            .copy_from_slice(&source[source_offset..source_offset + row_len]);
1210    }
1211
1212    bytes.into_boxed_slice()
1213}
1214
#[cfg(test)]
mod tests {
    // Unit tests for `AudioHandle` state/position bookkeeping, byte- and reader-backed
    // sources, and the frame-count / byte-count budget of full video decodes.
    use super::{
        AudioHandle, MAX_DECODED_VIDEO_BYTES, MAX_DECODED_VIDEO_FRAMES, MediaDecodeError,
        MediaDecoder, MediaSource, PlaybackState, VideoFrame, push_decoded_video_frame,
    };
    use std::{io::Cursor, sync::Arc, time::Duration};

    // 8_000 samples at 8_000 Hz is exactly one second of audio.
    #[test]
    fn duration_probe_works_for_memory_backed_wav() {
        let handle = AudioHandle::new(MediaSource::bytes(silent_wav(8_000, 8_000)));

        assert_eq!(handle.state(), PlaybackState::Stopped);
        assert_eq!(handle.duration().unwrap(), Some(Duration::from_secs(1)));
        assert_eq!(handle.position(), Duration::ZERO);
    }

    // Seeking while stopped moves the position but must not start playback.
    #[test]
    fn seek_updates_position_without_starting_playback() {
        let handle = AudioHandle::new(MediaSource::bytes(silent_wav(8_000, 8_000)));

        handle.seek(Duration::from_millis(250)).unwrap();

        assert_eq!(handle.state(), PlaybackState::Stopped);
        assert_eq!(handle.position(), Duration::from_millis(250));
    }

    // Same one-second probe as above, but the bytes come from a keyed reader factory.
    #[test]
    fn duration_probe_works_for_reader_backed_wav() {
        let wav = Arc::<[u8]>::from(silent_wav(8_000, 8_000));
        let handle = AudioHandle::new(MediaSource::reader("reader-wav", {
            move || Ok(Cursor::new(wav.clone()))
        }));

        assert_eq!(handle.duration().unwrap(), Some(Duration::from_secs(1)));
        assert_eq!(handle.position(), Duration::ZERO);
        assert_eq!(handle.state(), PlaybackState::Stopped);
    }

    // `stop` rewinds to the start after a seek.
    #[test]
    fn stop_resets_position() {
        let handle = AudioHandle::new(MediaSource::bytes(silent_wav(8_000, 8_000)));

        handle.seek(Duration::from_millis(300)).unwrap();
        handle.stop();

        assert_eq!(handle.position(), Duration::ZERO);
        assert_eq!(handle.state(), PlaybackState::Stopped);
    }

    // Sixteen zero bytes are not a valid container. Both metadata and full decode of an
    // in-memory source must fail with an error rather than panic.
    #[test]
    fn video_decoder_stages_in_memory_sources_before_decode() {
        let decoder = MediaDecoder::new(MediaSource::bytes([0u8; 16]));

        assert!(matches!(
            decoder.video_metadata().unwrap_err(),
            MediaDecodeError::Decode(_) | MediaDecodeError::NoVideoStream
        ));
        assert!(matches!(
            decoder.decode_video_frames().unwrap_err(),
            MediaDecodeError::Decode(_) | MediaDecodeError::NoVideoStream
        ));
    }

    // Same invalid payload, delivered through a reader-backed source.
    #[test]
    fn video_decoder_accepts_reader_backed_sources() {
        let payload = Arc::<[u8]>::from([0u8; 16]);
        let decoder = MediaDecoder::new(MediaSource::reader("reader-video", {
            move || Ok(Cursor::new(payload.clone()))
        }));

        assert!(matches!(
            decoder.video_metadata().unwrap_err(),
            MediaDecodeError::Decode(_) | MediaDecodeError::NoVideoStream
        ));
        assert!(matches!(
            decoder.decode_video_frames().unwrap_err(),
            MediaDecodeError::Decode(_) | MediaDecodeError::NoVideoStream
        ));
    }

    // Exactly MAX_DECODED_VIDEO_FRAMES small frames are accepted; the next one is rejected
    // with an error message mentioning "frames".
    #[test]
    fn full_video_decode_rejects_excessive_frame_counts() {
        let mut frames = Vec::new();
        let mut decoded_bytes = 0u64;

        for index in 0..MAX_DECODED_VIDEO_FRAMES {
            push_decoded_video_frame(
                &mut frames,
                &mut decoded_bytes,
                test_video_frame(index as u64, 4),
            )
            .unwrap();
        }

        let error = push_decoded_video_frame(
            &mut frames,
            &mut decoded_bytes,
            test_video_frame(MAX_DECODED_VIDEO_FRAMES as u64, 4),
        )
        .unwrap_err();

        assert!(matches!(error, MediaDecodeError::Decode(message) if message.contains("frames")));
    }

    // Starting one byte below the budget, a 2-byte frame pushes the total over
    // MAX_DECODED_VIDEO_BYTES and is rejected with a message mentioning "bytes".
    #[test]
    fn full_video_decode_rejects_excessive_byte_counts() {
        let mut frames = Vec::new();
        let mut decoded_bytes = MAX_DECODED_VIDEO_BYTES - 1;

        let error =
            push_decoded_video_frame(&mut frames, &mut decoded_bytes, test_video_frame(0, 2))
                .unwrap_err();

        assert!(matches!(error, MediaDecodeError::Decode(message) if message.contains("bytes")));
    }

    /// Builds a nominally 1x1 frame whose payload is `len` zero bytes.
    ///
    /// `len` is independent of the 1x1 size so tests can control the byte budget directly.
    fn test_video_frame(timestamp_millis: u64, len: usize) -> VideoFrame {
        VideoFrame {
            data: Arc::<[u8]>::from(vec![0; len]),
            width: 1,
            height: 1,
            timestamp: Duration::from_millis(timestamp_millis),
        }
    }

    /// Builds an in-memory mono, 16-bit PCM WAV file containing `samples` zero samples.
    ///
    /// The layout is a canonical 44-byte RIFF/WAVE header (a `fmt ` chunk of 16 bytes,
    /// format tag 1 = PCM) followed by the `data` chunk.
    fn silent_wav(sample_rate: u32, samples: u32) -> Vec<u8> {
        let channels = 1u16;
        let bits_per_sample = 16u16;
        let bytes_per_sample = (bits_per_sample / 8) as u32;
        let data_len = samples * channels as u32 * bytes_per_sample;
        let byte_rate = sample_rate * channels as u32 * bytes_per_sample;
        let block_align = channels * (bits_per_sample / 8);
        // RIFF chunk size = everything after the 8-byte "RIFF"+size prefix.
        let chunk_size = 36 + data_len;

        let mut wav = Vec::with_capacity((44 + data_len) as usize);
        wav.extend_from_slice(b"RIFF");
        wav.extend_from_slice(&chunk_size.to_le_bytes());
        wav.extend_from_slice(b"WAVE");
        wav.extend_from_slice(b"fmt ");
        wav.extend_from_slice(&16u32.to_le_bytes());
        wav.extend_from_slice(&1u16.to_le_bytes());
        wav.extend_from_slice(&channels.to_le_bytes());
        wav.extend_from_slice(&sample_rate.to_le_bytes());
        wav.extend_from_slice(&byte_rate.to_le_bytes());
        wav.extend_from_slice(&block_align.to_le_bytes());
        wav.extend_from_slice(&bits_per_sample.to_le_bytes());
        wav.extend_from_slice(b"data");
        wav.extend_from_slice(&data_len.to_le_bytes());
        // Zero-fill the sample data: silence.
        wav.resize(44 + data_len as usize, 0);
        wav
    }
}