1use log::debug;
4use runmat_accelerate_api::{self, GpuTensorHandle};
5use runmat_builtins::{ComplexTensor, LogicalArray, Tensor, Value};
6use runmat_macros::runtime_builtin;
7
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12use crate::builtins::common::{gpu_helpers, tensor};
13use crate::builtins::math::linalg::type_resolvers::bandwidth_type;
14use crate::{build_runtime_error, BuiltinResult, RuntimeError};
15
/// Acceleration spec registered for the `bandwidth` builtin.
///
/// Declares a custom provider hook named "bandwidth"; `bandwidth_gpu` below
/// tries that hook first and gathers to the host when the provider lacks it.
#[runmat_macros::register_gpu_spec(
    builtin_path = "crate::builtins::math::linalg::structure::bandwidth"
)]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "bandwidth",
    // Structure analysis is neither elementwise nor a standard reduction.
    op_kind: GpuOpKind::Custom("structure_analysis"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("bandwidth")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The result is a tiny scalar/1x2 value, so it is gathered immediately.
    residency: ResidencyPolicy::GatherImmediately,
    // NaN entries count as nonzero for bandwidth purposes (see compute_real_bandwidth).
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "WGPU providers compute bandwidth on-device when available; runtimes gather to the host as a fallback when providers lack the hook.",
};
34
/// Fusion spec registered for the `bandwidth` builtin.
///
/// Bandwidth is neither elementwise nor a reduction, so both fields are
/// `None`; the fusion planner treats it as a metadata-only operation.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::math::linalg::structure::bandwidth"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "bandwidth",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Structure query that returns a small host tensor; fusion treats it as a metadata operation.",
};
47
/// Builtin name used to tag every runtime error raised from this module.
const BUILTIN_NAME: &str = "bandwidth";
49
50fn runtime_error(name: &str, message: impl Into<String>) -> RuntimeError {
51 build_runtime_error(message).with_builtin(name).build()
52}
53
/// Which bandwidth component(s) the caller requested via the optional
/// second argument (see `parse_selector`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BandSelector {
    // No selector: return `[lower, upper]` as a 1x2 tensor.
    Both,
    // 'lower': return the lower bandwidth as a scalar.
    Lower,
    // 'upper': return the upper bandwidth as a scalar.
    Upper,
}
60
61#[runtime_builtin(
62 name = "bandwidth",
63 category = "math/linalg/structure",
64 summary = "Compute the lower and upper bandwidth of a matrix.",
65 keywords = "bandwidth,lower bandwidth,upper bandwidth,structure,gpu",
66 accel = "structure",
67 type_resolver(bandwidth_type),
68 builtin_path = "crate::builtins::math::linalg::structure::bandwidth"
69)]
70async fn bandwidth_builtin(matrix: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
71 let selector = parse_selector(&rest)?;
72 let data = MatrixData::from_value(matrix)?;
73 let (lower, upper) = data.bandwidth().await?;
74 match selector {
75 BandSelector::Both => {
76 let tensor = Tensor::new(vec![lower as f64, upper as f64], vec![1, 2])
77 .map_err(|e| runtime_error(BUILTIN_NAME, format!("{BUILTIN_NAME}: {e}")))?;
78 Ok(Value::Tensor(tensor))
79 }
80 BandSelector::Lower => Ok(Value::Num(lower as f64)),
81 BandSelector::Upper => Ok(Value::Num(upper as f64)),
82 }
83}
84
85fn parse_selector(args: &[Value]) -> BuiltinResult<BandSelector> {
86 match args.len() {
87 0 => Ok(BandSelector::Both),
88 1 => {
89 let text = tensor::value_to_string(&args[0]).ok_or_else(|| {
90 runtime_error(
91 BUILTIN_NAME,
92 "bandwidth: selector must be a character vector or string scalar",
93 )
94 })?;
95 let trimmed = text.trim();
96 let lowered = trimmed.to_ascii_lowercase();
97 match lowered.as_str() {
98 "lower" => Ok(BandSelector::Lower),
99 "upper" => Ok(BandSelector::Upper),
100 other => Err(runtime_error(
101 BUILTIN_NAME,
102 format!(
103 "bandwidth: unrecognized selector '{other}'; expected 'lower' or 'upper'"
104 ),
105 )),
106 }
107 }
108 _ => Err(runtime_error(
109 BUILTIN_NAME,
110 "bandwidth: too many input arguments",
111 )),
112 }
113}
114
115fn value_into_tensor_for(name: &str, value: Value) -> BuiltinResult<Tensor> {
116 match value {
117 Value::Tensor(t) => Ok(t),
118 Value::LogicalArray(logical) => logical_to_tensor(name, &logical),
119 Value::Num(n) => Tensor::new(vec![n], vec![1, 1])
120 .map_err(|e| runtime_error(name, format!("{name}: {e}"))),
121 Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1])
122 .map_err(|e| runtime_error(name, format!("{name}: {e}"))),
123 Value::Bool(b) => Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
124 .map_err(|e| runtime_error(name, format!("{name}: {e}"))),
125 other => Err(runtime_error(
126 name,
127 format!(
128 "{name}: unsupported input type {:?}; expected numeric or logical values",
129 other
130 ),
131 )),
132 }
133}
134
135fn logical_to_tensor(name: &str, logical: &LogicalArray) -> BuiltinResult<Tensor> {
136 let data: Vec<f64> = logical
137 .data
138 .iter()
139 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
140 .collect();
141 Tensor::new(data, logical.shape.clone())
142 .map_err(|e| runtime_error(name, format!("{name}: {e}")))
143}
144
/// Input matrix normalized into one of the three representations the
/// bandwidth computation handles.
enum MatrixData {
    // Host-resident real tensor (covers numeric, logical, and scalar inputs).
    Real(Tensor),
    // Host-resident complex tensor.
    Complex(ComplexTensor),
    // Device-resident tensor; computed via the provider hook or gathered.
    Gpu(GpuTensorHandle),
}
150
151impl MatrixData {
152 fn from_value(value: Value) -> BuiltinResult<Self> {
153 match value {
154 Value::ComplexTensor(ct) => Ok(Self::Complex(ct)),
155 Value::Complex(re, im) => {
156 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
157 .map_err(|e| runtime_error(BUILTIN_NAME, format!("{BUILTIN_NAME}: {e}")))?;
158 Ok(Self::Complex(tensor))
159 }
160 Value::GpuTensor(handle) => Ok(Self::Gpu(handle)),
161 other => {
162 let tensor = value_into_tensor_for(BUILTIN_NAME, other)?;
163 Ok(Self::Real(tensor))
164 }
165 }
166 }
167
168 async fn bandwidth(&self) -> BuiltinResult<(usize, usize)> {
169 match self {
170 MatrixData::Real(tensor) => bandwidth_host_real_tensor(tensor),
171 MatrixData::Complex(tensor) => bandwidth_host_complex_tensor(tensor),
172 MatrixData::Gpu(handle) => bandwidth_gpu(handle).await,
173 }
174 }
175}
176
177async fn bandwidth_gpu(handle: &GpuTensorHandle) -> BuiltinResult<(usize, usize)> {
178 let (rows, cols) = ensure_matrix_shape(&handle.shape)?;
179 if rows == 0 || cols == 0 {
180 return Ok((0, 0));
181 }
182 if let Some(provider) = runmat_accelerate_api::provider() {
183 match provider.bandwidth(handle) {
184 Ok(result) => {
185 let lower = result.lower as usize;
186 let upper = result.upper as usize;
187 return Ok((lower, upper));
188 }
189 Err(err) => {
190 debug!("bandwidth: provider bandwidth fallback: {err}");
191 }
192 }
193 }
194 let tensor = gpu_helpers::gather_tensor_async(handle).await?;
195 bandwidth_host_real_tensor(&tensor)
196}
197
198pub fn ensure_matrix_shape(shape: &[usize]) -> BuiltinResult<(usize, usize)> {
199 match shape.len() {
200 0 => Ok((1, 1)),
201 1 => Ok((1, shape[0])),
202 _ => {
203 if shape[2..].iter().any(|&dim| dim > 1) {
204 Err(runtime_error(
205 BUILTIN_NAME,
206 "bandwidth: input must be a 2-D matrix",
207 ))
208 } else {
209 Ok((shape[0], shape[1]))
210 }
211 }
212 }
213}
214
/// Compute `(lower, upper)` bandwidth for column-major real data with the
/// given shape; errors if the shape is not effectively 2-D.
pub fn bandwidth_host_real_data(shape: &[usize], data: &[f64]) -> BuiltinResult<(usize, usize)> {
    let (rows, cols) = ensure_matrix_shape(shape)?;
    Ok(compute_real_bandwidth(rows, cols, data))
}
219
/// Compute `(lower, upper)` bandwidth for column-major complex data (stored
/// as `(re, im)` pairs); errors if the shape is not effectively 2-D.
pub fn bandwidth_host_complex_data(
    shape: &[usize],
    data: &[(f64, f64)],
) -> BuiltinResult<(usize, usize)> {
    let (rows, cols) = ensure_matrix_shape(shape)?;
    Ok(compute_complex_bandwidth(rows, cols, data))
}
227
/// Convenience wrapper: bandwidth of a host real `Tensor`.
pub fn bandwidth_host_real_tensor(tensor: &Tensor) -> BuiltinResult<(usize, usize)> {
    bandwidth_host_real_data(&tensor.shape, &tensor.data)
}
231
/// Convenience wrapper: bandwidth of a host `ComplexTensor`.
pub fn bandwidth_host_complex_tensor(tensor: &ComplexTensor) -> BuiltinResult<(usize, usize)> {
    bandwidth_host_complex_data(&tensor.shape, &tensor.data)
}
235
/// Scan column-major `data` of a `rows` x `cols` matrix and return the
/// `(lower, upper)` bandwidth: the largest distance below/above the main
/// diagonal at which a nonzero (or NaN) entry appears.
fn compute_real_bandwidth(rows: usize, cols: usize, data: &[f64]) -> (usize, usize) {
    if rows == 0 || cols == 0 {
        return (0, 0);
    }
    let (mut lower, mut upper) = (0usize, 0usize);
    // Walk the linear buffer directly; `take` clamps to the logical element
    // count, and iterating `data` clamps to the buffer length, matching the
    // original bounds guard.
    for (idx, &value) in data.iter().enumerate().take(rows * cols) {
        // Column-major layout: linear index -> (row, col).
        let (row, col) = (idx % rows, idx / rows);
        // NaN compares unequal to zero, so it counts as a structural nonzero;
        // the explicit is_nan check keeps that intent visible.
        if value != 0.0 || value.is_nan() {
            if row >= col {
                lower = lower.max(row - col);
            } else {
                upper = upper.max(col - row);
            }
        }
    }
    (lower, upper)
}
261
/// Complex counterpart of `compute_real_bandwidth`: an entry is structural
/// when either its real or imaginary component differs from 0.0 (NaN counts,
/// since NaN != 0.0).
fn compute_complex_bandwidth(rows: usize, cols: usize, data: &[(f64, f64)]) -> (usize, usize) {
    if rows == 0 || cols == 0 {
        return (0, 0);
    }
    let (mut lower, mut upper) = (0usize, 0usize);
    for (idx, &(re, im)) in data.iter().enumerate().take(rows * cols) {
        // Column-major layout: linear index -> (row, col).
        let (row, col) = (idx % rows, idx / rows);
        if re != 0.0 || im != 0.0 {
            if row >= col {
                lower = lower.max(row - col);
            } else {
                upper = upper.max(col - row);
            }
        }
    }
    (lower, upper)
}
287
// Unit tests: host paths (real/complex/logical), selector parsing, shape
// validation, and GPU round-trips via the test provider / optional wgpu backend.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{LogicalArray, ResolveContext, Type};

    // Diagonal matrix: both bandwidths are zero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_diagonal_matrix() {
        let tensor = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let value = Value::Tensor(tensor);
        let result = bandwidth_builtin(value, Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![0.0, 0.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Type resolver reports the default two-output shape (1x2 tensor).
    #[test]
    fn bandwidth_type_defaults_to_two_element_tensor() {
        let out = bandwidth_type(
            &[Type::Tensor {
                shape: Some(vec![Some(3), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(1), Some(2)])
            }
        );
    }

    // 'lower' selector returns only the lower bandwidth as a scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_lower_selector() {
        let tensor = Tensor::new(
            vec![1.0, 2.0, 3.0, 0.0, 1.0, 4.0, 0.0, 0.0, 1.0],
            vec![3, 3],
        )
        .unwrap();
        let args = vec![Value::from("lower")];
        let result = bandwidth_builtin(Value::Tensor(tensor), args).expect("bandwidth");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    // 'upper' selector returns only the upper bandwidth as a scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_upper_selector() {
        let tensor = Tensor::new(
            vec![1.0, 0.0, 0.0, 2.0, 4.0, 0.0, 3.0, 5.0, 6.0],
            vec![3, 3],
        )
        .unwrap();
        let args = vec![Value::from("upper")];
        let result = bandwidth_builtin(Value::Tensor(tensor), args).expect("bandwidth");
        match result {
            Value::Num(n) => assert_eq!(n, 2.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    // Complex entries count as nonzero when either component is nonzero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_complex_matrix() {
        let data = vec![(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (0.0, 0.0)];
        let tensor = ComplexTensor::new(data, vec![2, 2]).unwrap();
        let result =
            bandwidth_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 1.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Non-square input is supported; bandwidths measured from the main diagonal.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_rectangular_matrix() {
        let tensor = Tensor::new(
            vec![0.0, 8.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 7.0, 0.0, 0.0, 10.0],
            vec![4, 3],
        )
        .unwrap();
        let result = bandwidth_builtin(Value::Tensor(tensor), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // 0x0 input yields [0, 0] rather than an error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_empty_matrix_returns_zero() {
        let tensor = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let result = bandwidth_builtin(Value::Tensor(tensor), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![0.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // NaN is treated as a structural nonzero entry.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_nan_counts_as_nonzero() {
        let tensor =
            Tensor::new(vec![0.0, f64::NAN, 0.0, 0.0], vec![2, 2]).expect("tensor construction");
        let result = bandwidth_builtin(Value::Tensor(tensor), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Logical arrays are coerced to 0.0/1.0 tensors before analysis.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_logical_input_supported() {
        let logical = LogicalArray::new(vec![1, 1, 1, 0], vec![2, 2]).expect("logical array");
        let result =
            bandwidth_builtin(Value::LogicalArray(logical), Vec::new()).expect("bandwidth");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 1.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Unknown selector strings surface an error naming the valid options.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_selector_validation() {
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let err =
            bandwidth_builtin(Value::Tensor(tensor), vec![Value::from("middle")]).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("lower") && message.contains("upper"),
            "unexpected error: {message}"
        );
    }

    // N-D input with a trailing dimension > 1 is rejected as non-2-D.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_rejects_higher_dimensions() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 1, 2]).unwrap();
        let err = bandwidth_builtin(Value::Tensor(tensor), Vec::new()).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("2-D"),
            "unexpected error message: {message}"
        );
    }

    // Upload to the test provider and confirm the GPU path matches host results.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bandwidth_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 2.0, 0.0, 0.0], vec![2, 2]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                bandwidth_builtin(Value::GpuTensor(handle), Vec::new()).expect("bandwidth");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![1, 2]);
            assert_eq!(gathered.data, vec![1.0, 0.0]);
        });
    }

    // With the wgpu backend enabled, the provider hook must agree with the CPU path.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn bandwidth_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        // Skip silently when no wgpu provider could be registered on this machine.
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };
        let tensor = Tensor::new(
            vec![0.0, 2.0, 0.0, 0.0, 0.0, 4.0, 5.0, 0.0, 6.0],
            vec![3, 3],
        )
        .unwrap();
        let cpu = super::bandwidth_host_real_tensor(&tensor).expect("cpu bandwidth");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_meta = provider.bandwidth(&handle).expect("provider bandwidth");
        assert_eq!(gpu_meta.lower as usize, cpu.0);
        assert_eq!(gpu_meta.upper as usize, cpu.1);

        let result =
            bandwidth_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("bandwidth");
        let gathered = test_support::gather(result).expect("gather");
        assert_eq!(gathered.shape, vec![1, 2]);
        assert_eq!(gathered.data, vec![cpu.0 as f64, cpu.1 as f64]);
        let _ = provider.free(&handle);
    }

    // Sync shim: drive the async builtin to completion for the tests above.
    fn bandwidth_builtin(matrix: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::bandwidth_builtin(matrix, rest))
    }
}