Skip to main content

runmat_runtime/builtins/math/poly/
roots.rs

1//! MATLAB-compatible `roots` builtin with GPU-aware semantics for RunMat.
2//!
3//! This implementation mirrors MATLAB behaviour, including handling for leading
4//! zeros, constant polynomials, and complex-valued coefficients. GPU inputs are
5//! gathered to the host because companion matrix eigenvalue computations are
6//! currently performed on the CPU.
7
8use nalgebra::DMatrix;
9use num_complex::Complex64;
10use runmat_builtins::{
11    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
12    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
13    ComplexTensor, Tensor, Value,
14};
15use runmat_macros::runtime_builtin;
16
17use crate::builtins::common::spec::{
18    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
19    ReductionNaN, ResidencyPolicy, ShapeRequirements,
20};
21use crate::builtins::common::{gpu_helpers, tensor};
22use crate::builtins::math::poly::type_resolvers::roots_type;
23use crate::{build_runtime_error, BuiltinResult, RuntimeError};
24
/// Tolerance for treating a leading coefficient as zero. `trim_leading_zeros`
/// scales it by the largest coefficient magnitude before comparing.
const LEADING_ZERO_TOL: f64 = 1.0e-12;
/// Tolerance used to snap negligible real/imaginary parts of computed roots to
/// zero and to decide whether the result can be returned as a real tensor.
const RESULT_ZERO_TOL: f64 = 1.0e-10;
/// Builtin name attached to every error built by this module.
const BUILTIN_NAME: &str = "roots";
28
/// Descriptor for the single output `r` of `roots`.
const ROOTS_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "r",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Roots of the polynomial as a column vector.",
}];

/// Descriptor for the single required input `c` of `roots`.
const ROOTS_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "c",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Polynomial coefficient vector in descending power order.",
}];

/// The only call form exposed by the builtin: `r = roots(c)`.
const ROOTS_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "r = roots(c)",
    inputs: &ROOTS_INPUTS,
    outputs: &ROOTS_OUTPUT,
}];
50
/// Error descriptor used when the argument is not a usable coefficient vector.
/// This is the descriptor `roots_error` attaches by default.
const ROOTS_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ROOTS.INVALID_INPUT",
    identifier: Some("RunMat:roots:InvalidInput"),
    when: "Input cannot be interpreted as a numeric coefficient vector.",
    message: "roots: invalid input",
};

/// Error descriptor used for failures inside the solver itself: the eigenvalue
/// computation or the construction of the output tensor.
const ROOTS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ROOTS.INTERNAL",
    identifier: Some("RunMat:roots:Internal"),
    when: "Runtime fails while building companion matrix outputs or solving eigenvalues.",
    message: "roots: internal runtime failure",
};

/// Every error descriptor published through `ROOTS_DESCRIPTOR`.
const ROOTS_ERRORS: [BuiltinErrorDescriptor; 2] = [ROOTS_ERROR_INVALID_INPUT, ROOTS_ERROR_INTERNAL];
66
/// Public descriptor for `roots`, bundling its signatures and error table.
/// It is referenced by path from the `#[runtime_builtin]` attribute on the
/// builtin entry point.
pub const ROOTS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &ROOTS_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &ROOTS_ERRORS,
};
73
// GPU metadata for `roots`, registered through `register_gpu_spec`.
// The spec lists no supported precisions and no provider hooks and uses
// `ResidencyPolicy::GatherImmediately`; as `notes` states, the eigenvalue solve
// runs on the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::roots")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "roots",
    op_kind: GpuOpKind::Custom("polynomial-roots"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Companion matrix eigenvalue solve executes on the host; providers currently fall back to the CPU implementation.",
};
89
90fn roots_error(message: impl Into<String>) -> RuntimeError {
91    roots_error_with(message, &ROOTS_ERROR_INVALID_INPUT)
92}
93
94fn roots_error_with(
95    message: impl Into<String>,
96    error: &'static BuiltinErrorDescriptor,
97) -> RuntimeError {
98    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
99    if let Some(identifier) = error.identifier {
100        builder = builder.with_identifier(identifier);
101    }
102    builder.build()
103}
104
// Fusion metadata for `roots`, registered through `register_fusion_spec`.
// No elementwise or reduction spec is provided; as `notes` states, the builtin
// terminates fusion and gathers its input to the host.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::roots")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "roots",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "Non-elementwise builtin that terminates fusion and gathers inputs to the host.",
};
115
116#[runtime_builtin(
117    name = "roots",
118    category = "math/poly",
119    summary = "Compute polynomial roots from a coefficient vector.",
120    keywords = "roots,polynomial,eigenvalues,companion",
121    accel = "sink",
122    type_resolver(roots_type),
123    descriptor(crate::builtins::math::poly::roots::ROOTS_DESCRIPTOR),
124    builtin_path = "crate::builtins::math::poly::roots"
125)]
126async fn roots_builtin(coefficients: Value) -> crate::BuiltinResult<Value> {
127    let coeffs = coefficients_to_complex(coefficients).await?;
128    let trimmed = trim_leading_zeros(coeffs);
129    if trimmed.is_empty() || trimmed.len() == 1 {
130        return empty_column();
131    }
132    let roots = solve_roots(&trimmed)?;
133    roots_to_value(&roots)
134}
135
136async fn coefficients_to_complex(value: Value) -> BuiltinResult<Vec<Complex64>> {
137    match value {
138        Value::GpuTensor(handle) => {
139            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
140            tensor_to_complex(tensor)
141        }
142        Value::Tensor(tensor) => tensor_to_complex(tensor),
143        Value::ComplexTensor(tensor) => complex_tensor_to_vec(tensor),
144        Value::LogicalArray(logical) => {
145            let tensor = tensor::logical_to_tensor(&logical).map_err(roots_error)?;
146            tensor_to_complex(tensor)
147        }
148        Value::Num(n) => {
149            let tensor =
150                Tensor::new(vec![n], vec![1, 1]).map_err(|e| roots_error(format!("roots: {e}")))?;
151            tensor_to_complex(tensor)
152        }
153        Value::Int(i) => {
154            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
155                .map_err(|e| roots_error(format!("roots: {e}")))?;
156            tensor_to_complex(tensor)
157        }
158        Value::Bool(b) => {
159            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
160                .map_err(|e| roots_error(format!("roots: {e}")))?;
161            tensor_to_complex(tensor)
162        }
163        other => Err(roots_error(format!(
164            "roots: expected a numeric vector of polynomial coefficients, got {other:?}"
165        ))),
166    }
167}
168
169fn tensor_to_complex(tensor: Tensor) -> BuiltinResult<Vec<Complex64>> {
170    ensure_vector_shape("roots", &tensor.shape)?;
171    Ok(tensor
172        .data
173        .into_iter()
174        .map(|value| Complex64::new(value, 0.0))
175        .collect())
176}
177
178fn complex_tensor_to_vec(tensor: ComplexTensor) -> BuiltinResult<Vec<Complex64>> {
179    ensure_vector_shape("roots", &tensor.shape)?;
180    Ok(tensor
181        .data
182        .into_iter()
183        .map(|(re, im)| Complex64::new(re, im))
184        .collect())
185}
186
187fn ensure_vector_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
188    let is_vector = match shape.len() {
189        0 => true,
190        1 => true,
191        2 => shape[0] == 1 || shape[1] == 1 || shape.iter().product::<usize>() == 0,
192        _ => shape.iter().filter(|&&dim| dim > 1).count() <= 1,
193    };
194    if !is_vector {
195        return Err(roots_error(format!(
196            "{name}: coefficients must be a vector (row or column), got shape {:?}",
197            shape
198        )));
199    }
200    Ok(())
201}
202
203fn trim_leading_zeros(mut coeffs: Vec<Complex64>) -> Vec<Complex64> {
204    if coeffs.is_empty() {
205        return coeffs;
206    }
207    let scale = coeffs.iter().map(|c| c.norm()).fold(0.0_f64, f64::max);
208    let tol = if scale == 0.0 {
209        LEADING_ZERO_TOL
210    } else {
211        LEADING_ZERO_TOL * scale
212    };
213    let first_nonzero = coeffs
214        .iter()
215        .position(|c| c.norm() > tol)
216        .unwrap_or(coeffs.len());
217    coeffs.split_off(first_nonzero)
218}
219
220fn solve_roots(coeffs: &[Complex64]) -> BuiltinResult<Vec<Complex64>> {
221    if coeffs.len() <= 1 {
222        return Ok(Vec::new());
223    }
224    if coeffs.len() == 2 {
225        let a = coeffs[0];
226        let b = coeffs[1];
227        if a.norm() <= LEADING_ZERO_TOL {
228            return Err(roots_error(
229                "roots: leading coefficient must be non-zero after trimming",
230            ));
231        }
232        return Ok(vec![-b / a]);
233    }
234
235    let degree = coeffs.len() - 1;
236    if degree == 3 {
237        return Ok(cubic_roots(coeffs[0], coeffs[1], coeffs[2], coeffs[3]));
238    }
239    let leading = coeffs[0];
240    if leading.norm() <= LEADING_ZERO_TOL {
241        return Err(roots_error(
242            "roots: leading coefficient must be non-zero after trimming",
243        ));
244    }
245
246    let mut companion = DMatrix::<Complex64>::zeros(degree, degree);
247    for row in 1..degree {
248        companion[(row, row - 1)] = Complex64::new(1.0, 0.0);
249    }
250
251    for (idx, coeff) in coeffs.iter().enumerate().skip(1) {
252        let value = -(*coeff) / leading;
253        let column = idx - 1;
254        if column < degree {
255            companion[(0, column)] = value;
256        }
257    }
258
259    let eigenvalues = companion.clone().eigenvalues().ok_or_else(|| {
260        roots_error_with(
261            "roots: failed to compute eigenvalues of the companion matrix",
262            &ROOTS_ERROR_INTERNAL,
263        )
264    })?;
265    Ok(eigenvalues.iter().map(|&z| canonicalize_root(z)).collect())
266}
267
268fn cubic_roots(a: Complex64, b: Complex64, c: Complex64, d: Complex64) -> Vec<Complex64> {
269    // Depressed cubic via Cardano: x = y - b/(3a), y^3 + p y + q = 0
270    let three = 3.0;
271    let nine = 9.0;
272    let twenty_seven = 27.0;
273    let a2 = a * a;
274    let a3 = a2 * a;
275    let p = (three * a * c - b * b) / (three * a2);
276    let q = (twenty_seven * a2 * d - nine * a * b * c + Complex64::new(2.0, 0.0) * b * b * b)
277        / (twenty_seven * a3);
278    let half = Complex64::new(0.5, 0.0);
279    let disc = (q * q) * half * half + (p * p * p) / Complex64::new(27.0, 0.0);
280    let sqrt_disc = disc.sqrt();
281    let u = (-q * half + sqrt_disc).powf(1.0 / 3.0);
282    let v = (-q * half - sqrt_disc).powf(1.0 / 3.0);
283    let omega = Complex64::new(-0.5, (3.0f64).sqrt() * 0.5);
284    let omega2 = omega * omega;
285    let shift = b / (three * a);
286    let y0 = u + v;
287    let y1 = u * omega + v * omega.conj();
288    let y2 = u * omega2 + v * omega;
289    vec![y0 - shift, y1 - shift, y2 - shift]
290}
291
292fn canonicalize_root(z: Complex64) -> Complex64 {
293    if !z.re.is_finite() || !z.im.is_finite() {
294        return z;
295    }
296    let mut real = z.re;
297    let mut imag = z.im;
298    let scale = 1.0 + real.abs();
299    if imag.abs() <= RESULT_ZERO_TOL * scale {
300        imag = 0.0;
301    }
302    if real.abs() <= RESULT_ZERO_TOL {
303        real = 0.0;
304    }
305    Complex64::new(real, imag)
306}
307
308fn roots_to_value(roots: &[Complex64]) -> BuiltinResult<Value> {
309    if roots.is_empty() {
310        return empty_column();
311    }
312    let all_real = roots
313        .iter()
314        .all(|z| z.im.abs() <= RESULT_ZERO_TOL * (1.0 + z.re.abs()));
315    if all_real {
316        let mut data: Vec<f64> = Vec::with_capacity(roots.len());
317        for &root in roots {
318            data.push(root.re);
319        }
320        let tensor = Tensor::new(data, vec![roots.len(), 1])
321            .map_err(|e| roots_error_with(format!("roots: {e}"), &ROOTS_ERROR_INTERNAL))?;
322        Ok(Value::Tensor(tensor))
323    } else {
324        let data: Vec<(f64, f64)> = roots.iter().map(|z| (z.re, z.im)).collect();
325        let tensor = ComplexTensor::new(data, vec![roots.len(), 1])
326            .map_err(|e| roots_error_with(format!("roots: {e}"), &ROOTS_ERROR_INTERNAL))?;
327        Ok(Value::ComplexTensor(tensor))
328    }
329}
330
331fn empty_column() -> BuiltinResult<Value> {
332    let tensor = Tensor::new(Vec::new(), vec![0, 1])
333        .map_err(|e| roots_error_with(format!("roots: {e}"), &ROOTS_ERROR_INTERNAL))?;
334    Ok(Value::Tensor(tensor))
335}
336
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ComplexTensor, LogicalArray, Tensor};

    /// Synchronous wrapper around the async builtin so tests can call it directly.
    fn roots_builtin(coefficients: Value) -> BuiltinResult<Value> {
        block_on(super::roots_builtin(coefficients))
    }

    /// Runs `roots` on real coefficients stored as an n-by-1 column tensor.
    fn roots_of_column(coefficients: Vec<f64>) -> BuiltinResult<Value> {
        let rows = coefficients.len();
        let column = Tensor::new(coefficients, vec![rows, 1]).unwrap();
        roots_builtin(Value::Tensor(column))
    }

    /// Asserts that the error message contains `needle`.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        let found = err.message().contains(needle);
        assert!(
            found,
            "expected error containing '{needle}', got '{}'",
            err.message()
        );
    }

    #[test]
    fn roots_descriptor_signatures_cover_core_forms() {
        let has_core_form = ROOTS_DESCRIPTOR
            .signatures
            .iter()
            .any(|signature| signature.label == "r = roots(c)");
        assert!(has_core_form);
    }

    #[test]
    fn roots_descriptor_errors_have_stable_codes() {
        let has_code = |wanted: &str| {
            ROOTS_DESCRIPTOR
                .errors
                .iter()
                .any(|error| error.code == wanted)
        };
        assert!(has_code("RM.ROOTS.INVALID_INPUT"));
        assert!(has_code("RM.ROOTS.INTERNAL"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_quadratic_real() {
        // p(x) = x^2 - 3x + 2 => roots 1 and 2
        let result = roots_of_column(vec![1.0, -3.0, 2.0]).expect("roots");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                let mut sorted = t.data;
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
                assert!((sorted[0] - 1.0).abs() < 1e-10);
                assert!((sorted[1] - 2.0).abs() < 1e-10);
            }
            other => panic!("expected real tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_leading_zeros_trimmed() {
        // Leading zeros reduce the input to x - 4.
        let result = roots_of_column(vec![0.0, 0.0, 1.0, -4.0]).expect("roots");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert!((t.data[0] - 4.0).abs() < 1e-10);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_complex_pair() {
        // p(x) = x^2 + 1 => roots -i and i
        let result = roots_of_column(vec![1.0, 0.0, 1.0]).expect("roots");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                let mut by_imag = t.data;
                by_imag.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
                let (lower, upper) = (by_imag[0], by_imag[1]);
                assert!((lower.0).abs() < 1e-10);
                assert!((lower.1 + 1.0).abs() < 1e-10);
                assert!((upper.0).abs() < 1e-10);
                assert!((upper.1 - 1.0).abs() < 1e-10);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_quartic_all_zero_roots() {
        // p(x) = x^4 => 4 roots at 0
        let result = roots_of_column(vec![1.0, 0.0, 0.0, 0.0, 0.0]).expect("roots quartic");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                for root in t.data.iter() {
                    assert!(root.abs() < 1e-8);
                }
            }
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                for (re, im) in t.data.iter() {
                    assert!(re.abs() < 1e-7 && im.abs() < 1e-7);
                }
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_accepts_complex_coefficients_input() {
        // p(x) = x^2 + 1 with complex coefficients path
        let entries = vec![(1.0, 0.0), (0.0, 0.0), (1.0, 0.0)];
        let coeffs = ComplexTensor::new(entries, vec![3, 1]).unwrap();
        let result = roots_builtin(Value::ComplexTensor(coeffs)).expect("roots complex input");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                // roots at i and -i
                let mut by_imag = t.data;
                by_imag.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
                let (lower, upper) = (by_imag[0], by_imag[1]);
                assert!(lower.0.abs() < 1e-10 && (lower.1 + 1.0).abs() < 1e-6);
                assert!(upper.0.abs() < 1e-10 && (upper.1 - 1.0).abs() < 1e-6);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_accepts_logical_coefficients() {
        // p(x) = x with logical coefficients [1 0]
        let logical = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let result = roots_builtin(Value::LogicalArray(logical)).expect("roots logical");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert!(t.data[0].abs() < 1e-12);
            }
            other => panic!("expected real tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_scalar_num_returns_empty() {
        let result = roots_builtin(Value::Num(5.0)).expect("roots scalar num");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_rejects_non_vector_input() {
        // A 2x2 matrix is not a coefficient vector.
        let matrix = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let err = roots_builtin(Value::Tensor(matrix)).expect_err("expected vector-shape error");
        assert_eq!(err.identifier(), ROOTS_ERROR_INVALID_INPUT.identifier);
        assert_error_contains(err, "vector");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_all_zero_coefficients_returns_empty() {
        let result = roots_of_column(vec![0.0, 0.0, 0.0]).expect("roots");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_gpu_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            // p(x) = x^3 - 9x => roots -3, 0, 3
            let host = Tensor::new(vec![1.0, 0.0, -9.0, 0.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = roots_builtin(Value::GpuTensor(handle)).expect("roots");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![3, 1]);
            let mut sorted = gathered.data;
            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
            assert!((sorted[0] + 3.0).abs() < 1e-9);
            assert!((sorted[1]).abs() < 1e-9);
            assert!((sorted[2] - 3.0).abs() < 1e-9);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_constant_polynomial_returns_empty() {
        let constant = Tensor::new(vec![5.0], vec![1, 1]).unwrap();
        let result = roots_builtin(Value::Tensor(constant)).expect("roots");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }
}