runmat_runtime/builtins/math/poly/roots.rs

//! MATLAB-compatible `roots` builtin with GPU-aware semantics for RunMat.
//!
//! This implementation mirrors MATLAB behaviour, including handling for leading
//! zeros, constant polynomials, and complex-valued coefficients. GPU inputs are
//! gathered to the host because companion matrix eigenvalue computations are
//! currently performed on the CPU.

use nalgebra::DMatrix;
use num_complex::Complex64;
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ReductionNaN, ResidencyPolicy, ShapeRequirements,
};
use crate::builtins::common::{gpu_helpers, tensor};
#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;
use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};

const LEADING_ZERO_TOL: f64 = 1.0e-12;
const RESULT_ZERO_TOL: f64 = 1.0e-10;

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "roots"
category: "math/poly"
keywords: ["roots", "polynomial roots", "companion matrix", "eigenvalues", "gpu"]
summary: "Compute the roots of a polynomial specified by its coefficients, matching MATLAB semantics including complex output."
references:
  - title: "MATLAB roots documentation"
    url: "https://www.mathworks.com/help/matlab/ref/roots.html"
  - title: "Golub & Van Loan – Matrix Computations, Chapter 7"
    url: "https://doi.org/10.56021/9781421407944"
gpu_support:
  elementwise: false
  reduction: false
  precisions: []
  broadcasting: "none"
  notes: "Roots builds a companion matrix and computes its eigenvalues on the CPU; GPU inputs are gathered automatically."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::math::poly::roots::tests"
  integration: "builtins::math::poly::roots::tests::roots_gpu_input_gathers_to_host"
---

# What does the `roots` function do in MATLAB / RunMat?
`roots(p)` returns the zeros of the polynomial whose coefficients are stored in `p`, with coefficients ordered from the highest power of `x` to the constant term. The result is always a column vector whose entries may be complex.

## How does the `roots` function behave in MATLAB / RunMat?
- Leading zeros in the coefficient vector are discarded before solving. A leading coefficient counts as zero when its magnitude is at most `1e-12` times the largest coefficient magnitude. If all coefficients are zero, the result is an empty column vector.
- Constant polynomials (degree 0) produce an empty output because they have no finite roots.
- Linear polynomials return the single solution `-b/a`. Cubic polynomials are solved in closed form (Cardano's formula), and polynomials of degree four or higher are solved via the eigenvalues of the companion matrix.
- Real coefficients can generate complex conjugate root pairs. Imaginary parts small enough to be round-off are set to zero so that near-real roots display as real numbers, matching MATLAB output.
- Input vectors can be row or column vectors. Matrices and N-D arrays with more than one non-singleton dimension are rejected.
- Inputs may be real or complex. Logical and integer types are converted to double precision automatically.
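
The companion-matrix step can be illustrated with a small Rust sketch. It builds the companion matrix for real coefficients and checks that a known root `r` is an eigenvalue. The check uses the fact that `[r^(n-1), ..., r, 1]` is a matching eigenvector whenever `p(r) = 0`. This is a simplified illustration (real coefficients only, no eigenvalue solver), not RunMat's implementation, and the helper names `companion` and `is_eigenpair` are hypothetical.

```rust
/// Builds the companion matrix of p(x) = c[0] x^n + c[1] x^(n-1) + ... + c[n]:
/// the first row holds -c[k]/c[0] for k = 1..=n, with ones on the subdiagonal.
fn companion(c: &[f64]) -> Vec<Vec<f64>> {
    let n = c.len() - 1; // polynomial degree
    let mut m = vec![vec![0.0; n]; n];
    for j in 0..n {
        m[0][j] = -c[j + 1] / c[0];
    }
    for i in 1..n {
        m[i][i - 1] = 1.0;
    }
    m
}

/// Returns true when r is an eigenvalue of m with eigenvector
/// v = [r^(n-1), ..., r, 1], i.e. when m * v equals r * v within tol.
fn is_eigenpair(m: &[Vec<f64>], r: f64, tol: f64) -> bool {
    let n = m.len();
    let v: Vec<f64> = (0..n).map(|i| r.powi((n - 1 - i) as i32)).collect();
    (0..n).all(|i| {
        let mv: f64 = (0..n).map(|j| m[i][j] * v[j]).sum();
        (mv - r * v[i]).abs() <= tol
    })
}

fn main() {
    // p(x) = x^2 - 3x + 2 = (x - 1)(x - 2)
    let m = companion(&[1.0, -3.0, 2.0]);
    println!("{:?}", m); // [[3.0, -2.0], [1.0, 0.0]]
    assert!(is_eigenpair(&m, 1.0, 1e-12));
    assert!(is_eigenpair(&m, 2.0, 1e-12));
    assert!(!is_eigenpair(&m, 3.0, 1e-12));
}
```

Because the first row holds `-c[k]/c[0]`, the characteristic polynomial of this matrix is `p(x)/c[0]`. Its eigenvalues are therefore exactly the roots of `p`.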

## `roots` Function GPU Execution Behaviour
RunMat gathers GPU-resident coefficient vectors to the host because the companion matrix eigenvalue computation presently runs only on the CPU. The output is produced on the host as well. When future providers supply a dedicated polynomial root solver, the builtin can be updated to keep residency on-device transparently.

## Examples of using the `roots` function in MATLAB / RunMat

### Finding roots of a quadratic polynomial

```matlab
p = [1 -3 2];
r = roots(p);
```

Expected output:

```matlab
r =
     2
     1
```

### Computing roots that include repeated factors

```matlab
p = [1 -2 1 0];   % (x - 1)^2 * x
r = roots(p);
```

Expected output:

```matlab
r =
     1
     1
     0
```
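
To sanity-check results like these, evaluate the polynomial at each root. The value should be zero up to round-off. The Rust sketch below uses Horner's rule with the same highest-power-first coefficient order that `roots` expects. `horner` is a hypothetical helper, not a RunMat API.

```rust
/// Evaluates p(x) = c[0] x^n + ... + c[n] with Horner's rule,
/// using the highest-power-first ordering that `roots` expects.
fn horner(c: &[f64], x: f64) -> f64 {
    c.iter().fold(0.0, |acc, &ci| acc * x + ci)
}

fn main() {
    let p = [1.0, -2.0, 1.0, 0.0]; // (x - 1)^2 * x
    for r in [1.0, 1.0, 0.0] {
        // Each exact root makes p(r) vanish.
        assert_eq!(horner(&p, r), 0.0);
    }
    println!("{}", horner(&p, 2.0)); // 2
}
```

With computed rather than exact roots, compare `|p(r)|` against a small tolerance instead of testing for exact zero. Repeated roots such as the double root at 1 are especially sensitive to round-off.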

### Handling leading zeros in the coefficient vector

```matlab
p = [0 0 1 -4];
r = roots(p);
```

Expected output:

```matlab
r =
     4
```

### Calculating complex roots from real coefficients

```matlab
p = [1 0 1];
r = roots(p);
```

Expected output:

```matlab
r =
   0.0000 + 1.0000i
   0.0000 - 1.0000i
```

### Solving roots of a polynomial stored on the GPU

```matlab
p = gpuArray([1 0 -9 0]);
r = roots(p);
```

Expected output:

```matlab
r =
    3.0000
   -3.0000
         0
```

The coefficients are gathered automatically, so no manual `gather` call is required.

## FAQ

### What shape must the coefficient vector have?
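`roots` accepts scalars, row vectors, column vectors, empty arrays, and N-D arrays with at most one non-singleton dimension. Matrices and other arrays with two or more non-singleton dimensions are rejected with an error.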
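
The shape rule comes down to this: at most one dimension may exceed 1, and empty 2-D arrays are allowed. The sketch below restates the check used by the RunMat implementation as a standalone Rust function. `is_vector_shape` is a hypothetical name chosen for illustration.

```rust
/// A shape is accepted when at most one dimension is larger than 1,
/// or when a 2-D array is empty.
fn is_vector_shape(shape: &[usize]) -> bool {
    match shape.len() {
        0 | 1 => true,
        2 => shape[0] == 1 || shape[1] == 1 || shape[0] * shape[1] == 0,
        _ => shape.iter().filter(|&&d| d > 1).count() <= 1,
    }
}

fn main() {
    assert!(is_vector_shape(&[1, 4])); // row vector
    assert!(is_vector_shape(&[4, 1])); // column vector
    assert!(is_vector_shape(&[1, 1, 4])); // N-D with one non-singleton dimension
    assert!(is_vector_shape(&[0, 3])); // empty
    assert!(!is_vector_shape(&[2, 2])); // matrix: rejected
    assert!(!is_vector_shape(&[2, 1, 3])); // two non-singleton dimensions: rejected
    println!("shape checks passed");
}
```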

### How are leading zeros handled?
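Leading zeros are removed before solving. A leading coefficient counts as zero when its magnitude is at most `1e-12` times the largest coefficient magnitude, so negligibly small leading terms are dropped as well. If all coefficients are zero, `roots` returns an empty column vector.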
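
The trimming rule is relative to the largest coefficient magnitude and uses a tolerance of `1e-12`. The Rust sketch below restates it for real coefficients only; RunMat's own implementation works on complex values. The `trim_leading_zeros` shown here is an illustrative standalone function, not a public API.

```rust
/// Relative tolerance used to decide whether a leading coefficient is zero.
const LEADING_ZERO_TOL: f64 = 1.0e-12;

/// Drops leading coefficients whose magnitude is at most
/// LEADING_ZERO_TOL times the largest coefficient magnitude.
fn trim_leading_zeros(coeffs: &[f64]) -> &[f64] {
    let scale = coeffs.iter().fold(0.0_f64, |m, c| m.max(c.abs()));
    let tol = if scale == 0.0 { LEADING_ZERO_TOL } else { LEADING_ZERO_TOL * scale };
    let first = coeffs.iter().position(|c| c.abs() > tol).unwrap_or(coeffs.len());
    &coeffs[first..]
}

fn main() {
    assert_eq!(trim_leading_zeros(&[0.0, 0.0, 1.0, -4.0]), &[1.0, -4.0]);
    // 1e-15 is negligible relative to the largest coefficient (2.0), so it is dropped.
    assert_eq!(trim_leading_zeros(&[1e-15, 1.0, 2.0]), &[1.0, 2.0]);
    // All-zero input leaves nothing to solve, so roots returns an empty column.
    assert!(trim_leading_zeros(&[0.0, 0.0, 0.0]).is_empty());
}
```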

### Does `roots` preserve the data type of the coefficients?
Coefficients are promoted to double precision internally. The output is a double vector when all roots are real and a complex double vector otherwise.

### Are the roots sorted?
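No. Roots are returned in the order produced by the solver: the companion-matrix eigenvalue routine, or the closed-form formula for cubics. That order is not guaranteed to match MATLAB's. MATLAB does not sort the roots either, so sort the result explicitly if you need a deterministic order.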

### Can I run `roots` entirely on the GPU?
Not yet. RunMat gathers coefficients from the GPU, solves the companion matrix on the CPU, and returns a host-resident vector. When GPU providers add a polynomial root solver, this builtin can be updated to route to it.

### How does RunMat handle numerical round-off?
Imaginary components with |imag| ≤ 1e-10·(1 + |real|) are set to zero so that near-real roots are displayed as real numbers, matching MATLAB formatting. Real parts with |real| ≤ 1e-10 are likewise set to zero.
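
The round-off cleanup can be sketched in Rust with roots represented as `(re, im)` pairs. This is a simplified restatement for illustration; `canonicalize` is a hypothetical helper name, and the `1e-10` tolerance is the one stated above.

```rust
/// Tolerance used when cleaning up round-off in computed roots.
const RESULT_ZERO_TOL: f64 = 1.0e-10;

/// Zeroes round-off-sized parts of a root given as (re, im):
/// |im| <= tol * (1 + |re|) becomes 0, and |re| <= tol becomes 0.
fn canonicalize(re: f64, im: f64) -> (f64, f64) {
    if !re.is_finite() || !im.is_finite() {
        return (re, im); // leave NaN/Inf untouched
    }
    let im = if im.abs() <= RESULT_ZERO_TOL * (1.0 + re.abs()) { 0.0 } else { im };
    let re = if re.abs() <= RESULT_ZERO_TOL { 0.0 } else { re };
    (re, im)
}

fn main() {
    let roots = [(3.0, 1e-14), (-2.0, -5e-12), (1e-13, 1.0)];
    let cleaned: Vec<(f64, f64)> = roots.iter().map(|&(re, im)| canonicalize(re, im)).collect();
    assert_eq!(cleaned, vec![(3.0, 0.0), (-2.0, 0.0), (0.0, 1.0)]);
    // The output is real only if every root has a zero imaginary part.
    let all_real = cleaned.iter().all(|&(_, im)| im == 0.0);
    println!("all real: {}", all_real); // all real: false
}
```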

## See Also
[polyval](./polyval), [polyfit](./polyfit), [residue](../signal/residue), [roots documentation (MathWorks)](https://www.mathworks.com/help/matlab/ref/roots.html)

## Source & Feedback
- The full source code for the implementation of the `roots` function is available at: [`crates/runmat-runtime/src/builtins/math/poly/roots.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/math/poly/roots.rs)
- Found a bug or behavioral difference? Please [open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
"#;

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "roots",
    op_kind: GpuOpKind::Custom("polynomial-roots"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Companion matrix eigenvalue solve executes on the host; providers currently fall back to the CPU implementation.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "roots",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: true,
    notes: "Non-elementwise builtin that terminates fusion and gathers inputs to the host.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("roots", DOC_MD);

#[runtime_builtin(
    name = "roots",
    category = "math/poly",
    summary = "Compute the roots of a polynomial specified by its coefficients.",
    keywords = "roots,polynomial,eigenvalues,companion",
    accel = "sink"
)]
fn roots_builtin(coefficients: Value) -> Result<Value, String> {
    let coeffs = coefficients_to_complex(coefficients)?;
    let trimmed = trim_leading_zeros(coeffs);
    if trimmed.is_empty() || trimmed.len() == 1 {
        return empty_column();
    }
    let roots = solve_roots(&trimmed)?;
    roots_to_value(&roots)
}

fn coefficients_to_complex(value: Value) -> Result<Vec<Complex64>, String> {
    match value {
        Value::GpuTensor(handle) => {
            let tensor = gpu_helpers::gather_tensor(&handle)?;
            tensor_to_complex(tensor)
        }
        Value::Tensor(tensor) => tensor_to_complex(tensor),
        Value::ComplexTensor(tensor) => complex_tensor_to_vec(tensor),
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical)?;
            tensor_to_complex(tensor)
        }
        Value::Num(n) => {
            let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("roots: {e}"))?;
            tensor_to_complex(tensor)
        }
        Value::Int(i) => {
            let tensor =
                Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|e| format!("roots: {e}"))?;
            tensor_to_complex(tensor)
        }
        Value::Bool(b) => {
            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
                .map_err(|e| format!("roots: {e}"))?;
            tensor_to_complex(tensor)
        }
        other => Err(format!(
            "roots: expected a numeric vector of polynomial coefficients, got {other:?}"
        )),
    }
}

fn tensor_to_complex(tensor: Tensor) -> Result<Vec<Complex64>, String> {
    ensure_vector_shape("roots", &tensor.shape)?;
    Ok(tensor
        .data
        .into_iter()
        .map(|value| Complex64::new(value, 0.0))
        .collect())
}

fn complex_tensor_to_vec(tensor: ComplexTensor) -> Result<Vec<Complex64>, String> {
    ensure_vector_shape("roots", &tensor.shape)?;
    Ok(tensor
        .data
        .into_iter()
        .map(|(re, im)| Complex64::new(re, im))
        .collect())
}

fn ensure_vector_shape(name: &str, shape: &[usize]) -> Result<(), String> {
    let is_vector = match shape.len() {
        0 => true,
        1 => true,
        2 => shape[0] == 1 || shape[1] == 1 || shape.iter().product::<usize>() == 0,
        _ => shape.iter().filter(|&&dim| dim > 1).count() <= 1,
    };
    if !is_vector {
        return Err(format!(
            "{name}: coefficients must be a vector (row or column), got shape {:?}",
            shape
        ));
    }
    Ok(())
}

fn trim_leading_zeros(mut coeffs: Vec<Complex64>) -> Vec<Complex64> {
    if coeffs.is_empty() {
        return coeffs;
    }
    let scale = coeffs.iter().map(|c| c.norm()).fold(0.0_f64, f64::max);
    let tol = if scale == 0.0 {
        LEADING_ZERO_TOL
    } else {
        LEADING_ZERO_TOL * scale
    };
    let first_nonzero = coeffs
        .iter()
        .position(|c| c.norm() > tol)
        .unwrap_or(coeffs.len());
    coeffs.split_off(first_nonzero)
}

fn solve_roots(coeffs: &[Complex64]) -> Result<Vec<Complex64>, String> {
    if coeffs.len() <= 1 {
        return Ok(Vec::new());
    }
    if coeffs.len() == 2 {
        let a = coeffs[0];
        let b = coeffs[1];
        // `trim_leading_zeros` already applied a scale-relative tolerance, so only an
        // exact zero is rejected here. An absolute threshold would wrongly reject
        // polynomials whose coefficients are all small (e.g. [1e-13 2e-13]).
        if a.norm() == 0.0 {
            return Err("roots: leading coefficient must be non-zero after trimming".to_string());
        }
        return Ok(vec![-b / a]);
    }

    let degree = coeffs.len() - 1;
    if degree == 3 {
        // Apply the same round-off cleanup as the eigenvalue path.
        return Ok(cubic_roots(coeffs[0], coeffs[1], coeffs[2], coeffs[3])
            .into_iter()
            .map(canonicalize_root)
            .collect());
    }
    let leading = coeffs[0];
    if leading.norm() == 0.0 {
        return Err("roots: leading coefficient must be non-zero after trimming".to_string());
    }

    let mut companion = DMatrix::<Complex64>::zeros(degree, degree);
    for row in 1..degree {
        companion[(row, row - 1)] = Complex64::new(1.0, 0.0);
    }

    for (idx, coeff) in coeffs.iter().enumerate().skip(1) {
        let value = -(*coeff) / leading;
        let column = idx - 1;
        if column < degree {
            companion[(0, column)] = value;
        }
    }

    let eigenvalues = companion.eigenvalues().ok_or_else(|| {
        "roots: failed to compute eigenvalues of the companion matrix".to_string()
    })?;
    Ok(eigenvalues.iter().map(|&z| canonicalize_root(z)).collect())
}

fn cubic_roots(a: Complex64, b: Complex64, c: Complex64, d: Complex64) -> Vec<Complex64> {
    // Depressed cubic via Cardano: x = y - b/(3a), y^3 + p y + q = 0
    let a2 = a * a;
    let a3 = a2 * a;
    let p = (3.0 * a * c - b * b) / (3.0 * a2);
    let q = (27.0 * a2 * d - 9.0 * a * b * c + 2.0 * b * b * b) / (27.0 * a3);
    let disc = q * q / 4.0 + p * p * p / 27.0;
    let sqrt_disc = disc.sqrt();
    // Take the cube root of whichever branch -q/2 ± sqrt(disc) has the larger
    // magnitude, avoiding cancellation between -q/2 and sqrt(disc).
    let w_plus = -q / 2.0 + sqrt_disc;
    let w_minus = -q / 2.0 - sqrt_disc;
    let w = if w_plus.norm() >= w_minus.norm() { w_plus } else { w_minus };
    let u = w.powf(1.0 / 3.0);
    // The Cardano terms must satisfy u * v = -p/3. Taking independent principal
    // cube roots of both branches breaks this pairing and returns wrong roots
    // (for example for x^3 - 6x + 9), so v is derived from u instead.
    let v = if u.norm() == 0.0 {
        Complex64::new(0.0, 0.0)
    } else {
        -p / (3.0 * u)
    };
    let omega = Complex64::new(-0.5, 3.0_f64.sqrt() * 0.5); // primitive cube root of unity
    let omega2 = omega * omega;
    let shift = b / (3.0 * a);
    let y0 = u + v;
    let y1 = u * omega + v * omega2;
    let y2 = u * omega2 + v * omega;
    vec![y0 - shift, y1 - shift, y2 - shift]
}

fn canonicalize_root(z: Complex64) -> Complex64 {
    if !z.re.is_finite() || !z.im.is_finite() {
        return z;
    }
    let mut real = z.re;
    let mut imag = z.im;
    let scale = 1.0 + real.abs();
    if imag.abs() <= RESULT_ZERO_TOL * scale {
        imag = 0.0;
    }
    if real.abs() <= RESULT_ZERO_TOL {
        real = 0.0;
    }
    Complex64::new(real, imag)
}

fn roots_to_value(roots: &[Complex64]) -> Result<Value, String> {
    if roots.is_empty() {
        return empty_column();
    }
    let all_real = roots
        .iter()
        .all(|z| z.im.abs() <= RESULT_ZERO_TOL * (1.0 + z.re.abs()));
    if all_real {
        let mut data: Vec<f64> = Vec::with_capacity(roots.len());
        for &root in roots {
            data.push(root.re);
        }
        let tensor = Tensor::new(data, vec![roots.len(), 1]).map_err(|e| format!("roots: {e}"))?;
        Ok(Value::Tensor(tensor))
    } else {
        let data: Vec<(f64, f64)> = roots.iter().map(|z| (z.re, z.im)).collect();
        let tensor =
            ComplexTensor::new(data, vec![roots.len(), 1]).map_err(|e| format!("roots: {e}"))?;
        Ok(Value::ComplexTensor(tensor))
    }
}

fn empty_column() -> Result<Value, String> {
    let tensor = Tensor::new(Vec::new(), vec![0, 1]).map_err(|e| format!("roots: {e}"))?;
    Ok(Value::Tensor(tensor))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ComplexTensor, LogicalArray, Tensor};

    #[test]
    fn roots_quadratic_real() {
        let coeffs = Tensor::new(vec![1.0, -3.0, 2.0], vec![3, 1]).unwrap();
        let result = roots_builtin(Value::Tensor(coeffs)).expect("roots");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                let mut roots = t.data;
                roots.sort_by(|a, b| a.partial_cmp(b).unwrap());
                assert!((roots[0] - 1.0).abs() < 1e-10);
                assert!((roots[1] - 2.0).abs() < 1e-10);
            }
            other => panic!("expected real tensor, got {other:?}"),
        }
    }

    #[test]
    fn roots_leading_zeros_trimmed() {
        let coeffs = Tensor::new(vec![0.0, 0.0, 1.0, -4.0], vec![4, 1]).unwrap();
        let result = roots_builtin(Value::Tensor(coeffs)).expect("roots");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert!((t.data[0] - 4.0).abs() < 1e-10);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn roots_complex_pair() {
        let coeffs = Tensor::new(vec![1.0, 0.0, 1.0], vec![3, 1]).unwrap();
        let result = roots_builtin(Value::Tensor(coeffs)).expect("roots");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                let mut roots = t.data;
                roots.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
                assert!((roots[0].0).abs() < 1e-10);
                assert!((roots[0].1 + 1.0).abs() < 1e-10);
                assert!((roots[1].0).abs() < 1e-10);
                assert!((roots[1].1 - 1.0).abs() < 1e-10);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[test]
    fn roots_quartic_all_zero_roots() {
        // p(x) = x^4 => 4 roots at 0
        let coeffs = Tensor::new(vec![1.0, 0.0, 0.0, 0.0, 0.0], vec![5, 1]).unwrap();
        let result = roots_builtin(Value::Tensor(coeffs)).expect("roots quartic");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                for &r in &t.data {
                    assert!(r.abs() < 1e-8);
                }
            }
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                for &(re, im) in &t.data {
                    assert!(re.abs() < 1e-7 && im.abs() < 1e-7);
                }
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    #[test]
    fn roots_accepts_complex_coefficients_input() {
        // p(x) = x^2 + 1 with complex coefficients path
        let coeffs =
            ComplexTensor::new(vec![(1.0, 0.0), (0.0, 0.0), (1.0, 0.0)], vec![3, 1]).unwrap();
        let result = roots_builtin(Value::ComplexTensor(coeffs)).expect("roots complex input");
        match result {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                // roots at i and -i
                let mut roots = t.data;
                roots.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
                assert!(roots[0].0.abs() < 1e-10 && (roots[0].1 + 1.0).abs() < 1e-6);
                assert!(roots[1].0.abs() < 1e-10 && (roots[1].1 - 1.0).abs() < 1e-6);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[test]
    fn roots_accepts_logical_coefficients() {
        // p(x) = x with logical coefficients [1 0]
        let la = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let result = roots_builtin(Value::LogicalArray(la)).expect("roots logical");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert!(t.data[0].abs() < 1e-12);
            }
            other => panic!("expected real tensor, got {other:?}"),
        }
    }

    #[test]
    fn roots_scalar_num_returns_empty() {
        let result = roots_builtin(Value::Num(5.0)).expect("roots scalar num");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[test]
    fn roots_rejects_non_vector_input() {
        let coeffs = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let err = roots_builtin(Value::Tensor(coeffs)).expect_err("expected vector-shape error");
        assert!(err.to_lowercase().contains("vector"));
    }

    #[test]
    fn roots_all_zero_coefficients_returns_empty() {
        let coeffs = Tensor::new(vec![0.0, 0.0, 0.0], vec![3, 1]).unwrap();
        let result = roots_builtin(Value::Tensor(coeffs)).expect("roots");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[test]
    fn roots_gpu_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let coeffs = Tensor::new(vec![1.0, 0.0, -9.0, 0.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = roots_builtin(Value::GpuTensor(handle)).expect("roots");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![3, 1]);
            let mut roots = gathered.data;
            roots.sort_by(|a, b| a.partial_cmp(b).unwrap());
            assert!((roots[0] + 3.0).abs() < 1e-9);
            assert!((roots[1]).abs() < 1e-9);
            assert!((roots[2] - 3.0).abs() < 1e-9);
        });
    }

    #[test]
    fn roots_constant_polynomial_returns_empty() {
        let coeffs = Tensor::new(vec![5.0], vec![1, 1]).unwrap();
        let result = roots_builtin(Value::Tensor(coeffs)).expect("roots");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}