Skip to main content

any_gpu/ops/
tensor_ops.rs

1// Unlicense — cochranblock.org
2// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
3//
4// Tensor manipulation: concat, transpose.
5
6use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
9// --- Concat ---
10
11#[repr(C)]
12#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
13struct ConcatParams {
14    n: u32,
15    outer: u32,
16    a_inner: u32,
17    b_inner: u32,
18}
19
/// WGSL compute shader backing `GpuDevice::concat`.
///
/// Bindings: 0 = uniform `P` (mirrors `ConcatParams`), 1 = input `a`,
/// 2 = input `b`, 3 = output. One invocation handles one output element;
/// invocations with `idx >= p.n` return immediately.
///
/// Each output index is split into a row (`idx / (a_inner + b_inner)`) and a
/// column within that row. Columns below `a_inner` are copied from `a`, and the
/// remaining columns are copied from `b`.
// NOTE(review): the flat index is rebuilt as `gid.x + gid.y * 65535 * 256`,
// which presumably mirrors how `super::dispatch_1d` splits large dispatches
// into a 2D grid of 256-thread workgroups — confirm the two stay in sync.
const SHADER_CONCAT: &str = "
struct P { n: u32, outer: u32, a_inner: u32, b_inner: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= p.n { return; }

    let combined = p.a_inner + p.b_inner;
    let outer_idx = idx / combined;
    let inner_idx = idx % combined;

    if inner_idx < p.a_inner {
        out[idx] = a[outer_idx * p.a_inner + inner_idx];
    } else {
        out[idx] = b[outer_idx * p.b_inner + (inner_idx - p.a_inner)];
    }
}
";
42
// --- Transpose (swap two adjacent dims) ---
44
45#[repr(C)]
46#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
47struct TransposeParams {
48    n: u32,
49    d0: u32,
50    d1: u32,
51    inner: u32,
52    outer_stride: u32,
53    _pad: [u32; 3],
54}
55
/// WGSL compute shader backing `GpuDevice::transpose`.
///
/// Bindings: 0 = uniform `P` (mirrors `TransposeParams`), 1 = input,
/// 2 = output. One invocation handles one output element; invocations with
/// `idx >= p.n` return immediately.
///
/// The output index is decomposed as `[outer, i1, i0, inner_idx]` over the
/// output shape `[outer, d1, d0, inner]`. The value is read from position
/// `[outer, i0, i1, inner_idx]` of the input shape `[outer, d0, d1, inner]`.
///
/// `outer_stride` is part of the uniform layout but is not read here; the
/// shader recomputes the same product (`d0 * d1 * inner`) locally as `block`.
// NOTE(review): the `gid.y * 65535u * 256u` term presumably matches how
// `super::dispatch_1d` splits large dispatches — confirm they stay in sync.
const SHADER_TRANSPOSE: &str = "
struct P { n: u32, d0: u32, d1: u32, inner: u32, outer_stride: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= p.n { return; }

    // Decompose output index: outer * (d1 * d0 * inner) + i1 * (d0 * inner) + i0 * inner + inner_idx
    let block = p.d0 * p.d1 * p.inner;
    let outer = idx / block;
    let rem = idx % block;
    let i1 = rem / (p.d0 * p.inner);
    let rem2 = rem % (p.d0 * p.inner);
    let i0 = rem2 / p.inner;
    let inner_idx = rem2 % p.inner;

    // In input, dims are (d0, d1) so source index swaps i0 and i1
    let src = outer * block + i0 * (p.d1 * p.inner) + i1 * p.inner + inner_idx;
    out[idx] = a[src];
}
";
80
81impl GpuDevice {
82    /// Concat two buffers along a given axis.
83    /// `outer_size` = product of dims before concat axis.
84    /// `a_inner` = a's size along concat axis * product of dims after.
85    /// `b_inner` = same for b.
86    pub fn concat(
87        &self,
88        a: &GpuBuffer, b: &GpuBuffer,
89        outer_size: u32, a_inner: u32, b_inner: u32,
90    ) -> Result<GpuBuffer> {
91        ensure!(a.len == (outer_size * a_inner) as usize);
92        ensure!(b.len == (outer_size * b_inner) as usize);
93        let total = outer_size * (a_inner + b_inner);
94        let out = self.alloc(total as usize);
95        let params = ConcatParams { n: total, outer: outer_size, a_inner, b_inner };
96        self.dispatch_shader(
97            SHADER_CONCAT, Some("concat"),
98            &params, &[a, b], &out,
99            super::dispatch_1d(total),
100        );
101        Ok(out)
102    }
103
104    /// Transpose two dimensions of a tensor.
105    /// Shape is [..., d0, d1, ...inner_dims].
106    /// `outer_size` = product of dims before d0.
107    /// `inner` = product of dims after d1.
108    pub fn transpose(
109        &self,
110        a: &GpuBuffer,
111        outer_size: u32, d0: u32, d1: u32, inner: u32,
112    ) -> Result<GpuBuffer> {
113        let total = outer_size * d0 * d1 * inner;
114        ensure!(a.len == total as usize);
115        let out = self.alloc(total as usize);
116        let params = TransposeParams {
117            n: total, d0, d1, inner,
118            outer_stride: d0 * d1 * inner, _pad: [0; 3],
119        };
120        self.dispatch_shader(
121            SHADER_TRANSPOSE, Some("transpose"),
122            &params, &[a], &out,
123            super::dispatch_1d(total),
124        );
125        Ok(out)
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared test device provided by the parent `ops` module.
    fn dev() -> &'static GpuDevice { &crate::ops::TEST_DEV }

    #[test]
    fn test_concat_flat() {
        let a = dev().upload(&[1.0, 2.0, 3.0]);
        let b = dev().upload(&[4.0, 5.0, 6.0]);
        let result = dev().read(&dev().concat(&a, &b, 1, 3, 3).unwrap()).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_concat_asymmetric() {
        // Different sized inner dims: a has 2 elements, b has 3 per outer block
        let a = dev().upload(&[1.0, 2.0]);
        let b = dev().upload(&[3.0, 4.0, 5.0]);
        let result = dev().read(&dev().concat(&a, &b, 1, 2, 3).unwrap()).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_concat_batched_channel_axis() {
        // batch=2, concat 1-channel and 2-channel tensors along C, spatial=2
        // a: [batch=2, c=1, spatial=2] = [10, 20, 30, 40]
        // b: [batch=2, c=2, spatial=2] = [1, 2, 3, 4, 5, 6, 7, 8]
        let a = dev().upload(&[10.0, 20.0, 30.0, 40.0]);
        let b = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        // outer=batch=2, a_inner=1*2=2, b_inner=2*2=4
        let result = dev().read(&dev().concat(&a, &b, 2, 2, 4).unwrap()).unwrap();
        // batch 0: [10,20, 1,2,3,4], batch 1: [30,40, 5,6,7,8]
        assert_eq!(result, vec![10.0, 20.0, 1.0, 2.0, 3.0, 4.0, 30.0, 40.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_concat_rejects_length_mismatch() {
        // `a` should hold outer_size * a_inner = 3 elements but only has 2.
        let a = dev().upload(&[1.0, 2.0]);
        let b = dev().upload(&[3.0]);
        assert!(dev().concat(&a, &b, 1, 3, 1).is_err());
    }

    #[test]
    fn test_transpose_2d() {
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = dev().read(&dev().transpose(&a, 1, 2, 3, 1).unwrap()).unwrap();
        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_square() {
        // 3x3 -> 3x3 transpose
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let result = dev().read(&dev().transpose(&a, 1, 3, 3, 1).unwrap()).unwrap();
        assert_eq!(result, vec![1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_transpose_batched() {
        let a = dev().upload(&[
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
            7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ]);
        let result = dev().read(&dev().transpose(&a, 2, 2, 3, 1).unwrap()).unwrap();
        assert_eq!(result, vec![
            1.0, 4.0, 2.0, 5.0, 3.0, 6.0,
            7.0, 10.0, 8.0, 11.0, 9.0, 12.0,
        ]);
    }

    #[test]
    fn test_transpose_inner_block() {
        // Input [d0=2, d1=3, inner=2] holding 0..12; output is [3, 2, 2] and
        // each pair of `inner` elements must move together.
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let result = dev().read(&dev().transpose(&dev().upload(&data), 1, 2, 3, 2).unwrap()).unwrap();
        assert_eq!(result, vec![
            0.0, 1.0, 6.0, 7.0,
            2.0, 3.0, 8.0, 9.0,
            4.0, 5.0, 10.0, 11.0,
        ]);
    }

    #[test]
    fn test_transpose_1x_n() {
        // 1xN transpose = Nx1 (column vector)
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let result = dev().read(&dev().transpose(&a, 1, 1, 5, 1).unwrap()).unwrap();
        // 5x1 is same flat data (no-op for 1-row)
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_transpose_rejects_length_mismatch() {
        // Shape 2x3 needs 6 elements; only 5 are provided.
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(dev().transpose(&a, 1, 2, 3, 1).is_err());
    }

    #[test]
    fn test_transpose_roundtrip() {
        // Transpose twice = identity
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect(); // 4x5
        let t1 = dev().transpose(&dev().upload(&data), 1, 4, 5, 1).unwrap();
        let t2 = dev().transpose(&t1, 1, 5, 4, 1).unwrap();
        let result = dev().read(&t2).unwrap();
        assert_eq!(result, data);
    }
}
208}