Skip to main content

runmat_runtime/builtins/control/
db.rs

1//! MATLAB-compatible `db` decibel conversion builtin for RunMat.
2
3use runmat_builtins::{ComplexTensor, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::broadcast::BroadcastPlan;
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::common::tensor;
12use crate::builtins::control::type_resolvers::db_type;
13use crate::{build_runtime_error, BuiltinResult, RuntimeError};
14
15const BUILTIN_NAME: &str = "db";
16
17#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::db")]
18pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
19    name: "db",
20    op_kind: GpuOpKind::Custom("decibel-conversion"),
21    supported_precisions: &[],
22    broadcast: BroadcastSemantics::Matlab,
23    provider_hooks: &[],
24    constant_strategy: ConstantStrategy::InlineLiteral,
25    residency: ResidencyPolicy::GatherImmediately,
26    nan_mode: ReductionNaN::Include,
27    two_pass_threshold: None,
28    workgroup_size: None,
29    accepts_nan_mode: false,
30    notes: "Host-side decibel conversion; gpuArray inputs are gathered before applying mode parsing, complex magnitudes, and optional resistance broadcasting.",
31};
32
33#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::db")]
34pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
35    name: "db",
36    shape: ShapeRequirements::BroadcastCompatible,
37    constant_strategy: ConstantStrategy::InlineLiteral,
38    elementwise: None,
39    reduction: None,
40    emits_nan: false,
41    notes: "db is a compound element-wise conversion with string mode parsing and optional resistance input; it terminates fusion and executes on the host.",
42};
43
44fn builtin_error(message: impl Into<String>) -> RuntimeError {
45    build_runtime_error(message)
46        .with_builtin(BUILTIN_NAME)
47        .build()
48}
49
50#[derive(Clone, Debug)]
51enum DbMode {
52    Voltage,
53    Power,
54    Resistance(Value),
55}
56
57#[runtime_builtin(
58    name = "db",
59    category = "control",
60    summary = "Convert numeric values to decibels using MATLAB-compatible voltage, power, or resistance forms.",
61    keywords = "db,decibel,voltage,power,resistance,complex",
62    accel = "metadata",
63    type_resolver(db_type),
64    builtin_path = "crate::builtins::control::db"
65)]
66async fn db_builtin(y: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
67    if rest.len() > 1 {
68        return Err(builtin_error(
69            "db: expected db(y), db(y, 'voltage'), db(y, 'power'), or db(y, R)",
70        ));
71    }
72
73    let y = crate::gather_if_needed_async(&y).await?;
74    let mode = match rest.into_iter().next() {
75        Some(arg) => parse_mode(crate::gather_if_needed_async(&arg).await?)?,
76        None => DbMode::Voltage,
77    };
78
79    let magnitudes = magnitude_tensor(y)?;
80    match mode {
81        DbMode::Voltage => map_magnitudes(magnitudes, |m| 20.0 * m.log10()),
82        DbMode::Power => map_magnitudes(magnitudes, |m| 10.0 * m.log10()),
83        DbMode::Resistance(reference) => {
84            let reference = resistance_tensor(reference)?;
85            db_with_resistance(&magnitudes, &reference)
86        }
87    }
88}
89
90fn parse_mode(value: Value) -> BuiltinResult<DbMode> {
91    match value {
92        Value::String(text) => parse_mode_string(&text),
93        Value::StringArray(array) if array.data.len() == 1 => parse_mode_string(&array.data[0]),
94        Value::StringArray(_) => Err(builtin_error("db: mode must be a scalar string")),
95        Value::CharArray(array) if array.rows == 1 => {
96            let text = array.data.iter().collect::<String>();
97            parse_mode_string(&text)
98        }
99        Value::CharArray(_) => Err(builtin_error("db: mode must be a character row vector")),
100        other => Ok(DbMode::Resistance(other)),
101    }
102}
103
104fn parse_mode_string(text: &str) -> BuiltinResult<DbMode> {
105    match text.to_ascii_lowercase().as_str() {
106        "voltage" => Ok(DbMode::Voltage),
107        "power" => Ok(DbMode::Power),
108        _ => Err(builtin_error(format!(
109            "db: unknown mode '{text}', expected 'voltage' or 'power'"
110        ))),
111    }
112}
113
114fn magnitude_tensor(value: Value) -> BuiltinResult<Tensor> {
115    match value {
116        Value::Complex(re, im) => Tensor::new(vec![re.hypot(im)], vec![1, 1])
117            .map_err(|e| builtin_error(format!("db: {e}"))),
118        Value::ComplexTensor(tensor) => complex_magnitudes(tensor),
119        Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => {
120            Err(builtin_error("db: expected numeric input"))
121        }
122        other => {
123            let mut tensor = tensor::value_into_tensor_for(BUILTIN_NAME, other)
124                .map_err(|e| builtin_error(format!("db: {e}")))?;
125            for value in &mut tensor.data {
126                *value = value.abs();
127            }
128            Ok(tensor)
129        }
130    }
131}
132
133fn complex_magnitudes(tensor: ComplexTensor) -> BuiltinResult<Tensor> {
134    let data = tensor
135        .data
136        .iter()
137        .map(|&(re, im)| re.hypot(im))
138        .collect::<Vec<_>>();
139    Tensor::new(data, tensor.shape).map_err(|e| builtin_error(format!("db: {e}")))
140}
141
142fn resistance_tensor(value: Value) -> BuiltinResult<Tensor> {
143    match value {
144        Value::Complex(_, _) | Value::ComplexTensor(_) => {
145            Err(builtin_error("db: resistance must be real"))
146        }
147        Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => {
148            Err(builtin_error("db: resistance must be numeric"))
149        }
150        other => {
151            let tensor = tensor::value_into_tensor_for(BUILTIN_NAME, other)
152                .map_err(|e| builtin_error(format!("db: {e}")))?;
153            for &resistance in &tensor.data {
154                if !resistance.is_finite() || resistance <= 0.0 {
155                    return Err(builtin_error(
156                        "db: resistance values must be finite and positive",
157                    ));
158                }
159            }
160            Ok(tensor)
161        }
162    }
163}
164
165fn map_magnitudes<F>(input: Tensor, op: F) -> BuiltinResult<Value>
166where
167    F: Fn(f64) -> f64,
168{
169    let data = input
170        .data
171        .iter()
172        .map(|&value| op(value))
173        .collect::<Vec<_>>();
174    let tensor = Tensor::new(data, input.shape).map_err(|e| builtin_error(format!("db: {e}")))?;
175    Ok(tensor::tensor_into_value(tensor))
176}
177
178fn db_with_resistance(magnitudes: &Tensor, reference: &Tensor) -> BuiltinResult<Value> {
179    let plan = BroadcastPlan::new(&magnitudes.shape, &reference.shape)
180        .map_err(|err| builtin_error(format!("db: {err}")))?;
181    if plan.is_empty() {
182        let tensor = Tensor::new(Vec::new(), plan.output_shape().to_vec())
183            .map_err(|e| builtin_error(format!("db: {e}")))?;
184        return Ok(tensor::tensor_into_value(tensor));
185    }
186
187    let mut data = vec![0.0; plan.len()];
188    for (out_idx, y_idx, r_idx) in plan.iter() {
189        let magnitude = magnitudes.data[y_idx];
190        let resistance = reference.data[r_idx];
191        data[out_idx] = 10.0 * ((magnitude * magnitude) / resistance).log10();
192    }
193    let tensor = Tensor::new(data, plan.output_shape().to_vec())
194        .map_err(|e| builtin_error(format!("db: {e}")))?;
195    Ok(tensor::tensor_into_value(tensor))
196}
197
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for the `db` builtin: type-resolver shape rules, each conversion mode,
    // argument validation, and gpuArray gathering.
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{CharArray, IntValue, LogicalArray, ResolveContext, StringArray, Type};

    /// Synchronous wrapper around the async builtin.
    fn db_builtin(y: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::db_builtin(y, rest))
    }

    /// Asserts that `value` is a scalar within 1e-12 of `expected`.
    fn assert_num_close(value: Value, expected: f64) {
        match value {
            Value::Num(actual) => assert!(
                (actual - expected).abs() < 1e-12,
                "expected {expected}, got {actual}"
            ),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    /// Asserts tensor shape and element values; infinite expectations must match exactly.
    fn assert_tensor_close(value: Value, expected_shape: &[usize], expected: &[f64]) {
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, expected_shape);
                assert_eq!(tensor.data.len(), expected.len());
                for (&actual, &expected) in tensor.data.iter().zip(expected) {
                    if expected.is_infinite() {
                        assert_eq!(actual, expected);
                    } else {
                        assert!(
                            (actual - expected).abs() < 1e-12,
                            "expected {expected}, got {actual}"
                        );
                    }
                }
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn db_type_unary_preserves_tensor_shape() {
        let out = db_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[test]
    fn db_type_scalar_returns_num() {
        let out = db_type(&[Type::Num], &ResolveContext::new(Vec::new()));
        assert_eq!(out, Type::Num);
    }

    #[test]
    fn db_type_string_mode_uses_input_shape() {
        let out = db_type(
            &[
                Type::Tensor {
                    shape: Some(vec![Some(4), Some(1)]),
                },
                Type::String,
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(4), Some(1)])
            }
        );
    }

    #[test]
    fn db_type_text_modes_use_unary_shape_rules() {
        let string_array_type = Type::from_value(&Value::StringArray(
            StringArray::new(vec!["power".into()], vec![1, 1]).unwrap(),
        ));
        let char_array_type = Type::from_value(&Value::CharArray(CharArray::new_row("power")));

        for mode in [Type::String, string_array_type, char_array_type] {
            let out = db_type(
                &[
                    Type::Tensor {
                        shape: Some(vec![Some(1), Some(1)]),
                    },
                    mode,
                ],
                &ResolveContext::new(Vec::new()),
            );
            assert_eq!(out, Type::Num);
        }
    }

    #[test]
    fn db_type_resistance_broadcasts_shapes() {
        let out = db_type(
            &[
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(1)]),
                },
                Type::Tensor {
                    shape: Some(vec![Some(1), Some(3)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_default_voltage_scalar() {
        assert_num_close(db_builtin(Value::Num(10.0), Vec::new()).expect("db"), 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_voltage_mode_matches_default() {
        let result = db_builtin(
            Value::Num(10.0),
            vec![Value::CharArray(CharArray::new_row("voltage"))],
        )
        .expect("db");
        assert_num_close(result, 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_power_mode_scalar() {
        let result = db_builtin(
            Value::Num(100.0),
            vec![Value::CharArray(CharArray::new_row("power"))],
        )
        .expect("db");
        assert_num_close(result, 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_mode_names_are_case_insensitive() {
        let result = db_builtin(
            Value::Num(100.0),
            vec![Value::CharArray(CharArray::new_row("POWER"))],
        )
        .expect("db");
        assert_num_close(result, 20.0);

        let result = db_builtin(Value::Num(10.0), vec![Value::from("Voltage")]).expect("db");
        assert_num_close(result, 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_string_array_scalar_mode() {
        let mode = StringArray::new(vec!["power".into()], vec![1, 1]).unwrap();
        let result =
            db_builtin(Value::Num(100.0), vec![Value::StringArray(mode)]).expect("db");
        assert_num_close(result, 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_nonscalar_string_array_mode() {
        let mode =
            StringArray::new(vec!["power".into(), "voltage".into()], vec![1, 2]).unwrap();
        let err = db_builtin(Value::Num(1.0), vec![Value::StringArray(mode)])
            .expect_err("non-scalar mode");
        assert!(err.message().contains("scalar string"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_negative_input_uses_magnitude() {
        assert_num_close(db_builtin(Value::Num(-10.0), Vec::new()).expect("db"), 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_zero_input_returns_negative_infinity() {
        match db_builtin(Value::Num(0.0), Vec::new()).expect("db") {
            Value::Num(value) => assert_eq!(value, f64::NEG_INFINITY),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_complex_scalar_uses_magnitude() {
        assert_num_close(
            db_builtin(Value::Complex(3.0, 4.0), Vec::new()).expect("db"),
            20.0 * 5.0f64.log10(),
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_tensor_elements() {
        let tensor = Tensor::new(vec![1.0, 10.0, 100.0], vec![1, 3]).unwrap();
        let result = db_builtin(Value::Tensor(tensor), Vec::new()).expect("db");
        assert_tensor_close(result, &[1, 3], &[0.0, 20.0, 40.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_complex_tensor_returns_real_tensor() {
        let tensor = ComplexTensor::new(vec![(3.0, 4.0), (0.0, -10.0)], vec![2, 1]).unwrap();
        let result = db_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("db");
        assert_tensor_close(result, &[2, 1], &[20.0 * 5.0f64.log10(), 20.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_resistance_scalar() {
        let result = db_builtin(Value::Num(10.0), vec![Value::Num(50.0)]).expect("db");
        assert_num_close(result, 10.0 * (2.0f64).log10());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_unit_resistance_matches_default_voltage() {
        // db(y, 1) is 10*log10(|y|^2), which must agree with the default 20*log10(|y|).
        for y in [0.5, 3.0, 42.0] {
            let result = db_builtin(Value::Num(y), vec![Value::Num(1.0)]).expect("db");
            assert_num_close(result, 20.0 * f64::log10(y));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_resistance_broadcasts() {
        let y = Tensor::new(vec![10.0, 20.0], vec![2, 1]).unwrap();
        let r = Tensor::new(vec![50.0, 100.0, 200.0], vec![1, 3]).unwrap();
        let result = db_builtin(Value::Tensor(y), vec![Value::Tensor(r)]).expect("db");
        assert_tensor_close(
            result,
            &[2, 3],
            &[
                10.0 * (100.0f64 / 50.0).log10(),
                10.0 * (400.0f64 / 50.0).log10(),
                10.0 * (100.0f64 / 100.0).log10(),
                10.0 * (400.0f64 / 100.0).log10(),
                10.0 * (100.0f64 / 200.0).log10(),
                10.0 * (400.0f64 / 200.0).log10(),
            ],
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_incompatible_resistance_shape() {
        let y = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let r = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        db_builtin(Value::Tensor(y), vec![Value::Tensor(r)])
            .expect_err("incompatible broadcast shapes");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_logical_and_integer_inputs_promote_to_double() {
        let logical = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let result = db_builtin(Value::LogicalArray(logical), Vec::new()).expect("db");
        assert_tensor_close(result, &[1, 2], &[0.0, f64::NEG_INFINITY]);

        let result = db_builtin(Value::Int(IntValue::I32(10)), Vec::new()).expect("db");
        assert_num_close(result, 20.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_too_many_arguments() {
        let err = db_builtin(Value::Num(1.0), vec![Value::Num(1.0), Value::Num(2.0)])
            .expect_err("too many arguments");
        assert!(err.message().contains("expected db(y)"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_invalid_mode() {
        let err = db_builtin(
            Value::Num(1.0),
            vec![Value::CharArray(CharArray::new_row("energy"))],
        )
        .expect_err("invalid mode");
        assert!(err.message().contains("unknown mode"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_nonpositive_resistance() {
        let err =
            db_builtin(Value::Num(1.0), vec![Value::Num(0.0)]).expect_err("invalid resistance");
        assert!(err.message().contains("finite and positive"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_complex_resistance() {
        let err = db_builtin(Value::Num(1.0), vec![Value::Complex(1.0, 1.0)])
            .expect_err("complex resistance");
        assert!(err.message().contains("resistance must be real"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_rejects_nonnumeric_input() {
        let err = db_builtin(Value::from("hello"), Vec::new()).expect_err("invalid input");
        assert!(err.message().contains("expected numeric"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn db_gpu_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 10.0, 100.0], vec![1, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = db_builtin(Value::GpuTensor(handle), Vec::new()).expect("db");
            assert_tensor_close(result, &[1, 3], &[0.0, 20.0, 40.0]);
        });
    }
}