use crate::{
Buffer, BufferAddress, BufferDescriptor, BufferSize, BufferUsage, BufferViewMut,
CommandEncoder, Device, MapMode,
};
use std::pin::Pin;
use std::task::{self, Poll};
use std::{future::Future, sync::mpsc};
/// Future that drives a set of `()`-returning futures and resolves once every
/// one of them has completed.
///
/// A slot holds `Some(future)` while that future is still pending and is reset
/// to `None` (dropping the future in place) as soon as it completes.
struct Join<F> {
    futures: Vec<Option<F>>,
}

impl<F: Future<Output = ()>> Future for Join<F> {
    type Output = ();

    /// Polls every still-pending future once and resolves when none remain.
    ///
    /// All remaining futures are polled on every call (no short-circuiting), so
    /// each of them registers the current waker and can make progress concurrently.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
        // SAFETY: nothing is ever moved out of `self`. The futures live in the
        // vector's heap allocation, which is never pushed to or reallocated after
        // construction, so their addresses are stable even if `Join` itself moves.
        // A finished future is dropped in place by overwriting its slot with `None`.
        let this = unsafe { self.get_unchecked_mut() };

        let mut all_ready = true;
        for slot in this.futures.iter_mut() {
            if let Some(future) = slot {
                // SAFETY: see above — `future` stays at this address until it is dropped.
                let future = unsafe { Pin::new_unchecked(future) };
                if future.poll(cx).is_ready() {
                    *slot = None;
                } else {
                    all_ready = false;
                }
            }
        }

        if all_ready {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
/// One staging buffer managed by a `StagingBelt`, together with how much of it
/// has already been handed out.
struct Chunk {
    /// Byte offset in `buffer` at which the next write will be placed.
    offset: BufferAddress,
    /// Total size of `buffer`, in bytes.
    size: BufferAddress,
    /// The staging buffer that writes are copied out of.
    buffer: Buffer,
}
/// Sub-allocates buffer writes out of reusable, mapped staging buffers
/// ("chunks") and records copies from them into destination buffers.
///
/// Chunk lifecycle:
/// - `write_buffer` takes a chunk from `active_chunks` or `free_chunks`, or
///   creates a new one, and keeps it in `active_chunks`;
/// - `finish` unmaps every active chunk and moves it to `closed_chunks`;
/// - `recall` starts re-mapping the closed chunks; re-mapped chunks come back
///   through the channel and are moved to `free_chunks` by a later `recall`.
pub struct StagingBelt {
    /// Sending half of the channel, cloned into each `recall` future to return
    /// re-mapped chunks.
    sender: mpsc::Sender<Chunk>,
    /// Receiving half of the channel, drained at the start of each `recall`.
    receiver: mpsc::Receiver<Chunk>,
    /// Minimum size, in bytes, of newly created chunks.
    chunk_size: BufferAddress,
    /// Chunks that are mapped and currently accepting writes.
    active_chunks: Vec<Chunk>,
    /// Chunks that have been unmapped by `finish` and await `recall`.
    closed_chunks: Vec<Chunk>,
    /// Mapped chunks with no outstanding writes, ready for reuse.
    free_chunks: Vec<Chunk>,
}
impl StagingBelt {
pub fn new(chunk_size: BufferAddress) -> Self {
let (sender, receiver) = mpsc::channel();
StagingBelt {
chunk_size,
active_chunks: Vec::new(),
closed_chunks: Vec::new(),
free_chunks: Vec::new(),
sender,
receiver,
}
}
pub fn write_buffer(
&mut self,
encoder: &mut CommandEncoder,
target: &Buffer,
offset: BufferAddress,
size: BufferSize,
device: &Device,
) -> BufferViewMut {
let mut chunk = if let Some(index) = self
.active_chunks
.iter()
.position(|chunk| chunk.offset + size.get() <= chunk.size)
{
self.active_chunks.swap_remove(index)
} else if let Some(index) = self
.free_chunks
.iter()
.position(|chunk| size.get() <= chunk.size)
{
self.free_chunks.swap_remove(index)
} else {
let size = self.chunk_size.max(size.get());
Chunk {
buffer: device.create_buffer(&BufferDescriptor {
label: Some("staging"),
size,
usage: BufferUsage::MAP_WRITE | BufferUsage::COPY_SRC,
mapped_at_creation: true,
}),
size,
offset: 0,
}
};
encoder.copy_buffer_to_buffer(&chunk.buffer, chunk.offset, target, offset, size.get());
let old_offset = chunk.offset;
chunk.offset += size.get();
let remainder = chunk.offset % crate::MAP_ALIGNMENT;
if remainder != 0 {
chunk.offset += crate::MAP_ALIGNMENT - remainder;
}
self.active_chunks.push(chunk);
self.active_chunks
.last()
.unwrap()
.buffer
.slice(old_offset..old_offset + size.get())
.get_mapped_range_mut()
}
pub fn finish(&mut self) {
for chunk in self.active_chunks.drain(..) {
chunk.buffer.unmap();
self.closed_chunks.push(chunk);
}
}
pub fn recall(&mut self) -> impl Future<Output = ()> + Send {
while let Ok(mut chunk) = self.receiver.try_recv() {
chunk.offset = 0;
self.free_chunks.push(chunk);
}
let sender = &self.sender;
let futures = self
.closed_chunks
.drain(..)
.map(|chunk| {
let sender = sender.clone();
let async_buffer = chunk.buffer.slice(..).map_async(MapMode::Write);
Some(async move {
async_buffer.await.ok();
let _ = sender.send(chunk);
})
})
.collect::<Vec<_>>();
Join { futures }
}
}