runmat_runtime/builtins/acceleration/gpu/
gather.rs1use crate::builtins::acceleration::gpu::type_resolvers::gather_type;
4use crate::builtins::common::spec::{
5 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
6 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
7};
8use crate::{build_runtime_error, make_cell, RuntimeError};
9use runmat_builtins::Value;
10use runmat_macros::runtime_builtin;
11
// GPU execution spec for `gather`. Registered with the acceleration layer so
// providers know how to treat this builtin during planning.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::gather")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "gather",
    // Not an elementwise/reduction primitive; providers key off the custom name.
    op_kind: GpuOpKind::Custom("gather"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    // Each argument is gathered independently, so no broadcast rules apply.
    broadcast: BroadcastSemantics::None,
    // Only needs the provider's `download` hook to copy device data to host.
    provider_hooks: &[ProviderHook::Custom("download")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Results are materialised on the host right away rather than staying GPU-resident.
    residency: ResidencyPolicy::GatherImmediately,
    // NaN handling fields are irrelevant for a pure download; these are inert defaults.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Downloads gpuArray handles via the provider's `download` hook and clears residency metadata; host inputs pass through unchanged.",
};
27
// Fusion-planner spec for `gather`. Marks the builtin as a residency sink:
// it cannot be fused into elementwise/reduction kernels.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::acceleration::gpu::gather")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "gather",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Neither an elementwise nor a reduction op — nothing for the fuser to inline.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a residency sink for fusion planning; always materialises host data and clears gpuArray residency tracking.",
};
38
39fn gather_error(message: impl Into<String>) -> RuntimeError {
40 build_runtime_error(message).with_builtin("gather").build()
41}
42
43#[runtime_builtin(
44 name = "gather",
45 category = "acceleration/gpu",
46 summary = "Bring gpuArray data back to host memory.",
47 keywords = "gather,gpuArray,accelerate,download",
48 accel = "sink",
49 type_resolver(gather_type),
50 builtin_path = "crate::builtins::acceleration::gpu::gather"
51)]
52async fn gather_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
53 let eval = evaluate(&args).await?;
54 let len = eval.len();
55 if let Some(out_count) = crate::output_count::current_output_count() {
56 if out_count == 0 {
57 return Ok(Value::OutputList(Vec::new()));
58 }
59 if len == 1 {
60 if out_count > 1 {
61 return Err(gather_error("gather: too many output arguments").into());
62 }
63 return Ok(Value::OutputList(vec![eval.into_first()]));
64 }
65 if out_count != len {
66 return Err(
67 gather_error("gather: number of outputs must match number of inputs").into(),
68 );
69 }
70 return Ok(Value::OutputList(eval.into_outputs()));
71 }
72 if len == 1 {
73 Ok(eval.into_first())
74 } else {
75 let outputs = eval.into_outputs();
76 make_cell(outputs, 1, len).map_err(|err| gather_error(err).into())
77 }
78}
79
/// Host-side results produced by [`evaluate`]: one gathered value per input
/// argument, kept in the original argument order.
#[derive(Debug, Clone)]
pub struct GatherResult {
    // Gathered values, ordered to match the inputs passed to `evaluate`.
    outputs: Vec<Value>,
}
85
86impl GatherResult {
87 fn new(outputs: Vec<Value>) -> Self {
88 Self { outputs }
89 }
90
91 pub fn len(&self) -> usize {
93 self.outputs.len()
94 }
95
96 pub fn is_empty(&self) -> bool {
97 self.outputs.is_empty()
98 }
99
100 pub fn outputs(&self) -> &[Value] {
102 &self.outputs
103 }
104
105 pub fn into_outputs(self) -> Vec<Value> {
107 self.outputs
108 }
109
110 pub fn into_first(self) -> Value {
112 self.outputs
113 .into_iter()
114 .next()
115 .expect("gather requires at least one input")
116 }
117}
118
119pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<GatherResult> {
121 if args.is_empty() {
122 return Err(gather_error("gather: not enough input arguments").into());
123 }
124 let mut outputs = Vec::with_capacity(args.len());
125 for value in args {
126 outputs.push(gather_argument(value).await?);
127 }
128 Ok(GatherResult::new(outputs))
129}
130
/// Gather a single value via the dispatcher: GPU-resident handles are
/// downloaded, host values pass through. The tests below show the dispatcher
/// also recursing into cells and structs — behaviour lives in
/// `gather_if_needed_async`, not here.
async fn gather_argument(value: &Value) -> crate::BuiltinResult<Value> {
    crate::dispatcher::gather_if_needed_async(value).await
}
134
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CellArray, ResolveContext, StructValue, Tensor, Type};

    // A plain host value has no GPU residency, so gather returns it unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_passes_through_host_values() {
        let value = Value::Num(42.0);
        let result = block_on(gather_builtin(vec![value.clone()])).expect("gather");
        assert_eq!(result, value);
    }

    // Uploading a tensor and gathering it must round-trip data and shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_downloads_gpu_tensor() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = block_on(gather_builtin(vec![Value::GpuTensor(handle)])).expect("gather");
            match result {
                Value::Tensor(host) => {
                    assert_eq!(host.shape, tensor.shape);
                    assert_eq!(host.data, tensor.data);
                }
                other => panic!("expected tensor result, got {other:?}"),
            }
        });
    }

    // A handle flagged as logical should gather into a LogicalArray, not a
    // numeric tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_preserves_logical_gpu_tensors() {
        test_support::with_test_provider(|provider| {
            let data = vec![0.0, 1.0, 1.0, 0.0];
            let tensor = Tensor::new(data.clone(), vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result = block_on(gather_builtin(vec![Value::GpuTensor(handle)])).expect("gather");
            match result {
                Value::LogicalArray(logical) => {
                    assert_eq!(logical.shape, vec![2, 2]);
                    assert_eq!(logical.data, vec![0, 1, 1, 0]);
                }
                other => panic!("expected logical array, got {other:?}"),
            }
        });
    }

    // Gathering a cell downloads GPU elements in place while leaving host
    // elements untouched.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_recurses_into_cells() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let cell = CellArray::new(vec![Value::GpuTensor(handle), Value::from("host")], 1, 2)
                .expect("cell");
            let result = block_on(gather_builtin(vec![Value::Cell(cell)])).expect("gather");
            let Value::Cell(gathered) = result else {
                panic!("expected cell result");
            };
            let first = gathered.get(0, 0).expect("first element");
            match first {
                Value::Tensor(t) => {
                    assert_eq!(t.shape, vec![2, 1]);
                    assert_eq!(t.data, tensor.data);
                }
                other => panic!("expected tensor in cell, got {other:?}"),
            }
            let second = gathered.get(0, 1).expect("second element");
            assert_eq!(second, Value::from("host"));
        });
    }

    // Struct fields holding GPU handles are downloaded; other fields survive.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_recurses_into_structs() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.5, -1.25], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let mut st = StructValue::new();
            st.insert("data", Value::GpuTensor(handle));
            st.insert("label", Value::from("gpu result"));

            let result = block_on(gather_builtin(vec![Value::Struct(st)])).expect("gather");
            let Value::Struct(gathered) = result else {
                panic!("expected struct result");
            };
            let Some(Value::Tensor(host)) = gathered.fields.get("data") else {
                panic!("missing tensor field");
            };
            assert_eq!(host.shape, vec![2, 1]);
            assert_eq!(host.data, tensor.data);
            let Some(Value::String(label)) = gathered.fields.get("label") else {
                panic!("missing label");
            };
            assert_eq!(label, "gpu result");
        });
    }

    // With no output-count context, multiple inputs are packed into a 1x2 cell.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_returns_cell_for_multiple_inputs() {
        let result = block_on(gather_builtin(vec![Value::Num(1.0), Value::from("two")]))
            .expect("gather cell");
        let Value::Cell(cell) = result else {
            panic!("expected cell for multiple inputs");
        };
        assert_eq!(cell.rows, 1);
        assert_eq!(cell.cols, 2);
        assert_eq!(cell.get(0, 0).unwrap(), Value::Num(1.0));
        assert_eq!(cell.get(0, 1).unwrap(), Value::from("two"));
    }

    // `evaluate` must keep outputs aligned with the input argument order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn evaluate_returns_outputs_in_order() {
        let eval = block_on(evaluate(&[
            Value::Num(5.0),
            Value::Bool(true),
            Value::from("hello"),
        ]))
        .expect("eval");
        assert_eq!(eval.len(), 3);
        assert_eq!(eval.outputs()[0], Value::Num(5.0));
        assert_eq!(eval.outputs()[1], Value::Bool(true));
        assert_eq!(eval.outputs()[2], Value::from("hello"));
    }

    // Zero arguments is an error, surfaced with the builtin-prefixed message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_requires_at_least_one_argument() {
        let err = block_on(gather_builtin(Vec::new())).expect_err("expected error");
        assert_eq!(err.to_string(), "gather: not enough input arguments");
    }

    // Type resolver: several inputs resolve to a cell type (mirrors the
    // multi-input runtime behaviour above).
    #[test]
    fn gather_type_resolves_multiple_outputs_to_cell() {
        assert_eq!(
            gather_type(&[Type::Num, Type::String], &ResolveContext::new(Vec::new())),
            Type::cell()
        );
    }

    // End-to-end round trip against the real wgpu provider. Skips (with a
    // warning) when no wgpu adapter is available in the environment.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn gather_wgpu_provider_roundtrip() {
        use runmat_accelerate_api::AccelProvider;

        match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) {
            Ok(provider) => {
                let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
                let view = HostTensorView {
                    data: &tensor.data,
                    shape: &tensor.shape,
                };
                let handle = provider.upload(&view).expect("upload");
                let eval =
                    block_on(evaluate(&[Value::GpuTensor(handle.clone())])).expect("evaluate");
                let outputs = eval.into_outputs();
                assert_eq!(outputs.len(), 1);
                match outputs.into_iter().next().unwrap() {
                    Value::Tensor(host) => {
                        assert_eq!(host.shape, tensor.shape);
                        assert_eq!(host.data, tensor.data);
                    }
                    other => panic!("expected tensor value, got {other:?}"),
                }
                let _ = provider.free(&handle);
            }
            Err(err) => {
                tracing::warn!("Skipping gather_wgpu_provider_roundtrip: {err}");
            }
        }
        // Restore the in-process provider so later tests are not left talking
        // to the wgpu backend registered above.
        runmat_accelerate::simple_provider::register_inprocess_provider();
    }
}