use std::{
ffi::CString,
os::fd::OwnedFd,
time::{SystemTime, UNIX_EPOCH},
};
use gbm::BufferObject;
use image::{ColorType, DynamicImage, ImageBuffer, Pixel, Rgba};
use memmap2::MmapMut;
use r_egl_wayland::{EGL_INSTALCE, r_egl as egl};
use rustix::{
fs::{self, SealFlags},
io, shm,
};
use wayland_client::protocol::{
wl_buffer::WlBuffer, wl_output, wl_shm::Format, wl_shm_pool::WlShmPool,
};
use crate::{
Error, Result,
convert::create_converter,
region::{LogicalRegion, Size},
};
/// Owns the Wayland protocol objects of a shared-memory frame capture.
///
/// Dropping the guard calls `destroy` on both the `wl_buffer` and the `wl_shm_pool`.
pub struct FrameGuard {
/// The `wl_buffer` the frame is copied into.
pub buffer: WlBuffer,
/// The `wl_shm_pool` destroyed alongside `buffer`.
/// NOTE(review): presumably the pool `buffer` was created from — confirm at construction site.
pub shm_pool: WlShmPool,
/// Frame dimensions associated with this buffer.
pub size: Size,
}
impl Drop for FrameGuard {
fn drop(&mut self) {
// The buffer is destroyed before the pool; keep this order.
self.buffer.destroy();
self.shm_pool.destroy();
}
}
/// Owns the `wl_buffer` of a DMA-based frame capture (per the type name).
///
/// Dropping the guard calls `destroy` on the buffer.
pub struct DMAFrameGuard {
/// The `wl_buffer` the frame is copied into.
pub buffer: WlBuffer,
}
impl Drop for DMAFrameGuard {
fn drop(&mut self) {
self.buffer.destroy();
}
}
/// Owns an EGL image and destroys it on the display it belongs to when dropped.
pub struct EGLImageGuard {
    /// The EGL image handle released on drop.
    pub image: egl::Image,
    /// The display `image` is destroyed against.
    pub(crate) egl_display: egl::Display,
}

impl Drop for EGLImageGuard {
    /// Releases the EGL image. A failure is logged and otherwise ignored,
    /// because `Drop` has no way to report it.
    fn drop(&mut self) {
        let outcome = EGL_INSTALCE.destroy_image(self.egl_display, self.image);
        if let Err(e) = outcome {
            tracing::error!("EGLimage destruction had error: {e}");
        }
    }
}
/// Layout of a shared-memory (`wl_shm`) frame buffer.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct FrameFormat {
/// Pixel format of the buffer, expressed as a `wl_shm` format.
pub format: Format,
/// Frame dimensions; `width` and `height` become the image's pixel dimensions.
pub size: Size,
/// Bytes per row; `byte_size` multiplies it by `size.height`.
pub stride: u32,
}
/// Format description for a DMA-based frame capture.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct DMAFrameFormat {
/// Raw pixel-format code.
/// NOTE(review): presumably a DRM fourcc value — confirm against where this is filled in.
pub format: u32,
/// Frame dimensions.
pub size: Size,
}
impl FrameFormat {
    /// Total number of bytes the frame occupies: one `stride`-sized row per
    /// pixel row. The computation is done in `u64` to avoid overflowing `u32`.
    pub fn byte_size(&self) -> u64 {
        let row_bytes = self.stride as u64;
        let row_count = self.size.height as u64;
        row_bytes * row_count
    }
}
#[tracing::instrument(skip(frame_data))]
fn create_image_buffer<P, C>(frame_format: &FrameFormat, frame_data: C) -> Result<ImageBuffer<P, C>>
where
P: Pixel<Subpixel = u8>,
C: std::ops::Deref<Target = [u8]>,
{
tracing::debug!("Creating image buffer");
ImageBuffer::from_raw(
frame_format.size.width,
frame_format.size.height,
frame_data,
)
.ok_or(Error::BufferTooSmall)
}
/// Storage backing a captured frame's pixels.
#[derive(Debug)]
pub enum FrameData {
/// A memory-mapped buffer; the only variant the conversion code in this file reads.
Mmap(MmapMut),
/// A GBM buffer object. In-place colour conversion and image extraction in this
/// file do not handle this variant.
GBMBo(BufferObject<()>),
}
/// A captured frame together with the metadata needed to turn it into an image.
#[derive(Debug)]
pub struct FrameCopy {
/// Buffer layout (format, size, stride) of `frame_data`.
pub frame_format: FrameFormat,
/// Current colour layout of `frame_data`; overwritten by `convert_color_inplace`.
pub frame_color_type: ColorType,
/// The pixel storage itself.
pub frame_data: FrameData,
/// Output transform associated with this frame.
/// NOTE(review): presumably applied by callers when composing images — confirm.
pub transform: wl_output::Transform,
/// Captured region in logical coordinates (per the type name) — confirm units with the caller.
pub logical_region: LogicalRegion,
/// NOTE(review): presumably the frame size in physical pixels — confirm.
pub physical_size: Size,
/// Set once `frame_data` has been converted in place, so the conversion runs at most once.
pub(crate) color_converted: bool,
}
impl FrameCopy {
pub(crate) fn convert_color_inplace(&mut self) -> Result<ColorType, Error> {
if self.color_converted {
return Ok(self.frame_color_type);
}
let frame_color_type = match create_converter(self.frame_format.format) {
Some(converter) => {
let FrameData::Mmap(raw) = &mut self.frame_data else {
return Err(Error::InvalidColor);
};
converter.convert_inplace(raw)
}
_ => {
tracing::error!("Unsupported buffer format: {:?}", self.frame_format.format);
tracing::error!(
"You can send a feature request for the above format to the mailing list for wayshot over at https://sr.ht/~shinyzenith/wayshot."
);
return Err(Error::NoSupportedBufferFormat);
}
};
self.frame_color_type = frame_color_type;
self.color_converted = true;
Ok(frame_color_type)
}
pub(crate) fn into_mmap_rgba_image_buffer(self) -> Result<ImageBuffer<Rgba<u8>, MmapMut>> {
if self.frame_color_type != ColorType::Rgba8 {
return Err(Error::InvalidColor);
}
match self.frame_data {
FrameData::Mmap(frame_mmap) => create_image_buffer(&self.frame_format, frame_mmap),
FrameData::GBMBo(_) => todo!(),
}
}
pub(crate) fn get_image(&mut self) -> Result<DynamicImage, Error> {
self.convert_color_inplace()?;
let image: DynamicImage = (&*self).try_into()?;
Ok(image)
}
}
impl TryFrom<&FrameCopy> for DynamicImage {
type Error = Error;
fn try_from(value: &FrameCopy) -> Result<Self> {
Ok(match value.frame_color_type {
ColorType::Rgb8 => {
let frame_data = match &value.frame_data {
FrameData::Mmap(frame_mmap) => frame_mmap.to_vec(),
FrameData::GBMBo(_) => todo!(),
};
Self::ImageRgb8(create_image_buffer(&value.frame_format, frame_data)?)
}
ColorType::Rgba8 => {
let frame_data = match &value.frame_data {
FrameData::Mmap(frame_mmap) => frame_mmap.to_vec(),
FrameData::GBMBo(_) => todo!(),
};
Self::ImageRgba8(create_image_buffer(&value.frame_format, frame_data)?)
}
_ => return Err(Error::InvalidColor),
})
}
}
/// Builds a name for a POSIX shared-memory object, used by the `shm_open` fallback.
///
/// The name has the form `/libwayshot-<subsec-nanos>-<sequence>`. The clock component
/// varies names across processes. The per-process sequence number ensures that
/// successive calls (e.g. retries after `EEXIST`) never produce the same name within
/// one process, even when the clock ticks coarsely or is unavailable ("unknown").
/// Previously such retries could keep receiving an identical name.
fn get_mem_file_handle() -> String {
    static SEQUENCE: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|time| time.subsec_nanos().to_string())
        .unwrap_or_else(|_| "unknown".to_owned());
    // Only uniqueness matters here, so relaxed ordering is sufficient.
    let sequence = SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    format!("/libwayshot-{nanos}-{sequence}")
}
/// Creates a file descriptor for shared memory suitable for backing a `wl_shm` pool.
///
/// On Linux and FreeBSD this first tries `memfd_create` with `CLOEXEC | ALLOW_SEALING`.
/// It retries on `EINTR` and makes a best-effort attempt to seal the file against
/// shrinking and against further seals. If the kernel reports `ENOSYS`, or on other
/// platforms, it falls back to `shm_open` with `O_CREAT | O_EXCL | O_RDWR`. A fresh name
/// is generated on `EEXIST`, and the name is unlinked immediately after a successful open.
///
/// # Errors
/// Any other failure from `memfd_create`, `shm_open` or `shm_unlink` is returned as
/// a [`std::io::Error`].
pub fn create_shm_fd() -> std::io::Result<OwnedFd> {
    #[cfg(any(target_os = "linux", target_os = "freebsd"))]
    loop {
        let memfd_name = CString::new("libwayshot")?;
        let memfd_flags = fs::MemfdFlags::CLOEXEC | fs::MemfdFlags::ALLOW_SEALING;
        let failure = match fs::memfd_create(memfd_name.as_c_str(), memfd_flags) {
            Ok(fd) => {
                // Sealing is best-effort: a failure here still leaves a usable fd.
                let _ = fs::fcntl_add_seals(&fd, SealFlags::SHRINK | SealFlags::SEAL);
                return Ok(fd);
            }
            Err(errno) => errno,
        };
        match failure {
            // Interrupted by a signal: try again.
            io::Errno::INTR => continue,
            // memfd_create is unavailable: use the shm_open path below.
            io::Errno::NOSYS => break,
            other => return Err(std::io::Error::from(other)),
        }
    }

    let mut shm_name = get_mem_file_handle();
    loop {
        let open_flags = shm::OFlags::CREATE | shm::OFlags::EXCL | shm::OFlags::RDWR;
        let permissions = fs::Mode::RUSR | fs::Mode::WUSR;
        let fd = match shm::open(shm_name.as_str(), open_flags, permissions) {
            Ok(fd) => fd,
            Err(io::Errno::EXIST) => {
                // Name already taken: pick another and retry.
                shm_name = get_mem_file_handle();
                continue;
            }
            Err(io::Errno::INTR) => continue,
            Err(errno) => return Err(std::io::Error::from(errno)),
        };
        // Drop the name right away so only the descriptor keeps the object alive.
        // If unlinking fails, `fd` is dropped (closed) and the error is returned.
        break shm::unlink(shm_name.as_str())
            .map(|()| fd)
            .map_err(std::io::Error::from);
    }
}