1use crate::device::{GpuBuffer, GpuDevice};
8use anyhow::{ensure, Result};
9
10pub struct Tensor {
13 pub(crate) buf: GpuBuffer,
14 dims: [u32; 6],
15 ndim: u8,
16}
17
18impl Tensor {
19 pub fn new(dev: &GpuDevice, data: &[f32], shape: &[u32]) -> Result<Self> {
21 let numel: u32 = shape.iter().product();
22 ensure!(
23 data.len() == numel as usize,
24 "shape {:?} needs {} elements, got {}",
25 shape, numel, data.len()
26 );
27 ensure!(shape.len() <= 6, "max 6 dimensions, got {}", shape.len());
28 let buf = dev.upload(data);
29 let mut dims = [0u32; 6];
30 dims[..shape.len()].copy_from_slice(shape);
31 Ok(Self { buf, dims, ndim: shape.len() as u8 })
32 }
33
34 pub fn from_buf(buf: GpuBuffer, shape: &[u32]) -> Result<Self> {
36 let numel: u32 = shape.iter().product();
37 ensure!(buf.len == numel as usize, "buffer has {} elements, shape needs {}", buf.len, numel);
38 ensure!(shape.len() <= 6, "max 6 dimensions");
39 let mut dims = [0u32; 6];
40 dims[..shape.len()].copy_from_slice(shape);
41 Ok(Self { buf, dims, ndim: shape.len() as u8 })
42 }
43
44 pub fn zeros(dev: &GpuDevice, shape: &[u32]) -> Result<Self> {
46 let numel: u32 = shape.iter().product();
47 ensure!(shape.len() <= 6, "max 6 dimensions");
48 let buf = dev.alloc(numel as usize);
49 let mut dims = [0u32; 6];
50 dims[..shape.len()].copy_from_slice(shape);
51 Ok(Self { buf, dims, ndim: shape.len() as u8 })
52 }
53
54 #[inline]
56 pub fn shape(&self) -> &[u32] {
57 &self.dims[..self.ndim as usize]
58 }
59
60 #[inline]
62 pub fn ndim(&self) -> usize {
63 self.ndim as usize
64 }
65
66 #[inline]
68 pub fn numel(&self) -> usize {
69 self.buf.len
70 }
71
72 pub fn to_vec(&self, dev: &GpuDevice) -> Result<Vec<f32>> {
74 dev.read(&self.buf)
75 }
76
77 #[inline]
79 pub fn buffer(&self) -> &GpuBuffer {
80 &self.buf
81 }
82
83 pub fn reshape(self, new_shape: &[u32]) -> Result<Self> {
85 let numel: u32 = new_shape.iter().product();
86 ensure!(
87 self.buf.len == numel as usize,
88 "reshape: {} elements can't become shape {:?} ({})",
89 self.buf.len, new_shape, numel
90 );
91 ensure!(new_shape.len() <= 6, "max 6 dimensions");
92 let mut dims = [0u32; 6];
93 dims[..new_shape.len()].copy_from_slice(new_shape);
94 Ok(Self { buf: self.buf, dims, ndim: new_shape.len() as u8 })
95 }
96
97 #[inline]
99 pub fn dim(&self, i: usize) -> u32 {
100 self.dims[i]
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Device shared by every test in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// A 2x3 tensor reports the shape, rank and element count it was built with.
    #[test]
    fn test_tensor_new() {
        let values: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::new(dev(), &values, &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.ndim(), 2);
        assert_eq!(tensor.numel(), 6);
    }

    /// Data uploaded through `new` reads back unchanged.
    #[test]
    fn test_tensor_readback() {
        let expected: Vec<f32> = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::new(dev(), &expected, &[3]).unwrap();
        let actual = tensor.to_vec(dev()).unwrap();
        assert_eq!(actual, expected);
    }

    /// Reshaping to a shape with the same element count succeeds.
    #[test]
    fn test_tensor_reshape() {
        let original = Tensor::new(dev(), &[1.0; 12], &[3, 4]).unwrap();
        let reshaped = original.reshape(&[2, 6]).unwrap();
        assert_eq!(reshaped.shape(), &[2, 6]);
        assert_eq!(reshaped.numel(), 12);
    }

    /// Reshaping to a different element count is rejected.
    #[test]
    fn test_tensor_reshape_mismatch() {
        let tensor = Tensor::new(dev(), &[1.0; 12], &[3, 4]).unwrap();
        let outcome = tensor.reshape(&[2, 5]);
        assert!(outcome.is_err());
    }

    /// `new` rejects data whose length does not match the shape.
    #[test]
    fn test_tensor_shape_mismatch() {
        let outcome = Tensor::new(dev(), &[1.0, 2.0, 3.0], &[2, 2]);
        assert!(outcome.is_err());
    }

    /// Four-dimensional tensors expose shape, rank, count and leading extents.
    #[test]
    fn test_tensor_4d() {
        let tensor = Tensor::new(dev(), &[0.0; 120], &[2, 3, 4, 5]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3, 4, 5]);
        assert_eq!(tensor.ndim(), 4);
        assert_eq!(tensor.numel(), 120);
        assert_eq!(tensor.dim(0), 2);
        assert_eq!(tensor.dim(1), 3);
    }

    /// `zeros` produces a tensor that reads back as all zeros.
    #[test]
    fn test_tensor_zeros() {
        let tensor = Tensor::zeros(dev(), &[3, 3]).unwrap();
        let contents = tensor.to_vec(dev()).unwrap();
        assert_eq!(contents, vec![0.0; 9]);
    }

    /// A single-element tensor with shape `[1]` round-trips.
    #[test]
    fn test_tensor_scalar() {
        let tensor = Tensor::new(dev(), &[42.0], &[1]).unwrap();
        assert_eq!(tensor.shape(), &[1]);
        let contents = tensor.to_vec(dev()).unwrap();
        assert_eq!(contents, vec![42.0]);
    }

    /// `from_buf` wraps an already uploaded buffer without altering its data.
    #[test]
    fn test_tensor_from_buf() {
        let uploaded = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let tensor = Tensor::from_buf(uploaded, &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.to_vec(dev()).unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// `from_buf` rejects a shape that does not match the buffer length.
    #[test]
    fn test_tensor_from_buf_mismatch() {
        let uploaded = dev().upload(&[1.0; 10]);
        let outcome = Tensor::from_buf(uploaded, &[3, 4]);
        assert!(outcome.is_err());
    }

    /// `buffer` exposes the underlying device buffer for direct reads.
    #[test]
    fn test_tensor_buffer_access() {
        let tensor = Tensor::new(dev(), &[7.0, 8.0], &[2]).unwrap();
        let raw = tensor.buffer();
        assert_eq!(raw.len, 2);
        assert_eq!(dev().read(raw).unwrap(), vec![7.0, 8.0]);
    }

    /// `dim` returns each extent of a four-dimensional tensor.
    #[test]
    fn test_tensor_dim_all() {
        let tensor = Tensor::new(dev(), &[0.0; 120], &[2, 3, 4, 5]).unwrap();
        assert_eq!(tensor.dim(0), 2);
        assert_eq!(tensor.dim(1), 3);
        assert_eq!(tensor.dim(2), 4);
        assert_eq!(tensor.dim(3), 5);
    }

    /// Six dimensions is the largest accepted rank.
    #[test]
    fn test_tensor_6d_max() {
        let tensor = Tensor::new(dev(), &[0.0; 1], &[1, 1, 1, 1, 1, 1]).unwrap();
        assert_eq!(tensor.ndim(), 6);
        assert_eq!(tensor.shape(), &[1, 1, 1, 1, 1, 1]);
    }

    /// Seven dimensions is rejected by `new`.
    #[test]
    fn test_tensor_7d_exceeds_max() {
        let outcome = Tensor::new(dev(), &[0.0; 1], &[1, 1, 1, 1, 1, 1, 1]);
        assert!(outcome.is_err());
    }

    /// Flattening and then restoring the shape keeps the data intact.
    #[test]
    fn test_tensor_reshape_flatten_unflatten() {
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::new(dev(), &values, &[2, 3, 4]).unwrap();

        let flattened = tensor.reshape(&[24]).unwrap();
        assert_eq!(flattened.shape(), &[24]);

        let restored = flattened.reshape(&[2, 3, 4]).unwrap();
        assert_eq!(restored.shape(), &[2, 3, 4]);
        assert_eq!(restored.to_vec(dev()).unwrap(), values);
    }

    /// Seven dimensions is rejected by `reshape`.
    #[test]
    fn test_tensor_reshape_7d_exceeds() {
        let tensor = Tensor::new(dev(), &[0.0; 1], &[1]).unwrap();
        let outcome = tensor.reshape(&[1, 1, 1, 1, 1, 1, 1]);
        assert!(outcome.is_err());
    }

    /// `zeros` works for extents that are not powers of two.
    #[test]
    fn test_tensor_zeros_odd_dim() {
        let tensor = Tensor::zeros(dev(), &[7, 13]).unwrap();
        assert_eq!(tensor.numel(), 91);
        let contents = tensor.to_vec(dev()).unwrap();
        assert!(contents.iter().all(|&v| v == 0.0));
    }
}