Skip to main content

runmat_runtime/builtins/acceleration/gpu/
gather.rs

1//! MATLAB-compatible `gather` builtin with provider-aware semantics.
2
3use crate::builtins::acceleration::gpu::type_resolvers::gather_type;
4use crate::builtins::common::spec::{
5    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
6    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
7};
8use crate::{build_runtime_error, make_cell, RuntimeError};
9use runmat_builtins::Value;
10use runmat_macros::runtime_builtin;
11
12#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::gather")]
13pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
14    name: "gather",
15    op_kind: GpuOpKind::Custom("gather"),
16    supported_precisions: &[ScalarType::F32, ScalarType::F64],
17    broadcast: BroadcastSemantics::None,
18    provider_hooks: &[ProviderHook::Custom("download")],
19    constant_strategy: ConstantStrategy::InlineLiteral,
20    residency: ResidencyPolicy::GatherImmediately,
21    nan_mode: ReductionNaN::Include,
22    two_pass_threshold: None,
23    workgroup_size: None,
24    accepts_nan_mode: false,
25    notes: "Downloads gpuArray handles via the provider's `download` hook and clears residency metadata; host inputs pass through unchanged.",
26};
27
28#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::acceleration::gpu::gather")]
29pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
30    name: "gather",
31    shape: ShapeRequirements::Any,
32    constant_strategy: ConstantStrategy::InlineLiteral,
33    elementwise: None,
34    reduction: None,
35    emits_nan: false,
36    notes: "Acts as a residency sink for fusion planning; always materialises host data and clears gpuArray residency tracking.",
37};
38
39fn gather_error(message: impl Into<String>) -> RuntimeError {
40    build_runtime_error(message).with_builtin("gather").build()
41}
42
43#[runtime_builtin(
44    name = "gather",
45    category = "acceleration/gpu",
46    summary = "Bring gpuArray data back to host memory.",
47    keywords = "gather,gpuArray,accelerate,download",
48    accel = "sink",
49    type_resolver(gather_type),
50    builtin_path = "crate::builtins::acceleration::gpu::gather"
51)]
52async fn gather_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
53    let eval = evaluate(&args).await?;
54    let len = eval.len();
55    if let Some(out_count) = crate::output_count::current_output_count() {
56        if out_count == 0 {
57            return Ok(Value::OutputList(Vec::new()));
58        }
59        if len == 1 {
60            if out_count > 1 {
61                return Err(gather_error("gather: too many output arguments").into());
62            }
63            return Ok(Value::OutputList(vec![eval.into_first()]));
64        }
65        if out_count != len {
66            return Err(
67                gather_error("gather: number of outputs must match number of inputs").into(),
68            );
69        }
70        return Ok(Value::OutputList(eval.into_outputs()));
71    }
72    if len == 1 {
73        Ok(eval.into_first())
74    } else {
75        let outputs = eval.into_outputs();
76        make_cell(outputs, 1, len).map_err(|err| gather_error(err).into())
77    }
78}
79
80/// Combined gather result used by single- and multi-output call sites.
81#[derive(Debug, Clone)]
82pub struct GatherResult {
83    outputs: Vec<Value>,
84}
85
86impl GatherResult {
87    fn new(outputs: Vec<Value>) -> Self {
88        Self { outputs }
89    }
90
91    /// Number of gathered outputs.
92    pub fn len(&self) -> usize {
93        self.outputs.len()
94    }
95
96    pub fn is_empty(&self) -> bool {
97        self.outputs.is_empty()
98    }
99
100    /// Borrowed slice of outputs (in call-order).
101    pub fn outputs(&self) -> &[Value] {
102        &self.outputs
103    }
104
105    /// Consume the result, yielding all outputs.
106    pub fn into_outputs(self) -> Vec<Value> {
107        self.outputs
108    }
109
110    /// Consume the result, yielding the first output (requires at least one input).
111    pub fn into_first(self) -> Value {
112        self.outputs
113            .into_iter()
114            .next()
115            .expect("gather requires at least one input")
116    }
117}
118
119/// Evaluate `gather` for arbitrary argument lists and return all outputs.
120pub async fn evaluate(args: &[Value]) -> crate::BuiltinResult<GatherResult> {
121    if args.is_empty() {
122        return Err(gather_error("gather: not enough input arguments").into());
123    }
124    let mut outputs = Vec::with_capacity(args.len());
125    for value in args {
126        outputs.push(gather_argument(value).await?);
127    }
128    Ok(GatherResult::new(outputs))
129}
130
/// Gather a single value, delegating to the dispatcher so gpuArray handles are
/// downloaded while plain host values pass through untouched.
async fn gather_argument(value: &Value) -> crate::BuiltinResult<Value> {
    crate::dispatcher::gather_if_needed_async(value).await
}
134
135#[cfg(test)]
136pub(crate) mod tests {
137    use super::*;
138    use crate::builtins::common::test_support;
139    use futures::executor::block_on;
140    use runmat_accelerate_api::HostTensorView;
141    use runmat_builtins::{CellArray, ResolveContext, StructValue, Tensor, Type};
142
143    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
144    #[test]
145    fn gather_passes_through_host_values() {
146        let value = Value::Num(42.0);
147        let result = block_on(gather_builtin(vec![value.clone()])).expect("gather");
148        assert_eq!(result, value);
149    }
150
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_downloads_gpu_tensor() {
        // Upload a 2x2 tensor through the test provider, gather it back, and
        // check the round trip preserves both shape and data.
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = block_on(gather_builtin(vec![Value::GpuTensor(handle)])).expect("gather");
            match result {
                Value::Tensor(host) => {
                    assert_eq!(host.shape, tensor.shape);
                    assert_eq!(host.data, tensor.data);
                }
                other => panic!("expected tensor result, got {other:?}"),
            }
        });
    }
171
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_preserves_logical_gpu_tensors() {
        // A handle flagged as logical must gather to a LogicalArray (0/1 data),
        // not a plain numeric tensor.
        test_support::with_test_provider(|provider| {
            let data = vec![0.0, 1.0, 1.0, 0.0];
            let tensor = Tensor::new(data.clone(), vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            // Mark the uploaded handle as carrying logical (boolean) data.
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result = block_on(gather_builtin(vec![Value::GpuTensor(handle)])).expect("gather");
            match result {
                Value::LogicalArray(logical) => {
                    assert_eq!(logical.shape, vec![2, 2]);
                    assert_eq!(logical.data, vec![0, 1, 1, 0]);
                }
                other => panic!("expected logical array, got {other:?}"),
            }
        });
    }
194
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_recurses_into_cells() {
        // Gathering a cell array must download GPU elements in place while
        // leaving host elements untouched.
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            // Mixed cell: one gpuArray element, one host string.
            let cell = CellArray::new(vec![Value::GpuTensor(handle), Value::from("host")], 1, 2)
                .expect("cell");
            let result = block_on(gather_builtin(vec![Value::Cell(cell)])).expect("gather");
            let Value::Cell(gathered) = result else {
                panic!("expected cell result");
            };
            // First element: downloaded to a host tensor with matching contents.
            let first = gathered.get(0, 0).expect("first element");
            match first {
                Value::Tensor(t) => {
                    assert_eq!(t.shape, vec![2, 1]);
                    assert_eq!(t.data, tensor.data);
                }
                other => panic!("expected tensor in cell, got {other:?}"),
            }
            // Second element: host value passes through unchanged.
            let second = gathered.get(0, 1).expect("second element");
            assert_eq!(second, Value::from("host"));
        });
    }
223
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_recurses_into_structs() {
        // Gathering a struct must download gpuArray field values while keeping
        // host fields (here a string label) unchanged.
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![3.5, -1.25], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let mut st = StructValue::new();
            st.insert("data", Value::GpuTensor(handle));
            st.insert("label", Value::from("gpu result"));

            let result = block_on(gather_builtin(vec![Value::Struct(st)])).expect("gather");
            let Value::Struct(gathered) = result else {
                panic!("expected struct result");
            };
            // GPU field downloaded to a host tensor.
            let Some(Value::Tensor(host)) = gathered.fields.get("data") else {
                panic!("missing tensor field");
            };
            assert_eq!(host.shape, vec![2, 1]);
            assert_eq!(host.data, tensor.data);
            // Host field untouched.
            let Some(Value::String(label)) = gathered.fields.get("label") else {
                panic!("missing label");
            };
            assert_eq!(label, "gpu result");
        });
    }
253
254    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
255    #[test]
256    fn gather_returns_cell_for_multiple_inputs() {
257        let result = block_on(gather_builtin(vec![Value::Num(1.0), Value::from("two")]))
258            .expect("gather cell");
259        let Value::Cell(cell) = result else {
260            panic!("expected cell for multiple inputs");
261        };
262        assert_eq!(cell.rows, 1);
263        assert_eq!(cell.cols, 2);
264        assert_eq!(cell.get(0, 0).unwrap(), Value::Num(1.0));
265        assert_eq!(cell.get(0, 1).unwrap(), Value::from("two"));
266    }
267
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn evaluate_returns_outputs_in_order() {
        // `evaluate` must preserve call order across heterogeneous host inputs.
        let eval = block_on(evaluate(&[
            Value::Num(5.0),
            Value::Bool(true),
            Value::from("hello"),
        ]))
        .expect("eval");
        assert_eq!(eval.len(), 3);
        assert_eq!(eval.outputs()[0], Value::Num(5.0));
        assert_eq!(eval.outputs()[1], Value::Bool(true));
        assert_eq!(eval.outputs()[2], Value::from("hello"));
    }
282
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gather_requires_at_least_one_argument() {
        // Calling gather with no inputs must fail with the MATLAB-style message.
        let err = block_on(gather_builtin(Vec::new())).expect_err("expected error");
        assert_eq!(err.to_string(), "gather: not enough input arguments");
    }
289
    #[test]
    fn gather_type_resolves_multiple_outputs_to_cell() {
        // The static type resolver mirrors the runtime behaviour: multiple
        // inputs resolve to a cell type.
        assert_eq!(
            gather_type(&[Type::Num, Type::String], &ResolveContext::new(Vec::new())),
            Type::cell()
        );
    }
297
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn gather_wgpu_provider_roundtrip() {
        use runmat_accelerate_api::AccelProvider;

        // Full round trip against the real wgpu backend. Registration may fail
        // on machines without a usable GPU, in which case the test is skipped
        // with a warning rather than failing.
        match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) {
            Ok(provider) => {
                let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
                let view = HostTensorView {
                    data: &tensor.data,
                    shape: &tensor.shape,
                };
                let handle = provider.upload(&view).expect("upload");
                let eval =
                    block_on(evaluate(&[Value::GpuTensor(handle.clone())])).expect("evaluate");
                let outputs = eval.into_outputs();
                assert_eq!(outputs.len(), 1);
                match outputs.into_iter().next().unwrap() {
                    Value::Tensor(host) => {
                        assert_eq!(host.shape, tensor.shape);
                        assert_eq!(host.data, tensor.data);
                    }
                    other => panic!("expected tensor value, got {other:?}"),
                }
                // Best-effort cleanup of the device buffer; errors are ignored.
                let _ = provider.free(&handle);
            }
            Err(err) => {
                tracing::warn!("Skipping gather_wgpu_provider_roundtrip: {err}");
            }
        }
        // Restore the simple provider so subsequent tests see a predictable backend.
        runmat_accelerate::simple_provider::register_inprocess_provider();
    }
334}