1use nalgebra::DMatrix;
9use num_complex::Complex64;
10use runmat_builtins::{
11 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
12 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
13 ComplexTensor, Tensor, Value,
14};
15use runmat_macros::runtime_builtin;
16
17use crate::builtins::common::spec::{
18 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
19 ReductionNaN, ResidencyPolicy, ShapeRequirements,
20};
21use crate::builtins::common::{gpu_helpers, tensor};
22use crate::builtins::math::poly::type_resolvers::roots_type;
23use crate::{build_runtime_error, BuiltinResult, RuntimeError};
24
/// Tolerance for treating leading coefficients as zero; `trim_leading_zeros`
/// scales it by the largest coefficient magnitude.
const LEADING_ZERO_TOL: f64 = 1e-12;
/// Tolerance for snapping tiny real/imaginary parts of computed roots to zero
/// and for deciding whether the result can be returned as a real tensor.
const RESULT_ZERO_TOL: f64 = 1e-10;
/// Builtin name attached to every runtime error raised here.
const BUILTIN_NAME: &str = "roots";
28
29const ROOTS_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
30 name: "r",
31 ty: BuiltinParamType::Any,
32 arity: BuiltinParamArity::Required,
33 default: None,
34 description: "Roots of the polynomial as a column vector.",
35}];
36
37const ROOTS_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
38 name: "c",
39 ty: BuiltinParamType::Any,
40 arity: BuiltinParamArity::Required,
41 default: None,
42 description: "Polynomial coefficient vector in descending power order.",
43}];
44
45const ROOTS_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
46 label: "r = roots(c)",
47 inputs: &ROOTS_INPUTS,
48 outputs: &ROOTS_OUTPUT,
49}];
50
51const ROOTS_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
52 code: "RM.ROOTS.INVALID_INPUT",
53 identifier: Some("RunMat:roots:InvalidInput"),
54 when: "Input cannot be interpreted as a numeric coefficient vector.",
55 message: "roots: invalid input",
56};
57
58const ROOTS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
59 code: "RM.ROOTS.INTERNAL",
60 identifier: Some("RunMat:roots:Internal"),
61 when: "Runtime fails while building companion matrix outputs or solving eigenvalues.",
62 message: "roots: internal runtime failure",
63};
64
65const ROOTS_ERRORS: [BuiltinErrorDescriptor; 2] = [ROOTS_ERROR_INVALID_INPUT, ROOTS_ERROR_INTERNAL];
66
67pub const ROOTS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
68 signatures: &ROOTS_SIGNATURES,
69 output_mode: BuiltinOutputMode::Fixed,
70 completion_policy: BuiltinCompletionPolicy::Public,
71 errors: &ROOTS_ERRORS,
72};
73
74#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::roots")]
75pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
76 name: "roots",
77 op_kind: GpuOpKind::Custom("polynomial-roots"),
78 supported_precisions: &[],
79 broadcast: BroadcastSemantics::None,
80 provider_hooks: &[],
81 constant_strategy: ConstantStrategy::InlineLiteral,
82 residency: ResidencyPolicy::GatherImmediately,
83 nan_mode: ReductionNaN::Include,
84 two_pass_threshold: None,
85 workgroup_size: None,
86 accepts_nan_mode: false,
87 notes: "Companion matrix eigenvalue solve executes on the host; providers currently fall back to the CPU implementation.",
88};
89
90fn roots_error(message: impl Into<String>) -> RuntimeError {
91 roots_error_with(message, &ROOTS_ERROR_INVALID_INPUT)
92}
93
94fn roots_error_with(
95 message: impl Into<String>,
96 error: &'static BuiltinErrorDescriptor,
97) -> RuntimeError {
98 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
99 if let Some(identifier) = error.identifier {
100 builder = builder.with_identifier(identifier);
101 }
102 builder.build()
103}
104
105#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::roots")]
106pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
107 name: "roots",
108 shape: ShapeRequirements::Any,
109 constant_strategy: ConstantStrategy::InlineLiteral,
110 elementwise: None,
111 reduction: None,
112 emits_nan: true,
113 notes: "Non-elementwise builtin that terminates fusion and gathers inputs to the host.",
114};
115
116#[runtime_builtin(
117 name = "roots",
118 category = "math/poly",
119 summary = "Compute polynomial roots from a coefficient vector.",
120 keywords = "roots,polynomial,eigenvalues,companion",
121 accel = "sink",
122 type_resolver(roots_type),
123 descriptor(crate::builtins::math::poly::roots::ROOTS_DESCRIPTOR),
124 builtin_path = "crate::builtins::math::poly::roots"
125)]
126async fn roots_builtin(coefficients: Value) -> crate::BuiltinResult<Value> {
127 let coeffs = coefficients_to_complex(coefficients).await?;
128 let trimmed = trim_leading_zeros(coeffs);
129 if trimmed.is_empty() || trimmed.len() == 1 {
130 return empty_column();
131 }
132 let roots = solve_roots(&trimmed)?;
133 roots_to_value(&roots)
134}
135
136async fn coefficients_to_complex(value: Value) -> BuiltinResult<Vec<Complex64>> {
137 match value {
138 Value::GpuTensor(handle) => {
139 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
140 tensor_to_complex(tensor)
141 }
142 Value::Tensor(tensor) => tensor_to_complex(tensor),
143 Value::ComplexTensor(tensor) => complex_tensor_to_vec(tensor),
144 Value::LogicalArray(logical) => {
145 let tensor = tensor::logical_to_tensor(&logical).map_err(roots_error)?;
146 tensor_to_complex(tensor)
147 }
148 Value::Num(n) => {
149 let tensor =
150 Tensor::new(vec![n], vec![1, 1]).map_err(|e| roots_error(format!("roots: {e}")))?;
151 tensor_to_complex(tensor)
152 }
153 Value::Int(i) => {
154 let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
155 .map_err(|e| roots_error(format!("roots: {e}")))?;
156 tensor_to_complex(tensor)
157 }
158 Value::Bool(b) => {
159 let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
160 .map_err(|e| roots_error(format!("roots: {e}")))?;
161 tensor_to_complex(tensor)
162 }
163 other => Err(roots_error(format!(
164 "roots: expected a numeric vector of polynomial coefficients, got {other:?}"
165 ))),
166 }
167}
168
169fn tensor_to_complex(tensor: Tensor) -> BuiltinResult<Vec<Complex64>> {
170 ensure_vector_shape("roots", &tensor.shape)?;
171 Ok(tensor
172 .data
173 .into_iter()
174 .map(|value| Complex64::new(value, 0.0))
175 .collect())
176}
177
178fn complex_tensor_to_vec(tensor: ComplexTensor) -> BuiltinResult<Vec<Complex64>> {
179 ensure_vector_shape("roots", &tensor.shape)?;
180 Ok(tensor
181 .data
182 .into_iter()
183 .map(|(re, im)| Complex64::new(re, im))
184 .collect())
185}
186
187fn ensure_vector_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
188 let is_vector = match shape.len() {
189 0 => true,
190 1 => true,
191 2 => shape[0] == 1 || shape[1] == 1 || shape.iter().product::<usize>() == 0,
192 _ => shape.iter().filter(|&&dim| dim > 1).count() <= 1,
193 };
194 if !is_vector {
195 return Err(roots_error(format!(
196 "{name}: coefficients must be a vector (row or column), got shape {:?}",
197 shape
198 )));
199 }
200 Ok(())
201}
202
203fn trim_leading_zeros(mut coeffs: Vec<Complex64>) -> Vec<Complex64> {
204 if coeffs.is_empty() {
205 return coeffs;
206 }
207 let scale = coeffs.iter().map(|c| c.norm()).fold(0.0_f64, f64::max);
208 let tol = if scale == 0.0 {
209 LEADING_ZERO_TOL
210 } else {
211 LEADING_ZERO_TOL * scale
212 };
213 let first_nonzero = coeffs
214 .iter()
215 .position(|c| c.norm() > tol)
216 .unwrap_or(coeffs.len());
217 coeffs.split_off(first_nonzero)
218}
219
220fn solve_roots(coeffs: &[Complex64]) -> BuiltinResult<Vec<Complex64>> {
221 if coeffs.len() <= 1 {
222 return Ok(Vec::new());
223 }
224 if coeffs.len() == 2 {
225 let a = coeffs[0];
226 let b = coeffs[1];
227 if a.norm() <= LEADING_ZERO_TOL {
228 return Err(roots_error(
229 "roots: leading coefficient must be non-zero after trimming",
230 ));
231 }
232 return Ok(vec![-b / a]);
233 }
234
235 let degree = coeffs.len() - 1;
236 if degree == 3 {
237 return Ok(cubic_roots(coeffs[0], coeffs[1], coeffs[2], coeffs[3]));
238 }
239 let leading = coeffs[0];
240 if leading.norm() <= LEADING_ZERO_TOL {
241 return Err(roots_error(
242 "roots: leading coefficient must be non-zero after trimming",
243 ));
244 }
245
246 let mut companion = DMatrix::<Complex64>::zeros(degree, degree);
247 for row in 1..degree {
248 companion[(row, row - 1)] = Complex64::new(1.0, 0.0);
249 }
250
251 for (idx, coeff) in coeffs.iter().enumerate().skip(1) {
252 let value = -(*coeff) / leading;
253 let column = idx - 1;
254 if column < degree {
255 companion[(0, column)] = value;
256 }
257 }
258
259 let eigenvalues = companion.clone().eigenvalues().ok_or_else(|| {
260 roots_error_with(
261 "roots: failed to compute eigenvalues of the companion matrix",
262 &ROOTS_ERROR_INTERNAL,
263 )
264 })?;
265 Ok(eigenvalues.iter().map(|&z| canonicalize_root(z)).collect())
266}
267
268fn cubic_roots(a: Complex64, b: Complex64, c: Complex64, d: Complex64) -> Vec<Complex64> {
269 let three = 3.0;
271 let nine = 9.0;
272 let twenty_seven = 27.0;
273 let a2 = a * a;
274 let a3 = a2 * a;
275 let p = (three * a * c - b * b) / (three * a2);
276 let q = (twenty_seven * a2 * d - nine * a * b * c + Complex64::new(2.0, 0.0) * b * b * b)
277 / (twenty_seven * a3);
278 let half = Complex64::new(0.5, 0.0);
279 let disc = (q * q) * half * half + (p * p * p) / Complex64::new(27.0, 0.0);
280 let sqrt_disc = disc.sqrt();
281 let u = (-q * half + sqrt_disc).powf(1.0 / 3.0);
282 let v = (-q * half - sqrt_disc).powf(1.0 / 3.0);
283 let omega = Complex64::new(-0.5, (3.0f64).sqrt() * 0.5);
284 let omega2 = omega * omega;
285 let shift = b / (three * a);
286 let y0 = u + v;
287 let y1 = u * omega + v * omega.conj();
288 let y2 = u * omega2 + v * omega;
289 vec![y0 - shift, y1 - shift, y2 - shift]
290}
291
292fn canonicalize_root(z: Complex64) -> Complex64 {
293 if !z.re.is_finite() || !z.im.is_finite() {
294 return z;
295 }
296 let mut real = z.re;
297 let mut imag = z.im;
298 let scale = 1.0 + real.abs();
299 if imag.abs() <= RESULT_ZERO_TOL * scale {
300 imag = 0.0;
301 }
302 if real.abs() <= RESULT_ZERO_TOL {
303 real = 0.0;
304 }
305 Complex64::new(real, imag)
306}
307
308fn roots_to_value(roots: &[Complex64]) -> BuiltinResult<Value> {
309 if roots.is_empty() {
310 return empty_column();
311 }
312 let all_real = roots
313 .iter()
314 .all(|z| z.im.abs() <= RESULT_ZERO_TOL * (1.0 + z.re.abs()));
315 if all_real {
316 let mut data: Vec<f64> = Vec::with_capacity(roots.len());
317 for &root in roots {
318 data.push(root.re);
319 }
320 let tensor = Tensor::new(data, vec![roots.len(), 1])
321 .map_err(|e| roots_error_with(format!("roots: {e}"), &ROOTS_ERROR_INTERNAL))?;
322 Ok(Value::Tensor(tensor))
323 } else {
324 let data: Vec<(f64, f64)> = roots.iter().map(|z| (z.re, z.im)).collect();
325 let tensor = ComplexTensor::new(data, vec![roots.len(), 1])
326 .map_err(|e| roots_error_with(format!("roots: {e}"), &ROOTS_ERROR_INTERNAL))?;
327 Ok(Value::ComplexTensor(tensor))
328 }
329}
330
331fn empty_column() -> BuiltinResult<Value> {
332 let tensor = Tensor::new(Vec::new(), vec![0, 1])
333 .map_err(|e| roots_error_with(format!("roots: {e}"), &ROOTS_ERROR_INTERNAL))?;
334 Ok(Value::Tensor(tensor))
335}
336
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ComplexTensor, LogicalArray, Tensor};

    /// Synchronous shim over the async builtin so tests can call it directly.
    fn roots_builtin(coefficients: Value) -> BuiltinResult<Value> {
        block_on(super::roots_builtin(coefficients))
    }

    /// Builds an `n x 1` real column tensor from `data`.
    fn column(data: Vec<f64>) -> Tensor {
        let rows = data.len();
        Tensor::new(data, vec![rows, 1]).unwrap()
    }

    /// Sorts real values ascending (panics on NaN).
    fn sorted(mut values: Vec<f64>) -> Vec<f64> {
        values.sort_by(|a, b| a.partial_cmp(b).unwrap());
        values
    }

    /// Sorts `(re, im)` pairs ascending by imaginary part (panics on NaN).
    fn sorted_by_imag(mut values: Vec<(f64, f64)>) -> Vec<(f64, f64)> {
        values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        values
    }

    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
        let message = err.message();
        assert!(
            message.contains(needle),
            "expected error containing '{needle}', got '{}'",
            message
        );
    }

    #[test]
    fn roots_descriptor_signatures_cover_core_forms() {
        assert!(ROOTS_DESCRIPTOR
            .signatures
            .iter()
            .any(|signature| signature.label == "r = roots(c)"));
    }

    #[test]
    fn roots_descriptor_errors_have_stable_codes() {
        let has_code = |code: &str| ROOTS_DESCRIPTOR.errors.iter().any(|error| error.code == code);
        assert!(has_code("RM.ROOTS.INVALID_INPUT"));
        assert!(has_code("RM.ROOTS.INTERNAL"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_quadratic_real() {
        match roots_builtin(Value::Tensor(column(vec![1.0, -3.0, 2.0]))).expect("roots") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                let r = sorted(t.data);
                assert!((r[0] - 1.0).abs() < 1e-10);
                assert!((r[1] - 2.0).abs() < 1e-10);
            }
            other => panic!("expected real tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_leading_zeros_trimmed() {
        match roots_builtin(Value::Tensor(column(vec![0.0, 0.0, 1.0, -4.0]))).expect("roots") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert!((t.data[0] - 4.0).abs() < 1e-10);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_complex_pair() {
        match roots_builtin(Value::Tensor(column(vec![1.0, 0.0, 1.0]))).expect("roots") {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                let r = sorted_by_imag(t.data);
                assert!(r[0].0.abs() < 1e-10);
                assert!((r[0].1 + 1.0).abs() < 1e-10);
                assert!(r[1].0.abs() < 1e-10);
                assert!((r[1].1 - 1.0).abs() < 1e-10);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_quartic_all_zero_roots() {
        let coeffs = column(vec![1.0, 0.0, 0.0, 0.0, 0.0]);
        match roots_builtin(Value::Tensor(coeffs)).expect("roots quartic") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                assert!(t.data.iter().all(|r| r.abs() < 1e-8));
            }
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![4, 1]);
                assert!(t
                    .data
                    .iter()
                    .all(|&(re, im)| re.abs() < 1e-7 && im.abs() < 1e-7));
            }
            other => panic!("unexpected output {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_accepts_complex_coefficients_input() {
        let coeffs =
            ComplexTensor::new(vec![(1.0, 0.0), (0.0, 0.0), (1.0, 0.0)], vec![3, 1]).unwrap();
        match roots_builtin(Value::ComplexTensor(coeffs)).expect("roots complex input") {
            Value::ComplexTensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                let r = sorted_by_imag(t.data);
                assert!(r[0].0.abs() < 1e-10 && (r[0].1 + 1.0).abs() < 1e-6);
                assert!(r[1].0.abs() < 1e-10 && (r[1].1 - 1.0).abs() < 1e-6);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_accepts_logical_coefficients() {
        let mask = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        match roots_builtin(Value::LogicalArray(mask)).expect("roots logical") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 1]);
                assert!(t.data[0].abs() < 1e-12);
            }
            other => panic!("expected real tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_scalar_num_returns_empty() {
        match roots_builtin(Value::Num(5.0)).expect("roots scalar num") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_rejects_non_vector_input() {
        let square = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let err = roots_builtin(Value::Tensor(square)).expect_err("expected vector-shape error");
        assert_eq!(err.identifier(), ROOTS_ERROR_INVALID_INPUT.identifier);
        assert_error_contains(err, "vector");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_all_zero_coefficients_returns_empty() {
        match roots_builtin(Value::Tensor(column(vec![0.0, 0.0, 0.0]))).expect("roots") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
                assert!(t.data.is_empty());
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_gpu_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let coeffs = column(vec![1.0, 0.0, -9.0, 0.0]);
            let view = HostTensorView {
                data: &coeffs.data,
                shape: &coeffs.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let output = roots_builtin(Value::GpuTensor(handle)).expect("roots");
            let gathered = test_support::gather(output).expect("gather");
            assert_eq!(gathered.shape, vec![3, 1]);
            let r = sorted(gathered.data);
            assert!((r[0] + 3.0).abs() < 1e-9);
            assert!(r[1].abs() < 1e-9);
            assert!((r[2] - 3.0).abs() < 1e-9);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn roots_constant_polynomial_returns_empty() {
        match roots_builtin(Value::Tensor(column(vec![5.0]))).expect("roots") {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![0, 1]);
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }
}