use std::mem::MaybeUninit;
use slop_alloc::{mem::CopyError, Buffer, HasBackend, Slice};
use crate::{DeviceCopy, TaskScope};
/// A buffer of `T` values allocated on the [`TaskScope`] backend.
///
/// Thin wrapper around `slop_alloc::Buffer<T, TaskScope>`. The inherent methods add
/// host <-> device copy helpers on top of it.
/// NOTE(review): presumably `TaskScope` allocations live in GPU device memory — confirm.
pub struct DeviceBuffer<T> {
    /// The wrapped allocation; all length/capacity bookkeeping is delegated to it.
    buf: Buffer<T, TaskScope>,
}
/// Exposes the [`TaskScope`] that backs a [`DeviceBuffer`].
impl<T: DeviceCopy> HasBackend for DeviceBuffer<T> {
    type Backend = TaskScope;

    /// Returns the backend reported by the wrapped `Buffer`.
    fn backend(&self) -> &TaskScope {
        let inner = &self.buf;
        inner.backend()
    }
}
impl<T: DeviceCopy> DeviceBuffer<T> {
    /// Creates an empty buffer with room for `capacity` elements, allocated in `scope`.
    pub fn with_capacity_in(capacity: usize, scope: TaskScope) -> Self {
        Self { buf: Buffer::with_capacity_in(capacity, scope) }
    }

    /// Wraps an existing `Buffer` that is already allocated on the `TaskScope` backend.
    pub fn from_raw(buf: Buffer<T, TaskScope>) -> Self {
        Self { buf }
    }

    /// Returns a raw pointer to the start of the buffer's storage.
    ///
    /// NOTE(review): presumably a device pointer; it should not be dereferenced on the host.
    pub fn as_ptr(&self) -> *const T {
        self.buf.as_ptr()
    }

    /// Returns a raw mutable pointer to the start of the buffer's storage.
    ///
    /// NOTE(review): presumably a device pointer; it should not be dereferenced on the host.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.buf.as_mut_ptr()
    }

    /// Copies the buffer's contents into the host slice `dst`.
    ///
    /// # Safety
    ///
    /// Forwards to `Buffer::copy_into_host`; the caller must uphold that method's contract.
    /// NOTE(review): presumably `dst` must have exactly `self.len()` slots — confirm in slop_alloc.
    ///
    /// # Errors
    ///
    /// Returns the [`CopyError`] reported by the underlying copy.
    pub unsafe fn copy_to_host_slice(&self, dst: &mut [MaybeUninit<T>]) -> Result<(), CopyError> {
        self.buf.copy_into_host(dst)
    }

    /// Appends the host elements in `src` to the end of this buffer.
    ///
    /// # Safety
    ///
    /// Forwards to `Buffer::extend_from_host_slice`; the caller must uphold that method's
    /// contract. NOTE(review): the exact requirement is defined in slop_alloc (possibly that
    /// `src` outlives an asynchronous copy) — confirm.
    ///
    /// # Errors
    ///
    /// Returns the [`CopyError`] reported by the underlying copy.
    pub unsafe fn extend_from_host_slice(&mut self, src: &[T]) -> Result<(), CopyError> {
        self.buf.extend_from_host_slice(src)
    }

    /// Appends the elements of another `TaskScope` slice to the end of this buffer.
    ///
    /// # Errors
    ///
    /// Returns the [`CopyError`] reported by the underlying copy.
    pub fn extend_from_device_slice(&mut self, src: &Slice<T, TaskScope>) -> Result<(), CopyError> {
        self.buf.extend_from_device_slice(src)
    }

    /// Splits the buffer at `at`.
    ///
    /// Returns a newly allocated buffer holding elements `[at, len)`. `self` keeps `[0, at)`.
    ///
    /// # Panics
    ///
    /// Panics if `at > self.len()` or if the copy into the new buffer fails.
    pub fn split_off(&mut self, at: usize) -> Self {
        let len = self.len();
        assert!(at <= len, "split_off index out of bounds: at {}, len {}", at, len);
        let mut tail = DeviceBuffer::with_capacity_in(len - at, self.backend().clone());
        tail.extend_from_device_slice(&self.buf[at..])
            .expect("DeviceBuffer::split_off: copying the tail into a new buffer failed");
        // SAFETY: `at <= len` was asserted above, so this only shrinks the length; the
        // dropped-off elements `[at, len)` were copied into `tail` just before.
        unsafe {
            self.buf.set_len(at);
        }
        tail
    }

    /// Copies the whole buffer back to the host as a `Vec<T>`.
    ///
    /// # Errors
    ///
    /// Returns the [`CopyError`] reported by the underlying copy.
    pub fn to_host(&self) -> Result<Vec<T>, CopyError> {
        let len = self.buf.len();
        let mut host_vec = Vec::with_capacity(len);
        // `Vec::with_capacity` only guarantees *at least* `len` slots (zero-sized types even
        // report `usize::MAX`), so hand the copy exactly `len` destination slots.
        let dst = &mut host_vec.spare_capacity_mut()[..len];
        // SAFETY: `dst` has exactly `self.len()` slots. After a successful copy, the first `len`
        // elements of `host_vec` are taken as initialized (assumes an `Ok` copy writes every
        // slot of `dst`).
        unsafe {
            self.copy_to_host_slice(dst)?;
            host_vec.set_len(len);
        }
        Ok(host_vec)
    }

    /// Number of elements currently stored in the buffer.
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// Returns `true` if the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }

    /// Overrides the buffer's length without touching its contents.
    ///
    /// # Safety
    ///
    /// Same contract as `Buffer::set_len`. NOTE(review): presumably `len` must not exceed the
    /// capacity, and the first `len` elements must be initialized.
    pub unsafe fn set_len(&mut self, len: usize) {
        self.buf.set_len(len);
    }

    /// Appends the contents of a host `Vec` to the end of this buffer.
    ///
    /// # Errors
    ///
    /// Returns the [`CopyError`] reported by the underlying copy.
    // `&Vec<T>` (rather than `&[T]`) is kept so the public signature stays stable for callers.
    #[allow(clippy::ptr_arg)]
    pub fn extend_from_vec(&mut self, host_data: &Vec<T>) -> Result<(), CopyError> {
        unsafe { self.extend_from_host_slice(host_data) }
    }

    /// Appends the contents of a host `Buffer` to the end of this buffer.
    ///
    /// # Errors
    ///
    /// Returns the [`CopyError`] reported by the underlying copy.
    pub fn extend(&mut self, host_data: &Buffer<T>) -> Result<(), CopyError> {
        unsafe { self.extend_from_host_slice(host_data) }
    }

    /// Allocates a buffer in `scope` sized to `host_buf` and copies `host_buf` into it.
    ///
    /// # Errors
    ///
    /// Returns the [`CopyError`] reported by the underlying copy.
    pub fn from_host(host_buf: &Buffer<T>, scope: &TaskScope) -> Result<Self, CopyError> {
        let mut device_buf = Self::with_capacity_in(host_buf.len(), scope.clone());
        device_buf.extend(host_buf)?;
        Ok(device_buf)
    }

    /// Allocates a buffer in `scope` sized to `host_slice` and copies `host_slice` into it.
    ///
    /// # Errors
    ///
    /// Returns the [`CopyError`] reported by the underlying copy.
    pub fn from_host_slice(host_slice: &[T], scope: &TaskScope) -> Result<Self, CopyError> {
        let mut device_buf = Self::with_capacity_in(host_slice.len(), scope.clone());
        unsafe { device_buf.extend_from_host_slice(host_slice)? };
        Ok(device_buf)
    }

    /// Consumes the wrapper and returns the underlying `Buffer`.
    pub fn into_inner(self) -> Buffer<T, TaskScope> {
        self.buf
    }

    /// Forwards to `Buffer::assume_init`.
    ///
    /// # Safety
    ///
    /// Same contract as `Buffer::assume_init`. NOTE(review): presumably this marks the full
    /// capacity as initialized, so every slot must actually hold a valid `T` — confirm.
    pub unsafe fn assume_init(&mut self) {
        self.buf.assume_init();
    }
}
#[cfg(test)]
mod tests {
    use rand::{thread_rng, Rng};
    use sp1_primitives::SP1Field;

    use super::*;

    /// Sends random field elements host -> device -> host and checks they come back unchanged.
    #[test]
    fn test_copy_buffer_into_backend() {
        const NUM_ELEMENTS: usize = 10000;

        let mut rng = thread_rng();
        let original: Vec<SP1Field> =
            std::iter::repeat_with(|| rng.gen::<SP1Field>()).take(NUM_ELEMENTS).collect();

        let round_tripped = crate::run_sync_in_place(|scope| {
            let mut on_device = DeviceBuffer::with_capacity_in(original.len(), scope);
            on_device.extend_from_vec(&original).unwrap();
            on_device.to_host().unwrap()
        })
        .unwrap();

        assert_eq!(round_tripped, original);
    }
}