1use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
/// Uniform parameters handed to the concat shader (`SHADER_CONCAT`).
///
/// `#[repr(C)]` fixes the field order in memory and the `bytemuck::Pod` derive
/// lets the value be viewed as raw bytes. The four `u32` fields mirror, in
/// order, the WGSL `struct P` declared inside `SHADER_CONCAT`, so fields must
/// not be reordered without updating the shader.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct ConcatParams {
    /// Total number of output elements; invocations with `idx >= n` exit early.
    n: u32,
    /// Number of outer rows (`outer_size` in `GpuDevice::concat`).
    /// The shader declares this field but never reads it.
    outer: u32,
    /// Elements per row contributed by the first input.
    a_inner: u32,
    /// Elements per row contributed by the second input.
    b_inner: u32,
}
19
/// WGSL compute shader that joins two buffers row by row.
///
/// Bindings (all in group 0): 0 = uniform `P` (layout of `ConcatParams`),
/// 1 = `a` (read-only storage), 2 = `b` (read-only storage),
/// 3 = `out` (read-write storage).
///
/// Each invocation (workgroup size 256) writes one output element. The flat
/// output index is `gid.x + gid.y * 65535 * 256`; presumably this matches how
/// `dispatch_1d` spills large dispatches into the y dimension (confirm against
/// that helper). Invocations with `idx >= p.n` return without writing.
///
/// The index is split into a row (`idx / (a_inner + b_inner)`) and a column
/// (`idx % (a_inner + b_inner)`); columns below `a_inner` are copied from `a`,
/// the remaining columns from `b`.
const SHADER_CONCAT: &str = "
struct P { n: u32, outer: u32, a_inner: u32, b_inner: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= p.n { return; }

 let combined = p.a_inner + p.b_inner;
 let outer_idx = idx / combined;
 let inner_idx = idx % combined;

 if inner_idx < p.a_inner {
 out[idx] = a[outer_idx * p.a_inner + inner_idx];
 } else {
 out[idx] = b[outer_idx * p.b_inner + (inner_idx - p.a_inner)];
 }
}
";
42
/// Uniform parameters handed to the transpose shader (`SHADER_TRANSPOSE`).
///
/// `#[repr(C)]` fixes the field order in memory and the `bytemuck::Pod` derive
/// lets the value be viewed as raw bytes. The eight `u32` slots (five values
/// plus three padding words) mirror, in order, the WGSL `struct P` declared
/// inside `SHADER_TRANSPOSE`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct TransposeParams {
    /// Total number of elements; invocations with `idx >= n` exit early.
    n: u32,
    /// Extent of the first swapped axis in the input.
    d0: u32,
    /// Extent of the second swapped axis in the input.
    d1: u32,
    /// Number of contiguous elements below the swapped axes.
    inner: u32,
    /// `d0 * d1 * inner`, as filled in by `GpuDevice::transpose`.
    /// The shader recomputes this product itself and does not read the field.
    outer_stride: u32,
    /// Padding up to 32 bytes; presumably keeps the uniform block a multiple
    /// of 16 bytes (TODO confirm the alignment requirement being targeted).
    _pad: [u32; 3],
}
55
/// WGSL compute shader that swaps two adjacent axes of a 4-D view.
///
/// Bindings (all in group 0): 0 = uniform `P` (layout of `TransposeParams`),
/// 1 = `a` (read-only storage), 2 = `out` (read-write storage).
///
/// Each invocation (workgroup size 256) writes one output element. The flat
/// output index is `gid.x + gid.y * 65535 * 256`; presumably this matches how
/// `dispatch_1d` spills large dispatches into the y dimension (confirm against
/// that helper). Invocations with `idx >= p.n` return without writing.
///
/// The output index is decomposed as `(outer, i1, i0, inner_idx)` over the
/// extents `(_, d1, d0, inner)`, and the element is read from the input at
/// `(outer, i0, i1, inner_idx)` over the extents `(_, d0, d1, inner)`.
/// `p.outer_stride` is declared but unused; the shader computes `block`
/// (`d0 * d1 * inner`) itself.
const SHADER_TRANSPOSE: &str = "
struct P { n: u32, d0: u32, d1: u32, inner: u32, outer_stride: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= p.n { return; }

 // Decompose output index: outer * (d1 * d0 * inner) + i1 * (d0 * inner) + i0 * inner + inner_idx
 let block = p.d0 * p.d1 * p.inner;
 let outer = idx / block;
 let rem = idx % block;
 let i1 = rem / (p.d0 * p.inner);
 let rem2 = rem % (p.d0 * p.inner);
 let i0 = rem2 / p.inner;
 let inner_idx = rem2 % p.inner;

 // In input, dims are (d0, d1) so source index swaps i0 and i1
 let src = outer * block + i0 * (p.d1 * p.inner) + i1 * p.inner + inner_idx;
 out[idx] = a[src];
}
";
80
81impl GpuDevice {
82 pub fn concat(
87 &self,
88 a: &GpuBuffer, b: &GpuBuffer,
89 outer_size: u32, a_inner: u32, b_inner: u32,
90 ) -> Result<GpuBuffer> {
91 ensure!(a.len == (outer_size * a_inner) as usize);
92 ensure!(b.len == (outer_size * b_inner) as usize);
93 let total = outer_size * (a_inner + b_inner);
94 let out = self.alloc(total as usize);
95 let params = ConcatParams { n: total, outer: outer_size, a_inner, b_inner };
96 self.dispatch_shader(
97 SHADER_CONCAT, Some("concat"),
98 ¶ms, &[a, b], &out,
99 super::dispatch_1d(total),
100 );
101 Ok(out)
102 }
103
104 pub fn transpose(
109 &self,
110 a: &GpuBuffer,
111 outer_size: u32, d0: u32, d1: u32, inner: u32,
112 ) -> Result<GpuBuffer> {
113 let total = outer_size * d0 * d1 * inner;
114 ensure!(a.len == total as usize);
115 let out = self.alloc(total as usize);
116 let params = TransposeParams {
117 n: total, d0, d1, inner,
118 outer_stride: d0 * d1 * inner, _pad: [0; 3],
119 };
120 self.dispatch_shader(
121 SHADER_TRANSPOSE, Some("transpose"),
122 ¶ms, &[a], &out,
123 super::dispatch_1d(total),
124 );
125 Ok(out)
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared GPU device used by every test in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// With a single outer row, concat appends `b` after `a`.
    #[test]
    fn test_concat_flat() {
        let gpu = dev();
        let lhs = gpu.upload(&[1.0, 2.0, 3.0]);
        let rhs = gpu.upload(&[4.0, 5.0, 6.0]);
        let joined = gpu.concat(&lhs, &rhs, 1, 3, 3).unwrap();
        let result = gpu.read(&joined).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// The two inputs may have different inner sizes.
    #[test]
    fn test_concat_asymmetric() {
        let gpu = dev();
        let lhs = gpu.upload(&[1.0, 2.0]);
        let rhs = gpu.upload(&[3.0, 4.0, 5.0]);
        let joined = gpu.concat(&lhs, &rhs, 1, 2, 3).unwrap();
        let result = gpu.read(&joined).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    /// With two outer rows, each output row is a row of `a` followed by a row of `b`.
    #[test]
    fn test_concat_batched_channel_axis() {
        let gpu = dev();
        let lhs = gpu.upload(&[10.0, 20.0, 30.0, 40.0]);
        let rhs = gpu.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let joined = gpu.concat(&lhs, &rhs, 2, 2, 4).unwrap();
        let result = gpu.read(&joined).unwrap();
        assert_eq!(
            result,
            vec![10.0, 20.0, 1.0, 2.0, 3.0, 4.0, 30.0, 40.0, 5.0, 6.0, 7.0, 8.0]
        );
    }

    /// A 2x3 matrix becomes its 3x2 transpose.
    #[test]
    fn test_transpose_2d() {
        let gpu = dev();
        let input = gpu.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let swapped = gpu.transpose(&input, 1, 2, 3, 1).unwrap();
        let result = gpu.read(&swapped).unwrap();
        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// A 3x3 matrix is mirrored across its diagonal.
    #[test]
    fn test_transpose_square() {
        let gpu = dev();
        let input = gpu.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let swapped = gpu.transpose(&input, 1, 3, 3, 1).unwrap();
        let result = gpu.read(&swapped).unwrap();
        assert_eq!(result, vec![1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0]);
    }

    /// Each of the two 2x3 batches is transposed independently.
    #[test]
    fn test_transpose_batched() {
        let gpu = dev();
        let input = gpu.upload(&[
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
            7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ]);
        let swapped = gpu.transpose(&input, 2, 2, 3, 1).unwrap();
        let result = gpu.read(&swapped).unwrap();
        assert_eq!(result, vec![
            1.0, 4.0, 2.0, 5.0, 3.0, 6.0,
            7.0, 10.0, 8.0, 11.0, 9.0, 12.0,
        ]);
    }

    /// Transposing a 1x5 matrix leaves the flat data unchanged.
    #[test]
    fn test_transpose_1x_n() {
        let gpu = dev();
        let input = gpu.upload(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let swapped = gpu.transpose(&input, 1, 1, 5, 1).unwrap();
        let result = gpu.read(&swapped).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    /// Transposing 4x5 and then 5x4 yields the original data.
    #[test]
    fn test_transpose_roundtrip() {
        let gpu = dev();
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let uploaded = gpu.upload(&data);
        let t1 = gpu.transpose(&uploaded, 1, 4, 5, 1).unwrap();
        let t2 = gpu.transpose(&t1, 1, 5, 4, 1).unwrap();
        let result = gpu.read(&t2).unwrap();
        assert_eq!(result, data);
    }
}