use std::sync::Arc;
use crate::GraphicsContext;
/// Errors that can occur while reading a texture back from the GPU or exporting it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReadbackError {
    /// Mapping the readback buffer for CPU access failed; carries the underlying message.
    MapFailed(String),
    /// The texture-to-buffer copy failed; carries the underlying message.
    CopyFailed(String),
    /// Turning the pixel data into an encoded image failed; carries the underlying message.
    EncodeFailed(String),
    /// An I/O operation failed; carries the underlying message.
    IoError(String),
    /// The texture has a zero width or height.
    InvalidDimensions,
    /// The texture format is not one the readback path knows how to handle.
    UnsupportedFormat,
}

impl std::fmt::Display for ReadbackError {
    /// Renders a fixed, human-readable prefix per variant, followed by `": <detail>"`
    /// for the variants that carry a message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::MapFailed(msg) => ("Buffer mapping failed", Some(msg)),
            Self::CopyFailed(msg) => ("Texture copy failed", Some(msg)),
            Self::EncodeFailed(msg) => ("Image encoding failed", Some(msg)),
            Self::IoError(msg) => ("IO error", Some(msg)),
            Self::InvalidDimensions => ("Invalid dimensions for readback", None),
            Self::UnsupportedFormat => ("Unsupported texture format for readback", None),
        };
        match detail {
            Some(msg) => write!(f, "{prefix}: {msg}"),
            None => f.write_str(prefix),
        }
    }
}

// No variant wraps another error value, so the default `source()` (None) is correct.
impl std::error::Error for ReadbackError {}
/// A texture-to-buffer copy that has been submitted to the GPU and can be read on the CPU.
///
/// Created by [`GpuReadback::from_texture`]. The copy covers mip level 0 and the first
/// array layer / depth slice of the source texture. [`GpuReadback::read`] blocks until the
/// copy has finished and returns the pixels.
pub struct GpuReadback {
    /// Context whose device owns `buffer`; `read` polls it so the map request can complete.
    context: Arc<GraphicsContext>,
    /// Staging buffer (`COPY_DST | MAP_READ`) holding `dimensions.1` rows of `bytes_per_row` bytes.
    buffer: wgpu::Buffer,
    /// `(width, height)` of the copied image in texels.
    dimensions: (u32, u32),
    /// Row stride inside `buffer`, padded up to `wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`.
    bytes_per_row: u32,
    /// Format of the source texture; `read` returns bytes in this format's layout.
    format: wgpu::TextureFormat,
}

impl GpuReadback {
    /// Records and submits a copy of `texture` into a freshly created readback buffer.
    ///
    /// The copy is only submitted here, not waited on; [`GpuReadback::read`] does the
    /// waiting. Supported formats are `Rgba8Unorm`, `Rgba8UnormSrgb`, `Bgra8Unorm`,
    /// `Bgra8UnormSrgb` and `Rgb10a2Unorm`.
    ///
    /// # Errors
    ///
    /// - [`ReadbackError::InvalidDimensions`] if the texture has zero width or height.
    /// - [`ReadbackError::UnsupportedFormat`] if the format is not listed above.
    /// - [`ReadbackError::CopyFailed`] if the texture lacks `TextureUsages::COPY_SRC` or is
    ///   multisampled; wgpu would otherwise raise a validation error for the copy.
    pub fn from_texture(
        context: Arc<GraphicsContext>,
        texture: &wgpu::Texture,
    ) -> Result<Self, ReadbackError> {
        let size = texture.size();
        let dimensions = (size.width, size.height);
        let format = texture.format();
        if dimensions.0 == 0 || dimensions.1 == 0 {
            return Err(ReadbackError::InvalidDimensions);
        }
        let bytes_per_pixel =
            Self::bytes_per_pixel(format).ok_or(ReadbackError::UnsupportedFormat)?;
        if !texture.usage().contains(wgpu::TextureUsages::COPY_SRC) {
            return Err(ReadbackError::CopyFailed(
                "texture was not created with TextureUsages::COPY_SRC".to_string(),
            ));
        }
        if texture.sample_count() > 1 {
            return Err(ReadbackError::CopyFailed(
                "multisampled textures cannot be copied to a buffer".to_string(),
            ));
        }

        // Each buffer row must start on a COPY_BYTES_PER_ROW_ALIGNMENT boundary, so rows are
        // padded here and the padding is stripped again in `read`.
        let unpadded_bytes_per_row = dimensions.0 * bytes_per_pixel;
        let align = wgpu::COPY_BYTES_PER_ROW_ALIGNMENT;
        let bytes_per_row = unpadded_bytes_per_row.div_ceil(align) * align;
        // Multiply in 64 bits: large textures (e.g. 32768x32768 RGBA8) overflow u32.
        let buffer_size =
            wgpu::BufferAddress::from(bytes_per_row) * wgpu::BufferAddress::from(dimensions.1);

        let buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
            label: Some("readback_buffer"),
            size: buffer_size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let mut encoder = context
            .device()
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("readback_encoder"),
            });

        // Copy a single layer/slice: the buffer holds exactly one `height`-row image, and
        // copying every array layer would overrun it.
        let copy_size = wgpu::Extent3d {
            width: dimensions.0,
            height: dimensions.1,
            depth_or_array_layers: 1,
        };
        encoder.copy_texture_to_buffer(
            wgpu::TexelCopyTextureInfo {
                texture,
                mip_level: 0,
                origin: wgpu::Origin3d::ZERO,
                aspect: wgpu::TextureAspect::All,
            },
            wgpu::TexelCopyBufferInfo {
                buffer: &buffer,
                layout: wgpu::TexelCopyBufferLayout {
                    offset: 0,
                    bytes_per_row: Some(bytes_per_row),
                    rows_per_image: Some(dimensions.1),
                },
            },
            copy_size,
        );
        context.queue().submit(Some(encoder.finish()));

        Ok(Self {
            context,
            buffer,
            dimensions,
            bytes_per_row,
            format,
        })
    }

    /// Waits for the copy to finish and returns the pixels as tightly packed rows.
    ///
    /// Bytes are in the source texture's native layout. For `Bgra8*` formats that is BGRA
    /// order; for `Rgb10a2Unorm` it is one packed little-endian `u32` per texel. The result
    /// has `width * height * 4` bytes. Blocks the calling thread while the device is polled.
    ///
    /// # Errors
    ///
    /// - [`ReadbackError::MapFailed`] if wgpu reports a mapping error, or if the map request
    ///   has not completed after polling. In that case the pending request is cancelled.
    /// - [`ReadbackError::UnsupportedFormat`] if the stored format is not readable. This
    ///   cannot happen for values built by `from_texture`.
    pub fn read(&self) -> Result<Vec<u8>, ReadbackError> {
        let bytes_per_pixel =
            Self::bytes_per_pixel(self.format).ok_or(ReadbackError::UnsupportedFormat)?;
        let height = self.dimensions.1 as usize;
        let row_len = self.dimensions.0 as usize * bytes_per_pixel as usize;
        let stride = self.bytes_per_row as usize;

        let buffer_slice = self.buffer.slice(..);
        let (sender, receiver) = std::sync::mpsc::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            // The receiver only disappears if `read` already returned; nothing to report then.
            let _ = sender.send(result);
        });
        // map_async only completes while the device is polled. Wait drives both the submitted
        // copy and the map request to completion.
        // NOTE(review): `Maintain` is the wgpu 24 name; wgpu 25+ renamed it `PollType`.
        let _ = self.context.device().poll(wgpu::Maintain::Wait);
        match receiver.try_recv() {
            Ok(Ok(())) => {}
            Ok(Err(err)) => return Err(ReadbackError::MapFailed(err.to_string())),
            Err(_) => {
                // Cancel the still-pending request so a later `read` can map the buffer again.
                self.buffer.unmap();
                return Err(ReadbackError::MapFailed(
                    "buffer mapping did not complete after polling the device".to_string(),
                ));
            }
        }

        let pixels = {
            let data = buffer_slice.get_mapped_range();
            let mut pixels = Vec::with_capacity(row_len * height);
            // Drop the per-row alignment padding added in `from_texture`.
            for row in data.chunks(stride).take(height) {
                pixels.extend_from_slice(&row[..row_len]);
            }
            pixels
        };
        // The mapped view must be dropped (end of the block above) before unmapping.
        self.buffer.unmap();
        Ok(pixels)
    }

    /// Reads the texture back and writes it to `path` as an 8-bit RGBA PNG.
    ///
    /// The data is always PNG, whatever extension `path` has. BGRA data is reordered to
    /// RGBA. `Rgb10a2Unorm` data is rescaled to 8 bits per channel. Byte values are
    /// written as-is, with no sRGB/linear conversion.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`GpuReadback::read`]. Returns [`ReadbackError::EncodeFailed`]
    /// if the image cannot be built or encoded. Returns [`ReadbackError::IoError`] if
    /// writing the file fails.
    #[cfg(feature = "image")]
    pub fn save_png(&self, path: impl AsRef<std::path::Path>) -> Result<(), ReadbackError> {
        let mut pixels = self.read()?;
        Self::convert_to_rgba8(self.format, &mut pixels)?;
        let img = image::RgbaImage::from_raw(self.dimensions.0, self.dimensions.1, pixels)
            .ok_or_else(|| {
                ReadbackError::EncodeFailed("Failed to create image from raw data".to_string())
            })?;
        img.save_with_format(path, image::ImageFormat::Png)
            .map_err(|err| match err {
                image::ImageError::IoError(io_err) => ReadbackError::IoError(io_err.to_string()),
                other => ReadbackError::EncodeFailed(other.to_string()),
            })
    }

    /// Returns `(width, height)` of the captured image in texels.
    pub fn dimensions(&self) -> (u32, u32) {
        self.dimensions
    }

    /// Returns the format of the source texture, which defines the layout of [`Self::read`].
    pub fn format(&self) -> wgpu::TextureFormat {
        self.format
    }

    /// Bytes per texel for the formats this readback supports, or `None` if unsupported.
    fn bytes_per_pixel(format: wgpu::TextureFormat) -> Option<u32> {
        match format {
            wgpu::TextureFormat::Rgba8Unorm
            | wgpu::TextureFormat::Rgba8UnormSrgb
            | wgpu::TextureFormat::Bgra8Unorm
            | wgpu::TextureFormat::Bgra8UnormSrgb
            | wgpu::TextureFormat::Rgb10a2Unorm => Some(4),
            _ => None,
        }
    }

    /// Rewrites tightly packed texels from `read` into RGBA8 order, in place.
    #[cfg(feature = "image")]
    fn convert_to_rgba8(
        format: wgpu::TextureFormat,
        pixels: &mut [u8],
    ) -> Result<(), ReadbackError> {
        match format {
            wgpu::TextureFormat::Rgba8Unorm | wgpu::TextureFormat::Rgba8UnormSrgb => {}
            wgpu::TextureFormat::Bgra8Unorm | wgpu::TextureFormat::Bgra8UnormSrgb => {
                for texel in pixels.chunks_exact_mut(4) {
                    texel.swap(0, 2);
                }
            }
            wgpu::TextureFormat::Rgb10a2Unorm => {
                // Packed little-endian u32: R in bits 0-9, G 10-19, B 20-29, A 30-31.
                for texel in pixels.chunks_exact_mut(4) {
                    let packed = u32::from_le_bytes([texel[0], texel[1], texel[2], texel[3]]);
                    // Round-to-nearest rescale of a 10-bit channel (0..=1023) to 0..=255.
                    let channel = |shift: u32| ((((packed >> shift) & 0x3ff) * 255 + 511) / 1023) as u8;
                    // 2-bit alpha (0..=3) maps exactly onto 0, 85, 170, 255.
                    let alpha = ((packed >> 30) * 85) as u8;
                    texel.copy_from_slice(&[channel(0), channel(10), channel(20), alpha]);
                }
            }
            _ => return Err(ReadbackError::UnsupportedFormat),
        }
        Ok(())
    }
}
/// Extension trait for starting a texture readback directly from a shared context.
pub trait ReadbackExt {
    /// Starts a readback of `texture`.
    ///
    /// For `Arc<GraphicsContext>` this delegates to [`GpuReadback::from_texture`] with a
    /// clone of the context, so it returns the same errors.
    fn capture_texture(&self, texture: &wgpu::Texture) -> Result<GpuReadback, ReadbackError>;
}

impl ReadbackExt for Arc<GraphicsContext> {
    fn capture_texture(&self, texture: &wgpu::Texture) -> Result<GpuReadback, ReadbackError> {
        // Cheap refcount bump; `from_texture` takes ownership of its context handle.
        let shared_context = Arc::clone(self);
        GpuReadback::from_texture(shared_context, texture)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant renders its documented prefix, plus the detail for message variants.
    #[test]
    fn test_readback_error_display() {
        let cases = [
            (
                ReadbackError::MapFailed("test".to_string()),
                "Buffer mapping failed: test",
            ),
            (
                ReadbackError::CopyFailed("test".to_string()),
                "Texture copy failed: test",
            ),
            (
                ReadbackError::EncodeFailed("test".to_string()),
                "Image encoding failed: test",
            ),
            (ReadbackError::IoError("test".to_string()), "IO error: test"),
            (
                ReadbackError::InvalidDimensions,
                "Invalid dimensions for readback",
            ),
            (
                ReadbackError::UnsupportedFormat,
                "Unsupported texture format for readback",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    /// Padding keeps rows aligned and never adds a full extra alignment unit.
    #[test]
    fn test_bytes_per_row_alignment() {
        let align = wgpu::COPY_BYTES_PER_ROW_ALIGNMENT;
        let unpadded: u32 = 100 * 4;
        assert_eq!(unpadded.div_ceil(align) * align, 512);
        for width in [1u32, 64, 100, 1920] {
            let unpadded = width * 4;
            let padded = unpadded.div_ceil(align) * align;
            assert_eq!(padded % align, 0);
            assert!(padded >= unpadded);
            assert!(padded - unpadded < align);
        }
    }

    /// The derived traits distinguish variants and payloads, and the error is a std Error.
    /// This replaces a former tautological `matches!(X, X)` check.
    #[test]
    fn test_readback_error_equality() {
        assert_eq!(
            ReadbackError::InvalidDimensions,
            ReadbackError::InvalidDimensions.clone()
        );
        assert_ne!(
            ReadbackError::MapFailed("a".to_string()),
            ReadbackError::MapFailed("b".to_string())
        );
        assert_ne!(
            ReadbackError::MapFailed("a".to_string()),
            ReadbackError::CopyFailed("a".to_string())
        );
        let boxed: Box<dyn std::error::Error> = Box::new(ReadbackError::UnsupportedFormat);
        assert!(boxed.source().is_none());
    }
}