Skip to main content

any_gpu/
tensor.rs

1// Unlicense — cochranblock.org
2// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
3//
4// Tensor: shaped view over a GpuBuffer. Tracks dimensions for op dispatch.
5// No autograd yet — that comes in Sprint 4.
6
7use crate::device::{GpuBuffer, GpuDevice};
8use anyhow::{ensure, Result};
9
10/// GPU tensor with shape metadata. Wraps a GpuBuffer.
11/// Shape is stored inline (max 6 dims covers batch x channel x D x H x W + extra).
12pub struct Tensor {
13    pub(crate) buf: GpuBuffer,
14    dims: [u32; 6],
15    ndim: u8,
16}
17
18impl Tensor {
19    /// Create a tensor from data with the given shape.
20    pub fn new(dev: &GpuDevice, data: &[f32], shape: &[u32]) -> Result<Self> {
21        let numel: u32 = shape.iter().product();
22        ensure!(
23            data.len() == numel as usize,
24            "shape {:?} needs {} elements, got {}",
25            shape, numel, data.len()
26        );
27        ensure!(shape.len() <= 6, "max 6 dimensions, got {}", shape.len());
28        let buf = dev.upload(data);
29        let mut dims = [0u32; 6];
30        dims[..shape.len()].copy_from_slice(shape);
31        Ok(Self { buf, dims, ndim: shape.len() as u8 })
32    }
33
34    /// Create a tensor from an existing GpuBuffer with the given shape.
35    pub fn from_buf(buf: GpuBuffer, shape: &[u32]) -> Result<Self> {
36        let numel: u32 = shape.iter().product();
37        ensure!(buf.len == numel as usize, "buffer has {} elements, shape needs {}", buf.len, numel);
38        ensure!(shape.len() <= 6, "max 6 dimensions");
39        let mut dims = [0u32; 6];
40        dims[..shape.len()].copy_from_slice(shape);
41        Ok(Self { buf, dims, ndim: shape.len() as u8 })
42    }
43
44    /// Create a zero tensor with the given shape.
45    pub fn zeros(dev: &GpuDevice, shape: &[u32]) -> Result<Self> {
46        let numel: u32 = shape.iter().product();
47        ensure!(shape.len() <= 6, "max 6 dimensions");
48        let buf = dev.alloc(numel as usize);
49        let mut dims = [0u32; 6];
50        dims[..shape.len()].copy_from_slice(shape);
51        Ok(Self { buf, dims, ndim: shape.len() as u8 })
52    }
53
54    /// Shape as a slice.
55    #[inline]
56    pub fn shape(&self) -> &[u32] {
57        &self.dims[..self.ndim as usize]
58    }
59
60    /// Number of dimensions.
61    #[inline]
62    pub fn ndim(&self) -> usize {
63        self.ndim as usize
64    }
65
66    /// Total number of elements.
67    #[inline]
68    pub fn numel(&self) -> usize {
69        self.buf.len
70    }
71
72    /// Read tensor data back to CPU.
73    pub fn to_vec(&self, dev: &GpuDevice) -> Result<Vec<f32>> {
74        dev.read(&self.buf)
75    }
76
77    /// Borrow the underlying GpuBuffer.
78    #[inline]
79    pub fn buffer(&self) -> &GpuBuffer {
80        &self.buf
81    }
82
83    /// Reshape to a new shape (same total elements, no data copy).
84    pub fn reshape(self, new_shape: &[u32]) -> Result<Self> {
85        let numel: u32 = new_shape.iter().product();
86        ensure!(
87            self.buf.len == numel as usize,
88            "reshape: {} elements can't become shape {:?} ({})",
89            self.buf.len, new_shape, numel
90        );
91        ensure!(new_shape.len() <= 6, "max 6 dimensions");
92        let mut dims = [0u32; 6];
93        dims[..new_shape.len()].copy_from_slice(new_shape);
94        Ok(Self { buf: self.buf, dims, ndim: new_shape.len() as u8 })
95    }
96
97    /// Get a single dimension size.
98    #[inline]
99    pub fn dim(&self, i: usize) -> u32 {
100        self.dims[i]
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Device handle used by every test in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// A 2x3 tensor reports its shape, rank and element count.
    #[test]
    fn test_tensor_new() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::new(dev(), &values, &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.ndim(), 2);
        assert_eq!(tensor.numel(), 6);
    }

    /// Uploaded data reads back unchanged.
    #[test]
    fn test_tensor_readback() {
        let expected = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::new(dev(), &expected, &[3]).unwrap();
        let actual = tensor.to_vec(dev()).unwrap();
        assert_eq!(actual, expected);
    }

    /// Reshaping keeps the element count and installs the new shape.
    #[test]
    fn test_tensor_reshape() {
        let original = Tensor::new(dev(), &[1.0; 12], &[3, 4]).unwrap();
        let reshaped = original.reshape(&[2, 6]).unwrap();
        assert_eq!(reshaped.shape(), &[2, 6]);
        assert_eq!(reshaped.numel(), 12);
    }

    /// Reshaping to a different element count is rejected.
    #[test]
    fn test_tensor_reshape_mismatch() {
        let original = Tensor::new(dev(), &[1.0; 12], &[3, 4]).unwrap();
        let outcome = original.reshape(&[2, 5]);
        assert!(outcome.is_err());
    }

    /// Data length must match the shape's element count.
    #[test]
    fn test_tensor_shape_mismatch() {
        let outcome = Tensor::new(dev(), &[1.0, 2.0, 3.0], &[2, 2]);
        assert!(outcome.is_err());
    }

    /// Four-dimensional shapes are supported.
    #[test]
    fn test_tensor_4d() {
        // NCHW: batch=2, channels=3, height=4, width=5
        let nchw = [2, 3, 4, 5];
        let tensor = Tensor::new(dev(), &[0.0; 120], &nchw).unwrap();
        assert_eq!(tensor.shape(), &[2, 3, 4, 5]);
        assert_eq!(tensor.ndim(), 4);
        assert_eq!(tensor.numel(), 120);
        assert_eq!(tensor.dim(0), 2);
        assert_eq!(tensor.dim(1), 3);
    }

    /// `zeros` yields a buffer that reads back as all zeros.
    #[test]
    fn test_tensor_zeros() {
        let tensor = Tensor::zeros(dev(), &[3, 3]).unwrap();
        let contents = tensor.to_vec(dev()).unwrap();
        assert_eq!(contents, vec![0.0; 9]);
    }

    /// A single-element tensor round-trips.
    #[test]
    fn test_tensor_scalar() {
        let tensor = Tensor::new(dev(), &[42.0], &[1]).unwrap();
        assert_eq!(tensor.shape(), &[1]);
        assert_eq!(tensor.to_vec(dev()).unwrap(), vec![42.0]);
    }

    /// Wrapping an existing buffer preserves its contents.
    #[test]
    fn test_tensor_from_buf() {
        let uploaded = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let tensor = Tensor::from_buf(uploaded, &[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
        let contents = tensor.to_vec(dev()).unwrap();
        assert_eq!(contents, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// A buffer whose length disagrees with the shape is rejected.
    #[test]
    fn test_tensor_from_buf_mismatch() {
        let uploaded = dev().upload(&[1.0; 10]);
        let outcome = Tensor::from_buf(uploaded, &[3, 4]);
        assert!(outcome.is_err());
    }

    /// `buffer()` exposes the underlying GpuBuffer for direct reads.
    #[test]
    fn test_tensor_buffer_access() {
        let tensor = Tensor::new(dev(), &[7.0, 8.0], &[2]).unwrap();
        let raw = tensor.buffer();
        assert_eq!(raw.len, 2);
        assert_eq!(dev().read(raw).unwrap(), vec![7.0, 8.0]);
    }

    /// `dim(i)` matches every entry of the shape.
    #[test]
    fn test_tensor_dim_all() {
        let tensor = Tensor::new(dev(), &[0.0; 120], &[2, 3, 4, 5]).unwrap();
        let expected: [u32; 4] = [2, 3, 4, 5];
        for (axis, &extent) in expected.iter().enumerate() {
            assert_eq!(tensor.dim(axis), extent);
        }
    }

    /// Six dimensions is the largest accepted rank.
    #[test]
    fn test_tensor_6d_max() {
        let tensor = Tensor::new(dev(), &[0.0; 1], &[1, 1, 1, 1, 1, 1]).unwrap();
        assert_eq!(tensor.ndim(), 6);
        assert_eq!(tensor.shape(), &[1, 1, 1, 1, 1, 1]);
    }

    /// Seven dimensions is rejected at construction.
    #[test]
    fn test_tensor_7d_exceeds_max() {
        let outcome = Tensor::new(dev(), &[0.0; 1], &[1, 1, 1, 1, 1, 1, 1]);
        assert!(outcome.is_err());
    }

    /// Flattening and restoring a shape leaves the data untouched.
    #[test]
    fn test_tensor_reshape_flatten_unflatten() {
        let expected: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let cube = Tensor::new(dev(), &expected, &[2, 3, 4]).unwrap();

        let flattened = cube.reshape(&[24]).unwrap();
        assert_eq!(flattened.shape(), &[24]);

        let restored = flattened.reshape(&[2, 3, 4]).unwrap();
        assert_eq!(restored.shape(), &[2, 3, 4]);
        assert_eq!(restored.to_vec(dev()).unwrap(), expected);
    }

    /// Reshaping to seven dimensions is rejected.
    #[test]
    fn test_tensor_reshape_7d_exceeds() {
        let tensor = Tensor::new(dev(), &[0.0; 1], &[1]).unwrap();
        let outcome = tensor.reshape(&[1, 1, 1, 1, 1, 1, 1]);
        assert!(outcome.is_err());
    }

    /// `zeros` works for element counts that are not a power of two.
    #[test]
    fn test_tensor_zeros_odd_dim() {
        let tensor = Tensor::zeros(dev(), &[7, 13]).unwrap();
        assert_eq!(tensor.numel(), 91);
        let contents = tensor.to_vec(dev()).unwrap();
        assert!(contents.iter().all(|&value| value == 0.0));
    }
}