1use runmat_builtins::{ComplexTensor, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::broadcast::BroadcastPlan;
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::common::tensor;
12use crate::builtins::control::type_resolvers::db_type;
13use crate::{build_runtime_error, BuiltinResult, RuntimeError};
14
/// Builtin name attached to every error raised here and passed to the shared tensor helpers.
const BUILTIN_NAME: &str = "db";
16
// GPU metadata for `db`. No device precisions or provider hooks are declared and the residency
// policy is `GatherImmediately`; `db_builtin` gathers the signal and the optional mode/resistance
// operand to the host before doing any conversion work.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::db")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "db",
    op_kind: GpuOpKind::Custom("decibel-conversion"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host-side decibel conversion; gpuArray inputs are gathered before applying mode parsing, complex magnitudes, and optional resistance broadcasting.",
};
32
// Fusion metadata for `db`. No element-wise or reduction template is supplied
// (`elementwise: None`, `reduction: None`). As the notes state, the builtin therefore ends a
// fusion group and executes on the host.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::db")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "db",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "db is a compound element-wise conversion with string mode parsing and optional resistance input; it terminates fusion and executes on the host.",
};
43
44fn builtin_error(message: impl Into<String>) -> RuntimeError {
45 build_runtime_error(message)
46 .with_builtin(BUILTIN_NAME)
47 .build()
48}
49
/// How `db` interprets its input, selected by the optional second argument.
#[derive(Clone, Debug)]
enum DbMode {
    /// Voltage (amplitude) input: `20*log10(|y|)`. Used when no second argument is given.
    Voltage,
    /// Power input: `10*log10(|y|)`.
    Power,
    /// Voltage across a resistance `R`: `10*log10(|y|^2 / R)`.
    /// Holds the raw resistance argument; it is validated and converted only after mode
    /// parsing, immediately before the broadcasted computation.
    Resistance(Value),
}
56
57#[runtime_builtin(
58 name = "db",
59 category = "control",
60 summary = "Convert numeric values to decibels using MATLAB-compatible voltage, power, or resistance forms.",
61 keywords = "db,decibel,voltage,power,resistance,complex",
62 accel = "metadata",
63 type_resolver(db_type),
64 builtin_path = "crate::builtins::control::db"
65)]
66async fn db_builtin(y: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
67 if rest.len() > 1 {
68 return Err(builtin_error(
69 "db: expected db(y), db(y, 'voltage'), db(y, 'power'), or db(y, R)",
70 ));
71 }
72
73 let y = crate::gather_if_needed_async(&y).await?;
74 let mode = match rest.into_iter().next() {
75 Some(arg) => parse_mode(crate::gather_if_needed_async(&arg).await?)?,
76 None => DbMode::Voltage,
77 };
78
79 let magnitudes = magnitude_tensor(y)?;
80 match mode {
81 DbMode::Voltage => map_magnitudes(magnitudes, |m| 20.0 * m.log10()),
82 DbMode::Power => map_magnitudes(magnitudes, |m| 10.0 * m.log10()),
83 DbMode::Resistance(reference) => {
84 let reference = resistance_tensor(reference)?;
85 db_with_resistance(&magnitudes, &reference)
86 }
87 }
88}
89
90fn parse_mode(value: Value) -> BuiltinResult<DbMode> {
91 match value {
92 Value::String(text) => parse_mode_string(&text),
93 Value::StringArray(array) if array.data.len() == 1 => parse_mode_string(&array.data[0]),
94 Value::StringArray(_) => Err(builtin_error("db: mode must be a scalar string")),
95 Value::CharArray(array) if array.rows == 1 => {
96 let text = array.data.iter().collect::<String>();
97 parse_mode_string(&text)
98 }
99 Value::CharArray(_) => Err(builtin_error("db: mode must be a character row vector")),
100 other => Ok(DbMode::Resistance(other)),
101 }
102}
103
104fn parse_mode_string(text: &str) -> BuiltinResult<DbMode> {
105 match text.to_ascii_lowercase().as_str() {
106 "voltage" => Ok(DbMode::Voltage),
107 "power" => Ok(DbMode::Power),
108 _ => Err(builtin_error(format!(
109 "db: unknown mode '{text}', expected 'voltage' or 'power'"
110 ))),
111 }
112}
113
114fn magnitude_tensor(value: Value) -> BuiltinResult<Tensor> {
115 match value {
116 Value::Complex(re, im) => Tensor::new(vec![re.hypot(im)], vec![1, 1])
117 .map_err(|e| builtin_error(format!("db: {e}"))),
118 Value::ComplexTensor(tensor) => complex_magnitudes(tensor),
119 Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => {
120 Err(builtin_error("db: expected numeric input"))
121 }
122 other => {
123 let mut tensor = tensor::value_into_tensor_for(BUILTIN_NAME, other)
124 .map_err(|e| builtin_error(format!("db: {e}")))?;
125 for value in &mut tensor.data {
126 *value = value.abs();
127 }
128 Ok(tensor)
129 }
130 }
131}
132
133fn complex_magnitudes(tensor: ComplexTensor) -> BuiltinResult<Tensor> {
134 let data = tensor
135 .data
136 .iter()
137 .map(|&(re, im)| re.hypot(im))
138 .collect::<Vec<_>>();
139 Tensor::new(data, tensor.shape).map_err(|e| builtin_error(format!("db: {e}")))
140}
141
142fn resistance_tensor(value: Value) -> BuiltinResult<Tensor> {
143 match value {
144 Value::Complex(_, _) | Value::ComplexTensor(_) => {
145 Err(builtin_error("db: resistance must be real"))
146 }
147 Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => {
148 Err(builtin_error("db: resistance must be numeric"))
149 }
150 other => {
151 let tensor = tensor::value_into_tensor_for(BUILTIN_NAME, other)
152 .map_err(|e| builtin_error(format!("db: {e}")))?;
153 for &resistance in &tensor.data {
154 if !resistance.is_finite() || resistance <= 0.0 {
155 return Err(builtin_error(
156 "db: resistance values must be finite and positive",
157 ));
158 }
159 }
160 Ok(tensor)
161 }
162 }
163}
164
165fn map_magnitudes<F>(input: Tensor, op: F) -> BuiltinResult<Value>
166where
167 F: Fn(f64) -> f64,
168{
169 let data = input
170 .data
171 .iter()
172 .map(|&value| op(value))
173 .collect::<Vec<_>>();
174 let tensor = Tensor::new(data, input.shape).map_err(|e| builtin_error(format!("db: {e}")))?;
175 Ok(tensor::tensor_into_value(tensor))
176}
177
178fn db_with_resistance(magnitudes: &Tensor, reference: &Tensor) -> BuiltinResult<Value> {
179 let plan = BroadcastPlan::new(&magnitudes.shape, &reference.shape)
180 .map_err(|err| builtin_error(format!("db: {err}")))?;
181 if plan.is_empty() {
182 let tensor = Tensor::new(Vec::new(), plan.output_shape().to_vec())
183 .map_err(|e| builtin_error(format!("db: {e}")))?;
184 return Ok(tensor::tensor_into_value(tensor));
185 }
186
187 let mut data = vec![0.0; plan.len()];
188 for (out_idx, y_idx, r_idx) in plan.iter() {
189 let magnitude = magnitudes.data[y_idx];
190 let resistance = reference.data[r_idx];
191 data[out_idx] = 10.0 * ((magnitude * magnitude) / resistance).log10();
192 }
193 let tensor = Tensor::new(data, plan.output_shape().to_vec())
194 .map_err(|e| builtin_error(format!("db: {e}")))?;
195 Ok(tensor::tensor_into_value(tensor))
196}
197
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, IntValue, LogicalArray, ResolveContext, StringArray, Type};

    /// Synchronous wrapper so tests can call the async builtin directly.
    fn db_builtin(y: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::db_builtin(y, rest))
    }

    /// Asserts that `value` is a scalar `Value::Num` within `1e-12` of `expected`.
    fn assert_num_close(value: Value, expected: f64) {
        match value {
            Value::Num(actual) => assert!(
                (actual - expected).abs() < 1e-12,
                "expected {expected}, got {actual}"
            ),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    /// Asserts that `value` is a `Value::Tensor` with `expected_shape` whose elements match
    /// `expected`. Infinite expectations must match exactly; finite ones within `1e-12`.
    fn assert_tensor_close(value: Value, expected_shape: &[usize], expected: &[f64]) {
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, expected_shape);
                assert_eq!(tensor.data.len(), expected.len());
                for (&actual, &expected) in tensor.data.iter().zip(expected) {
                    if expected.is_infinite() {
                        assert_eq!(actual, expected);
                    } else {
                        assert!(
                            (actual - expected).abs() < 1e-12,
                            "expected {expected}, got {actual}"
                        );
                    }
                }
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn db_type_unary_preserves_tensor_shape() {
        let out = db_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[test]
    fn db_type_scalar_returns_num() {
        let out = db_type(&[Type::Num], &ResolveContext::new(Vec::new()));
        assert_eq!(out, Type::Num);
    }

    #[test]
    fn db_type_string_mode_uses_input_shape() {
        let out = db_type(
            &[
                Type::Tensor {
                    shape: Some(vec![Some(4), Some(1)]),
                },
                Type::String,
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(4), Some(1)])
            }
        );
    }

    #[test]
    fn db_type_text_modes_use_unary_shape_rules() {
        let string_array_type = Type::from_value(&Value::StringArray(
            StringArray::new(vec!["power".into()], vec![1, 1]).unwrap(),
        ));
        let char_array_type = Type::from_value(&Value::CharArray(CharArray::new_row("power")));

        for mode in [Type::String, string_array_type, char_array_type] {
            let out = db_type(
                &[
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(1)]),
                    },
                    mode,
                ],
                &ResolveContext::new(Vec::new()),
            );
            assert_eq!(out, Type::Num);
        }
    }

    #[test]
    fn db_type_resistance_broadcasts_shapes() {
        let out = db_type(
            &[
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(1)]),
                },
                Type::Tensor {
                    shape: Some(vec![Some(1), Some(3)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_default_voltage_scalar() {
        assert_num_close(db_builtin(Value::Num(10.0), Vec::new()).expect("db"), 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_voltage_mode_matches_default() {
        let result = db_builtin(
            Value::Num(10.0),
            vec![Value::CharArray(CharArray::new_row("voltage"))],
        )
        .expect("db");
        assert_num_close(result, 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_power_mode_scalar() {
        let result = db_builtin(
            Value::Num(100.0),
            vec![Value::CharArray(CharArray::new_row("power"))],
        )
        .expect("db");
        assert_num_close(result, 20.0);
    }

    /// Mode names are matched without regard to ASCII case.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_mode_matching_is_case_insensitive() {
        let result = db_builtin(
            Value::Num(100.0),
            vec![Value::CharArray(CharArray::new_row("POWER"))],
        )
        .expect("db");
        assert_num_close(result, 20.0);
    }

    /// A one-element string array is accepted as a mode name.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_string_array_mode_selects_power() {
        let mode = StringArray::new(vec!["power".into()], vec![1, 1]).unwrap();
        let result =
            db_builtin(Value::Num(100.0), vec![Value::StringArray(mode)]).expect("db");
        assert_num_close(result, 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_negative_input_uses_magnitude() {
        assert_num_close(db_builtin(Value::Num(-10.0), Vec::new()).expect("db"), 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_zero_input_returns_negative_infinity() {
        match db_builtin(Value::Num(0.0), Vec::new()).expect("db") {
            Value::Num(value) => assert_eq!(value, f64::NEG_INFINITY),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_complex_scalar_uses_magnitude() {
        assert_num_close(
            db_builtin(Value::Complex(3.0, 4.0), Vec::new()).expect("db"),
            20.0 * 5.0f64.log10(),
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_tensor_elements() {
        let tensor = Tensor::new(vec![1.0, 10.0, 100.0], vec![1, 3]).unwrap();
        let result = db_builtin(Value::Tensor(tensor), Vec::new()).expect("db");
        assert_tensor_close(result, &[1, 3], &[0.0, 20.0, 40.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_complex_tensor_returns_real_tensor() {
        let tensor = ComplexTensor::new(vec![(3.0, 4.0), (0.0, -10.0)], vec![2, 1]).unwrap();
        let result = db_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("db");
        assert_tensor_close(result, &[2, 1], &[20.0 * 5.0f64.log10(), 20.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_resistance_scalar() {
        let result = db_builtin(Value::Num(10.0), vec![Value::Num(50.0)]).expect("db");
        assert_num_close(result, 10.0 * (2.0f64).log10());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_resistance_broadcasts() {
        let y = Tensor::new(vec![10.0, 20.0], vec![2, 1]).unwrap();
        let r = Tensor::new(vec![50.0, 100.0, 200.0], vec![1, 3]).unwrap();
        let result = db_builtin(Value::Tensor(y), vec![Value::Tensor(r)]).expect("db");
        assert_tensor_close(
            result,
            &[2, 3],
            &[
                10.0 * (100.0f64 / 50.0).log10(),
                10.0 * (400.0f64 / 50.0).log10(),
                10.0 * (100.0f64 / 100.0).log10(),
                10.0 * (400.0f64 / 100.0).log10(),
                10.0 * (100.0f64 / 200.0).log10(),
                10.0 * (400.0f64 / 200.0).log10(),
            ],
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_logical_and_integer_inputs_promote_to_double() {
        let logical = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let result = db_builtin(Value::LogicalArray(logical), Vec::new()).expect("db");
        assert_tensor_close(result, &[1, 2], &[0.0, f64::NEG_INFINITY]);

        let result = db_builtin(Value::Int(IntValue::I32(10)), Vec::new()).expect("db");
        assert_num_close(result, 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_invalid_mode() {
        let err = db_builtin(
            Value::Num(1.0),
            vec![Value::CharArray(CharArray::new_row("energy"))],
        )
        .expect_err("invalid mode");
        assert!(err.message().contains("unknown mode"));
    }

    /// More than one trailing argument is an arity error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_too_many_arguments() {
        let err = db_builtin(Value::Num(1.0), vec![Value::Num(1.0), Value::Num(2.0)])
            .expect_err("too many arguments");
        assert!(err.message().contains("expected db(y)"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_nonpositive_resistance() {
        let err =
            db_builtin(Value::Num(1.0), vec![Value::Num(0.0)]).expect_err("invalid resistance");
        assert!(err.message().contains("finite and positive"));
    }

    /// NaN fails the "finite and positive" resistance check.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_nan_resistance() {
        let err = db_builtin(Value::Num(1.0), vec![Value::Num(f64::NAN)])
            .expect_err("NaN resistance");
        assert!(err.message().contains("finite and positive"));
    }

    /// Complex resistances are rejected rather than reduced to a magnitude.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_complex_resistance() {
        let err = db_builtin(Value::Num(1.0), vec![Value::Complex(1.0, 1.0)])
            .expect_err("complex resistance");
        assert!(err.message().contains("resistance must be real"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_nonnumeric_input() {
        let err = db_builtin(Value::from("hello"), Vec::new()).expect_err("invalid input");
        assert!(err.message().contains("expected numeric"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_gpu_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 10.0, 100.0], vec![1, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = db_builtin(Value::GpuTensor(handle), Vec::new()).expect("db");
            assert_tensor_close(result, &[1, 3], &[0.0, 20.0, 40.0]);
        });
    }
}