// runmat_runtime/builtins/math/linalg/structure/bandwidth.rs
1//! MATLAB-compatible `bandwidth` builtin with GPU-aware semantics for RunMat.
2
3use log::debug;
4use runmat_accelerate_api::{self, GpuTensorHandle};
5use runmat_builtins::{ComplexTensor, LogicalArray, Tensor, Value};
6use runmat_macros::runtime_builtin;
7
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12use crate::builtins::common::{gpu_helpers, tensor};
13use crate::builtins::math::linalg::type_resolvers::bandwidth_type;
14use crate::{build_runtime_error, BuiltinResult, RuntimeError};
15
/// GPU execution spec for `bandwidth`.
///
/// Structure queries do not map onto elementwise or reduction kernels, so the
/// op is exposed as a custom provider hook; when a provider lacks it, the
/// runtime gathers the tensor to the host and computes there.
#[runmat_macros::register_gpu_spec(
    builtin_path = "crate::builtins::math::linalg::structure::bandwidth"
)]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "bandwidth",
    // Custom kind: no standard kernel shape fits a structure query.
    op_kind: GpuOpKind::Custom("structure_analysis"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    // bandwidth takes a single matrix; no broadcasting is involved.
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("bandwidth")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The result is a tiny host value, so gather immediately when falling back.
    residency: ResidencyPolicy::GatherImmediately,
    // NaN entries count as nonzero (see compute_real_bandwidth), so they are
    // included rather than skipped.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "WGPU providers compute bandwidth on-device when available; runtimes gather to the host as a fallback when providers lack the hook.",
};
34
/// Fusion spec for `bandwidth`.
///
/// The builtin produces a small host tensor (or scalar), so it participates in
/// fusion planning only as a metadata operation — never as fusible
/// elementwise or reduction work.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::math::linalg::structure::bandwidth"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "bandwidth",
    // Any shape is accepted here; 2-D validation happens at runtime.
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Structure query that returns a small host tensor; fusion treats it as a metadata operation.",
};
47
48const BUILTIN_NAME: &str = "bandwidth";
49
50fn runtime_error(name: &str, message: impl Into<String>) -> RuntimeError {
51    build_runtime_error(message).with_builtin(name).build()
52}
53
/// Which bandwidth component(s) the caller requested via the optional
/// second argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BandSelector {
    /// No selector given: return `[lower upper]` as a 1x2 row vector.
    Both,
    /// `'lower'`: return only the lower bandwidth as a scalar.
    Lower,
    /// `'upper'`: return only the upper bandwidth as a scalar.
    Upper,
}
60
61#[runtime_builtin(
62    name = "bandwidth",
63    category = "math/linalg/structure",
64    summary = "Compute the lower and upper bandwidth of a matrix.",
65    keywords = "bandwidth,lower bandwidth,upper bandwidth,structure,gpu",
66    accel = "structure",
67    type_resolver(bandwidth_type),
68    builtin_path = "crate::builtins::math::linalg::structure::bandwidth"
69)]
70async fn bandwidth_builtin(matrix: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
71    let selector = parse_selector(&rest)?;
72    let data = MatrixData::from_value(matrix)?;
73    let (lower, upper) = data.bandwidth().await?;
74    match selector {
75        BandSelector::Both => {
76            let tensor = Tensor::new(vec![lower as f64, upper as f64], vec![1, 2])
77                .map_err(|e| runtime_error(BUILTIN_NAME, format!("{BUILTIN_NAME}: {e}")))?;
78            Ok(Value::Tensor(tensor))
79        }
80        BandSelector::Lower => Ok(Value::Num(lower as f64)),
81        BandSelector::Upper => Ok(Value::Num(upper as f64)),
82    }
83}
84
85fn parse_selector(args: &[Value]) -> BuiltinResult<BandSelector> {
86    match args.len() {
87        0 => Ok(BandSelector::Both),
88        1 => {
89            let text = tensor::value_to_string(&args[0]).ok_or_else(|| {
90                runtime_error(
91                    BUILTIN_NAME,
92                    "bandwidth: selector must be a character vector or string scalar",
93                )
94            })?;
95            let trimmed = text.trim();
96            let lowered = trimmed.to_ascii_lowercase();
97            match lowered.as_str() {
98                "lower" => Ok(BandSelector::Lower),
99                "upper" => Ok(BandSelector::Upper),
100                other => Err(runtime_error(
101                    BUILTIN_NAME,
102                    format!(
103                        "bandwidth: unrecognized selector '{other}'; expected 'lower' or 'upper'"
104                    ),
105                )),
106            }
107        }
108        _ => Err(runtime_error(
109            BUILTIN_NAME,
110            "bandwidth: too many input arguments",
111        )),
112    }
113}
114
115fn value_into_tensor_for(name: &str, value: Value) -> BuiltinResult<Tensor> {
116    match value {
117        Value::Tensor(t) => Ok(t),
118        Value::LogicalArray(logical) => logical_to_tensor(name, &logical),
119        Value::Num(n) => Tensor::new(vec![n], vec![1, 1])
120            .map_err(|e| runtime_error(name, format!("{name}: {e}"))),
121        Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1])
122            .map_err(|e| runtime_error(name, format!("{name}: {e}"))),
123        Value::Bool(b) => Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
124            .map_err(|e| runtime_error(name, format!("{name}: {e}"))),
125        other => Err(runtime_error(
126            name,
127            format!(
128                "{name}: unsupported input type {:?}; expected numeric or logical values",
129                other
130            ),
131        )),
132    }
133}
134
135fn logical_to_tensor(name: &str, logical: &LogicalArray) -> BuiltinResult<Tensor> {
136    let data: Vec<f64> = logical
137        .data
138        .iter()
139        .map(|&b| if b != 0 { 1.0 } else { 0.0 })
140        .collect();
141    Tensor::new(data, logical.shape.clone())
142        .map_err(|e| runtime_error(name, format!("{name}: {e}")))
143}
144
/// Input representations accepted by `bandwidth`, normalized from `Value`.
enum MatrixData {
    /// Host-resident real tensor (numeric, logical, or scalar promoted).
    Real(Tensor),
    /// Host-resident complex tensor.
    Complex(ComplexTensor),
    /// Device-resident tensor; handled via the provider hook or a host gather.
    Gpu(GpuTensorHandle),
}
150
151impl MatrixData {
152    fn from_value(value: Value) -> BuiltinResult<Self> {
153        match value {
154            Value::ComplexTensor(ct) => Ok(Self::Complex(ct)),
155            Value::Complex(re, im) => {
156                let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
157                    .map_err(|e| runtime_error(BUILTIN_NAME, format!("{BUILTIN_NAME}: {e}")))?;
158                Ok(Self::Complex(tensor))
159            }
160            Value::GpuTensor(handle) => Ok(Self::Gpu(handle)),
161            other => {
162                let tensor = value_into_tensor_for(BUILTIN_NAME, other)?;
163                Ok(Self::Real(tensor))
164            }
165        }
166    }
167
168    async fn bandwidth(&self) -> BuiltinResult<(usize, usize)> {
169        match self {
170            MatrixData::Real(tensor) => bandwidth_host_real_tensor(tensor),
171            MatrixData::Complex(tensor) => bandwidth_host_complex_tensor(tensor),
172            MatrixData::Gpu(handle) => bandwidth_gpu(handle).await,
173        }
174    }
175}
176
177async fn bandwidth_gpu(handle: &GpuTensorHandle) -> BuiltinResult<(usize, usize)> {
178    let (rows, cols) = ensure_matrix_shape(&handle.shape)?;
179    if rows == 0 || cols == 0 {
180        return Ok((0, 0));
181    }
182    if let Some(provider) = runmat_accelerate_api::provider() {
183        match provider.bandwidth(handle) {
184            Ok(result) => {
185                let lower = result.lower as usize;
186                let upper = result.upper as usize;
187                return Ok((lower, upper));
188            }
189            Err(err) => {
190                debug!("bandwidth: provider bandwidth fallback: {err}");
191            }
192        }
193    }
194    let tensor = gpu_helpers::gather_tensor_async(handle).await?;
195    bandwidth_host_real_tensor(&tensor)
196}
197
198pub fn ensure_matrix_shape(shape: &[usize]) -> BuiltinResult<(usize, usize)> {
199    match shape.len() {
200        0 => Ok((1, 1)),
201        1 => Ok((1, shape[0])),
202        _ => {
203            if shape[2..].iter().any(|&dim| dim > 1) {
204                Err(runtime_error(
205                    BUILTIN_NAME,
206                    "bandwidth: input must be a 2-D matrix",
207                ))
208            } else {
209                Ok((shape[0], shape[1]))
210            }
211        }
212    }
213}
214
215pub fn bandwidth_host_real_data(shape: &[usize], data: &[f64]) -> BuiltinResult<(usize, usize)> {
216    let (rows, cols) = ensure_matrix_shape(shape)?;
217    Ok(compute_real_bandwidth(rows, cols, data))
218}
219
220pub fn bandwidth_host_complex_data(
221    shape: &[usize],
222    data: &[(f64, f64)],
223) -> BuiltinResult<(usize, usize)> {
224    let (rows, cols) = ensure_matrix_shape(shape)?;
225    Ok(compute_complex_bandwidth(rows, cols, data))
226}
227
/// Compute the `(lower, upper)` bandwidth of a host real tensor.
pub fn bandwidth_host_real_tensor(tensor: &Tensor) -> BuiltinResult<(usize, usize)> {
    bandwidth_host_real_data(&tensor.shape, &tensor.data)
}
231
/// Compute the `(lower, upper)` bandwidth of a host complex tensor.
pub fn bandwidth_host_complex_tensor(tensor: &ComplexTensor) -> BuiltinResult<(usize, usize)> {
    bandwidth_host_complex_data(&tensor.shape, &tensor.data)
}
235
/// Scan a column-major real matrix and report its `(lower, upper)` bandwidth:
/// the largest distance of any nonzero entry below / above the main diagonal.
///
/// NaN entries count as nonzero — per IEEE 754, `NaN != 0.0` is `true`, so no
/// separate `is_nan` check is needed (the original `|| value.is_nan()` was
/// dead code). If `data` is shorter than `rows * cols`, only the elements
/// actually present are inspected, matching the original's defensive bound
/// check; any surplus elements beyond `rows * cols` are ignored.
fn compute_real_bandwidth(rows: usize, cols: usize, data: &[f64]) -> (usize, usize) {
    if rows == 0 || cols == 0 {
        return (0, 0);
    }
    let mut lower = 0usize;
    let mut upper = 0usize;
    // Column-major layout: element (row, col) lives at index row + col * rows,
    // so a single pass over the flat buffer recovers both coordinates.
    for (idx, &value) in data.iter().enumerate().take(rows.saturating_mul(cols)) {
        if value != 0.0 {
            let row = idx % rows;
            let col = idx / rows;
            if row >= col {
                lower = lower.max(row - col);
            } else {
                upper = upper.max(col - row);
            }
        }
    }
    (lower, upper)
}
261
/// Scan a column-major complex matrix and report its `(lower, upper)`
/// bandwidth. An entry is nonzero when its real or imaginary part differs
/// from 0.0 (NaN components compare unequal to zero, so they count as
/// nonzero too).
fn compute_complex_bandwidth(rows: usize, cols: usize, data: &[(f64, f64)]) -> (usize, usize) {
    if rows == 0 || cols == 0 {
        return (0, 0);
    }
    let mut lower = 0usize;
    let mut upper = 0usize;
    // Column-major layout: flat index idx maps to (idx % rows, idx / rows).
    for (idx, &(re, im)) in data.iter().enumerate().take(rows.saturating_mul(cols)) {
        if re == 0.0 && im == 0.0 {
            continue;
        }
        let row = idx % rows;
        let col = idx / rows;
        if row >= col {
            lower = lower.max(row - col);
        } else {
            upper = upper.max(col - row);
        }
    }
    (lower, upper)
}
287
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{LogicalArray, ResolveContext, Type};

    // A diagonal matrix has zero lower and upper bandwidth; the default
    // (no-selector) call returns them packed into a 1x2 row vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_diagonal_matrix() {
        let tensor = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let value = Value::Tensor(tensor);
        let result = bandwidth_builtin(value, Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![0.0, 0.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // The type resolver should report a 1x2 tensor for the default call form.
    #[test]
    fn bandwidth_type_defaults_to_two_element_tensor() {
        let out = bandwidth_type(
            &[Type::Tensor {
                shape: Some(vec![Some(3), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(1), Some(2)])
            }
        );
    }

    // 'lower' selector on a lower-triangular 3x3 returns the scalar 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_lower_selector() {
        let tensor = Tensor::new(
            vec![1.0, 2.0, 3.0, 0.0, 1.0, 4.0, 0.0, 0.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        let args = vec![Value::from("lower")];
        let result = bandwidth_builtin(Value::Tensor(tensor), args).expect("bandwidth");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    // 'upper' selector on an upper-triangular 3x3 returns the scalar 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_upper_selector() {
        let tensor = Tensor::new(
            vec![1.0, 0.0, 0.0, 2.0, 4.0, 0.0, 3.0, 5.0, 6.0],
            vec![3, 3],
        )
        .unwrap();
        let args = vec![Value::from("upper")];
        let result = bandwidth_builtin(Value::Tensor(tensor), args).expect("bandwidth");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    // Complex inputs: any entry with a nonzero real or imaginary part counts.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_complex_matrix() {
        let data = vec![(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (0.0, 0.0)];
        let tensor = ComplexTensor::new(data, vec![2, 2]).unwrap();
        let result =
            bandwidth_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 1.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Non-square matrices are supported (column-major 4x3 fixture).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_rectangular_matrix() {
        let tensor = Tensor::new(
            vec![0.0, 8.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 7.0, 0.0, 0.0, 10.0],
            vec![4, 3],
        )
        .unwrap();
        let result = bandwidth_builtin(Value::Tensor(tensor), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Empty (0x0) matrices yield a (0, 0) bandwidth pair.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_empty_matrix_returns_zero() {
        let tensor = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let result = bandwidth_builtin(Value::Tensor(tensor), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![0.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // NaN entries are treated as nonzero when measuring the bands.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_nan_counts_as_nonzero() {
        let tensor =
            Tensor::new(vec![0.0, f64::NAN, 0.0, 0.0], vec![2, 2]).expect("tensor construction");
        let result = bandwidth_builtin(Value::Tensor(tensor), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Logical arrays are widened to 0.0/1.0 before the scan.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_logical_input_supported() {
        let logical = LogicalArray::new(vec![1, 1, 1, 0], vec![2, 2]).expect("logical array");
        let result =
            bandwidth_builtin(Value::LogicalArray(logical), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Unrecognized selectors produce an error naming the valid options.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_selector_validation() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err =
            bandwidth_builtin(Value::Tensor(tensor), vec![Value::from("middle")]).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("lower") && message.contains("upper"),
            "unexpected error: {message}"
        );
    }

    // Shapes with a trailing dimension > 1 are rejected as non-2-D.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_rejects_higher_dimensions() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 1, 2]).unwrap();
        let err = bandwidth_builtin(Value::Tensor(tensor), Vec::new()).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("2-D"),
            "unexpected error message: {message}"
        );
    }

    // Uploading to the test provider and running bandwidth should match the
    // host result after gathering.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 2.0, 0.0, 0.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                bandwidth_builtin(Value::GpuTensor(handle), Vec::new()).expect("bandwidth");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![1, 2]);
            assert_eq!(gathered.data, vec![1.0, 0.0]);
        });
    }

    // With the real WGPU backend enabled, the provider hook must agree with
    // the CPU implementation, both via the raw hook and the builtin.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn bandwidth_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        // Skip silently when no WGPU adapter is available in this environment.
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };
        let tensor = Tensor::new(
            vec![0.0, 2.0, 0.0, 0.0, 0.0, 4.0, 5.0, 0.0, 6.0],
            vec![3, 3],
        )
        .unwrap();
        let cpu = super::bandwidth_host_real_tensor(&tensor).expect("cpu bandwidth");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_meta = provider.bandwidth(&handle).expect("provider bandwidth");
        assert_eq!(gpu_meta.lower as usize, cpu.0);
        assert_eq!(gpu_meta.upper as usize, cpu.1);

        let result =
            bandwidth_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("bandwidth");
        let gathered = test_support::gather(result).expect("gather");
        assert_eq!(gathered.shape, vec![1, 2]);
        assert_eq!(gathered.data, vec![cpu.0 as f64, cpu.1 as f64]);
        let _ = provider.free(&handle);
    }

    // Synchronous shim over the async builtin so tests can call it directly.
    fn bandwidth_builtin(matrix: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::bandwidth_builtin(matrix, rest))
    }
}