use cubecl_common::stream_id::StreamId;
use cubecl_zspace::{Shape, Strides};
use crate::{
memory_management::{ManagedMemoryBinding, ManagedMemoryHandle},
server::CopyDescriptor,
};
/// A managed memory handle bundled with optional start/end offsets, a stream
/// identifier and a size.
///
/// `size_in_used` subtracts both offsets from `size`, so the offsets are
/// expressed in the same unit as `size` (presumably bytes; confirm against the
/// allocator that produces these handles).
pub struct Handle {
/// The underlying managed memory handle.
pub memory: ManagedMemoryHandle,
/// Offset removed from the front of the region; `None` until
/// `Handle::offset_start` is called (or the field is set directly).
pub offset_start: Option<u64>,
/// Offset removed from the back of the region; `None` until
/// `Handle::offset_end` is called (or the field is set directly).
pub offset_end: Option<u64>,
/// Stream identifier stored alongside the memory handle.
pub stream: StreamId,
/// Total size of the region, not reduced by the offsets.
pub(crate) size: u64,
}
/// Hand-written `Debug` output for [`Handle`].
///
/// Note that the `memory` field is printed under the label `id`; every other
/// field is printed under its own name.
impl core::fmt::Debug for Handle {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            memory,
            offset_start,
            offset_end,
            stream,
            size,
        } = self;

        let mut out = f.debug_struct("Handle");
        out.field("id", memory);
        out.field("offset_start", offset_start);
        out.field("offset_end", offset_end);
        out.field("stream", stream);
        out.field("size", size);
        out.finish()
    }
}
/// Field-by-field clone of a [`Handle`].
///
/// The memory handle is duplicated through its own `clone()`; the offsets,
/// stream and size are plain copies.
impl Clone for Handle {
    fn clone(&self) -> Self {
        let Self {
            memory,
            offset_start,
            offset_end,
            stream,
            size,
        } = self;

        Self {
            memory: memory.clone(),
            offset_start: *offset_start,
            offset_end: *offset_end,
            stream: *stream,
            size: *size,
        }
    }
}
impl Handle {
    /// Builds a handle around an existing managed memory handle.
    ///
    /// Both offsets start out as `None`; `stream` and `size` are stored as given.
    pub fn from_memory(id: ManagedMemoryHandle, stream: StreamId, size: u64) -> Self {
        Self {
            memory: id,
            offset_start: None,
            offset_end: None,
            stream,
            size,
        }
    }

    /// Builds a handle around a freshly created `ManagedMemoryHandle::new()`,
    /// with no offsets set.
    pub fn new(stream: StreamId, size: u64) -> Self {
        Self::from_memory(ManagedMemoryHandle::new(), stream, size)
    }

    /// Forwards to `can_mut()` on the underlying memory handle.
    pub fn can_mut(&self) -> bool {
        self.memory.can_mut()
    }

    /// Consumes the handle and turns it into a [`Binding`].
    ///
    /// The memory handle is converted through its own `binding()`; offsets,
    /// stream and size are carried over unchanged.
    pub fn binding(self) -> Binding {
        let (offset_start, offset_end) = (self.offset_start, self.offset_end);
        let (stream, size) = (self.stream, self.size);
        let memory = self.memory.binding();

        Binding {
            memory,
            offset_start,
            offset_end,
            stream,
            size,
        }
    }

    /// Adds `offset` to the start offset and returns the updated handle.
    ///
    /// Calls accumulate: an unset offset becomes `Some(offset)`, an existing
    /// one is increased by `offset` (plain `u64` addition).
    pub fn offset_start(mut self, offset: u64) -> Self {
        let shifted = match self.offset_start {
            Some(current) => current + offset,
            None => offset,
        };
        self.offset_start = Some(shifted);
        self
    }

    /// Adds `offset` to the end offset and returns the updated handle.
    ///
    /// Calls accumulate: an unset offset becomes `Some(offset)`, an existing
    /// one is increased by `offset` (plain `u64` addition).
    pub fn offset_end(mut self, offset: u64) -> Self {
        let shifted = match self.offset_end {
            Some(current) => current + offset,
            None => offset,
        };
        self.offset_end = Some(shifted);
        self
    }

    /// Consumes the handle and packages its binding together with the given
    /// `shape`, `strides` and `elem_size` into a [`CopyDescriptor`].
    pub fn copy_descriptor(
        self,
        shape: Shape,
        strides: Strides,
        elem_size: usize,
    ) -> CopyDescriptor {
        let handle = self.binding();

        CopyDescriptor {
            shape,
            strides,
            elem_size,
            handle,
        }
    }

    /// Returns `size` minus the start and end offsets (unset offsets count as 0).
    ///
    /// The subtraction is unchecked `u64` arithmetic: if the offsets together
    /// exceed `size` it overflows, which panics when overflow checks are on and
    /// wraps otherwise.
    pub fn size_in_used(&self) -> u64 {
        let leading = self.offset_start.unwrap_or_default();
        let trailing = self.offset_end.unwrap_or_default();
        // Keep the left-to-right order so overflow behaviour matches a plain
        // `size - start - end` expression.
        self.size - leading - trailing
    }

    /// Returns the stored `size`, ignoring any offsets.
    pub fn size(&self) -> u64 {
        self.size
    }
}
/// The binding counterpart of a `Handle`: a managed memory binding plus the
/// same offsets, stream identifier and size.
///
/// `Handle::binding` builds one of these by converting the memory handle and
/// copying the remaining fields across.
#[derive(Clone, Debug)]
pub struct Binding {
/// The underlying managed memory binding.
pub memory: ManagedMemoryBinding,
/// Offset removed from the front of the region, if any.
pub offset_start: Option<u64>,
/// Offset removed from the back of the region, if any.
pub offset_end: Option<u64>,
/// Stream identifier stored alongside the memory binding.
pub stream: StreamId,
/// Total size of the region, not reduced by the offsets (same unit as the
/// offsets, since `size_in_used` subtracts them from it).
pub size: u64,
}
impl Binding {
    /// Returns `size` minus the start and end offsets (unset offsets count as 0).
    ///
    /// The subtraction is unchecked `u64` arithmetic: if the offsets together
    /// exceed `size` it overflows, which panics when overflow checks are on and
    /// wraps otherwise.
    pub fn size_in_used(&self) -> u64 {
        let leading = self.offset_start.unwrap_or_default();
        let trailing = self.offset_end.unwrap_or_default();
        // Keep the left-to-right order so overflow behaviour matches a plain
        // `size - start - end` expression.
        self.size - leading - trailing
    }

    /// Returns the stored `size`, ignoring any offsets.
    pub fn size(&self) -> u64 {
        self.size
    }
}