use super::device::{ComputeBuffer, ComputeDevice};
/// A `ComputeBuffer` paired with the shape used to interpret its contents.
///
/// The constructors in this file check the requested shape against the
/// length of the supplied data or buffer before building a value.
pub struct ComputeTensor<B: ComputeBuffer> {
/// Backing storage. It is public, so callers can read or replace it
/// directly; nothing re-validates it against `shape` when they do.
pub buffer: B,
/// Size of each dimension, in order. Private: it is only set by the
/// constructors and by `reshape`.
shape: Vec<usize>,
}
impl<B: ComputeBuffer> ComputeTensor<B> {
    /// Returns the number of elements described by `shape`.
    ///
    /// An empty `shape` yields 1 (the empty product), matching
    /// `Iterator::product`.
    ///
    /// # Panics
    ///
    /// Panics if the product of the dimensions does not fit in `usize`.
    /// A plain `iter().product()` would wrap silently in release builds and
    /// could let a mismatched shape slip past the length assertions.
    fn checked_numel(shape: &[usize]) -> usize {
        // A zero-length dimension makes the tensor empty regardless of the
        // other dimensions, so it must never trip the overflow check.
        if shape.contains(&0) {
            return 0;
        }
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .unwrap_or_else(|| panic!("shape {:?} element count overflows usize", shape))
    }

    /// Builds a tensor by passing `data` to `dev.upload` and recording `shape`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `shape`, or if that
    /// product overflows `usize`.
    pub fn from_data<D: ComputeDevice<Buffer = B>>(dev: &D, data: &[f32], shape: &[usize]) -> Self {
        let numel = Self::checked_numel(shape);
        assert_eq!(data.len(), numel, "data length {} != shape product {}", data.len(), numel);
        Self { buffer: dev.upload(data), shape: shape.to_vec() }
    }

    /// Builds a tensor of the given `shape` whose elements are all `0.0`.
    ///
    /// The zeros are materialised in a temporary `Vec<f32>` and then handed
    /// to `dev.upload`.
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` overflows `usize`.
    pub fn zeros<D: ComputeDevice<Buffer = B>>(dev: &D, shape: &[usize]) -> Self {
        let numel = Self::checked_numel(shape);
        let data = vec![0.0f32; numel];
        Self { buffer: dev.upload(&data), shape: shape.to_vec() }
    }

    /// Wraps an existing `buffer`, interpreting it with `shape`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer.len()` differs from the product of `shape`, or if
    /// that product overflows `usize`.
    pub fn from_buffer(buffer: B, shape: Vec<usize>) -> Self {
        let numel = Self::checked_numel(&shape);
        assert_eq!(buffer.len(), numel, "buffer len {} != shape product {}", buffer.len(), numel);
        Self { buffer, shape }
    }

    /// Consumes the tensor and returns it with `new_shape`, reusing the same
    /// buffer without touching its contents.
    ///
    /// # Panics
    ///
    /// Panics if `new_shape` describes a different number of elements than
    /// the current shape, or if its product overflows `usize`.
    #[must_use = "reshape consumes the tensor and returns the reshaped one"]
    pub fn reshape(self, new_shape: &[usize]) -> Self {
        let old_numel = self.numel();
        let new_numel = Self::checked_numel(new_shape);
        assert_eq!(old_numel, new_numel, "reshape: {} != {}", old_numel, new_numel);
        Self { buffer: self.buffer, shape: new_shape.to_vec() }
    }

    /// Returns the size of each dimension.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the number of elements implied by the shape.
    ///
    /// # Panics
    ///
    /// Panics if the product of the dimensions overflows `usize`.
    pub fn numel(&self) -> usize {
        Self::checked_numel(&self.shape)
    }

    /// Returns the buffer contents as a `Vec<f32>` by delegating to the
    /// buffer's own `to_vec`.
    pub fn to_vec(&self) -> Vec<f32> {
        self.buffer.to_vec()
    }

    /// Consumes a 2-D tensor of shape `[rows, cols]` and returns a tensor of
    /// shape `[cols, rows]` backed by the buffer that `dev.transpose_2d`
    /// produces.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is not 2-D, or if the buffer returned by the
    /// device does not hold `rows * cols` elements.
    #[must_use = "transpose_2d consumes the tensor and returns the transposed one"]
    pub fn transpose_2d<D: ComputeDevice<Buffer = B>>(self, dev: &D) -> Self {
        assert_eq!(self.shape.len(), 2, "transpose_2d requires 2D tensor");
        let rows = self.shape[0];
        let cols = self.shape[1];
        let buf = dev.transpose_2d(&self.buffer, rows, cols);
        // Going through `from_buffer` re-checks the device's output length.
        Self::from_buffer(buf, vec![cols, rows])
    }
}