Skip to main content

runmat_runtime/builtins/math/poly/
roots.rs

1//! MATLAB-compatible `roots` builtin with GPU-aware semantics for RunMat.
2//!
3//! This implementation mirrors MATLAB behaviour, including handling for leading
4//! zeros, constant polynomials, and complex-valued coefficients. GPU inputs are
5//! gathered to the host because companion matrix eigenvalue computations are
6//! currently performed on the CPU.
7
8use nalgebra::DMatrix;
9use num_complex::Complex64;
10use runmat_builtins::{ComplexTensor, Tensor, Value};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::common::{gpu_helpers, tensor};
18use crate::builtins::math::poly::type_resolvers::roots_type;
19use crate::{build_runtime_error, BuiltinResult, RuntimeError};
20
/// Tolerance used when deciding whether a leading coefficient counts as zero.
const LEADING_ZERO_TOL: f64 = 1e-12;
/// Tolerance below which a computed real or imaginary part is treated as zero.
const RESULT_ZERO_TOL: f64 = 1e-10;
/// Name under which this builtin is registered and reported in errors.
const BUILTIN_NAME: &str = "roots";
24
25#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::roots")]
26pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
27    name: "roots",
28    op_kind: GpuOpKind::Custom("polynomial-roots"),
29    supported_precisions: &[],
30    broadcast: BroadcastSemantics::None,
31    provider_hooks: &[],
32    constant_strategy: ConstantStrategy::InlineLiteral,
33    residency: ResidencyPolicy::GatherImmediately,
34    nan_mode: ReductionNaN::Include,
35    two_pass_threshold: None,
36    workgroup_size: None,
37    accepts_nan_mode: false,
38    notes: "Companion matrix eigenvalue solve executes on the host; providers currently fall back to the CPU implementation.",
39};
40
41fn roots_error(message: impl Into<String>) -> RuntimeError {
42    build_runtime_error(message)
43        .with_builtin(BUILTIN_NAME)
44        .build()
45}
46
47#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::roots")]
48pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
49    name: "roots",
50    shape: ShapeRequirements::Any,
51    constant_strategy: ConstantStrategy::InlineLiteral,
52    elementwise: None,
53    reduction: None,
54    emits_nan: true,
55    notes: "Non-elementwise builtin that terminates fusion and gathers inputs to the host.",
56};
57
58#[runtime_builtin(
59    name = "roots",
60    category = "math/poly",
61    summary = "Compute the roots of a polynomial specified by its coefficients.",
62    keywords = "roots,polynomial,eigenvalues,companion",
63    accel = "sink",
64    type_resolver(roots_type),
65    builtin_path = "crate::builtins::math::poly::roots"
66)]
67async fn roots_builtin(coefficients: Value) -> crate::BuiltinResult<Value> {
68    let coeffs = coefficients_to_complex(coefficients).await?;
69    let trimmed = trim_leading_zeros(coeffs);
70    if trimmed.is_empty() || trimmed.len() == 1 {
71        return empty_column();
72    }
73    let roots = solve_roots(&trimmed)?;
74    roots_to_value(&roots)
75}
76
77async fn coefficients_to_complex(value: Value) -> BuiltinResult<Vec<Complex64>> {
78    match value {
79        Value::GpuTensor(handle) => {
80            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
81            tensor_to_complex(tensor)
82        }
83        Value::Tensor(tensor) => tensor_to_complex(tensor),
84        Value::ComplexTensor(tensor) => complex_tensor_to_vec(tensor),
85        Value::LogicalArray(logical) => {
86            let tensor = tensor::logical_to_tensor(&logical).map_err(roots_error)?;
87            tensor_to_complex(tensor)
88        }
89        Value::Num(n) => {
90            let tensor =
91                Tensor::new(vec![n], vec![1, 1]).map_err(|e| roots_error(format!("roots: {e}")))?;
92            tensor_to_complex(tensor)
93        }
94        Value::Int(i) => {
95            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
96                .map_err(|e| roots_error(format!("roots: {e}")))?;
97            tensor_to_complex(tensor)
98        }
99        Value::Bool(b) => {
100            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
101                .map_err(|e| roots_error(format!("roots: {e}")))?;
102            tensor_to_complex(tensor)
103        }
104        other => Err(roots_error(format!(
105            "roots: expected a numeric vector of polynomial coefficients, got {other:?}"
106        ))),
107    }
108}
109
110fn tensor_to_complex(tensor: Tensor) -> BuiltinResult<Vec<Complex64>> {
111    ensure_vector_shape("roots", &tensor.shape)?;
112    Ok(tensor
113        .data
114        .into_iter()
115        .map(|value| Complex64::new(value, 0.0))
116        .collect())
117}
118
119fn complex_tensor_to_vec(tensor: ComplexTensor) -> BuiltinResult<Vec<Complex64>> {
120    ensure_vector_shape("roots", &tensor.shape)?;
121    Ok(tensor
122        .data
123        .into_iter()
124        .map(|(re, im)| Complex64::new(re, im))
125        .collect())
126}
127
128fn ensure_vector_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
129    let is_vector = match shape.len() {
130        0 => true,
131        1 => true,
132        2 => shape[0] == 1 || shape[1] == 1 || shape.iter().product::<usize>() == 0,
133        _ => shape.iter().filter(|&&dim| dim > 1).count() <= 1,
134    };
135    if !is_vector {
136        return Err(roots_error(format!(
137            "{name}: coefficients must be a vector (row or column), got shape {:?}",
138            shape
139        )));
140    }
141    Ok(())
142}
143
144fn trim_leading_zeros(mut coeffs: Vec<Complex64>) -> Vec<Complex64> {
145    if coeffs.is_empty() {
146        return coeffs;
147    }
148    let scale = coeffs.iter().map(|c| c.norm()).fold(0.0_f64, f64::max);
149    let tol = if scale == 0.0 {
150        LEADING_ZERO_TOL
151    } else {
152        LEADING_ZERO_TOL * scale
153    };
154    let first_nonzero = coeffs
155        .iter()
156        .position(|c| c.norm() > tol)
157        .unwrap_or(coeffs.len());
158    coeffs.split_off(first_nonzero)
159}
160
161fn solve_roots(coeffs: &[Complex64]) -> BuiltinResult<Vec<Complex64>> {
162    if coeffs.len() <= 1 {
163        return Ok(Vec::new());
164    }
165    if coeffs.len() == 2 {
166        let a = coeffs[0];
167        let b = coeffs[1];
168        if a.norm() <= LEADING_ZERO_TOL {
169            return Err(roots_error(
170                "roots: leading coefficient must be non-zero after trimming",
171            ));
172        }
173        return Ok(vec![-b / a]);
174    }
175
176    let degree = coeffs.len() - 1;
177    if degree == 3 {
178        return Ok(cubic_roots(coeffs[0], coeffs[1], coeffs[2], coeffs[3]));
179    }
180    let leading = coeffs[0];
181    if leading.norm() <= LEADING_ZERO_TOL {
182        return Err(roots_error(
183            "roots: leading coefficient must be non-zero after trimming",
184        ));
185    }
186
187    let mut companion = DMatrix::<Complex64>::zeros(degree, degree);
188    for row in 1..degree {
189        companion[(row, row - 1)] = Complex64::new(1.0, 0.0);
190    }
191
192    for (idx, coeff) in coeffs.iter().enumerate().skip(1) {
193        let value = -(*coeff) / leading;
194        let column = idx - 1;
195        if column < degree {
196            companion[(0, column)] = value;
197        }
198    }
199
200    let eigenvalues = companion.clone().eigenvalues().ok_or_else(|| {
201        roots_error("roots: failed to compute eigenvalues of the companion matrix")
202    })?;
203    Ok(eigenvalues.iter().map(|&z| canonicalize_root(z)).collect())
204}
205
206fn cubic_roots(a: Complex64, b: Complex64, c: Complex64, d: Complex64) -> Vec<Complex64> {
207    // Depressed cubic via Cardano: x = y - b/(3a), y^3 + p y + q = 0
208    let three = 3.0;
209    let nine = 9.0;
210    let twenty_seven = 27.0;
211    let a2 = a * a;
212    let a3 = a2 * a;
213    let p = (three * a * c - b * b) / (three * a2);
214    let q = (twenty_seven * a2 * d - nine * a * b * c + Complex64::new(2.0, 0.0) * b * b * b)
215        / (twenty_seven * a3);
216    let half = Complex64::new(0.5, 0.0);
217    let disc = (q * q) * half * half + (p * p * p) / Complex64::new(27.0, 0.0);
218    let sqrt_disc = disc.sqrt();
219    let u = (-q * half + sqrt_disc).powf(1.0 / 3.0);
220    let v = (-q * half - sqrt_disc).powf(1.0 / 3.0);
221    let omega = Complex64::new(-0.5, (3.0f64).sqrt() * 0.5);
222    let omega2 = omega * omega;
223    let shift = b / (three * a);
224    let y0 = u + v;
225    let y1 = u * omega + v * omega.conj();
226    let y2 = u * omega2 + v * omega;
227    vec![y0 - shift, y1 - shift, y2 - shift]
228}
229
230fn canonicalize_root(z: Complex64) -> Complex64 {
231    if !z.re.is_finite() || !z.im.is_finite() {
232        return z;
233    }
234    let mut real = z.re;
235    let mut imag = z.im;
236    let scale = 1.0 + real.abs();
237    if imag.abs() <= RESULT_ZERO_TOL * scale {
238        imag = 0.0;
239    }
240    if real.abs() <= RESULT_ZERO_TOL {
241        real = 0.0;
242    }
243    Complex64::new(real, imag)
244}
245
246fn roots_to_value(roots: &[Complex64]) -> BuiltinResult<Value> {
247    if roots.is_empty() {
248        return empty_column();
249    }
250    let all_real = roots
251        .iter()
252        .all(|z| z.im.abs() <= RESULT_ZERO_TOL * (1.0 + z.re.abs()));
253    if all_real {
254        let mut data: Vec<f64> = Vec::with_capacity(roots.len());
255        for &root in roots {
256            data.push(root.re);
257        }
258        let tensor = Tensor::new(data, vec![roots.len(), 1])
259            .map_err(|e| roots_error(format!("roots: {e}")))?;
260        Ok(Value::Tensor(tensor))
261    } else {
262        let data: Vec<(f64, f64)> = roots.iter().map(|z| (z.re, z.im)).collect();
263        let tensor = ComplexTensor::new(data, vec![roots.len(), 1])
264            .map_err(|e| roots_error(format!("roots: {e}")))?;
265        Ok(Value::ComplexTensor(tensor))
266    }
267}
268
269fn empty_column() -> BuiltinResult<Value> {
270    let tensor =
271        Tensor::new(Vec::new(), vec![0, 1]).map_err(|e| roots_error(format!("roots: {e}")))?;
272    Ok(Value::Tensor(tensor))
273}
274
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ComplexTensor, LogicalArray, Tensor};

    /// Blocking wrapper around the async builtin so tests can call it directly.
    fn roots_builtin(coefficients: Value) -> BuiltinResult<Value> {
        block_on(super::roots_builtin(coefficients))
    }

    /// Wraps real coefficients in an N×1 tensor value.
    fn real_column(coefficients: Vec<f64>) -> Value {
        let len = coefficients.len();
        Value::Tensor(Tensor::new(coefficients, vec![len, 1]).unwrap())
    }

    /// Unwraps a real tensor, panicking with "expected <expectation>, got ..." otherwise.
    fn into_real(value: Value, expectation: &str) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected {expectation}, got {other:?}"),
        }
    }

    /// Unwraps a complex tensor, panicking otherwise.
    fn into_complex(value: Value) -> ComplexTensor {
        match value {
            Value::ComplexTensor(t) => t,
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        let message = err.message();
        assert!(
            message.contains(needle),
            "expected error containing '{needle}', got '{}'",
            message
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_quadratic_real() {
        let solved = roots_builtin(real_column(vec![1.0, -3.0, 2.0])).expect("roots");
        let column = into_real(solved, "real tensor");
        assert_eq!(column.shape, vec![2, 1]);
        let mut sorted = column.data;
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert!((sorted[0] - 1.0).abs() < 1e-10);
        assert!((sorted[1] - 2.0).abs() < 1e-10);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_leading_zeros_trimmed() {
        let solved = roots_builtin(real_column(vec![0.0, 0.0, 1.0, -4.0])).expect("roots");
        let column = into_real(solved, "tensor");
        assert_eq!(column.shape, vec![1, 1]);
        assert!((column.data[0] - 4.0).abs() < 1e-10);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_complex_pair() {
        let solved = roots_builtin(real_column(vec![1.0, 0.0, 1.0])).expect("roots");
        let column = into_complex(solved);
        assert_eq!(column.shape, vec![2, 1]);
        // Order by imaginary part so the pair is (-i, +i).
        let mut sorted = column.data;
        sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        assert!((sorted[0].0).abs() < 1e-10);
        assert!((sorted[0].1 + 1.0).abs() < 1e-10);
        assert!((sorted[1].0).abs() < 1e-10);
        assert!((sorted[1].1 - 1.0).abs() < 1e-10);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_quartic_all_zero_roots() {
        // p(x) = x^4 => 4 roots at 0
        let solved =
            roots_builtin(real_column(vec![1.0, 0.0, 0.0, 0.0, 0.0])).expect("roots quartic");
        match solved {
            Value::Tensor(column) => {
                assert_eq!(column.shape, vec![4, 1]);
                for &root in &column.data {
                    assert!(root.abs() < 1e-8);
                }
            }
            Value::ComplexTensor(column) => {
                assert_eq!(column.shape, vec![4, 1]);
                for &(re, im) in &column.data {
                    assert!(re.abs() < 1e-7 && im.abs() < 1e-7);
                }
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_accepts_complex_coefficients_input() {
        // p(x) = x^2 + 1 with complex coefficients path
        let input =
            ComplexTensor::new(vec![(1.0, 0.0), (0.0, 0.0), (1.0, 0.0)], vec![3, 1]).unwrap();
        let solved = roots_builtin(Value::ComplexTensor(input)).expect("roots complex input");
        let column = into_complex(solved);
        assert_eq!(column.shape, vec![2, 1]);
        // roots at i and -i
        let mut sorted = column.data;
        sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        assert!(sorted[0].0.abs() < 1e-10 && (sorted[0].1 + 1.0).abs() < 1e-6);
        assert!(sorted[1].0.abs() < 1e-10 && (sorted[1].1 - 1.0).abs() < 1e-6);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_accepts_logical_coefficients() {
        // p(x) = x with logical coefficients [1 0]
        let input = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let solved = roots_builtin(Value::LogicalArray(input)).expect("roots logical");
        let column = into_real(solved, "real tensor");
        assert_eq!(column.shape, vec![1, 1]);
        assert!(column.data[0].abs() < 1e-12);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_scalar_num_returns_empty() {
        let solved = roots_builtin(Value::Num(5.0)).expect("roots scalar num");
        let column = into_real(solved, "empty tensor");
        assert_eq!(column.shape, vec![0, 1]);
        assert!(column.data.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_rejects_non_vector_input() {
        let matrix = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let err = roots_builtin(Value::Tensor(matrix)).expect_err("expected vector-shape error");
        assert_error_contains(err, "vector");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_all_zero_coefficients_returns_empty() {
        let solved = roots_builtin(real_column(vec![0.0, 0.0, 0.0])).expect("roots");
        let column = into_real(solved, "empty tensor");
        assert_eq!(column.shape, vec![0, 1]);
        assert!(column.data.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_gpu_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            // p(x) = x^3 - 9x, uploaded through the test provider.
            let host = Tensor::new(vec![1.0, 0.0, -9.0, 0.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let solved = roots_builtin(Value::GpuTensor(handle)).expect("roots");
            let gathered = test_support::gather(solved).expect("gather");
            assert_eq!(gathered.shape, vec![3, 1]);
            let mut sorted = gathered.data;
            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
            assert!((sorted[0] + 3.0).abs() < 1e-9);
            assert!((sorted[1]).abs() < 1e-9);
            assert!((sorted[2] - 3.0).abs() < 1e-9);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_constant_polynomial_returns_empty() {
        let solved = roots_builtin(real_column(vec![5.0])).expect("roots");
        let column = into_real(solved, "empty tensor");
        assert_eq!(column.shape, vec![0, 1]);
    }
}