use crate::{buffer::TensorBuffer, Device, Result, TensorError};
use scirs2_core::ndarray::ArrayD;
use std::sync::Arc;
use wgpu::util::DeviceExt;
#[cfg(feature = "gpu")]
use crate::gpu::memory_tracing::{AllocationId, GLOBAL_GPU_MEMORY_TRACKER};
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct GpuBuffer<T> {
buffer: Arc<wgpu::Buffer>,
pub device: Arc<wgpu::Device>,
pub queue: Arc<wgpu::Queue>,
device_enum: Device,
len: usize,
is_pinned: bool,
#[cfg(feature = "gpu")]
allocation_id: Option<AllocationId>,
_phantom: std::marker::PhantomData<T>,
}
/// A zero-copy window (`offset..offset + len`, in elements) into a shared
/// [`GpuBuffer`]. Owns no GPU memory of its own; keeps the parent alive via `Arc`.
pub struct GpuBufferView<T> {
    // Parent buffer this view borrows from (refcounted, so views can outlive
    // the handle they were created from).
    parent_buffer: Arc<GpuBuffer<T>>,
    // Element offset of the window inside the parent buffer.
    offset: usize,
    // Number of elements visible through the view.
    len: usize,
    // Copied from the parent at construction time.
    device_enum: Device,
    _phantom: std::marker::PhantomData<T>,
}
#[cfg(feature = "gpu")]
impl<T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static> TensorBuffer
    for GpuBufferView<T>
{
    type Elem = T;

    fn device(&self) -> &Device {
        &self.device_enum
    }

    /// Number of elements visible through this view (not the parent length).
    fn len(&self) -> usize {
        self.len
    }

    fn size_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }

    /// Creates a sub-view; `offset` is relative to this view's own window.
    fn view(&self, offset: usize, len: usize) -> Result<Box<dyn TensorBuffer<Elem = T>>> {
        // `checked_add` guards against usize overflow, which would wrap in
        // release builds and slip past the bounds check below.
        let end = offset.checked_add(len).ok_or_else(|| {
            TensorError::invalid_operation_simple(format!(
                "View out of bounds: {}+{} overflows usize",
                offset, len
            ))
        })?;
        if end > self.len {
            return Err(TensorError::invalid_operation_simple(format!(
                "View out of bounds: {}+{} > {}",
                offset, len, self.len
            )));
        }
        Ok(Box::new(GpuBufferView::new(
            Arc::clone(&self.parent_buffer),
            self.offset + offset,
            len,
        )?))
    }

    /// Downloads only the elements covered by this view.
    ///
    /// Fix: this previously returned the parent buffer's *entire* contents,
    /// which disagreed with `len()`/`size_bytes()` and with `to_cpu_array`.
    fn to_cpu(&self) -> Result<Vec<T>> {
        let full = self.parent_buffer.to_cpu()?;
        // In-bounds by the invariant enforced in `GpuBufferView::new`.
        Ok(full[self.offset..self.offset + self.len].to_vec())
    }

    /// GPU memory has no host pointer; null is returned by design.
    unsafe fn as_ptr(&self) -> *const T {
        std::ptr::null()
    }

    unsafe fn as_mut_ptr(&mut self) -> *mut T {
        std::ptr::null_mut()
    }

    fn clone_buffer(&self) -> Result<Box<dyn TensorBuffer<Elem = T>>> {
        Ok(Box::new(self.clone()))
    }
}
impl<T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static> GpuBufferView<T> {
    /// Creates a view over `parent[offset..offset + len]` (element indices).
    ///
    /// # Errors
    /// Fails when the window does not fit inside the parent buffer, including
    /// when `offset + len` overflows `usize`.
    pub fn new(parent: Arc<GpuBuffer<T>>, offset: usize, len: usize) -> Result<Self> {
        // `checked_add`: a wrapping `offset + len` in release mode would
        // defeat the bounds check below.
        let end = offset.checked_add(len).ok_or_else(|| {
            TensorError::invalid_operation_simple(format!(
                "Buffer view out of bounds: {}+{} overflows usize",
                offset, len
            ))
        })?;
        if end > parent.len() {
            return Err(TensorError::invalid_operation_simple(format!(
                "Buffer view out of bounds: {}+{} > {}",
                offset,
                len,
                parent.len()
            )));
        }
        Ok(Self {
            device_enum: parent.device_enum.clone(),
            parent_buffer: parent,
            offset,
            len,
            _phantom: std::marker::PhantomData,
        })
    }

    /// The buffer this view borrows from.
    #[inline]
    pub fn parent(&self) -> &GpuBuffer<T> {
        &self.parent_buffer
    }

    /// Element offset of this view inside the parent buffer.
    #[inline]
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Raw wgpu buffer of the parent (the view itself owns no GPU memory).
    #[inline]
    pub fn buffer(&self) -> &wgpu::Buffer {
        &self.parent_buffer.buffer
    }

    /// wgpu device the parent buffer lives on.
    #[inline]
    pub fn device(&self) -> &wgpu::Device {
        &self.parent_buffer.device
    }

    /// wgpu queue associated with the parent buffer.
    #[inline]
    pub fn queue(&self) -> &wgpu::Queue {
        &self.parent_buffer.queue
    }

    /// Downloads the view's elements as a 1-D dynamic-dimension host array.
    pub fn to_cpu_array(&self) -> Result<ArrayD<T>> {
        let full_data = self.parent_buffer.to_cpu()?;
        // Slice indices are in-bounds by the invariant checked in `new`.
        let view_data = full_data[self.offset..self.offset + self.len].to_vec();
        Ok(scirs2_core::ndarray::Array1::from(view_data).into_dyn())
    }
}
impl<T> Clone for GpuBufferView<T> {
    /// Cheap clone: bumps the parent buffer's refcount and copies the window
    /// coordinates; no GPU memory is touched.
    fn clone(&self) -> Self {
        let parent_buffer = Arc::clone(&self.parent_buffer);
        let device_enum = self.device_enum.clone();
        Self {
            parent_buffer,
            device_enum,
            offset: self.offset,
            len: self.len,
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<T> Clone for GpuBuffer<T> {
    /// Cheap clone: shares the underlying wgpu buffer, device and queue via
    /// `Arc`; no GPU memory is copied. The allocation id is shared too, so
    /// the tracker sees all clones as a single allocation.
    fn clone(&self) -> Self {
        let buffer = Arc::clone(&self.buffer);
        let device = Arc::clone(&self.device);
        let queue = Arc::clone(&self.queue);
        Self {
            buffer,
            device,
            queue,
            len: self.len,
            is_pinned: self.is_pinned,
            device_enum: self.device_enum.clone(),
            #[cfg(feature = "gpu")]
            allocation_id: self.allocation_id,
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static> GpuBuffer<T> {
    /// Registers an allocation with the global GPU memory tracker.
    ///
    /// Best-effort: returns `None` if the tracker lock cannot be acquired,
    /// so buffer creation never fails because of tracking.
    #[cfg(feature = "gpu")]
    fn track_allocation(
        size_bytes: usize,
        device_id: usize,
        operation: &str,
    ) -> Option<AllocationId> {
        if let Ok(mut tracker) = GLOBAL_GPU_MEMORY_TRACKER.lock() {
            Some(tracker.track_allocation(
                size_bytes,
                device_id,
                operation.to_string(),
                None,
                Some(std::any::type_name::<T>().to_string()),
            ))
        } else {
            None
        }
    }

    /// The tracker id assigned when this buffer was created, if tracking succeeded.
    #[cfg(feature = "gpu")]
    pub fn allocation_id(&self) -> Option<AllocationId> {
        self.allocation_id
    }

    /// Allocates a buffer of `len` elements, zero-initialized, on `device_id`.
    ///
    /// # Errors
    /// Fails if the global GPU context is unavailable.
    pub fn zeros(len: usize, device_id: usize) -> Result<Self> {
        let context = crate::gpu::GpuContext::global()?;
        // Upload an explicit zeroed host vector so the contents are
        // well-defined independent of backend zero-init behavior.
        let zeros_data = vec![T::zeroed(); len];
        let buffer = context
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("gpu_buffer_zeros"),
                contents: bytemuck::cast_slice(&zeros_data),
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_DST
                    | wgpu::BufferUsages::COPY_SRC,
            });
        #[cfg(feature = "gpu")]
        let allocation_id =
            Self::track_allocation(len * std::mem::size_of::<T>(), device_id, "zeros");
        Ok(Self {
            buffer: Arc::new(buffer),
            device: Arc::clone(&context.device),
            queue: Arc::clone(&context.queue),
            device_enum: Device::Gpu(device_id),
            len,
            is_pinned: false,
            #[cfg(feature = "gpu")]
            allocation_id,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Wraps an existing `wgpu::Buffer`, taking sole ownership of it.
    ///
    /// `len` is the element count the buffer is interpreted as; it is not
    /// validated against the buffer's byte size here.
    pub fn from_wgpu_buffer(
        buffer: wgpu::Buffer,
        device: Arc<wgpu::Device>,
        queue: Arc<wgpu::Queue>,
        device_enum: Device,
        len: usize,
    ) -> Self {
        #[cfg(feature = "gpu")]
        let allocation_id = if let Device::Gpu(device_id) = device_enum {
            Self::track_allocation(
                len * std::mem::size_of::<T>(),
                device_id,
                "from_wgpu_buffer",
            )
        } else {
            None
        };
        Self {
            buffer: Arc::new(buffer),
            device,
            queue,
            device_enum,
            len,
            is_pinned: false,
            #[cfg(feature = "gpu")]
            allocation_id,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Alias for [`Self::from_wgpu_buffer`], kept for API compatibility.
    pub fn from_raw_buffer(
        buffer: wgpu::Buffer,
        device: Arc<wgpu::Device>,
        queue: Arc<wgpu::Queue>,
        device_enum: Device,
        len: usize,
    ) -> Self {
        Self::from_wgpu_buffer(buffer, device, queue, device_enum, len)
    }

    /// Wraps a buffer that is already shared (`Arc`) with other owners.
    ///
    /// NOTE(review): this records a *new* tracker allocation for memory that
    /// may already be tracked by another owner of the same `Arc`, so usage
    /// figures can double-count — confirm the intended semantics.
    pub fn from_shared_buffer(
        buffer: Arc<wgpu::Buffer>,
        device: Arc<wgpu::Device>,
        queue: Arc<wgpu::Queue>,
        device_enum: Device,
        len: usize,
    ) -> Self {
        #[cfg(feature = "gpu")]
        let allocation_id = if let Device::Gpu(device_id) = device_enum {
            Self::track_allocation(
                len * std::mem::size_of::<T>(),
                device_id,
                "from_shared_buffer",
            )
        } else {
            None
        };
        Self {
            buffer,
            device,
            queue,
            device_enum,
            len,
            is_pinned: false,
            #[cfg(feature = "gpu")]
            allocation_id,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Uploads a host `ndarray` to a new GPU buffer on `device_id`.
    ///
    /// Contiguous (standard-layout) arrays are uploaded directly; other
    /// layouts are first materialized into a contiguous `Vec`.
    pub fn from_cpu_array(array: &ArrayD<T>, device_id: usize) -> Result<Self> {
        let context = crate::gpu::GpuContext::global()?;
        let slice = if array.is_standard_layout() {
            array.as_slice().ok_or_else(|| {
                TensorError::invalid_operation_simple(
                    "Cannot convert non-contiguous array to slice".to_string(),
                )
            })?
        } else {
            // Non-contiguous: copy into a contiguous buffer and upload that.
            let data: Vec<T> = array.iter().cloned().collect();
            return Self::from_slice(&data, &Device::Gpu(device_id));
        };
        let buffer = context
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("gpu_buffer_from_cpu_array"),
                contents: bytemuck::cast_slice(slice),
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_DST
                    | wgpu::BufferUsages::COPY_SRC,
            });
        #[cfg(feature = "gpu")]
        let allocation_id = Self::track_allocation(
            array.len() * std::mem::size_of::<T>(),
            device_id,
            "from_cpu_array",
        );
        Ok(Self {
            buffer: Arc::new(buffer),
            device: Arc::clone(&context.device),
            queue: Arc::clone(&context.queue),
            device_enum: Device::Gpu(device_id),
            len: array.len(),
            is_pinned: false,
            #[cfg(feature = "gpu")]
            allocation_id,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Downloads the whole buffer as a 1-D dynamic-dimension host array.
    pub fn to_cpu_array(&self) -> Result<ArrayD<T>> {
        let data = self.to_cpu()?;
        Ok(scirs2_core::ndarray::Array1::from(data).into_dyn())
    }

    /// Number of `T` elements (not bytes).
    pub fn len(&self) -> usize {
        self.len
    }

    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Total size in bytes: `len * size_of::<T>()`.
    pub fn size_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }

    /// Uploads a host slice to a new buffer on the given GPU device.
    ///
    /// # Errors
    /// Returns an error for non-GPU `device` values or when the global GPU
    /// context is unavailable.
    pub fn from_slice(slice: &[T], device: &Device) -> Result<Self> {
        match device {
            Device::Gpu(device_id) => {
                let context = crate::gpu::GpuContext::global()?;
                let buffer = context
                    .device
                    .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some("tensor_buffer_from_slice"),
                        contents: bytemuck::cast_slice(slice),
                        usage: wgpu::BufferUsages::STORAGE
                            | wgpu::BufferUsages::COPY_DST
                            | wgpu::BufferUsages::COPY_SRC,
                    });
                #[cfg(feature = "gpu")]
                let allocation_id = Self::track_allocation(
                    std::mem::size_of_val(slice),
                    *device_id,
                    "from_slice",
                );
                Ok(Self {
                    buffer: Arc::new(buffer),
                    device: Arc::clone(&context.device),
                    queue: Arc::clone(&context.queue),
                    device_enum: device.clone(),
                    len: slice.len(),
                    is_pinned: false,
                    #[cfg(feature = "gpu")]
                    allocation_id,
                    _phantom: std::marker::PhantomData,
                })
            }
            _ => Err(TensorError::invalid_operation_simple(
                "Expected GPU device".to_string(),
            )),
        }
    }

    /// Moves the data to `target_device` by round-tripping through host memory.
    ///
    /// NOTE(review): this performs a full GPU→CPU→GPU copy even when the
    /// target equals the current device; a device-side copy would be cheaper.
    pub fn transfer_to_device(&self, target_device: &Device) -> Result<Self> {
        match target_device {
            Device::Gpu(_device_id) => {
                let cpu_data = self.to_cpu()?;
                Self::from_slice(&cpu_data, target_device)
            }
            _ => Err(TensorError::invalid_operation_simple(
                "Expected GPU device".to_string(),
            )),
        }
    }

    /// Copies the buffer's contents back to host memory.
    ///
    /// Blocks until the transfer completes: the data is copied into a
    /// MAP_READ staging buffer, the device is polled until the map callback
    /// fires, and the mapped bytes are cast back to `T`.
    pub fn to_cpu(&self) -> Result<Vec<T>> {
        // Fast path: nothing to copy; also avoids a zero-sized staging buffer.
        if self.len == 0 {
            return Ok(Vec::new());
        }
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("staging_buffer"),
            size: self.size_bytes() as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("gpu_to_cpu_encoder"),
            });
        encoder.copy_buffer_to_buffer(
            &self.buffer,
            0,
            &staging_buffer,
            0,
            self.size_bytes() as u64,
        );
        self.queue.submit(Some(encoder.finish()));
        let buffer_slice = staging_buffer.slice(..);
        // `map_async` completes during device polling; the oneshot channel
        // carries the result out of the callback.
        let (sender, receiver) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = sender.send(result);
        });
        self.device.poll(wgpu::PollType::wait_indefinitely()).ok();
        match futures::executor::block_on(receiver) {
            Ok(Ok(())) => {
                let data = buffer_slice.get_mapped_range();
                let result = bytemuck::cast_slice(&data).to_vec();
                // The mapped range must be dropped before `unmap`.
                drop(data);
                staging_buffer.unmap();
                Ok(result)
            }
            _ => Err(TensorError::invalid_operation_simple(
                "Failed to read GPU buffer".to_string(),
            )),
        }
    }

    /// Borrows the raw wgpu buffer.
    pub fn buffer(&self) -> &wgpu::Buffer {
        &self.buffer
    }

    /// Shared handle to the raw wgpu buffer.
    pub fn buffer_arc(&self) -> Arc<wgpu::Buffer> {
        Arc::clone(&self.buffer)
    }

    /// Logical device this buffer lives on.
    pub fn device_enum(&self) -> Device {
        self.device_enum.clone()
    }

    /// Borrows the wgpu device handle.
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }

    /// Borrows the wgpu queue handle.
    #[inline]
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }

    /// Whether this buffer was created through a `*_pinned` constructor.
    #[inline]
    pub fn is_pinned(&self) -> bool {
        self.is_pinned
    }

    /// Like [`Self::from_cpu_array`], but marks the buffer as pinned.
    pub fn from_cpu_array_pinned(array: &ArrayD<T>, device_id: usize) -> Result<Self> {
        let mut buffer = Self::from_cpu_array(array, device_id)?;
        buffer.is_pinned = true;
        Ok(buffer)
    }

    /// Like [`Self::zeros`], but marks the buffer as pinned.
    pub fn zeros_pinned(len: usize, device_id: usize) -> Result<Self> {
        let mut buffer = Self::zeros(len, device_id)?;
        buffer.is_pinned = true;
        Ok(buffer)
    }
}
/// Reports the allocation as freed to the tracker when the *last* handle to
/// the underlying wgpu buffer goes away.
#[cfg(feature = "gpu")]
impl<T> Drop for GpuBuffer<T> {
    fn drop(&mut self) {
        // Only the final owner reports the free: clones share the same
        // `allocation_id` (see `Clone`), so freeing once per clone would
        // over-report.
        // NOTE(review): `Arc::strong_count` is racy under concurrent drops —
        // two threads can each observe a count > 1 and the free is then never
        // reported. Acceptable for best-effort accounting; verify if exact
        // tracking is required.
        if Arc::strong_count(&self.buffer) == 1 {
            if let Some(alloc_id) = self.allocation_id {
                // Best-effort: a poisoned tracker lock silently skips the free.
                if let Ok(mut tracker) = GLOBAL_GPU_MEMORY_TRACKER.lock() {
                    tracker.track_free(alloc_id);
                }
            }
        }
    }
}
#[cfg(feature = "gpu")]
impl<T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static> TensorBuffer
    for GpuBuffer<T>
{
    type Elem = T;

    fn device(&self) -> &Device {
        &self.device_enum
    }

    fn len(&self) -> usize {
        self.len
    }

    fn size_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }

    /// Creates a zero-copy view. The view holds its own `Arc`-backed clone of
    /// this buffer, so it keeps the GPU allocation alive independently.
    fn view(&self, offset: usize, len: usize) -> Result<Box<dyn TensorBuffer<Elem = T>>> {
        // `checked_add` guards against usize overflow, which would wrap in
        // release builds and slip past the bounds check below.
        let end = offset.checked_add(len).ok_or_else(|| {
            TensorError::invalid_operation_simple(format!(
                "View out of bounds: {}+{} overflows usize",
                offset, len
            ))
        })?;
        if end > self.len {
            return Err(TensorError::invalid_operation_simple(format!(
                "View out of bounds: {}+{} > {}",
                offset, len, self.len
            )));
        }
        Ok(Box::new(GpuBufferView::new(
            Arc::new(self.clone()),
            offset,
            len,
        )?))
    }

    fn to_cpu(&self) -> Result<Vec<T>> {
        // Delegate to the inherent blocking download.
        GpuBuffer::to_cpu(self)
    }

    /// GPU memory has no host pointer; null is returned by design.
    unsafe fn as_ptr(&self) -> *const T {
        std::ptr::null()
    }

    unsafe fn as_mut_ptr(&mut self) -> *mut T {
        std::ptr::null_mut()
    }

    fn clone_buffer(&self) -> Result<Box<dyn TensorBuffer<Elem = T>>> {
        Ok(Box::new(self.clone()))
    }
}
/// Operations a GPU backend must provide for typed buffer management.
pub trait GpuBufferOps<T> {
    /// Allocates a buffer of `size` elements.
    fn create_buffer(&self, size: usize) -> Result<GpuBuffer<T>>;
    /// Uploads `data` from host memory into a new buffer.
    fn copy_from_host(&self, data: &[T]) -> Result<GpuBuffer<T>>;
    /// Downloads `buffer`'s contents back to host memory.
    fn copy_to_host(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>>;
}
/// Owner of a device/queue pair intended for buffer allocation.
pub struct BufferManager {
    // NOTE(review): both handles are currently unused — `allocate` goes
    // through the global GPU context instead; verify intent.
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
}
impl BufferManager {
    /// Creates a manager over the given device and queue.
    pub fn new(device: Arc<wgpu::Device>, queue: Arc<wgpu::Queue>) -> Self {
        Self { device, queue }
    }
    /// Allocates a zeroed buffer of `size` elements.
    ///
    /// NOTE(review): ignores `self.device`/`self.queue` and always allocates
    /// on device 0 via the global context — confirm this is intentional
    /// before relying on per-manager device selection.
    pub fn allocate<T>(&self, size: usize) -> Result<GpuBuffer<T>>
    where
        T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
    {
        GpuBuffer::zeros(size, 0)
    }
}
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    use crate::gpu::memory_tracing::{current_gpu_memory_usage, GLOBAL_GPU_MEMORY_TRACKER};

    /// An allocate/free round-trip should be visible to the global tracker.
    /// The test is a silent no-op when no GPU context can be created.
    #[test]
    fn test_gpu_buffer_memory_tracking() {
        // Best-effort: start from a clean tracker state.
        if let Ok(mut tracker) = GLOBAL_GPU_MEMORY_TRACKER.lock() {
            tracker.reset();
        }
        let baseline = current_gpu_memory_usage();
        let gpu_buf = match GpuBuffer::<f32>::zeros(1024, 0) {
            Ok(b) => b,
            Err(_) => return, // no GPU available; nothing to verify
        };
        let after_alloc = current_gpu_memory_usage();
        assert!(
            after_alloc >= baseline,
            "Memory usage should increase after allocation"
        );
        assert!(
            gpu_buf.allocation_id().is_some(),
            "Buffer should have an allocation ID"
        );
        drop(gpu_buf);
        assert!(
            current_gpu_memory_usage() <= after_alloc,
            "Memory usage should not increase after dropping"
        );
    }
}