1use nalgebra::DMatrix;
9use num_complex::Complex64;
10use runmat_builtins::{ComplexTensor, Tensor, Value};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::common::{gpu_helpers, tensor};
18use crate::builtins::math::poly::type_resolvers::roots_type;
19use crate::{build_runtime_error, BuiltinResult, RuntimeError};
20
21const LEADING_ZERO_TOL: f64 = 1.0e-12;
22const RESULT_ZERO_TOL: f64 = 1.0e-10;
23const BUILTIN_NAME: &str = "roots";
24
25#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::roots")]
26pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
27 name: "roots",
28 op_kind: GpuOpKind::Custom("polynomial-roots"),
29 supported_precisions: &[],
30 broadcast: BroadcastSemantics::None,
31 provider_hooks: &[],
32 constant_strategy: ConstantStrategy::InlineLiteral,
33 residency: ResidencyPolicy::GatherImmediately,
34 nan_mode: ReductionNaN::Include,
35 two_pass_threshold: None,
36 workgroup_size: None,
37 accepts_nan_mode: false,
38 notes: "Companion matrix eigenvalue solve executes on the host; providers currently fall back to the CPU implementation.",
39};
40
41fn roots_error(message: impl Into<String>) -> RuntimeError {
42 build_runtime_error(message)
43 .with_builtin(BUILTIN_NAME)
44 .build()
45}
46
47#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::roots")]
48pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
49 name: "roots",
50 shape: ShapeRequirements::Any,
51 constant_strategy: ConstantStrategy::InlineLiteral,
52 elementwise: None,
53 reduction: None,
54 emits_nan: true,
55 notes: "Non-elementwise builtin that terminates fusion and gathers inputs to the host.",
56};
57
58#[runtime_builtin(
59 name = "roots",
60 category = "math/poly",
61 summary = "Compute the roots of a polynomial specified by its coefficients.",
62 keywords = "roots,polynomial,eigenvalues,companion",
63 accel = "sink",
64 type_resolver(roots_type),
65 builtin_path = "crate::builtins::math::poly::roots"
66)]
67async fn roots_builtin(coefficients: Value) -> crate::BuiltinResult<Value> {
68 let coeffs = coefficients_to_complex(coefficients).await?;
69 let trimmed = trim_leading_zeros(coeffs);
70 if trimmed.is_empty() || trimmed.len() == 1 {
71 return empty_column();
72 }
73 let roots = solve_roots(&trimmed)?;
74 roots_to_value(&roots)
75}
76
77async fn coefficients_to_complex(value: Value) -> BuiltinResult<Vec<Complex64>> {
78 match value {
79 Value::GpuTensor(handle) => {
80 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
81 tensor_to_complex(tensor)
82 }
83 Value::Tensor(tensor) => tensor_to_complex(tensor),
84 Value::ComplexTensor(tensor) => complex_tensor_to_vec(tensor),
85 Value::LogicalArray(logical) => {
86 let tensor = tensor::logical_to_tensor(&logical).map_err(roots_error)?;
87 tensor_to_complex(tensor)
88 }
89 Value::Num(n) => {
90 let tensor =
91 Tensor::new(vec![n], vec![1, 1]).map_err(|e| roots_error(format!("roots: {e}")))?;
92 tensor_to_complex(tensor)
93 }
94 Value::Int(i) => {
95 let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
96 .map_err(|e| roots_error(format!("roots: {e}")))?;
97 tensor_to_complex(tensor)
98 }
99 Value::Bool(b) => {
100 let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
101 .map_err(|e| roots_error(format!("roots: {e}")))?;
102 tensor_to_complex(tensor)
103 }
104 other => Err(roots_error(format!(
105 "roots: expected a numeric vector of polynomial coefficients, got {other:?}"
106 ))),
107 }
108}
109
110fn tensor_to_complex(tensor: Tensor) -> BuiltinResult<Vec<Complex64>> {
111 ensure_vector_shape("roots", &tensor.shape)?;
112 Ok(tensor
113 .data
114 .into_iter()
115 .map(|value| Complex64::new(value, 0.0))
116 .collect())
117}
118
119fn complex_tensor_to_vec(tensor: ComplexTensor) -> BuiltinResult<Vec<Complex64>> {
120 ensure_vector_shape("roots", &tensor.shape)?;
121 Ok(tensor
122 .data
123 .into_iter()
124 .map(|(re, im)| Complex64::new(re, im))
125 .collect())
126}
127
128fn ensure_vector_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
129 let is_vector = match shape.len() {
130 0 => true,
131 1 => true,
132 2 => shape[0] == 1 || shape[1] == 1 || shape.iter().product::<usize>() == 0,
133 _ => shape.iter().filter(|&&dim| dim > 1).count() <= 1,
134 };
135 if !is_vector {
136 return Err(roots_error(format!(
137 "{name}: coefficients must be a vector (row or column), got shape {:?}",
138 shape
139 )));
140 }
141 Ok(())
142}
143
144fn trim_leading_zeros(mut coeffs: Vec<Complex64>) -> Vec<Complex64> {
145 if coeffs.is_empty() {
146 return coeffs;
147 }
148 let scale = coeffs.iter().map(|c| c.norm()).fold(0.0_f64, f64::max);
149 let tol = if scale == 0.0 {
150 LEADING_ZERO_TOL
151 } else {
152 LEADING_ZERO_TOL * scale
153 };
154 let first_nonzero = coeffs
155 .iter()
156 .position(|c| c.norm() > tol)
157 .unwrap_or(coeffs.len());
158 coeffs.split_off(first_nonzero)
159}
160
161fn solve_roots(coeffs: &[Complex64]) -> BuiltinResult<Vec<Complex64>> {
162 if coeffs.len() <= 1 {
163 return Ok(Vec::new());
164 }
165 if coeffs.len() == 2 {
166 let a = coeffs[0];
167 let b = coeffs[1];
168 if a.norm() <= LEADING_ZERO_TOL {
169 return Err(roots_error(
170 "roots: leading coefficient must be non-zero after trimming",
171 ));
172 }
173 return Ok(vec![-b / a]);
174 }
175
176 let degree = coeffs.len() - 1;
177 if degree == 3 {
178 return Ok(cubic_roots(coeffs[0], coeffs[1], coeffs[2], coeffs[3]));
179 }
180 let leading = coeffs[0];
181 if leading.norm() <= LEADING_ZERO_TOL {
182 return Err(roots_error(
183 "roots: leading coefficient must be non-zero after trimming",
184 ));
185 }
186
187 let mut companion = DMatrix::<Complex64>::zeros(degree, degree);
188 for row in 1..degree {
189 companion[(row, row - 1)] = Complex64::new(1.0, 0.0);
190 }
191
192 for (idx, coeff) in coeffs.iter().enumerate().skip(1) {
193 let value = -(*coeff) / leading;
194 let column = idx - 1;
195 if column < degree {
196 companion[(0, column)] = value;
197 }
198 }
199
200 let eigenvalues = companion.clone().eigenvalues().ok_or_else(|| {
201 roots_error("roots: failed to compute eigenvalues of the companion matrix")
202 })?;
203 Ok(eigenvalues.iter().map(|&z| canonicalize_root(z)).collect())
204}
205
206fn cubic_roots(a: Complex64, b: Complex64, c: Complex64, d: Complex64) -> Vec<Complex64> {
207 let three = 3.0;
209 let nine = 9.0;
210 let twenty_seven = 27.0;
211 let a2 = a * a;
212 let a3 = a2 * a;
213 let p = (three * a * c - b * b) / (three * a2);
214 let q = (twenty_seven * a2 * d - nine * a * b * c + Complex64::new(2.0, 0.0) * b * b * b)
215 / (twenty_seven * a3);
216 let half = Complex64::new(0.5, 0.0);
217 let disc = (q * q) * half * half + (p * p * p) / Complex64::new(27.0, 0.0);
218 let sqrt_disc = disc.sqrt();
219 let u = (-q * half + sqrt_disc).powf(1.0 / 3.0);
220 let v = (-q * half - sqrt_disc).powf(1.0 / 3.0);
221 let omega = Complex64::new(-0.5, (3.0f64).sqrt() * 0.5);
222 let omega2 = omega * omega;
223 let shift = b / (three * a);
224 let y0 = u + v;
225 let y1 = u * omega + v * omega.conj();
226 let y2 = u * omega2 + v * omega;
227 vec![y0 - shift, y1 - shift, y2 - shift]
228}
229
230fn canonicalize_root(z: Complex64) -> Complex64 {
231 if !z.re.is_finite() || !z.im.is_finite() {
232 return z;
233 }
234 let mut real = z.re;
235 let mut imag = z.im;
236 let scale = 1.0 + real.abs();
237 if imag.abs() <= RESULT_ZERO_TOL * scale {
238 imag = 0.0;
239 }
240 if real.abs() <= RESULT_ZERO_TOL {
241 real = 0.0;
242 }
243 Complex64::new(real, imag)
244}
245
246fn roots_to_value(roots: &[Complex64]) -> BuiltinResult<Value> {
247 if roots.is_empty() {
248 return empty_column();
249 }
250 let all_real = roots
251 .iter()
252 .all(|z| z.im.abs() <= RESULT_ZERO_TOL * (1.0 + z.re.abs()));
253 if all_real {
254 let mut data: Vec<f64> = Vec::with_capacity(roots.len());
255 for &root in roots {
256 data.push(root.re);
257 }
258 let tensor = Tensor::new(data, vec![roots.len(), 1])
259 .map_err(|e| roots_error(format!("roots: {e}")))?;
260 Ok(Value::Tensor(tensor))
261 } else {
262 let data: Vec<(f64, f64)> = roots.iter().map(|z| (z.re, z.im)).collect();
263 let tensor = ComplexTensor::new(data, vec![roots.len(), 1])
264 .map_err(|e| roots_error(format!("roots: {e}")))?;
265 Ok(Value::ComplexTensor(tensor))
266 }
267}
268
269fn empty_column() -> BuiltinResult<Value> {
270 let tensor =
271 Tensor::new(Vec::new(), vec![0, 1]).map_err(|e| roots_error(format!("roots: {e}")))?;
272 Ok(Value::Tensor(tensor))
273}
274
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ComplexTensor, LogicalArray, Tensor};

    /// Blocking wrapper around the async builtin so tests can call it directly.
    fn roots_builtin(coefficients: Value) -> BuiltinResult<Value> {
        block_on(super::roots_builtin(coefficients))
    }

    /// Asserts that the error message contains `needle`.
    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        assert!(
            err.message().contains(needle),
            "expected error containing '{needle}', got '{}'",
            err.message()
        );
    }

    // x^2 - 3x + 2 has the real roots 1 and 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_quadratic_real() {
        let input = Tensor::new(vec![1.0, -3.0, 2.0], vec![3, 1]).unwrap();
        let t = match roots_builtin(Value::Tensor(input)).expect("roots") {
            Value::Tensor(t) => t,
            other => panic!("expected real tensor, got {other:?}"),
        };
        assert_eq!(t.shape, vec![2, 1]);
        let mut found = t.data;
        found.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert!((found[0] - 1.0).abs() < 1e-10);
        assert!((found[1] - 2.0).abs() < 1e-10);
    }

    // Leading zeros are dropped, leaving x - 4.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_leading_zeros_trimmed() {
        let input = Tensor::new(vec![0.0, 0.0, 1.0, -4.0], vec![4, 1]).unwrap();
        let t = match roots_builtin(Value::Tensor(input)).expect("roots") {
            Value::Tensor(t) => t,
            other => panic!("expected tensor, got {other:?}"),
        };
        assert_eq!(t.shape, vec![1, 1]);
        assert!((t.data[0] - 4.0).abs() < 1e-10);
    }

    // x^2 + 1 has the conjugate pair -i and +i.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_complex_pair() {
        let input = Tensor::new(vec![1.0, 0.0, 1.0], vec![3, 1]).unwrap();
        let t = match roots_builtin(Value::Tensor(input)).expect("roots") {
            Value::ComplexTensor(t) => t,
            other => panic!("expected complex tensor, got {other:?}"),
        };
        assert_eq!(t.shape, vec![2, 1]);
        let mut found = t.data;
        found.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        assert!((found[0].0).abs() < 1e-10);
        assert!((found[0].1 + 1.0).abs() < 1e-10);
        assert!((found[1].0).abs() < 1e-10);
        assert!((found[1].1 - 1.0).abs() < 1e-10);
    }

    // x^4 has four roots at zero; either a real or a complex tensor is accepted.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_quartic_all_zero_roots() {
        let input = Tensor::new(vec![1.0, 0.0, 0.0, 0.0, 0.0], vec![5, 1]).unwrap();
        match roots_builtin(Value::Tensor(input)).expect("roots quartic") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                for r in t.data.iter() {
                    assert!(r.abs() < 1e-8);
                }
            }
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                for (re, im) in t.data.iter() {
                    assert!(re.abs() < 1e-7 && im.abs() < 1e-7);
                }
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    // Same polynomial as `roots_complex_pair`, supplied as a complex tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_accepts_complex_coefficients_input() {
        let input =
            ComplexTensor::new(vec![(1.0, 0.0), (0.0, 0.0), (1.0, 0.0)], vec![3, 1]).unwrap();
        let t = match roots_builtin(Value::ComplexTensor(input)).expect("roots complex input") {
            Value::ComplexTensor(t) => t,
            other => panic!("expected complex tensor, got {other:?}"),
        };
        assert_eq!(t.shape, vec![2, 1]);
        let mut found = t.data;
        found.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        assert!(found[0].0.abs() < 1e-10 && (found[0].1 + 1.0).abs() < 1e-6);
        assert!(found[1].0.abs() < 1e-10 && (found[1].1 - 1.0).abs() < 1e-6);
    }

    // Logical [1, 0] is the polynomial x, whose only root is 0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_accepts_logical_coefficients() {
        let input = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let t = match roots_builtin(Value::LogicalArray(input)).expect("roots logical") {
            Value::Tensor(t) => t,
            other => panic!("expected real tensor, got {other:?}"),
        };
        assert_eq!(t.shape, vec![1, 1]);
        assert!(t.data[0].abs() < 1e-12);
    }

    // A scalar is a constant polynomial and has no roots.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_scalar_num_returns_empty() {
        let t = match roots_builtin(Value::Num(5.0)).expect("roots scalar num") {
            Value::Tensor(t) => t,
            other => panic!("expected empty tensor, got {other:?}"),
        };
        assert_eq!(t.shape, vec![0, 1]);
        assert!(t.data.is_empty());
    }

    // A 2x2 matrix is not a coefficient vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_rejects_non_vector_input() {
        let input = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let err = roots_builtin(Value::Tensor(input)).expect_err("expected vector-shape error");
        assert_error_contains(err, "vector");
    }

    // All-zero coefficients trim to nothing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_all_zero_coefficients_returns_empty() {
        let input = Tensor::new(vec![0.0, 0.0, 0.0], vec![3, 1]).unwrap();
        let t = match roots_builtin(Value::Tensor(input)).expect("roots") {
            Value::Tensor(t) => t,
            other => panic!("expected empty tensor, got {other:?}"),
        };
        assert_eq!(t.shape, vec![0, 1]);
        assert!(t.data.is_empty());
    }

    // x^3 - 9x uploaded through the test provider; roots are -3, 0 and 3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_gpu_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let host = Tensor::new(vec![1.0, 0.0, -9.0, 0.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let output = roots_builtin(Value::GpuTensor(handle)).expect("roots");
            let gathered = test_support::gather(output).expect("gather");
            assert_eq!(gathered.shape, vec![3, 1]);
            let mut found = gathered.data;
            found.sort_by(|a, b| a.partial_cmp(b).unwrap());
            assert!((found[0] + 3.0).abs() < 1e-9);
            assert!((found[1]).abs() < 1e-9);
            assert!((found[2] - 3.0).abs() < 1e-9);
        });
    }

    // A 1x1 tensor is a constant polynomial and has no roots.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_constant_polynomial_returns_empty() {
        let input = Tensor::new(vec![5.0], vec![1, 1]).unwrap();
        let t = match roots_builtin(Value::Tensor(input)).expect("roots") {
            Value::Tensor(t) => t,
            other => panic!("expected empty tensor, got {other:?}"),
        };
        assert_eq!(t.shape, vec![0, 1]);
    }
}