use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use parking_lot::RwLock;
/// Identifier of a GPU submission fence. `usize::MAX` is reserved as the
/// "no fence" sentinel (`NO_FENCE`), so it can never be a valid id.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FenceId(usize);
impl FenceId {
pub(crate) const fn new(raw: usize) -> Self {
debug_assert!(raw != NO_FENCE, "FenceId cannot be usize::MAX (reserved sentinel)");
Self(raw)
}
pub fn get(self) -> usize {
self.0
}
}
/// Identifier of a gradient slot in the autograd engine.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GradId(pub usize);
/// Identifier of the operation that produced a tensor (see `TensorMeta::creator`).
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct OpId(pub usize);
/// Identifier of a learnable parameter.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ParamId(pub usize);
/// Element type of the backing buffer. CPU-side copies are always held as
/// `f32`; `F16` and `Q8` describe the GPU-resident encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    F32,
    F16,
    /// Blockwise 8-bit quantization: each block packs one f32 scale followed
    /// by `block_size` quantized bytes.
    Q8 { block_size: usize },
}
impl DType {
    /// Bytes per element of the encoded payload. For Q8 this is the 1-byte
    /// quantized value only; the per-block f32 scale is accounted for in
    /// `gpu_buf_size`.
    pub fn byte_size(self) -> usize {
        match self {
            DType::F32 => 4,
            DType::F16 => 2,
            DType::Q8 { .. } => 1,
        }
    }
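    /// Size in bytes of the GPU buffer backing `numel` elements of this
    /// dtype, rounded up to a 4-byte multiple (wgpu requires copy sizes to
    /// be 4-byte aligned). Q8 stores one f32 scale per block ahead of the
    /// quantized bytes.
    ///
    /// An illustrative check of the arithmetic (marked `ignore` since it
    /// assumes this type is in scope):
    /// ```ignore
    /// let q8 = DType::Q8 { block_size: 32 };
    /// // ceil(100 / 32) = 4 blocks; 4 * (4 + 32) = 144 bytes, already 4-aligned.
    /// assert_eq!(q8.gpu_buf_size(100), 144);
    /// assert_eq!(DType::F32.gpu_buf_size(10), 40);
    /// ```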
    pub fn gpu_buf_size(self, numel: usize) -> u64 {
        let raw = match self {
            DType::F32 => numel * 4,
            DType::F16 => numel * 2,
            DType::Q8 { block_size } => {
                // Round up to whole blocks; each block is a 4-byte scale
                // plus `block_size` quantized bytes.
                let num_blocks = (numel + block_size - 1) / block_size;
                num_blocks * (4 + block_size)
            }
        };
        // Round up to a 4-byte multiple to satisfy wgpu's copy alignment.
        ((raw + 3) & !3) as u64
    }
pub fn q8_block_stride(self) -> usize {
match self {
DType::Q8 { block_size } => 4 + block_size,
_ => panic!("q8_block_stride called on non-Q8 dtype"),
}
}
pub fn is_quantized(self) -> bool {
matches!(self, DType::Q8 { .. })
}
}
/// Sentinel stored in `StorageInner::fence` when no GPU work is outstanding.
const NO_FENCE: usize = usize::MAX;
pub type DataReadGuard<'a> = parking_lot::MappedRwLockReadGuard<'a, [f32]>;
pub type DataWriteGuard<'a> = parking_lot::MappedRwLockWriteGuard<'a, [f32]>;
/// Which side of a `Both` storage holds changes the other has not seen yet.
#[cfg(feature = "gpu")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub(crate) enum DirtySide {
    /// The CPU copy was written; the GPU buffer is stale.
    Cpu,
    /// The GPU buffer was written; the CPU copy is stale.
    Gpu,
    /// Both copies agree.
    Clean,
}
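/// Where the elements currently live. `ensure_cpu`/`ensure_gpu` migrate
/// `Cpu`/`Gpu` states into `Both`, with `DirtySide` recording which copy is
/// stale. `Transferring` is a transient placeholder parked in the slot while
/// data moves between devices outside the lock; any access that observes it
/// panics.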
#[allow(dead_code)]
pub(crate) enum StorageData {
Cpu(Vec<f32>),
#[cfg(feature = "gpu")]
Gpu {
buffer: wgpu::Buffer,
len: usize,
},
#[cfg(feature = "gpu")]
Both {
cpu: Vec<f32>,
gpu: wgpu::Buffer,
dirty: DirtySide,
},
#[cfg(feature = "gpu")]
Transferring,
#[cfg(feature = "jit")]
Deferred {
var_id: usize,
},
}
impl fmt::Debug for StorageData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StorageData::Cpu(v) => f.debug_tuple("Cpu").field(&v.len()).finish(),
#[cfg(feature = "gpu")]
StorageData::Gpu { len, .. } => {
f.debug_struct("Gpu").field("len", len).finish()
}
#[cfg(feature = "gpu")]
StorageData::Both { cpu, dirty, .. } => f
.debug_struct("Both")
.field("len", &cpu.len())
.field("dirty", dirty)
.finish(),
#[cfg(feature = "gpu")]
StorageData::Transferring => write!(f, "Transferring"),
#[cfg(feature = "jit")]
StorageData::Deferred { var_id } => {
f.debug_struct("Deferred").field("var_id", var_id).finish()
}
}
}
}
pub struct StorageInner {
    data: RwLock<StorageData>,
    /// Logical element count (number of f32 values), independent of the
    /// on-GPU encoding.
    len: usize,
    dtype: DType,
    pub(crate) device_index: usize,
    /// Monotonic counter bumped by `bump_version` on in-place mutation.
    version: AtomicUsize,
    /// Raw `FenceId`, or `NO_FENCE` when no GPU work is outstanding.
    fence: AtomicUsize,
}
impl fmt::Debug for StorageInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StorageInner")
.field("len", &self.len)
.field("version", &self.version.load(Ordering::Relaxed))
.field("fence", &self.fence.load(Ordering::Relaxed))
.finish()
}
}
#[cfg(feature = "gpu")]
impl Drop for StorageInner {
fn drop(&mut self) {
let data = self.data.get_mut();
let placeholder = StorageData::Cpu(Vec::new());
let owned = std::mem::replace(data, placeholder);
match owned {
StorageData::Gpu { buffer, .. } => {
if let Some(ctx) = crate::backend::gpu::context::GpuContext::get() {
ctx.pool.release(buffer);
}
}
StorageData::Both { gpu, .. } => {
if let Some(ctx) = crate::backend::gpu::context::GpuContext::get() {
ctx.pool.release(gpu);
}
}
_ => {}
}
}
}
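/// Shared, thread-safe handle to tensor storage. Clones are cheap `Arc`
/// bumps, and all clones observe the same data, version counter, and fence.
///
/// A minimal CPU-side sketch (`ignore`d since it assumes the type is in
/// scope):
/// ```ignore
/// let h = StorageHandle::new(vec![1.0, 2.0, 3.0]);
/// assert_eq!(h.len(), 3);
/// assert_eq!(&h.data()[..], &[1.0, 2.0, 3.0]);
/// h.data_write()[0] = 9.0; // under the GPU feature this marks the CPU side dirty
/// assert_eq!(h.data()[0], 9.0);
/// ```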
#[derive(Clone, Debug)]
pub struct StorageHandle {
inner: Arc<StorageInner>,
}
impl StorageHandle {
pub fn new(data: Vec<f32>) -> Self {
let len = data.len();
Self {
inner: Arc::new(StorageInner {
data: RwLock::new(StorageData::Cpu(data)),
len,
dtype: DType::F32,
device_index: 0,
version: AtomicUsize::new(0),
fence: AtomicUsize::new(NO_FENCE),
}),
}
}
    // Shared constructor for GPU-resident storage of any dtype; the public
    // per-dtype constructors below delegate to it.
    #[cfg(feature = "gpu")]
    fn new_gpu_with_dtype(buffer: wgpu::Buffer, len: usize, dtype: DType) -> Self {
        Self {
            inner: Arc::new(StorageInner {
                data: RwLock::new(StorageData::Gpu { buffer, len }),
                len,
                dtype,
                device_index: 0,
                version: AtomicUsize::new(0),
                fence: AtomicUsize::new(NO_FENCE),
            }),
        }
    }
    #[cfg(feature = "gpu")]
    pub fn new_gpu(buffer: wgpu::Buffer, len: usize) -> Self {
        Self::new_gpu_with_dtype(buffer, len, DType::F32)
    }
    #[cfg(feature = "gpu")]
    pub fn new_gpu_f16(buffer: wgpu::Buffer, len: usize) -> Self {
        Self::new_gpu_with_dtype(buffer, len, DType::F16)
    }
    #[cfg(feature = "gpu")]
    pub fn new_gpu_q8(buffer: wgpu::Buffer, len: usize, block_size: usize) -> Self {
        Self::new_gpu_with_dtype(buffer, len, DType::Q8 { block_size })
    }
#[cfg(feature = "jit")]
pub fn new_deferred(var_id: usize, len: usize, dtype: DType) -> Self {
Self {
inner: Arc::new(StorageInner {
data: RwLock::new(StorageData::Deferred { var_id }),
len,
dtype,
device_index: 0,
version: AtomicUsize::new(0),
fence: AtomicUsize::new(NO_FENCE),
}),
}
}
#[cfg(feature = "jit")]
pub fn materialize_gpu(&self, buffer: wgpu::Buffer) {
let mut guard = self.inner.data.write();
*guard = StorageData::Gpu { buffer, len: self.inner.len };
}
pub fn len(&self) -> usize {
self.inner.len
}
pub fn device_index(&self) -> usize {
self.inner.device_index
}
    pub(crate) fn set_device_index(&mut self, idx: usize) {
        // Arc::get_mut only succeeds while this handle is the sole owner;
        // once the storage is shared this is silently a no-op.
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.device_index = idx;
        }
    }
pub fn dtype(&self) -> DType {
self.inner.dtype
}
pub fn ptr_id(&self) -> usize {
Arc::as_ptr(&self.inner) as usize
}
#[cfg(feature = "gpu")]
pub fn download_raw_bytes(&self) -> Vec<u8> {
self.ensure_gpu();
let guard = self.inner.data.read();
let buffer = match &*guard {
StorageData::Gpu { buffer, .. } => buffer,
StorageData::Both { gpu, .. } => gpu,
_ => unreachable!("ensure_gpu guarantees GPU data"),
};
let byte_size = self.inner.dtype.gpu_buf_size(self.inner.len);
let ctx = self.device_ctx();
ctx.download_raw_bytes(buffer, byte_size)
}
#[cfg(feature = "gpu")]
fn device_ctx(&self) -> &'static crate::backend::gpu::context::GpuContext {
#[cfg(feature = "multi_gpu")]
{
let mgpu = crate::backend::gpu::context::MultiGpuContext::get()
.expect("GPU context required");
return mgpu.device(self.inner.device_index);
}
#[cfg(not(feature = "multi_gpu"))]
{
crate::backend::gpu::context::GpuContext::get()
.expect("GPU context required")
}
}
#[cfg(feature = "gpu")]
pub fn is_gpu(&self) -> bool {
matches!(
&*self.inner.data.read(),
StorageData::Gpu { .. } | StorageData::Both { .. }
)
}
#[cfg(feature = "gpu")]
pub fn mark_gpu_dirty(&self) {
let mut guard = self.inner.data.write();
if let StorageData::Both { dirty, .. } = &mut *guard {
*dirty = DirtySide::Gpu;
}
}
pub fn version(&self) -> usize {
self.inner.version.load(Ordering::Acquire)
}
    /// Atomically increments the version counter, returning the *previous*
    /// value (`fetch_add` semantics).
    pub fn bump_version(&self) -> usize {
        self.inner.version.fetch_add(1, Ordering::AcqRel)
    }
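    /// The GPU fence this storage is waiting on, if any. The atomic holds
    /// `NO_FENCE` (`usize::MAX`) when nothing is pending, which is why
    /// `FenceId::new` rejects that value.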
pub fn fence(&self) -> Option<FenceId> {
let raw = self.inner.fence.load(Ordering::Acquire);
if raw == NO_FENCE {
None
} else {
Some(FenceId::new(raw))
}
}
pub fn set_fence(&self, fence: FenceId) {
self.inner.fence.store(fence.get(), Ordering::Release);
}
pub fn clear_fence(&self) {
self.inner.fence.store(NO_FENCE, Ordering::Release);
}
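    /// Read access to the elements as `f32`, synchronizing to the CPU
    /// first; F16/Q8 storage is dequantized on the GPU and downloaded.
    /// Panics if the storage is mid-transfer or still JIT-deferred.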
pub fn data(&self) -> DataReadGuard<'_> {
#[cfg(feature = "gpu")]
match self.inner.dtype {
DType::F16 => self.ensure_cpu_f16_as_f32(),
DType::Q8 { .. } => self.ensure_cpu_q8_as_f32(),
_ => self.ensure_cpu(),
}
#[cfg(not(feature = "gpu"))]
assert_eq!(self.inner.dtype, DType::F32, "F16/Q8 tensors require GPU feature");
parking_lot::RwLockReadGuard::map(self.inner.data.read(), |sd| match sd {
StorageData::Cpu(v) => v.as_slice(),
#[cfg(feature = "gpu")]
StorageData::Both { cpu, .. } => cpu.as_slice(),
#[cfg(feature = "gpu")]
StorageData::Gpu { .. } => unreachable!("ensure_cpu guarantees CPU data"),
#[cfg(feature = "gpu")]
StorageData::Transferring => {
panic!("cannot read tensor data while a device transfer is in progress")
}
#[cfg(feature = "jit")]
StorageData::Deferred { .. } => {
panic!("cannot read JIT-deferred tensor — flush the JIT block first")
}
})
}
#[cfg(feature = "gpu")]
fn ensure_cpu_f16_as_f32(&self) {
{
let guard = self.inner.data.read();
match &*guard {
StorageData::Cpu(_) => return,
StorageData::Both {
dirty: DirtySide::Clean | DirtySide::Cpu,
..
} => return,
_ => {}
}
}
let ctx = self.device_ctx();
        let extracted = {
            // Re-check under the write lock: another thread may have
            // completed the transfer since the read-lock fast path above.
            let mut guard = self.inner.data.write();
            match &*guard {
                StorageData::Cpu(_) => return,
                StorageData::Both {
                    dirty: DirtySide::Clean | DirtySide::Cpu,
                    ..
                } => return,
                // Never steal a transfer already in flight on another
                // thread; doing so would clobber its result when we store
                // our state back.
                StorageData::Transferring => {
                    panic!("concurrent device transfer on the same storage")
                }
                _ => {}
            }
            std::mem::replace(&mut *guard, StorageData::Transferring)
        };
let new_state = match extracted {
StorageData::Gpu { buffer, len } => {
let f32_buf = ctx.pool.acquire(
&ctx.device,
(len * 4) as u64,
crate::backend::gpu::context::STORAGE_USAGE,
);
crate::backend::gpu::compute::cast_f16_to_f32_dispatch(
ctx, &buffer, &f32_buf, len as u32,
);
let cpu_data = ctx.download(&f32_buf, len);
ctx.pool.release(f32_buf);
StorageData::Both {
cpu: cpu_data,
gpu: buffer,
dirty: DirtySide::Clean,
}
}
StorageData::Both {
gpu,
dirty: DirtySide::Gpu,
..
} => {
let len = self.inner.len;
let f32_buf = ctx.pool.acquire(
&ctx.device,
(len * 4) as u64,
crate::backend::gpu::context::STORAGE_USAGE,
);
crate::backend::gpu::compute::cast_f16_to_f32_dispatch(
ctx, &gpu, &f32_buf, len as u32,
);
let cpu_data = ctx.download(&f32_buf, len);
ctx.pool.release(f32_buf);
StorageData::Both {
cpu: cpu_data,
gpu,
dirty: DirtySide::Clean,
}
}
other => other,
};
*self.inner.data.write() = new_state;
}
#[cfg(feature = "gpu")]
fn ensure_cpu_q8_as_f32(&self) {
{
let guard = self.inner.data.read();
match &*guard {
StorageData::Cpu(_) => return,
StorageData::Both {
dirty: DirtySide::Clean | DirtySide::Cpu,
..
} => return,
_ => {}
}
}
let ctx = self.device_ctx();
        let extracted = {
            // Re-check under the write lock: another thread may have
            // completed the transfer since the read-lock fast path above.
            let mut guard = self.inner.data.write();
            match &*guard {
                StorageData::Cpu(_) => return,
                StorageData::Both {
                    dirty: DirtySide::Clean | DirtySide::Cpu,
                    ..
                } => return,
                // Never steal a transfer already in flight on another thread.
                StorageData::Transferring => {
                    panic!("concurrent device transfer on the same storage")
                }
                _ => {}
            }
            std::mem::replace(&mut *guard, StorageData::Transferring)
        };
let block_size = match self.inner.dtype {
DType::Q8 { block_size } => block_size,
_ => unreachable!(),
};
let new_state = match extracted {
StorageData::Gpu { buffer, len } => {
let f32_buf = ctx.pool.acquire(
&ctx.device,
(len * 4) as u64,
crate::backend::gpu::context::STORAGE_USAGE,
);
crate::backend::gpu::compute::dequantize_dispatch(
ctx, &buffer, &f32_buf, len as u32, block_size as u32,
);
let cpu_data = ctx.download(&f32_buf, len);
ctx.pool.release(f32_buf);
StorageData::Both {
cpu: cpu_data,
gpu: buffer,
dirty: DirtySide::Clean,
}
}
StorageData::Both {
gpu,
dirty: DirtySide::Gpu,
..
} => {
let len = self.inner.len;
let f32_buf = ctx.pool.acquire(
&ctx.device,
(len * 4) as u64,
crate::backend::gpu::context::STORAGE_USAGE,
);
crate::backend::gpu::compute::dequantize_dispatch(
ctx, &gpu, &f32_buf, len as u32, block_size as u32,
);
let cpu_data = ctx.download(&f32_buf, len);
ctx.pool.release(f32_buf);
StorageData::Both {
cpu: cpu_data,
gpu,
dirty: DirtySide::Clean,
}
}
other => other,
};
*self.inner.data.write() = new_state;
}
    pub fn data_write(&self) -> DataWriteGuard<'_> {
        // Mirror `data()`: F16/Q8 storage must first be dequantized into an
        // f32 CPU copy; plain `ensure_cpu` would download the raw buffer
        // and misinterpret half/quantized bytes as f32. Note that writes
        // land only in that f32 copy; pushing them back into an F16/Q8 GPU
        // buffer would require re-encoding, which this path does not do.
        #[cfg(feature = "gpu")]
        match self.inner.dtype {
            DType::F16 => self.ensure_cpu_f16_as_f32(),
            DType::Q8 { .. } => self.ensure_cpu_q8_as_f32(),
            _ => self.ensure_cpu(),
        }
parking_lot::RwLockWriteGuard::map(self.inner.data.write(), |sd| match sd {
StorageData::Cpu(v) => v.as_mut_slice(),
#[cfg(feature = "gpu")]
StorageData::Both { cpu, dirty, .. } => {
*dirty = DirtySide::Cpu;
cpu.as_mut_slice()
}
#[cfg(feature = "gpu")]
StorageData::Gpu { .. } => unreachable!("ensure_cpu guarantees CPU data"),
#[cfg(feature = "gpu")]
StorageData::Transferring => {
panic!("cannot write tensor data while a device transfer is in progress")
}
#[cfg(feature = "jit")]
StorageData::Deferred { .. } => {
panic!("cannot write JIT-deferred tensor — flush the JIT block first")
}
})
}
pub fn downgrade(&self) -> WeakStorageHandle {
WeakStorageHandle {
inner: Arc::downgrade(&self.inner),
}
}
#[cfg(feature = "gpu")]
pub fn gpu_buffer(
&self,
) -> parking_lot::MappedRwLockReadGuard<'_, wgpu::Buffer> {
self.ensure_gpu();
parking_lot::RwLockReadGuard::map(self.inner.data.read(), |sd| match sd {
StorageData::Gpu { buffer, .. } => buffer,
StorageData::Both { gpu, .. } => gpu,
StorageData::Cpu(_) => unreachable!("ensure_gpu guarantees GPU data"),
StorageData::Transferring => {
panic!("cannot access GPU buffer while a device transfer is in progress")
}
#[cfg(feature = "jit")]
StorageData::Deferred { .. } => {
panic!("cannot access GPU buffer of JIT-deferred tensor — flush the JIT block first")
}
})
}
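    /// Makes an up-to-date CPU copy available, downloading from the GPU if
    /// needed. Read-lock fast path first, then a re-check under the write
    /// lock before parking the slot as `Transferring` for the download.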
#[cfg(feature = "gpu")]
fn ensure_cpu(&self) {
{
let guard = self.inner.data.read();
match &*guard {
StorageData::Cpu(_) => return,
StorageData::Both {
dirty: DirtySide::Clean | DirtySide::Cpu,
..
} => return,
_ => {}
}
}
        let extracted = {
            // Re-check under the write lock: another thread may have
            // completed the transfer since the read-lock fast path above.
            let mut guard = self.inner.data.write();
            match &*guard {
                StorageData::Cpu(_) => return,
                StorageData::Both {
                    dirty: DirtySide::Clean | DirtySide::Cpu,
                    ..
                } => return,
                // Never steal a transfer already in flight on another thread.
                StorageData::Transferring => {
                    panic!("concurrent device transfer on the same storage")
                }
                _ => {}
            }
            std::mem::replace(&mut *guard, StorageData::Transferring)
        };
let ctx = self.device_ctx();
let new_state = match extracted {
StorageData::Gpu { buffer, len } => {
let cpu_data = ctx.download(&buffer, len);
StorageData::Both {
cpu: cpu_data,
gpu: buffer,
dirty: DirtySide::Clean,
}
}
StorageData::Both {
gpu,
dirty: DirtySide::Gpu,
..
} => {
let len = self.inner.len;
let cpu_data = ctx.download(&gpu, len);
StorageData::Both {
cpu: cpu_data,
gpu,
dirty: DirtySide::Clean,
}
}
other => other,
};
*self.inner.data.write() = new_state;
}
#[cfg(feature = "gpu")]
pub fn ensure_gpu(&self) {
{
let guard = self.inner.data.read();
match &*guard {
StorageData::Gpu { .. } => return,
StorageData::Both {
dirty: DirtySide::Clean | DirtySide::Gpu,
..
} => return,
_ => {}
}
}
        let extracted = {
            // Re-check under the write lock: another thread may have
            // completed the upload since the read-lock fast path above.
            let mut guard = self.inner.data.write();
            match &*guard {
                StorageData::Gpu { .. } => return,
                StorageData::Both {
                    dirty: DirtySide::Clean | DirtySide::Gpu,
                    ..
                } => return,
                // Never steal a transfer already in flight on another thread.
                StorageData::Transferring => {
                    panic!("concurrent device transfer on the same storage")
                }
                _ => {}
            }
            std::mem::replace(&mut *guard, StorageData::Transferring)
        };
let ctx = self.device_ctx();
let new_state = match extracted {
StorageData::Cpu(cpu_data) => {
let buffer = ctx.upload(&cpu_data);
StorageData::Both {
cpu: cpu_data,
gpu: buffer,
dirty: DirtySide::Clean,
}
}
            StorageData::Both {
                cpu,
                gpu,
                dirty: DirtySide::Cpu,
            } => {
                // Re-upload the modified CPU copy. This assumes an f32 GPU
                // buffer; an F16/Q8 buffer would need re-encoding instead.
                ctx.queue.write_buffer(&gpu, 0, bytemuck::cast_slice(&cpu));
                StorageData::Both {
                    cpu,
                    gpu,
                    dirty: DirtySide::Clean,
                }
            }
other => other,
};
*self.inner.data.write() = new_state;
}
}
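/// Non-owning observer of a `StorageHandle`; does not keep the underlying
/// buffers alive. Call `upgrade` to regain access.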
#[derive(Clone)]
pub struct WeakStorageHandle {
inner: Weak<StorageInner>,
}
impl fmt::Debug for WeakStorageHandle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.inner.upgrade() {
Some(inner) => f
.debug_struct("WeakStorageHandle")
.field("alive", &true)
.field("version", &inner.version.load(Ordering::Relaxed))
.finish(),
None => f
.debug_struct("WeakStorageHandle")
.field("alive", &false)
.finish(),
}
}
}
impl WeakStorageHandle {
pub fn upgrade(&self) -> Option<StorageHandle> {
self.inner.upgrade().map(|inner| StorageHandle { inner })
}
}
// Compile-time assertion that the storage types are Send + Sync: the helper
// fns are never called, but their bodies are still type-checked.
const _: () = {
fn _assert_send<T: Send>() {}
fn _assert_sync<T: Sync>() {}
fn _assertions() {
_assert_send::<StorageInner>();
_assert_sync::<StorageInner>();
_assert_send::<StorageHandle>();
_assert_sync::<StorageHandle>();
}
};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Layout {
shape: Vec<usize>,
strides: Vec<usize>,
offset: usize,
}
impl Layout {
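    /// Row-major (C-order) layout over `shape`: the last dimension has
    /// stride 1.
    ///
    /// ```ignore
    /// let l = Layout::contiguous(vec![2, 3, 4]);
    /// assert_eq!(l.strides(), &[12, 4, 1]);
    /// assert_eq!(l.numel(), 24);
    /// ```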
    pub fn contiguous(shape: Vec<usize>) -> Self {
        Self::contiguous_with_offset(shape, 0)
    }
pub fn ndim(&self) -> usize { self.shape.len() }
pub fn numel(&self) -> usize { self.shape.iter().product() }
pub fn shape(&self) -> &[usize] { &self.shape }
pub fn strides(&self) -> &[usize] { &self.strides }
pub fn offset(&self) -> usize { self.offset }
pub fn is_contiguous(&self) -> bool {
if self.shape.is_empty() { return true; }
let mut expected = 1usize;
for i in (0..self.ndim()).rev() {
if self.strides[i] != expected { return false; }
expected *= self.shape[i];
}
true
}
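    /// Swaps two dimensions by swapping their shape and stride entries; no
    /// data is moved, so the result is generally non-contiguous.
    ///
    /// ```ignore
    /// let t = Layout::contiguous(vec![2, 3]).transposed(0, 1);
    /// assert_eq!(t.shape(), &[3, 2]);
    /// assert_eq!(t.strides(), &[1, 3]);
    /// assert!(!t.is_contiguous());
    /// ```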
pub fn transposed(&self, dim0: usize, dim1: usize) -> Self {
assert!(
dim0 < self.ndim() && dim1 < self.ndim(),
"transpose dims ({}, {}) out of range for {}-D tensor",
dim0, dim1, self.ndim(),
);
let mut shape = self.shape.clone();
let mut strides = self.strides.clone();
shape.swap(dim0, dim1);
strides.swap(dim0, dim1);
Self { shape, strides, offset: self.offset }
}
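    /// Returns a contiguous layout over `new_shape`, or `None` when this
    /// layout is non-contiguous (e.g. freshly transposed) and the caller
    /// must materialize a copy first. Panics if the element counts differ.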
pub fn reshaped(&self, new_shape: Vec<usize>) -> Option<Self> {
let old_numel: usize = self.shape.iter().product();
let new_numel: usize = new_shape.iter().product();
assert_eq!(
old_numel, new_numel,
"cannot reshape {} elements into shape {:?} ({} elements)",
old_numel, new_shape, new_numel,
);
if !self.is_contiguous() { return None; }
Some(Self::contiguous_with_offset(new_shape, self.offset))
}
    fn contiguous_with_offset(shape: Vec<usize>, offset: usize) -> Self {
        let ndim = shape.len();
        let mut strides = vec![0usize; ndim];
        if ndim > 0 {
            // Row-major: the innermost dimension is unit-stride; each outer
            // stride is the product of the inner dimensions.
            strides[ndim - 1] = 1;
            for i in (0..ndim - 1).rev() {
                strides[i] = strides[i + 1] * shape[i + 1];
            }
        }
        Self { shape, strides, offset }
    }
}
/// Autograd bookkeeping attached to a tracked tensor.
pub struct TensorMeta {
pub requires_grad: bool,
pub grad_id: Option<GradId>,
pub creator: Option<OpId>,
pub is_leaf: bool,
pub retains_grad: bool,
pub total_grads: AtomicUsize,
}
impl TensorMeta {
pub fn leaf(requires_grad: bool) -> Self {
Self {
requires_grad,
grad_id: None,
creator: None,
is_leaf: true,
retains_grad: false,
total_grads: AtomicUsize::new(0),
}
}
}
impl fmt::Debug for TensorMeta {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TensorMeta")
.field("requires_grad", &self.requires_grad)
.field("grad_id", &self.grad_id)
.field("creator", &self.creator)
.field("is_leaf", &self.is_leaf)
.field("retains_grad", &self.retains_grad)
.field("total_grads", &self.total_grads.load(Ordering::Relaxed))
.finish()
}
}
#[derive(Debug, Clone)]
pub enum AutogradState {
None,
Tracked(Arc<TensorMeta>),
}
/// A tensor: shared element storage plus a layout mapping logical indices
/// onto it, with optional autograd tracking.
#[derive(Clone, Debug)]
pub struct Tensor {
pub(crate) storage: StorageHandle,
pub(crate) layout: Layout,
pub(crate) state: AutogradState,
}
impl Tensor {
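    /// Builds a contiguous F32 tensor on the CPU; panics if `data` and
    /// `shape` disagree on the element count.
    ///
    /// ```ignore
    /// let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    /// assert_eq!(t.shape(), &[2, 2]);
    /// assert_eq!(t.numel(), 4);
    /// assert!(!t.requires_grad());
    /// ```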
pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
let numel: usize = shape.iter().product();
assert_eq!(
numel,
data.len(),
"shape {:?} expects {} elements but got {}",
shape, numel, data.len(),
);
let layout = Layout::contiguous(shape);
Self {
storage: StorageHandle::new(data),
layout,
state: AutogradState::None,
}
}
pub fn shape(&self) -> &[usize] { self.layout.shape() }
pub fn strides(&self) -> &[usize] { self.layout.strides() }
pub fn ndim(&self) -> usize { self.layout.ndim() }
pub fn numel(&self) -> usize { self.layout.numel() }
pub fn is_contiguous(&self) -> bool { self.layout.is_contiguous() }
pub fn dtype(&self) -> DType { self.storage.dtype() }
pub fn requires_grad(&self) -> bool {
match &self.state {
AutogradState::None => false,
AutogradState::Tracked(meta) => meta.requires_grad,
}
}
pub fn version(&self) -> usize { self.storage.version() }
pub fn data(&self) -> DataReadGuard<'_> {
self.storage.data()
}
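    /// Toggles gradient tracking on a leaf tensor: enabling allocates a
    /// fresh `GradId` and leaf metadata, disabling drops the autograd
    /// metadata entirely. Already-matching states are left untouched.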
pub fn set_requires_grad(&mut self, requires_grad: bool) {
match &self.state {
AutogradState::None if requires_grad => {
let grad_id = crate::autograd::context::next_grad_id();
let mut meta = TensorMeta::leaf(true);
meta.grad_id = Some(grad_id);
self.state = AutogradState::Tracked(Arc::new(meta));
}
AutogradState::Tracked(_) if !requires_grad => {
self.state = AutogradState::None;
}
_ => {}
}
}
pub(crate) fn meta(&self) -> Option<&Arc<TensorMeta>> {
match &self.state {
AutogradState::Tracked(meta) => Some(meta),
AutogradState::None => None,
}
}
pub fn grad_id(&self) -> Option<GradId> {
match &self.state {
AutogradState::Tracked(meta) => meta.grad_id,
AutogradState::None => None,
}
}
pub(crate) fn from_storage_and_layout(
storage: StorageHandle,
layout: Layout,
) -> Tensor {
Tensor { storage, layout, state: AutogradState::None }
}
#[cfg(feature = "gpu")]
pub fn to_gpu(&self) {
self.storage.ensure_gpu();
}
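    /// Copies this tensor to `target_device`, staging through the f32 CPU
    /// copy; the result is therefore always a contiguous F32 tensor on the
    /// target device and carries no autograd state.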
#[cfg(feature = "multi_gpu")]
    pub fn to_device(&self, target_device: usize) -> Tensor {
        use crate::backend::gpu::context::{MultiGpuContext, STORAGE_USAGE};
        if self.storage.device_index() == target_device && self.storage.is_gpu() {
            return self.clone();
        }
        // The transfer stages through the f32 CPU copy of the whole storage,
        // so it only round-trips faithfully for contiguous, offset-0 tensors.
        assert!(
            self.is_contiguous() && self.layout.offset() == 0,
            "to_device requires a contiguous tensor with zero offset"
        );
        let mgpu = MultiGpuContext::get().expect("MultiGpuContext required for to_device");
        assert!(target_device < mgpu.num_devices(), "device index out of range");
        let target_ctx = mgpu.device(target_device);
        let guard = self.storage.data();
        let cpu_data: &[f32] = &guard;
        // `data()` always yields f32 (F16/Q8 are dequantized), so size the
        // destination as an F32 buffer regardless of the source dtype.
        let byte_size = DType::F32.gpu_buf_size(self.numel());
        let buffer = target_ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("multi_gpu_transfer"),
            size: byte_size,
            usage: STORAGE_USAGE,
            mapped_at_creation: false,
        });
        target_ctx
            .queue
            .write_buffer(&buffer, 0, bytemuck::cast_slice(&cpu_data[..self.numel()]));
        drop(guard);
        let mut storage = StorageHandle::new_gpu(buffer, self.numel());
        storage.set_device_index(target_device);
Tensor {
storage,
layout: crate::tensor::Layout::contiguous(self.shape().to_vec()),
state: AutogradState::None,
}
}
pub fn device_index(&self) -> usize {
self.storage.device_index()
}
}