use crate as wgpu;
use std::fmt;
use std::future::Future;
use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use image::{GenericImage, GenericImageView};
/// Captures GPU textures into buffers that can be read back on the CPU.
///
/// Textures not already in `Capturer::DST_FORMAT` are first rendered into an intermediate
/// texture of that format. Reads of the resulting `Snapshot`s are throttled through a pool that
/// is shared with every snapshot this capturer produces.
#[derive(Debug, Default)]
pub struct Capturer {
    /// Cached reshaper + destination texture used to convert non-`DST_FORMAT` sources; rebuilt
    /// when the source texture's descriptor changes.
    converter_data_pair: Mutex<Option<ConverterDataPair>>,
    /// Created lazily on the first `Snapshot::read` and shared with every snapshot.
    thread_pool: Arc<Mutex<Option<Arc<ThreadPool>>>>,
    /// Maximum number of concurrent snapshot reads; `None` means `num_cpus::get()`.
    workers: Option<u32>,
    /// How long to wait for a worker (or for active reads to finish); `None` waits forever.
    timeout: Option<Duration>,
}
/// Throttles snapshot reads so that no more than `workers` are active at once.
///
/// It owns no threads itself: read futures are handed to `tokio::spawn` and only counted here.
#[derive(Debug)]
struct ThreadPool {
    /// Count of in-flight read futures, compared against `workers` before spawning another.
    active_futures: Arc<AtomicU32>,
    /// Maximum number of read futures allowed to be active concurrently.
    workers: u32,
    /// How long to wait for a free worker (or for active futures to finish); `None` waits forever.
    timeout: Option<Duration>,
}
/// A texture capture produced by `Capturer::capture`.
///
/// The copy commands are recorded into the encoder passed to `capture`.
/// NOTE(review): presumably that encoder must be submitted before the snapshot is read.
pub struct Snapshot {
    /// Buffer the (possibly format-converted) texture is copied into.
    buffer: wgpu::RowPaddedBuffer,
    /// The capturer's shared pool slot; the pool is created lazily on the first read.
    thread_pool: Arc<Mutex<Option<Arc<ThreadPool>>>>,
    /// Copied from the capturer; used if this snapshot is the one that creates the pool.
    workers: Option<u32>,
    /// Copied from the capturer; used if this snapshot is the one that creates the pool.
    timeout: Option<Duration>,
}
/// Error returned when a capturer's timeout elapses while waiting on snapshot workers.
///
/// Holds the value that could not be processed: the un-spawned read future when waiting for a
/// free worker, or `()` when waiting for active reads to finish.
pub struct AwaitWorkerTimeout<F>(pub F);
/// Cached state for converting a source texture into `Capturer::DST_FORMAT`.
#[derive(Debug)]
struct ConverterDataPair {
    /// Descriptor of the source texture this pair was built for; compared on each capture to
    /// decide whether the pair must be rebuilt.
    src_descriptor: wgpu::TextureDescriptor<'static>,
    /// Renders the source texture into `dst_texture`.
    reshaper: wgpu::TextureReshaper,
    /// Single-sampled `DST_FORMAT` texture whose contents are copied into the snapshot buffer.
    dst_texture: wgpu::Texture,
}
/// A snapshot buffer that has been mapped for reading, viewable as an RGBA8 image.
pub struct Rgba8AsyncMappedImageBuffer<'buffer>(wgpu::ImageReadMapping<'buffer>);
impl ThreadPool {
fn spawn_when_worker_available<F>(&self, future: F) -> Result<(), AwaitWorkerTimeout<F>>
where
F: 'static + Future<Output = ()> + Send,
{
let mut start = None;
let mut interval_us = 128;
while self.active_futures() >= self.workers() {
if let Some(timeout) = self.timeout {
let start = start.get_or_insert_with(instant::Instant::now);
if start.elapsed() > timeout {
return Err(AwaitWorkerTimeout(future));
}
}
let duration = Duration::from_micros(interval_us);
std::thread::sleep(duration);
interval_us *= 2;
}
let active_futures = self.active_futures.clone();
let future = async move {
active_futures.fetch_add(1, atomic::Ordering::SeqCst);
future.await;
active_futures.fetch_sub(1, atomic::Ordering::SeqCst);
};
tokio::spawn(future);
Ok(())
}
fn active_futures(&self) -> u32 {
self.active_futures.load(atomic::Ordering::SeqCst)
}
fn workers(&self) -> u32 {
self.workers
}
fn await_active_futures(&self, device: &wgpu::Device) -> Result<(), AwaitWorkerTimeout<()>> {
let mut start = None;
let mut interval_us = 128;
while self.active_futures() > 0 {
if let Some(timeout) = self.timeout {
let start = start.get_or_insert_with(instant::Instant::now);
if start.elapsed() > timeout {
return Err(AwaitWorkerTimeout(()));
}
}
device.poll(wgpu::Maintain::Wait);
let duration = Duration::from_micros(interval_us);
std::thread::sleep(duration);
interval_us *= 2;
}
Ok(())
}
}
impl Capturer {
pub const DST_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
pub fn new(workers: Option<u32>, timeout: Option<Duration>) -> Self {
Capturer {
converter_data_pair: Default::default(),
thread_pool: Default::default(),
workers,
timeout,
}
}
pub fn active_snapshots(&self) -> u32 {
if let Ok(guard) = self.thread_pool.lock() {
if let Some(tp) = guard.as_ref() {
return tp.active_futures.load(atomic::Ordering::SeqCst);
}
}
0
}
pub fn workers(&self) -> u32 {
if let Ok(guard) = self.thread_pool.lock() {
if let Some(tp) = guard.as_ref() {
return tp.workers();
}
}
self.workers.unwrap_or(num_cpus::get() as u32)
}
pub fn capture(
&self,
device: &wgpu::Device,
encoder: &mut wgpu::CommandEncoder,
src_texture: &wgpu::Texture,
) -> Snapshot {
let buffer = if src_texture.format() != Self::DST_FORMAT {
let mut converter_data_pair = self
.converter_data_pair
.lock()
.expect("failed to lock converter");
let converter_data_pair = converter_data_pair
.get_or_insert_with(|| create_converter_data_pair(device, src_texture));
if !wgpu::texture_descriptor_eq(
src_texture.descriptor(),
&converter_data_pair.src_descriptor,
) {
*converter_data_pair = create_converter_data_pair(device, src_texture);
}
let dst_view = converter_data_pair.dst_texture.view();
converter_data_pair
.reshaper
.encode_render_pass(&dst_view.build(), encoder);
converter_data_pair.dst_texture.to_buffer(device, encoder)
} else {
src_texture.to_buffer(device, encoder)
};
Snapshot {
buffer,
thread_pool: self.thread_pool.clone(),
workers: self.workers,
timeout: self.timeout,
}
}
pub fn await_active_snapshots(
&self,
device: &wgpu::Device,
) -> Result<(), AwaitWorkerTimeout<()>> {
if let Ok(guard) = self.thread_pool.lock() {
if let Some(tp) = guard.as_ref() {
return tp.await_active_futures(device);
}
}
Ok(())
}
}
impl Snapshot {
    /// Asynchronously maps the snapshot's buffer for reading.
    ///
    /// NOTE(review): the mapping presumably only completes while the owning `wgpu::Device` is
    /// being polled (e.g. via `Capturer::await_active_snapshots`) — confirm.
    ///
    /// # Errors
    ///
    /// Returns the `wgpu::BufferAsyncError` produced if mapping the buffer fails.
    pub async fn read_async<'buffer>(
        &'buffer self,
    ) -> Result<Rgba8AsyncMappedImageBuffer<'buffer>, wgpu::BufferAsyncError> {
        let mapping = self.buffer.read().await?;
        Ok(Rgba8AsyncMappedImageBuffer(mapping))
    }

    /// Reads the snapshot on the capturer's shared pool and passes the result to `callback`.
    ///
    /// Blocks the calling thread until a worker slot is available. The read is spawned with
    /// `tokio::spawn`, so this must be called from within a tokio runtime.
    ///
    /// # Errors
    ///
    /// If the capturer's timeout elapses before a worker becomes available, returns the
    /// un-spawned read future wrapped in `AwaitWorkerTimeout`.
    pub fn read<F>(self, callback: F) -> Result<(), AwaitWorkerTimeout<impl Future<Output = ()>>>
    where
        F: 'static + Send + FnOnce(Result<Rgba8AsyncMappedImageBuffer, wgpu::BufferAsyncError>),
    {
        let thread_pool = self.thread_pool();
        let read_future = async move {
            let res = self.read_async().await;
            callback(res);
        };
        thread_pool.spawn_when_worker_available(read_future)
    }

    /// Returns the capturer's shared pool, creating it from this snapshot's settings if no
    /// snapshot has been read yet.
    fn thread_pool(&self) -> Arc<ThreadPool> {
        // The guarded value is a plain `Option<Arc<_>>` that is never left half-written, so a
        // lock poisoned by a panic elsewhere is safe to recover instead of panicking here too.
        let mut guard = self
            .thread_pool
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let thread_pool = guard.get_or_insert_with(|| {
            Arc::new(ThreadPool {
                active_futures: Arc::new(AtomicU32::new(0)),
                workers: self.workers.unwrap_or_else(|| num_cpus::get() as u32),
                timeout: self.timeout,
            })
        });
        Arc::clone(thread_pool)
    }
}
impl<'b> Rgba8AsyncMappedImageBuffer<'b> {
    /// Borrows the mapped bytes as an RGBA8 image view without copying.
    pub fn as_image(&self) -> image::SubImage<wgpu::ImageHolder<image::Rgba<u8>>> {
        let mapping = &self.0;
        // SAFETY: `Capturer::capture` only produces snapshot buffers from textures in
        // `Capturer::DST_FORMAT` (`Rgba8UnormSrgb`), so every texel is four `u8` channels,
        // matching `image::Rgba<u8>`. NOTE(review): confirm this covers the full safety contract
        // of `ImageReadMapping::as_image`.
        unsafe { mapping.as_image::<image::Rgba<u8>>() }
    }

    /// Copies the mapped image into an owned, heap-allocated `ImageBuffer`.
    ///
    /// # Panics
    ///
    /// Panics if the copy fails; the destination is created with the view's own dimensions, so
    /// this indicates an internal error.
    pub fn to_owned(&self) -> image::ImageBuffer<image::Rgba<u8>, Vec<u8>> {
        let source = self.as_image();
        let (width, height) = source.dimensions();
        let mut owned = image::ImageBuffer::new(width, height);
        owned
            .copy_from(&source, 0, 0)
            .expect("nannou internal error: image copy failed");
        owned
    }
}
impl<T> std::error::Error for AwaitWorkerTimeout<T> {}

/// Formats as `AwaitWorkerTimeout`; the payload is omitted so `T` need not be `Debug` (it is
/// usually an opaque future).
impl<T> fmt::Debug for AwaitWorkerTimeout<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("AwaitWorkerTimeout").finish()
    }
}

/// Human-readable message for end users, as `Display` on an error type should provide (the
/// previous implementation duplicated the `Debug` output).
impl<T> fmt::Display for AwaitWorkerTimeout<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("timed out while waiting on snapshot capture workers")
    }
}
/// Builds the cached state used to convert `src_texture` into `Capturer::DST_FORMAT`.
///
/// The destination texture copies `src_texture`'s descriptor but is single-sampled, uses
/// `Capturer::DST_FORMAT`, and is usable as a render attachment and copy source. The reshaper
/// renders from a default view of `src_texture` into a single-sampled target of the destination
/// format. The source descriptor is stored so later captures can detect a changed source.
fn create_converter_data_pair(
    device: &wgpu::Device,
    src_texture: &wgpu::Texture,
) -> ConverterDataPair {
    let dst_texture = wgpu::TextureBuilder::from(src_texture.descriptor.clone())
        .sample_count(1)
        .format(Capturer::DST_FORMAT)
        .usage(wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::COPY_SRC)
        .build(device);

    // Kept as a named local so the view lives until the end of the function.
    let src_view = src_texture.create_view(&wgpu::TextureViewDescriptor::default());
    let reshaper = wgpu::TextureReshaper::new(
        device,
        &src_view,
        src_texture.sample_count(),
        src_texture.sample_type(),
        // Destination sample count: matches `sample_count(1)` on `dst_texture` above.
        1,
        dst_texture.format(),
    );

    ConverterDataPair {
        src_descriptor: src_texture.descriptor.clone(),
        reshaper,
        dst_texture,
    }
}