// recorder-for-jetkvm 0.1.0
//
// JetKVM recorder and screenshot utility
// Documentation
use std::io::ErrorKind;
use std::path::Path;

use anyhow::{Context, Result, anyhow, bail};
use tokio::sync::{broadcast, watch};
use tracing::{info, warn};

use crate::h264::{NAL_TYPE_IDR, NalUnit};

#[derive(Debug, Default)]
struct AccessUnit {
    rtp_timestamp: u32,
    data: Vec<u8>,
    contains_idr: bool,
    contains_sps: bool,
    contains_pps: bool,
}

impl AccessUnit {
    fn new(rtp_timestamp: u32) -> Self {
        Self {
            rtp_timestamp,
            ..Self::default()
        }
    }

    fn push(&mut self, nal: &NalUnit) {
        self.data.extend_from_slice(&nal.data);

        match nal.nal_type() {
            Some(NAL_TYPE_IDR) => self.contains_idr = true,
            Some(7) => self.contains_sps = true,
            Some(8) => self.contains_pps = true,
            _ => {}
        }
    }
}

pub async fn capture(
    mut nal_rx: broadcast::Receiver<NalUnit>,
    output_path: &Path,
    mut shutdown: watch::Receiver<bool>,
) -> Result<()> {
    let codec = ffmpeg_the_third::decoder::find(ffmpeg_the_third::codec::Id::H264)
        .context("H.264 decoder not found in linked FFmpeg")?;
    let mut decoder = ffmpeg_the_third::codec::Context::new()
        .decoder()
        .open_as(codec)
        .context("failed to open H.264 decoder")?
        .video()
        .context("linked H.264 decoder is not a video decoder")?;

    let mut current_sps: Option<Vec<u8>> = None;
    let mut current_pps: Option<Vec<u8>> = None;
    let mut pending_access_unit: Option<AccessUnit> = None;
    let mut started = false;

    loop {
        tokio::select! {
            _ = shutdown.changed() => {
                bail!("shutdown requested before screenshot was captured");
            }
            recv_result = nal_rx.recv() => {
                match recv_result {
                    Ok(nal) => {
                        if pending_access_unit
                            .as_ref()
                            .is_some_and(|au| au.rtp_timestamp != nal.rtp_timestamp)
                        {
                            let access_unit = pending_access_unit
                                .take()
                                .expect("checked pending access unit above");
                            if maybe_capture_access_unit(
                                &mut decoder,
                                &access_unit,
                                current_sps.as_deref(),
                                current_pps.as_deref(),
                                &mut started,
                                output_path,
                            )? {
                                return Ok(());
                            }
                        }

                        match nal.nal_type() {
                            Some(7) => current_sps = Some(nal.data.to_vec()),
                            Some(8) => current_pps = Some(nal.data.to_vec()),
                            _ => {}
                        }

                        pending_access_unit
                            .get_or_insert_with(|| AccessUnit::new(nal.rtp_timestamp))
                            .push(&nal);
                    }
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        warn!("screenshot capture lagged, missed {n} NAL units");
                        pending_access_unit = None;
                        started = false;
                        decoder.flush();
                    }
                    Err(broadcast::error::RecvError::Closed) => {
                        if let Some(access_unit) = pending_access_unit.take()
                            && maybe_capture_access_unit(
                                &mut decoder,
                                &access_unit,
                                current_sps.as_deref(),
                                current_pps.as_deref(),
                                &mut started,
                                output_path,
                            )?
                        {
                            return Ok(());
                        }

                        bail!("video stream ended before a screenshot could be captured");
                    }
                }
            }
        }
    }
}

/// Feeds one access unit to the decoder and, if a frame comes out, saves it as a PNG.
///
/// While `*started` is `false` the function is waiting for a usable keyframe: access units
/// without an IDR NAL unit are ignored, and an IDR access unit is skipped with a warning
/// unless both an SPS and a PPS are available, either inside the unit itself or through
/// `current_sps` / `current_pps`. For that first keyframe the cached parameter sets are
/// placed in front of the access unit data where the unit does not carry its own.
///
/// `*started` is set to `true` once an access unit is accepted for decoding, and reset to
/// `false` (with the decoder flushed) when sending the packet or receiving a frame fails.
///
/// Returns `Ok(true)` after a decoded frame has been written to `output_path`, and
/// `Ok(false)` when no frame was produced for this access unit.
///
/// # Errors
///
/// Propagates any error returned by `encode_png`.
fn maybe_capture_access_unit(
    decoder: &mut ffmpeg_the_third::decoder::Video,
    access_unit: &AccessUnit,
    current_sps: Option<&[u8]>,
    current_pps: Option<&[u8]>,
    started: &mut bool,
    output_path: &Path,
) -> Result<bool> {
    if access_unit.data.is_empty() {
        return Ok(false);
    }

    let first_keyframe = !*started;
    if first_keyframe {
        if !access_unit.contains_idr {
            return Ok(false);
        }

        let sps_available = access_unit.contains_sps || current_sps.is_some();
        let pps_available = access_unit.contains_pps || current_pps.is_some();
        if !(sps_available && pps_available) {
            warn!("skipping keyframe without cached SPS/PPS data");
            return Ok(false);
        }
    }

    *started = true;

    let packet_data = build_decoder_packet(access_unit, current_sps, current_pps, first_keyframe);
    let packet = ffmpeg_the_third::Packet::copy(&packet_data);

    if let Err(err) = decoder.send_packet(&packet) {
        warn!("failed to send access unit to decoder, waiting for next keyframe: {err}");
        decoder.flush();
        *started = false;
        return Ok(false);
    }

    // A single receive attempt is all that is needed: every outcome ends this call, so
    // no retry loop is required.
    let mut decoded = ffmpeg_the_third::frame::Video::empty();
    match decoder.receive_frame(&mut decoded) {
        Ok(()) => {
            encode_png(&decoded, output_path)?;
            info!(
                path = %output_path.display(),
                width = decoded.width(),
                height = decoded.height(),
                "saved screenshot"
            );
            Ok(true)
        }
        // The decoder has no frame to hand out for this packet; this is not a failure.
        Err(err) if is_would_block(&err) || matches!(err, ffmpeg_the_third::Error::Eof) => {
            Ok(false)
        }
        Err(err) => {
            warn!("failed to decode access unit, waiting for next keyframe: {err}");
            decoder.flush();
            *started = false;
            Ok(false)
        }
    }
}

fn build_decoder_packet(
    access_unit: &AccessUnit,
    current_sps: Option<&[u8]>,
    current_pps: Option<&[u8]>,
    prepend_parameter_sets: bool,
) -> Vec<u8> {
    let mut packet = Vec::with_capacity(
        access_unit.data.len()
            + current_sps.map_or(0, |data| data.len())
            + current_pps.map_or(0, |data| data.len()),
    );

    if prepend_parameter_sets {
        if !access_unit.contains_sps
            && let Some(sps) = current_sps
        {
            packet.extend_from_slice(sps);
        }
        if !access_unit.contains_pps
            && let Some(pps) = current_pps
        {
            packet.extend_from_slice(pps);
        }
    }

    packet.extend_from_slice(&access_unit.data);
    packet
}

fn encode_png(frame: &ffmpeg_the_third::frame::Video, output_path: &Path) -> Result<()> {
    let mut rgb_frame = ffmpeg_the_third::frame::Video::new(
        ffmpeg_the_third::format::Pixel::RGB24,
        frame.width(),
        frame.height(),
    );
    let mut scaler = frame
        .converter(ffmpeg_the_third::format::Pixel::RGB24)
        .context("failed to create RGB frame converter")?;
    scaler
        .run(frame, &mut rgb_frame)
        .context("failed to convert decoded frame to RGB24")?;

    let codec = ffmpeg_the_third::encoder::find(ffmpeg_the_third::codec::Id::PNG)
        .context("PNG encoder not found in linked FFmpeg")?;
    let mut encoder = ffmpeg_the_third::codec::Context::new()
        .encoder()
        .video()
        .context("failed to create PNG encoder context")?;
    encoder.set_width(rgb_frame.width());
    encoder.set_height(rgb_frame.height());
    encoder.set_format(ffmpeg_the_third::format::Pixel::RGB24);
    encoder.set_time_base(ffmpeg_the_third::Rational(1, 1));

    let mut encoder = encoder.open_as(codec).context("failed to open PNG encoder")?;
    encoder
        .send_frame(&rgb_frame)
        .context("failed to send RGB frame to PNG encoder")?;
    encoder
        .send_eof()
        .context("failed to flush PNG encoder")?;

    let mut packet = ffmpeg_the_third::Packet::empty();
    loop {
        match encoder.receive_packet(&mut packet) {
            Ok(()) => {
                let data = packet
                    .data()
                    .ok_or_else(|| anyhow!("PNG encoder returned an empty packet"))?;

                if let Some(parent) = output_path.parent() {
                    std::fs::create_dir_all(parent).with_context(|| {
                        format!("failed to create screenshot directory: {}", parent.display())
                    })?;
                }

                std::fs::write(output_path, data).with_context(|| {
                    format!("failed to write screenshot: {}", output_path.display())
                })?;
                return Ok(());
            }
            Err(err) if is_would_block(&err) => continue,
            Err(ffmpeg_the_third::Error::Eof) => {
                bail!("PNG encoder produced no output packet");
            }
            Err(err) => {
                return Err(err).context("failed to receive PNG packet from encoder");
            }
        }
    }
}

fn is_would_block(err: &ffmpeg_the_third::Error) -> bool {
    match err {
        ffmpeg_the_third::Error::Other { errno } => {
            std::io::Error::from_raw_os_error(*errno).kind() == ErrorKind::WouldBlock
        }
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::time::Instant;

    // Sample NAL units: a 4-byte start code, a header byte and one payload byte.
    // The low five bits of the header byte are 7 (SPS), 8 (PPS) and 5 (IDR) respectively.
    const SPS: &[u8] = &[0, 0, 0, 1, 0x67, 0x64];
    const PPS: &[u8] = &[0, 0, 0, 1, 0x68, 0xee];
    const IDR: &[u8] = &[0, 0, 0, 1, 0x65, 0xaa];

    /// Builds a `NalUnit` from raw bytes; `is_keyframe` is derived from the low five bits
    /// of the byte at index 4.
    fn nal(rtp_timestamp: u32, bytes: &[u8]) -> NalUnit {
        let header_type = bytes.get(4).map(|header| header & 0x1f);
        NalUnit {
            data: Bytes::copy_from_slice(bytes),
            is_keyframe: matches!(header_type, Some(5 | 7 | 8)),
            timestamp: Instant::now(),
            rtp_timestamp,
        }
    }

    /// Builds an access unit by pushing each byte string as a NAL unit, in order.
    fn access_unit_with(rtp_timestamp: u32, nals: &[&[u8]]) -> AccessUnit {
        let mut unit = AccessUnit::new(rtp_timestamp);
        for bytes in nals {
            unit.push(&nal(rtp_timestamp, bytes));
        }
        unit
    }

    #[test]
    fn test_build_decoder_packet_prefixes_cached_parameter_sets_for_first_keyframe() {
        let access_unit = access_unit_with(42, &[IDR]);

        let packet = build_decoder_packet(&access_unit, Some(SPS), Some(PPS), true);

        assert_eq!(packet, [SPS, PPS, IDR].concat());
    }

    #[test]
    fn test_build_decoder_packet_does_not_duplicate_parameter_sets() {
        let access_unit = access_unit_with(42, &[SPS, PPS, IDR]);

        let packet = build_decoder_packet(&access_unit, Some(SPS), Some(PPS), true);

        assert_eq!(packet, [SPS, PPS, IDR].concat());
    }
}