Skip to main content

runmat_runtime/builtins/math/poly/
polyfit.rs

1//! MATLAB-compatible `polyfit` builtin with GPU-aware semantics for RunMat.
2
3use log::{trace, warn};
4use num_complex::Complex64;
5use runmat_accelerate_api::ProviderPolyfitResult;
6use runmat_builtins::{ComplexTensor, StructValue, Tensor, Value};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::tensor;
10use crate::dispatcher;
11use crate::{build_runtime_error, BuiltinResult, RuntimeError};
12
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::math::poly::type_resolvers::polyfit_type;
18
/// Numerical tolerance used for integer checks and near-zero pivot/scale tests.
const EPS: f64 = 1.0e-12;
/// Tolerance for deciding whether a coefficient's imaginary part is negligible.
const EPS_NAN: f64 = 1.0e-12;
/// Builtin name attached to every `RuntimeError` raised by this module.
const BUILTIN_NAME: &str = "polyfit";
22
/// GPU execution specification registered with the accelerator layer.
///
/// `polyfit` is exposed through a custom provider hook; results are gathered
/// to the host immediately (`ResidencyPolicy::GatherImmediately`) because the
/// fit is a terminal operation rather than an elementwise kernel.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::poly::polyfit")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "polyfit",
    op_kind: GpuOpKind::Custom("polyfit"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("polyfit")],
    constant_strategy: ConstantStrategy::UniformBuffer,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Providers may gather to the host and invoke the shared Householder QR solver; WGPU implements this path today.",
};
39
40fn polyfit_error(message: impl Into<String>) -> RuntimeError {
41    build_runtime_error(message)
42        .with_builtin(BUILTIN_NAME)
43        .build()
44}
45
/// Fusion specification: `polyfit` is a sink node, so any fusion graph that
/// reaches it materialises its inputs and terminates here.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::poly::polyfit")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "polyfit",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::UniformBuffer,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a sink node—polynomial fitting materialises results eagerly and terminates fusion graphs.",
};
56
/// VM entry point for the `polyfit` builtin.
///
/// Runs the fit once and then shapes the result according to how many output
/// values the call site requested: `[p]`, `[p, S]`, or `[p, S, mu]`.
#[runtime_builtin(
    name = "polyfit",
    category = "math/poly",
    summary = "Fit an n-th degree polynomial to data points with MATLAB-compatible outputs.",
    keywords = "polyfit,polynomial,least-squares,gpu",
    accel = "sink",
    sink = true,
    type_resolver(polyfit_type),
    builtin_path = "crate::builtins::math::poly::polyfit"
)]
async fn polyfit_builtin(
    x: Value,
    y: Value,
    degree: Value,
    rest: Vec<Value>,
) -> crate::BuiltinResult<Value> {
    let eval = evaluate(x, y, degree, &rest).await?;
    if let Some(out_count) = crate::output_count::current_output_count() {
        if out_count == 0 {
            // Called purely for side effects: no outputs requested.
            return Ok(Value::OutputList(Vec::new()));
        }
        let mut outputs = vec![eval.coefficients()];
        if out_count >= 2 {
            outputs.push(eval.stats());
        }
        if out_count >= 3 {
            outputs.push(eval.mu());
        }
        return Ok(crate::output_count::output_list_with_padding(
            out_count, outputs,
        ));
    }
    // No output-count context (e.g. expression position): coefficients only.
    Ok(eval.coefficients())
}
91
/// Evaluate `polyfit`, returning the multi-output envelope used by the VM.
///
/// Tries the GPU provider fast path first (requires both operands to be
/// device-resident); otherwise gathers everything to the host and runs the
/// shared Householder QR solver.
pub async fn evaluate(
    x: Value,
    y: Value,
    degree: Value,
    rest: &[Value],
) -> BuiltinResult<PolyfitEval> {
    let deg = parse_degree(&degree)?;

    if let Some(eval) = try_gpu_polyfit(&x, &y, deg, rest).await? {
        return Ok(eval);
    }

    let x_host = dispatcher::gather_if_needed_async(&x).await?;
    let y_host = dispatcher::gather_if_needed_async(&y).await?;

    let x_data = real_vector("polyfit", "X", x_host).await?;
    let (y_data, is_complex_input) = complex_vector("polyfit", "Y", y_host).await?;

    if x_data.len() != y_data.len() {
        return Err(polyfit_error(
            "polyfit: X and Y vectors must be the same length",
        ));
    }
    if x_data.is_empty() {
        return Err(polyfit_error(
            "polyfit: X and Y must contain at least one sample",
        ));
    }
    // Under-determined fits still proceed (MATLAB-style), but warn the user.
    if deg + 1 > x_data.len() && x_data.len() > 1 {
        warn!(
            "polyfit: polynomial degree {} is ill-conditioned for {} data points; results may be inaccurate",
            deg,
            x_data.len()
        );
    }

    let weights = parse_weights(rest, x_data.len()).await?;
    let mut solution = solve_polyfit(&x_data, &y_data, deg, weights.as_deref())?;
    if is_complex_input {
        // Propagate complexity detected on Y even if the fitted coefficients
        // happen to come out numerically real.
        solution.is_complex = true;
    }

    PolyfitEval::from_solution(solution)
}
137
/// Attempt the provider (GPU) polyfit fast path.
///
/// Returns `Ok(None)` whenever the fast path does not apply — no provider
/// registered, X/Y not device-resident, unsupported trailing arguments, or
/// the provider itself declining — so the caller falls back to the host
/// solver instead of erroring.
async fn try_gpu_polyfit(
    x: &Value,
    y: &Value,
    degree: usize,
    rest: &[Value],
) -> BuiltinResult<Option<PolyfitEval>> {
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };

    // Only take the GPU path when both inputs already live on the device.
    let x_handle = match x {
        Value::GpuTensor(handle) => handle,
        _ => return Ok(None),
    };
    let y_handle = match y {
        Value::GpuTensor(handle) => handle,
        _ => return Ok(None),
    };

    // At most one trailing argument (the weight vector) is supported here.
    if rest.len() > 1 {
        return Ok(None);
    }

    // Weights must also be device-resident; any other value type bails out.
    let weight_handle = match rest.first() {
        Some(Value::GpuTensor(handle)) => Some(handle),
        Some(_) => return Ok(None),
        None => None,
    };

    let result = match provider
        .polyfit(x_handle, y_handle, degree, weight_handle)
        .await
    {
        Ok(res) => res,
        Err(err) => {
            // Provider failures are non-fatal: trace and use the host path.
            trace!("polyfit: provider path unavailable ({err}); falling back to host");
            return Ok(None);
        }
    };

    let solution = PolyfitSolution::from_provider(result)?;
    PolyfitEval::from_solution(solution).map(Some)
}
182
/// Internal result of a polynomial fit, before conversion to runtime values.
#[derive(Clone, Debug)]
struct PolyfitSolution {
    /// Coefficients in original x-space, highest power first.
    coeffs: Vec<Complex64>,
    /// Column-major `cols`×`cols` upper-triangular R factor from the QR step.
    r_matrix: Vec<f64>,
    /// Centering offset: mean(x).
    mu_mean: f64,
    /// Centering scale: sample std of x (1.0 for degenerate inputs).
    mu_scale: f64,
    /// Euclidean norm of the least-squares residual.
    normr: f64,
    /// Degrees of freedom: rows - cols when positive, otherwise 0.
    df: f64,
    /// Number of coefficients (degree + 1).
    cols: usize,
    /// Whether any coefficient carries a non-negligible imaginary part.
    is_complex: bool,
}
194
195impl PolyfitSolution {
196    fn from_provider(result: ProviderPolyfitResult) -> BuiltinResult<Self> {
197        let cols = result.coefficients.len();
198        if cols == 0 {
199            return Err(polyfit_error(
200                "polyfit: provider returned empty coefficient vector",
201            ));
202        }
203        if result.r_matrix.len() != cols * cols {
204            return Err(polyfit_error(
205                "polyfit: provider returned malformed R matrix",
206            ));
207        }
208        let [mu_mean, mu_scale] = result.mu;
209        Ok(Self {
210            coeffs: result
211                .coefficients
212                .into_iter()
213                .map(|re| Complex64::new(re, 0.0))
214                .collect(),
215            r_matrix: result.r_matrix,
216            mu_mean,
217            mu_scale,
218            normr: result.normr,
219            df: result.df,
220            cols,
221            is_complex: false,
222        })
223    }
224}
225
/// Multi-output envelope for `polyfit`, mirroring MATLAB semantics.
#[derive(Debug)]
pub struct PolyfitEval {
    /// First output `p`: 1×(degree+1) coefficient row vector.
    coefficients: Value,
    /// Second output `S`: struct with fields `R`, `df`, `normr`.
    stats: Value,
    /// Third output `mu`: `[mean(x), std(x)]` centering pair.
    mu: Value,
    /// Whether the fit produced (or was fed) complex data.
    is_complex: bool,
}
234
235impl PolyfitEval {
236    fn from_solution(solution: PolyfitSolution) -> BuiltinResult<Self> {
237        let coefficients = coefficients_to_value(&solution.coeffs)?;
238        let stats = build_stats(
239            &solution.r_matrix,
240            solution.cols,
241            solution.normr,
242            solution.df,
243        )?;
244        let mu = build_mu(solution.mu_mean, solution.mu_scale)?;
245        Ok(Self {
246            coefficients,
247            stats,
248            mu,
249            is_complex: solution.is_complex,
250        })
251    }
252
253    /// Polynomial coefficients ordered from highest power to constant term.
254    pub fn coefficients(&self) -> Value {
255        self.coefficients.clone()
256    }
257
258    /// Structure `S` containing fields `R`, `df`, and `normr`.
259    pub fn stats(&self) -> Value {
260        self.stats.clone()
261    }
262
263    /// Centering and scaling vector `[mean(x), std(x)]`.
264    pub fn mu(&self) -> Value {
265        self.mu.clone()
266    }
267
268    /// Returns `true` if the fitted polynomial contains a complex coefficient.
269    pub fn is_complex(&self) -> bool {
270        self.is_complex
271    }
272}
273
274fn parse_degree(value: &Value) -> BuiltinResult<usize> {
275    match value {
276        Value::Int(i) => {
277            let raw = i.to_i64();
278            if raw < 0 {
279                return Err(polyfit_error(
280                    "polyfit: degree must be a non-negative integer",
281                ));
282            }
283            Ok(raw as usize)
284        }
285        Value::Num(n) => {
286            if !n.is_finite() {
287                return Err(polyfit_error("polyfit: degree must be finite"));
288            }
289            let rounded = n.round();
290            if (rounded - n).abs() > EPS {
291                return Err(polyfit_error("polyfit: degree must be an integer"));
292            }
293            if rounded < 0.0 {
294                return Err(polyfit_error(
295                    "polyfit: degree must be a non-negative integer",
296                ));
297            }
298            Ok(rounded as usize)
299        }
300        Value::Tensor(t) if tensor::is_scalar_tensor(t) => parse_degree(&Value::Num(t.data[0])),
301        Value::LogicalArray(l) if l.len() == 1 => {
302            parse_degree(&Value::Num(if l.data[0] != 0 { 1.0 } else { 0.0 }))
303        }
304        other => Err(polyfit_error(format!(
305            "polyfit: degree must be a scalar numeric value, got {other:?}"
306        ))),
307    }
308}
309
/// Convert `value` into a dense real `Vec<f64>`, gathering GPU tensors first.
///
/// Accepts real tensors, logical arrays, and real scalars; complex inputs
/// are rejected (X and the weight vector must be real-valued). `context` and
/// `label` only shape the error messages.
#[async_recursion::async_recursion(?Send)]
async fn real_vector(context: &str, label: &str, value: Value) -> BuiltinResult<Vec<f64>> {
    match value {
        Value::Tensor(mut tensor) => {
            ensure_vector_shape(context, label, &tensor.shape)?;
            // The tensor is owned here, so drain instead of cloning the data.
            Ok(tensor.data.drain(..).collect())
        }
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical).map_err(polyfit_error)?;
            ensure_vector_shape(context, label, &tensor.shape)?;
            Ok(tensor.data)
        }
        Value::Num(n) => Ok(vec![n]),
        Value::Int(i) => Ok(vec![i.to_f64()]),
        Value::Bool(b) => Ok(vec![if b { 1.0 } else { 0.0 }]),
        Value::GpuTensor(handle) => {
            // Gather to the host, then retry as a plain tensor.
            let gathered =
                crate::builtins::common::gpu_helpers::gather_tensor_async(&handle).await?;
            real_vector(context, label, Value::Tensor(gathered)).await
        }
        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(polyfit_error(format!(
            "{context}: {label} must be real-valued; complex inputs are not supported"
        ))),
        other => Err(polyfit_error(format!(
            "{context}: expected {label} to be a numeric vector, got {other:?}"
        ))),
    }
}
338
339#[async_recursion::async_recursion(?Send)]
340async fn complex_vector(
341    context: &str,
342    label: &str,
343    value: Value,
344) -> BuiltinResult<(Vec<Complex64>, bool)> {
345    match value {
346        Value::Tensor(mut tensor) => {
347            ensure_vector_shape(context, label, &tensor.shape)?;
348            let all_real = true;
349            let data = tensor
350                .data
351                .drain(..)
352                .map(|x| Complex64::new(x, 0.0))
353                .collect();
354            Ok((data, all_real))
355        }
356        Value::ComplexTensor(tensor) => {
357            ensure_vector_shape(context, label, &tensor.shape)?;
358            let is_complex = tensor.data.iter().any(|&(_, im)| im.abs() > EPS);
359            let data = tensor
360                .data
361                .into_iter()
362                .map(|(re, im)| Complex64::new(re, im))
363                .collect::<Vec<_>>();
364            Ok((data, is_complex))
365        }
366        Value::LogicalArray(logical) => {
367            let tensor = tensor::logical_to_tensor(&logical).map_err(polyfit_error)?;
368            ensure_vector_shape(context, label, &tensor.shape)?;
369            Ok((
370                tensor
371                    .data
372                    .iter()
373                    .map(|&x| Complex64::new(x, 0.0))
374                    .collect(),
375                false,
376            ))
377        }
378        Value::Num(n) => Ok((vec![Complex64::new(n, 0.0)], false)),
379        Value::Int(i) => Ok((vec![Complex64::new(i.to_f64(), 0.0)], false)),
380        Value::Bool(b) => Ok((vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)], false)),
381        Value::Complex(re, im) => Ok((vec![Complex64::new(re, im)], im.abs() > EPS)),
382        Value::GpuTensor(handle) => {
383            let gathered =
384                crate::builtins::common::gpu_helpers::gather_tensor_async(&handle).await?;
385            complex_vector(context, label, Value::Tensor(gathered)).await
386        }
387        other => Err(polyfit_error(format!(
388            "{context}: expected {label} to be a numeric vector, got {other:?}"
389        ))),
390    }
391}
392
393async fn parse_weights(rest: &[Value], len: usize) -> BuiltinResult<Option<Vec<f64>>> {
394    match rest.len() {
395        0 => Ok(None),
396        1 => {
397            let gathered = dispatcher::gather_if_needed_async(&rest[0]).await?;
398            let data = real_vector("polyfit", "weights", gathered).await?;
399            if data.len() != len {
400                return Err(polyfit_error(
401                    "polyfit: weight vector must match the size of X",
402                ));
403            }
404            validate_weights(&data)?;
405            Ok(Some(data))
406        }
407        _ => Err(polyfit_error("polyfit: too many input arguments")),
408    }
409}
410
411fn validate_weights(weights: &[f64]) -> BuiltinResult<()> {
412    for (idx, w) in weights.iter().enumerate() {
413        if !w.is_finite() {
414            return Err(polyfit_error(format!(
415                "polyfit: weight at position {} must be finite",
416                idx + 1
417            )));
418        }
419        if *w < 0.0 {
420            return Err(polyfit_error("polyfit: weights must be non-negative"));
421        }
422    }
423    Ok(())
424}
425
/// Core least-squares solver shared by the host path and the provider
/// fallback.
///
/// Pipeline: validate inputs → centre/scale X → build a column-major
/// Vandermonde matrix in u = (x - mean)/scale → optionally apply
/// sqrt-weights to both sides → Householder QR → back-substitution → map the
/// coefficients back to the original x domain.
fn solve_polyfit(
    x_data: &[f64],
    y_data: &[Complex64],
    degree: usize,
    weights: Option<&[f64]>,
) -> BuiltinResult<PolyfitSolution> {
    if x_data.len() != y_data.len() {
        return Err(polyfit_error(
            "polyfit: X and Y vectors must be the same length",
        ));
    }
    if x_data.is_empty() {
        return Err(polyfit_error(
            "polyfit: X and Y must contain at least one sample",
        ));
    }
    if let Some(w) = weights {
        if w.len() != x_data.len() {
            return Err(polyfit_error(
                "polyfit: weight vector must match the size of X",
            ));
        }
        validate_weights(w)?;
    }

    // Centre and scale X for numerical stability (MATLAB's mu convention).
    let mean = x_data.iter().sum::<f64>() / x_data.len() as f64;
    if !mean.is_finite() {
        return Err(polyfit_error("polyfit: mean of X must be finite"));
    }
    let scale = compute_scale(x_data, mean)?;
    let scaled: Vec<f64> = x_data.iter().map(|&v| (v - mean) / scale).collect();

    let mut rhs = y_data.to_vec();
    for (idx, value) in rhs.iter().enumerate() {
        if !value.re.is_finite() || !value.im.is_finite() {
            return Err(polyfit_error(format!(
                "polyfit: Y must contain finite values (encountered NaN/Inf at position {})",
                idx + 1
            )));
        }
    }
    if let Some(w) = weights {
        apply_weights_rhs(&mut rhs, w)?;
    }

    let rows = scaled.len();
    let cols = degree + 1;
    let mut vandermonde = build_vandermonde(&scaled, cols);
    if let Some(w) = weights {
        apply_weights_matrix(&mut vandermonde, rows, cols, w)?;
    }

    // QR factorise in place; `transformed_rhs` becomes Qᵀb.
    let mut transformed_rhs = rhs.clone();
    householder_qr(&mut vandermonde, rows, cols, &mut transformed_rhs)?;
    let coeff_scaled = solve_upper(&vandermonde, rows, cols, &transformed_rhs)?;
    let coeff_original = transform_coefficients(&coeff_scaled, mean, scale);

    // The residual norm is the tail of Qᵀb beyond the first `cols` rows.
    let normr = residual_norm(&transformed_rhs, rows, cols);
    let df = if rows > cols {
        (rows - cols) as f64
    } else {
        0.0
    };
    let r_matrix = extract_upper(&vandermonde, rows, cols);
    let is_complex = coeff_original.iter().any(|c| c.im.abs() > EPS_NAN);

    Ok(PolyfitSolution {
        coeffs: coeff_original,
        r_matrix,
        mu_mean: mean,
        mu_scale: scale,
        normr,
        df,
        cols,
        is_complex,
    })
}
503
504fn compute_scale(data: &[f64], mean: f64) -> BuiltinResult<f64> {
505    if data.len() <= 1 {
506        return Ok(1.0);
507    }
508    let mut acc = 0.0;
509    for &value in data {
510        if !value.is_finite() {
511            return Err(polyfit_error("polyfit: X must contain finite values"));
512        }
513        let diff = value - mean;
514        acc += diff * diff;
515    }
516    let denom = (data.len() as f64 - 1.0).max(1.0);
517    let std = (acc / denom).sqrt();
518    let scale = if std.abs() <= EPS { 1.0 } else { std };
519    if !scale.is_finite() {
520        return Err(polyfit_error(
521            "polyfit: failed to compute a stable scaling factor",
522        ));
523    }
524    Ok(scale)
525}
526
/// Build a column-major Vandermonde matrix for the scaled samples `u`.
///
/// Column `c` holds u^(cols-1-c), so the last column is the constant term —
/// matching MATLAB's highest-power-first coefficient ordering.
fn build_vandermonde(u: &[f64], cols: usize) -> Vec<f64> {
    let rows = u.len();
    if cols == 0 {
        return Vec::new();
    }
    let mut matrix = vec![0.0; rows * cols];
    for (row, &x) in u.iter().enumerate() {
        // Walk the columns from the constant term backwards, multiplying up
        // the power as we go: identical product chain to the original.
        let mut power = 1.0;
        for col in (0..cols).rev() {
            matrix[row + col * rows] = power;
            power *= x;
        }
    }
    matrix
}
545
546fn apply_weights_matrix(
547    matrix: &mut [f64],
548    rows: usize,
549    cols: usize,
550    weights: &[f64],
551) -> BuiltinResult<()> {
552    for (row, weight) in weights.iter().enumerate().take(rows) {
553        let sqrt_w = weight.sqrt();
554        if !sqrt_w.is_finite() {
555            return Err(polyfit_error(format!(
556                "polyfit: weight at position {} must be finite",
557                row + 1
558            )));
559        }
560        for col in 0..cols {
561            let idx = row + col * rows;
562            matrix[idx] *= sqrt_w;
563        }
564    }
565    Ok(())
566}
567
568fn apply_weights_rhs(rhs: &mut [Complex64], weights: &[f64]) -> BuiltinResult<()> {
569    for (idx, (value, weight)) in rhs.iter_mut().zip(weights.iter()).enumerate() {
570        let sqrt_w = weight.sqrt();
571        if !sqrt_w.is_finite() {
572            return Err(polyfit_error(format!(
573                "polyfit: weight at position {} must be finite",
574                idx + 1
575            )));
576        }
577        *value *= sqrt_w;
578    }
579    Ok(())
580}
581
582fn ensure_vector_shape(context: &str, label: &str, shape: &[usize]) -> BuiltinResult<()> {
583    if !is_vector_shape(shape) {
584        return Err(polyfit_error(format!(
585            "{context}: {label} must be a vector"
586        )));
587    }
588    Ok(())
589}
590
/// A shape is a vector when at most one dimension exceeds 1 (scalars and
/// empty shapes count as vectors).
fn is_vector_shape(shape: &[usize]) -> bool {
    let mut extended_seen = false;
    for &dim in shape {
        if dim > 1 {
            if extended_seen {
                return false;
            }
            extended_seen = true;
        }
    }
    true
}
594
/// In-place Householder QR of a column-major `rows`×`cols` matrix, applying
/// each reflector to `rhs` as well: on return the upper triangle of `matrix`
/// holds R and `rhs` holds Qᵀb.
///
/// Columns whose remaining sub-column norm is below `EPS` are skipped, so
/// rank-deficient systems degrade gracefully (paired with the zero-diagonal
/// handling in `solve_upper`).
fn householder_qr(
    matrix: &mut [f64],
    rows: usize,
    cols: usize,
    rhs: &mut [Complex64],
) -> BuiltinResult<()> {
    let min_dim = rows.min(cols);
    for k in 0..min_dim {
        // Squared norm of column k at and below the diagonal.
        let mut norm_sq = 0.0;
        for row in k..rows {
            let val = matrix[row + k * rows];
            norm_sq += val * val;
        }
        if norm_sq <= EPS {
            continue;
        }
        let norm = norm_sq.sqrt();
        let x0 = matrix[k + k * rows];
        // Sign chosen opposite to x0 to avoid cancellation in v[0].
        let alpha = if x0 >= 0.0 { -norm } else { norm };
        let mut v = vec![0.0; rows - k];
        v[0] = x0 - alpha;
        for row in (k + 1)..rows {
            v[row - k] = matrix[row + k * rows];
        }
        let v_norm_sq: f64 = v.iter().map(|&x| x * x).sum();
        if v_norm_sq <= EPS {
            continue;
        }
        let beta = 2.0 / v_norm_sq;
        // Column k becomes [alpha, 0, ..., 0]ᵀ explicitly.
        matrix[k + k * rows] = alpha;
        for row in (k + 1)..rows {
            matrix[row + k * rows] = 0.0;
        }
        // Apply H = I - beta·v·vᵀ to the remaining columns.
        for col in (k + 1)..cols {
            let mut dot = 0.0;
            for (idx, &vi) in v.iter().enumerate() {
                let row_idx = k + idx;
                dot += vi * matrix[row_idx + col * rows];
            }
            let factor = beta * dot;
            for (idx, &vi) in v.iter().enumerate() {
                let row_idx = k + idx;
                matrix[row_idx + col * rows] -= factor * vi;
            }
        }
        // Apply the same reflector to the (complex) right-hand side.
        let mut dot = Complex64::new(0.0, 0.0);
        for (idx, &vi) in v.iter().enumerate() {
            let row_idx = k + idx;
            dot += rhs[row_idx] * vi;
        }
        let factor = Complex64::new(beta, 0.0) * dot;
        for (idx, &vi) in v.iter().enumerate() {
            let row_idx = k + idx;
            rhs[row_idx] -= factor * vi;
        }
    }
    Ok(())
}
653
/// Back-substitution on the upper-triangular system R·c = Qᵀb produced by
/// `householder_qr`.
///
/// Diagonal entries with |R[k,k]| <= EPS yield a zero coefficient instead of
/// an error, mirroring the rank-deficiency tolerance of the factorisation.
fn solve_upper(
    matrix: &[f64],
    rows: usize,
    cols: usize,
    rhs: &[Complex64],
) -> BuiltinResult<Vec<Complex64>> {
    if rhs.len() < rows {
        return Err(polyfit_error(
            "polyfit internal error: RHS dimension mismatch",
        ));
    }
    let mut coeffs = vec![Complex64::new(0.0, 0.0); cols];
    for col in (0..cols).rev() {
        // Columns beyond the matrix height (possible for under-determined
        // fits with cols > rows) see a zero diagonal and zero RHS.
        let diag = if col < rows {
            matrix[col + col * rows]
        } else {
            0.0
        };
        if diag.abs() <= EPS {
            coeffs[col] = Complex64::new(0.0, 0.0);
            continue;
        }
        let mut acc = if col < rows {
            rhs[col]
        } else {
            Complex64::new(0.0, 0.0)
        };
        for next in (col + 1)..cols {
            // `idx` holds the R entry R[col, next] (0.0 when off the matrix).
            let idx = if col < rows {
                matrix[col + next * rows]
            } else {
                0.0
            };
            acc -= Complex64::new(idx, 0.0) * coeffs[next];
        }
        coeffs[col] = acc / Complex64::new(diag, 0.0);
    }
    Ok(coeffs)
}
693
694fn residual_norm(rhs: &[Complex64], rows: usize, cols: usize) -> f64 {
695    let tail_start = rows.min(cols);
696    let mut acc = 0.0;
697    for value in rhs.iter().skip(tail_start) {
698        acc += value.norm_sqr();
699    }
700    acc.sqrt()
701}
702
/// Copy the upper triangle of the factored column-major `rows`×`cols` matrix
/// into a dense column-major `cols`×`cols` R matrix (zeros elsewhere).
fn extract_upper(matrix: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut upper = vec![0.0; cols * cols];
    for col in 0..cols {
        // Rows past `rows` don't exist in the source; leave them zero.
        let copy_rows = rows.min(col + 1);
        for row in 0..copy_rows {
            upper[row + col * cols] = matrix[row + col * rows];
        }
    }
    upper
}
714
/// Expand coefficients fitted in the centred variable u = (x - mean)/scale
/// back into powers of x.
///
/// `coeffs` arrives highest power first. The loop is a Horner-style build-up
/// where `poly` holds the partial polynomial in x, lowest power first: each
/// step multiplies the partial polynomial by (x - mean)/scale and adds the
/// next scaled-domain coefficient. The final reverse restores MATLAB's
/// highest-power-first ordering.
fn transform_coefficients(coeffs: &[Complex64], mean: f64, scale: f64) -> Vec<Complex64> {
    let mut poly: Vec<Complex64> = Vec::new();
    for &coeff in coeffs {
        let mut next = vec![Complex64::new(0.0, 0.0); poly.len() + 1];
        for (idx, &value) in poly.iter().enumerate() {
            // Multiply the partial polynomial by (x - mean)/scale.
            next[idx + 1] += value / scale;
            next[idx] -= value * (mean / scale);
        }
        next[0] += coeff;
        poly = next;
    }
    poly.reverse();
    poly
}
729
730fn coefficients_to_value(coeffs: &[Complex64]) -> BuiltinResult<Value> {
731    let all_real = coeffs
732        .iter()
733        .all(|c| c.im.abs() <= EPS_NAN && c.re.is_finite());
734    if all_real {
735        let data: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
736        let tensor = Tensor::new(data, vec![1, coeffs.len()])
737            .map_err(|e| polyfit_error(format!("polyfit: {e}")))?;
738        Ok(Value::Tensor(tensor))
739    } else {
740        let data: Vec<(f64, f64)> = coeffs.iter().map(|c| (c.re, c.im)).collect();
741        let tensor = ComplexTensor::new(data, vec![1, coeffs.len()])
742            .map_err(|e| polyfit_error(format!("polyfit: {e}")))?;
743        Ok(Value::ComplexTensor(tensor))
744    }
745}
746
747fn build_stats(r: &[f64], n: usize, normr: f64, df: f64) -> BuiltinResult<Value> {
748    let tensor =
749        Tensor::new(r.to_vec(), vec![n, n]).map_err(|e| polyfit_error(format!("polyfit: {e}")))?;
750    let mut st = StructValue::new();
751    st.fields.insert("R".to_string(), Value::Tensor(tensor));
752    st.fields.insert("df".to_string(), Value::Num(df));
753    st.fields.insert("normr".to_string(), Value::Num(normr));
754    Ok(Value::Struct(st))
755}
756
757fn build_mu(mean: f64, scale: f64) -> BuiltinResult<Value> {
758    if !scale.is_finite() || scale.abs() <= EPS {
759        return Err(polyfit_error("polyfit: mu(2) must be non-zero and finite"));
760    }
761    let tensor = Tensor::new(vec![mean, scale], vec![1, 2])
762        .map_err(|e| polyfit_error(format!("polyfit: {e}")))?;
763    Ok(Value::Tensor(tensor))
764}
765
/// Real-valued polyfit outputs handed back to acceleration providers.
#[derive(Debug, Clone)]
pub struct PolyfitHostRealResult {
    /// Coefficients ordered highest power first.
    pub coefficients: Vec<f64>,
    /// Column-major upper-triangular R factor, (degree+1)² entries.
    pub r_matrix: Vec<f64>,
    /// `[mean(x), std(x)]` centering pair.
    pub mu: [f64; 2],
    /// Euclidean norm of the least-squares residual.
    pub normr: f64,
    /// Degrees of freedom of the fit.
    pub df: f64,
}
774
775pub fn polyfit_host_real_for_provider(
776    x: &[f64],
777    y: &[f64],
778    degree: usize,
779    weights: Option<&[f64]>,
780) -> BuiltinResult<PolyfitHostRealResult> {
781    if x.len() != y.len() {
782        return Err(polyfit_error(
783            "polyfit: X and Y vectors must be the same length",
784        ));
785    }
786    if let Some(w) = weights {
787        if w.len() != x.len() {
788            return Err(polyfit_error(
789                "polyfit: weight vector must match the size of X",
790            ));
791        }
792        validate_weights(w)?;
793    }
794    let complex_y: Vec<Complex64> = y.iter().copied().map(|v| Complex64::new(v, 0.0)).collect();
795    let solution = solve_polyfit(x, &complex_y, degree, weights)?;
796    let PolyfitSolution {
797        coeffs,
798        r_matrix,
799        mu_mean,
800        mu_scale,
801        normr,
802        df,
803        cols: _,
804        is_complex,
805    } = solution;
806    if is_complex {
807        return Err(polyfit_error(
808            "polyfit: provider fallback produced complex coefficients for real data",
809        ));
810    }
811    let coeffs: Vec<f64> = coeffs.into_iter().map(|c| c.re).collect();
812    let mu = [mu_mean, mu_scale];
813    Ok(PolyfitHostRealResult {
814        coefficients: coeffs,
815        r_matrix,
816        mu,
817        normr,
818        df,
819    })
820}
821
822#[cfg(test)]
823pub(crate) mod tests {
824    use super::*;
825    use crate::builtins::common::test_support;
826    use futures::executor::block_on;
827
828    fn assert_error_contains(err: crate::RuntimeError, needle: &str) {
829        assert!(
830            err.message().contains(needle),
831            "expected error containing '{needle}', got '{}'",
832            err.message()
833        );
834    }
835
836    fn evaluate(
837        x: Value,
838        y: Value,
839        degree: Value,
840        rest: &[Value],
841    ) -> Result<PolyfitEval, RuntimeError> {
842        block_on(super::evaluate(x, y, degree, rest))
843    }
844
845    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
846    #[test]
847    fn fits_linear_data() {
848        let x = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
849        let mut y_vals = Vec::new();
850        for i in 0..4 {
851            y_vals.push(1.5 * i as f64 + 2.0);
852        }
853        let y = Tensor::new(y_vals, vec![4, 1]).unwrap();
854        let eval = evaluate(
855            Value::Tensor(x),
856            Value::Tensor(y),
857            Value::Int(runmat_builtins::IntValue::I32(1)),
858            &[],
859        )
860        .expect("polyfit");
861        match eval.coefficients() {
862            Value::Tensor(t) => {
863                assert_eq!(t.shape, vec![1, 2]);
864                assert!((t.data[0] - 1.5).abs() < 1e-10);
865                assert!((t.data[1] - 2.0).abs() < 1e-10);
866            }
867            other => panic!("expected tensor coefficients, got {other:?}"),
868        }
869    }
870
871    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
872    #[test]
873    fn returns_struct_and_mu() {
874        let x = Tensor::new(vec![-1.0, 0.0, 1.0], vec![3, 1]).unwrap();
875        let y = Tensor::new(vec![1.0, 0.0, 1.0], vec![3, 1]).unwrap();
876        let eval = evaluate(
877            Value::Tensor(x),
878            Value::Tensor(y),
879            Value::Int(runmat_builtins::IntValue::I32(2)),
880            &[],
881        )
882        .expect("polyfit");
883        match eval.stats() {
884            Value::Struct(s) => {
885                assert!(s.fields.contains_key("R"));
886                assert!(s.fields.contains_key("df"));
887                assert!(s.fields.contains_key("normr"));
888            }
889            other => panic!("expected struct, got {other:?}"),
890        }
891        match eval.mu() {
892            Value::Tensor(t) => {
893                assert_eq!(t.shape, vec![1, 2]);
894                assert!((t.data[0]).abs() < 1e-10);
895                assert!(t.data[1].abs() > 0.0);
896            }
897            other => panic!("expected tensor mu, got {other:?}"),
898        }
899    }
900
901    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
902    #[test]
903    fn weighted_fit_matches_unweighted_when_weights_equal() {
904        let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
905        let y = Tensor::new(vec![1.0, 3.0, 7.0], vec![3, 1]).unwrap();
906        let weights = Tensor::new(vec![1.0, 1.0, 1.0], vec![3, 1]).unwrap();
907        let eval_unweighted = evaluate(
908            Value::Tensor(x.clone()),
909            Value::Tensor(y.clone()),
910            Value::Int(runmat_builtins::IntValue::I32(2)),
911            &[],
912        )
913        .expect("polyfit");
914        let eval_weighted = evaluate(
915            Value::Tensor(x),
916            Value::Tensor(y),
917            Value::Int(runmat_builtins::IntValue::I32(2)),
918            &[Value::Tensor(weights)],
919        )
920        .expect("polyfit");
921        assert_eq!(eval_unweighted.coefficients(), eval_weighted.coefficients());
922    }
923
924    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
925    #[test]
926    fn accepts_logical_degree_scalar() {
927        let x = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
928        let y = Tensor::new(vec![1.0, 3.0], vec![2, 1]).unwrap();
929        let logical = runmat_builtins::LogicalArray::new(vec![1], vec![1, 1]).unwrap();
930        let eval = evaluate(
931            Value::Tensor(x),
932            Value::Tensor(y),
933            Value::LogicalArray(logical),
934            &[],
935        )
936        .expect("polyfit");
937        assert!(matches!(eval.coefficients(), Value::Tensor(_)));
938    }
939
940    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
941    #[test]
942    fn rejects_non_integer_degree() {
943        let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
944        let y = Tensor::new(vec![1.0, 3.0, 7.0], vec![3, 1]).unwrap();
945        let err = evaluate(Value::Tensor(x), Value::Tensor(y), Value::Num(1.5), &[])
946            .expect_err("polyfit should reject non-integer degree");
947        assert_error_contains(err, "integer");
948    }
949
950    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
951    #[test]
952    fn rejects_infinite_weights() {
953        let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
954        let y = Tensor::new(vec![1.0, 3.0, 7.0], vec![3, 1]).unwrap();
955        let weights = Tensor::new(vec![1.0, f64::INFINITY, 1.0], vec![3, 1]).unwrap();
956        let err = evaluate(
957            Value::Tensor(x),
958            Value::Tensor(y),
959            Value::Int(runmat_builtins::IntValue::I32(2)),
960            &[Value::Tensor(weights)],
961        )
962        .expect_err("polyfit should reject infinite weights");
963        assert_error_contains(err, "weight at position 2");
964    }
965
966    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
967    #[test]
968    fn gpu_inputs_are_gathered() {
969        test_support::with_test_provider(|provider| {
970            let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
971            let y = Tensor::new(vec![1.0, 3.0, 7.0], vec![3, 1]).unwrap();
972            let view = runmat_accelerate_api::HostTensorView {
973                data: &x.data,
974                shape: &x.shape,
975            };
976            let x_handle = provider.upload(&view).expect("upload");
977            let view_y = runmat_accelerate_api::HostTensorView {
978                data: &y.data,
979                shape: &y.shape,
980            };
981            let y_handle = provider.upload(&view_y).expect("upload");
982            let eval = evaluate(
983                Value::GpuTensor(x_handle),
984                Value::GpuTensor(y_handle),
985                Value::Int(runmat_builtins::IntValue::I32(2)),
986                &[],
987            )
988            .expect("polyfit");
989            assert!(matches!(eval.coefficients(), Value::Tensor(_)));
990        });
991    }
992
993    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
994    #[test]
995    fn gpu_weights_are_gathered() {
996        test_support::with_test_provider(|provider| {
997            let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
998            let y = Tensor::new(vec![1.0, 3.0, 7.0], vec![3, 1]).unwrap();
999            let weights = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
1000
1001            let x_view = runmat_accelerate_api::HostTensorView {
1002                data: &x.data,
1003                shape: &x.shape,
1004            };
1005            let y_view = runmat_accelerate_api::HostTensorView {
1006                data: &y.data,
1007                shape: &y.shape,
1008            };
1009            let w_view = runmat_accelerate_api::HostTensorView {
1010                data: &weights.data,
1011                shape: &weights.shape,
1012            };
1013
1014            let x_handle = provider.upload(&x_view).expect("upload x");
1015            let y_handle = provider.upload(&y_view).expect("upload y");
1016            let w_handle = provider.upload(&w_view).expect("upload weights");
1017
1018            let cpu_eval = evaluate(
1019                Value::Tensor(x.clone()),
1020                Value::Tensor(y.clone()),
1021                Value::Int(runmat_builtins::IntValue::I32(2)),
1022                &[Value::Tensor(weights.clone())],
1023            )
1024            .expect("cpu polyfit");
1025
1026            let gpu_eval = evaluate(
1027                Value::GpuTensor(x_handle.clone()),
1028                Value::GpuTensor(y_handle.clone()),
1029                Value::Int(runmat_builtins::IntValue::I32(2)),
1030                &[Value::GpuTensor(w_handle.clone())],
1031            )
1032            .expect("gpu polyfit with weights");
1033
1034            assert_eq!(cpu_eval.coefficients(), gpu_eval.coefficients());
1035            assert_eq!(cpu_eval.mu(), gpu_eval.mu());
1036
1037            let _ = provider.free(&x_handle);
1038            let _ = provider.free(&y_handle);
1039            let _ = provider.free(&w_handle);
1040        });
1041    }
1042
1043    #[cfg(feature = "wgpu")]
1044    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1045    #[test]
1046    fn polyfit_wgpu_matches_cpu() {
1047        let options = runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default();
1048        let _provider =
1049            match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(options) {
1050                Ok(p) => p,
1051                Err(err) => {
1052                    warn!("polyfit_wgpu_matches_cpu: skipping test ({err})");
1053                    return;
1054                }
1055            };
1056        let x = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![4, 1]).unwrap();
1057        let y = Tensor::new(vec![1.0, 3.0, 7.0, 13.0], vec![4, 1]).unwrap();
1058
1059        let cpu_eval = evaluate(
1060            Value::Tensor(x.clone()),
1061            Value::Tensor(y.clone()),
1062            Value::Int(runmat_builtins::IntValue::I32(2)),
1063            &[],
1064        )
1065        .expect("cpu polyfit");
1066
1067        let trait_provider = runmat_accelerate_api::provider().expect("wgpu provider registered");
1068        let x_view = runmat_accelerate_api::HostTensorView {
1069            data: &x.data,
1070            shape: &x.shape,
1071        };
1072        let y_view = runmat_accelerate_api::HostTensorView {
1073            data: &y.data,
1074            shape: &y.shape,
1075        };
1076        let x_handle = trait_provider.upload(&x_view).expect("upload x");
1077        let y_handle = trait_provider.upload(&y_view).expect("upload y");
1078
1079        let gpu_eval = evaluate(
1080            Value::GpuTensor(x_handle.clone()),
1081            Value::GpuTensor(y_handle.clone()),
1082            Value::Int(runmat_builtins::IntValue::I32(2)),
1083            &[],
1084        )
1085        .expect("gpu polyfit");
1086
1087        let _ = trait_provider.free(&x_handle);
1088        let _ = trait_provider.free(&y_handle);
1089
1090        let cpu_coeff = match cpu_eval.coefficients() {
1091            Value::Tensor(t) => t,
1092            other => panic!("expected tensor coefficients, got {other:?}"),
1093        };
1094        let gpu_coeff = match gpu_eval.coefficients() {
1095            Value::Tensor(t) => t,
1096            other => panic!("expected tensor coefficients, got {other:?}"),
1097        };
1098        assert_eq!(cpu_coeff.shape, gpu_coeff.shape);
1099        for (a, b) in cpu_coeff.data.iter().zip(gpu_coeff.data.iter()) {
1100            assert!((a - b).abs() < 1e-9, "coeff mismatch {a} vs {b}");
1101        }
1102
1103        let cpu_mu = match cpu_eval.mu() {
1104            Value::Tensor(t) => t,
1105            other => panic!("expected tensor mu, got {other:?}"),
1106        };
1107        let gpu_mu = match gpu_eval.mu() {
1108            Value::Tensor(t) => t,
1109            other => panic!("expected tensor mu, got {other:?}"),
1110        };
1111        assert_eq!(cpu_mu.shape, gpu_mu.shape);
1112        for (a, b) in cpu_mu.data.iter().zip(gpu_mu.data.iter()) {
1113            assert!((a - b).abs() < 1e-9, "mu mismatch {a} vs {b}");
1114        }
1115
1116        let cpu_stats = match cpu_eval.stats() {
1117            Value::Struct(s) => s,
1118            other => panic!("expected struct stats, got {other:?}"),
1119        };
1120        let gpu_stats = match gpu_eval.stats() {
1121            Value::Struct(s) => s,
1122            other => panic!("expected struct stats, got {other:?}"),
1123        };
1124        let cpu_r = match cpu_stats.fields.get("R").expect("R present") {
1125            Value::Tensor(t) => t.clone(),
1126            other => panic!("expected tensor R, got {other:?}"),
1127        };
1128        let gpu_r = match gpu_stats.fields.get("R").expect("R present") {
1129            Value::Tensor(t) => t.clone(),
1130            other => panic!("expected tensor R, got {other:?}"),
1131        };
1132        assert_eq!(cpu_r.shape, gpu_r.shape);
1133        for (a, b) in cpu_r.data.iter().zip(gpu_r.data.iter()) {
1134            assert!((a - b).abs() < 1e-9, "R mismatch {a} vs {b}");
1135        }
1136        let cpu_df = match cpu_stats.fields.get("df").expect("df present") {
1137            Value::Num(n) => *n,
1138            other => panic!("expected numeric df, got {other:?}"),
1139        };
1140        let gpu_df = match gpu_stats.fields.get("df").expect("df present") {
1141            Value::Num(n) => *n,
1142            other => panic!("expected numeric df, got {other:?}"),
1143        };
1144        assert!((cpu_df - gpu_df).abs() < 1e-9);
1145        let cpu_normr = match cpu_stats.fields.get("normr").expect("normr present") {
1146            Value::Num(n) => *n,
1147            other => panic!("expected numeric normr, got {other:?}"),
1148        };
1149        let gpu_normr = match gpu_stats.fields.get("normr").expect("normr present") {
1150            Value::Num(n) => *n,
1151            other => panic!("expected numeric normr, got {other:?}"),
1152        };
1153        assert!((cpu_normr - gpu_normr).abs() < 1e-9);
1154    }
1155
1156    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1157    #[test]
1158    fn rejects_mismatched_lengths() {
1159        let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
1160        let y = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1161        let err = evaluate(
1162            Value::Tensor(x),
1163            Value::Tensor(y),
1164            Value::Int(runmat_builtins::IntValue::I32(1)),
1165            &[],
1166        )
1167        .expect_err("polyfit should reject mismatched vector lengths");
1168        assert_error_contains(err, "same length");
1169    }
1170
1171    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1172    #[test]
1173    fn rejects_non_vector_inputs() {
1174        let x = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap();
1175        let y = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
1176        let err = evaluate(
1177            Value::Tensor(x),
1178            Value::Tensor(y),
1179            Value::Int(runmat_builtins::IntValue::I32(1)),
1180            &[],
1181        )
1182        .expect_err("polyfit should reject non-vector X");
1183        assert_error_contains(err, "vector");
1184    }
1185
1186    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1187    #[test]
1188    fn rejects_weight_length_mismatch() {
1189        let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
1190        let y = Tensor::new(vec![1.0, 3.0, 7.0], vec![3, 1]).unwrap();
1191        let weights = Tensor::new(vec![1.0, 1.0], vec![2, 1]).unwrap();
1192        let err = evaluate(
1193            Value::Tensor(x),
1194            Value::Tensor(y),
1195            Value::Int(runmat_builtins::IntValue::I32(2)),
1196            &[Value::Tensor(weights)],
1197        )
1198        .expect_err("polyfit should reject mismatched weights");
1199        assert_error_contains(err, "weight vector must match");
1200    }
1201
1202    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1203    #[test]
1204    fn rejects_negative_weights() {
1205        let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
1206        let y = Tensor::new(vec![1.0, 3.0, 7.0], vec![3, 1]).unwrap();
1207        let weights = Tensor::new(vec![1.0, -1.0, 1.0], vec![3, 1]).unwrap();
1208        let err = evaluate(
1209            Value::Tensor(x),
1210            Value::Tensor(y),
1211            Value::Int(runmat_builtins::IntValue::I32(2)),
1212            &[Value::Tensor(weights)],
1213        )
1214        .expect_err("polyfit should reject negative weights");
1215        assert_error_contains(err, "non-negative");
1216    }
1217
1218    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1219    #[test]
1220    fn fits_complex_data() {
1221        let x = Tensor::new(vec![0.0, 1.0, 2.0], vec![3, 1]).unwrap();
1222        let complex_values =
1223            ComplexTensor::new(vec![(0.0, 1.0), (1.0, 0.5), (4.0, -0.25)], vec![3, 1]).unwrap();
1224        let eval = evaluate(
1225            Value::Tensor(x),
1226            Value::ComplexTensor(complex_values),
1227            Value::Int(runmat_builtins::IntValue::I32(2)),
1228            &[],
1229        )
1230        .expect("polyfit complex");
1231        match eval.coefficients() {
1232            Value::ComplexTensor(t) => {
1233                assert_eq!(t.shape, vec![1, 3]);
1234            }
1235            other => panic!("expected complex tensor coefficients, got {other:?}"),
1236        }
1237    }
1238}