use super::{BasisError, BasisOptions, Dense, KnotSource, create_basis};
use gam_linalg::faer_ndarray::fast_ata;
use ndarray::{Array1, Array2, ArrayView1};
/// Build the cyclic (wrap-around) difference penalty `DᵀD` of the given `order`.
///
/// `D` is the `order`-th power of the cyclic forward-difference operator `Δ`, where
/// `(Δ·M)[i, j] = M[(i + 1) mod n, j] - M[i, j]`, so the last coefficient is differenced
/// against the first. The result is `num_basis_functions × num_basis_functions`.
///
/// # Errors
/// `BasisError::InvalidPenaltyOrder` when `order` is 0 or not smaller than
/// `num_basis_functions`.
pub fn create_cyclic_difference_penalty_matrix(
    num_basis_functions: usize,
    order: usize,
) -> Result<Array2<f64>, BasisError> {
    let n = num_basis_functions;
    if order == 0 || order >= n {
        return Err(BasisError::InvalidPenaltyOrder {
            order,
            num_basis: n,
        });
    }
    // Start from the identity and apply the cyclic forward difference `order` times.
    let difference = (0..order).fold(Array2::<f64>::eye(n), |prev, _| {
        Array2::from_shape_fn((n, n), |(row, col)| {
            prev[[(row + 1) % n, col]] - prev[[row, col]]
        })
    });
    Ok(fast_ata(&difference))
}
/// Build the open (non-wrapping) difference penalty `DᵀD` of the given `order`.
///
/// `D` has `num_basis_functions - order` rows. Row `r` holds the signed binomial stencil
/// `(-1)^(order - j) · C(order, j)` in columns `r + j` for `j = 0..=order` and zeros
/// elsewhere, so no difference crosses from the last coefficient back to the first.
///
/// # Errors
/// `BasisError::InvalidPenaltyOrder` when `order` is 0 or not smaller than
/// `num_basis_functions`.
pub fn create_open_difference_penalty_matrix(
    num_basis_functions: usize,
    order: usize,
) -> Result<Array2<f64>, BasisError> {
    if order == 0 || order >= num_basis_functions {
        return Err(BasisError::InvalidPenaltyOrder {
            order,
            num_basis: num_basis_functions,
        });
    }
    // Coefficients of the order-th forward difference, e.g. [1, -2, 1] for order 2.
    let stencil: Vec<f64> = (0..=order)
        .map(|j| {
            let sign = if (order - j) % 2 == 0 { 1.0 } else { -1.0 };
            sign * binomial(order, j)
        })
        .collect();
    let row_count = num_basis_functions - order;
    let d = Array2::from_shape_fn((row_count, num_basis_functions), |(r, c)| {
        match c.checked_sub(r) {
            Some(offset) if offset <= order => stencil[offset],
            _ => 0.0,
        }
    });
    Ok(fast_ata(&d))
}
/// Penalty "jet" that blends the open and cyclic difference penalties through `gamma`.
///
/// The open penalty `S_open` and the wrap-around correction `S_wrap = S_cyclic - S_open`
/// are passed, together with `gamma`, to `gam_geometry::conductance_penalty_jet`. Its
/// three matrices are returned unchanged. The tests in this file expect the first to equal
/// `S_open` at `gamma = 0` and `S_cyclic` at `gamma = 1`. They expect the second to match a
/// finite-difference derivative with respect to `gamma`.
///
/// # Errors
/// Propagates `BasisError::InvalidPenaltyOrder` from the underlying penalty builders.
pub fn create_closure_difference_penalty_jet(
    num_basis_functions: usize,
    order: usize,
    gamma: f64,
) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), BasisError> {
    let open = create_open_difference_penalty_matrix(num_basis_functions, order)?;
    let wrap = create_cyclic_difference_penalty_matrix(num_basis_functions, order)? - &open;
    let jet = gam_geometry::conductance_penalty_jet(&open, &wrap, gamma);
    Ok(jet)
}
/// Binomial coefficient `C(n, k)` evaluated in `f64`.
///
/// Returns `0.0` when `k > n`, since more items cannot be chosen than exist. The previous
/// version instead underflowed the `usize` subtraction `n - i`, which panics in debug builds.
/// The smaller of `k` and `n - k` is used, so at most `n / 2` factors are multiplied. After
/// step `i` the accumulator equals `C(n, i + 1)`, so results are exact while they stay below
/// 2^53.
pub(crate) fn binomial(n: usize, k: usize) -> f64 {
    if k > n {
        return 0.0;
    }
    let k = k.min(n - k);
    (0..k).fold(1.0_f64, |acc, i| acc * (n - i) as f64 / (i + 1) as f64)
}
/// Map `x` into the half-open window `[start, start + period)`.
///
/// `f64::rem_euclid` can round a tiny negative offset up to exactly `period`. That case is
/// folded back onto `start` so the result never lands on the excluded upper end.
#[inline]
pub(crate) fn wrap_to_period(x: f64, start: f64, period: f64) -> f64 {
    match (x - start).rem_euclid(period) {
        rounded_up if rounded_up >= period => start,
        offset => start + offset,
    }
}
/// Shortest distance between `x` and `c` on a circle of circumference `period`.
///
/// The raw gap is reduced modulo `period`. The distance is then the shorter of going that
/// way round or the other way (`period - gap`).
#[inline]
pub(crate) fn cyclic_distance_1d(x: f64, c: f64, period: f64) -> f64 {
    let forward = (x - c).abs().rem_euclid(period);
    let backward = period - forward;
    forward.min(backward)
}
/// Uniform knot spacing `h = period / num_basis` and the anchor `start` snapped down onto
/// the grid of integer multiples of `h`.
///
/// Anchoring to that canonical grid makes the knot positions independent of where the
/// caller places the seam (`start`). The pair is returned as `(anchor, h)`.
#[inline]
pub(crate) fn cyclic_knot_anchor(start: f64, period: f64, num_basis: usize) -> (f64, f64) {
    let spacing = period / num_basis as f64;
    let phase = start.rem_euclid(spacing);
    (start - phase, spacing)
}
/// Uniform knot vector used to build the extended (pre-folding) cyclic B-spline basis.
///
/// The spacing is `h = (end - start) / num_basis` and the knots start at the anchor from
/// `cyclic_knot_anchor`. There are `num_basis + 2 * degree + 1` knots, running from
/// `anchor - degree * h` to `anchor + (num_basis + degree) * h`, so the period is padded by
/// `degree` knots on each side.
pub(crate) fn cyclic_uniform_knot_vector(
    start: f64,
    end: f64,
    degree: usize,
    num_basis: usize,
) -> Array1<f64> {
    let (anchor, h) = cyclic_knot_anchor(start, end - start, num_basis);
    let knot_count = num_basis + 2 * degree + 1;
    let lead = degree as f64;
    (0..knot_count)
        .map(|k| anchor + (k as f64 - lead) * h)
        .collect()
}
/// Evaluate a periodic B-spline basis of the given `degree` with `num_basis` functions at
/// each point of `data`. The period is `[start, end)`.
///
/// Data are first wrapped into `[anchor, anchor + period)`, where `anchor` is `start`
/// snapped to the knot grid. They are then evaluated on the padded uniform knot vector from
/// `cyclic_uniform_knot_vector`. Finally, columns of that extended basis are summed modulo
/// `num_basis`, so extended columns `j` and `j + num_basis` contribute to the same
/// periodic function.
///
/// Returns the `data.len() × num_basis` dense basis together with the knot vector used.
///
/// # Errors
/// - `BasisError::InvalidRange` when `end - start` is not a finite positive number. This
///   covers `end <= start`, NaN or infinite bounds, and bounds whose difference overflows.
/// - An invalid-basis error when `num_basis <= degree`.
/// - Any error from `create_basis` is propagated.
pub(crate) fn create_cyclic_bspline_basis_dense(
    data: ArrayView1<'_, f64>,
    start: f64,
    end: f64,
    degree: usize,
    num_basis: usize,
) -> Result<(Array2<f64>, Array1<f64>), BasisError> {
    let period = end - start;
    // A plain `end <= start` test lets NaN/∞ bounds through, because every comparison with
    // NaN is false. A finite, positive period implies both bounds are finite and ordered.
    if !(period.is_finite() && period > 0.0) {
        return Err(BasisError::InvalidRange(start, end));
    }
    if num_basis <= degree {
        crate::bail_invalid_basis!(
            "cyclic B-spline basis requires more basis functions ({num_basis}) than degree ({degree})"
        );
    }
    let (anchor, _) = cyclic_knot_anchor(start, period, num_basis);
    let wrapped = data.mapv(|x| wrap_to_period(x, anchor, period));
    let knots = cyclic_uniform_knot_vector(start, end, degree, num_basis);
    let (extended, _) = create_basis::<Dense>(
        wrapped.view(),
        KnotSource::Provided(knots.view()),
        degree,
        BasisOptions::value(),
    )?;
    let mut cyclic = Array2::<f64>::zeros((data.len(), num_basis));
    for i in 0..extended.nrows() {
        for j in 0..extended.ncols() {
            // Fold the padded extended basis onto the periodic one.
            cyclic[[i, j % num_basis]] += extended[[i, j]];
        }
    }
    Ok((cyclic, knots))
}
#[cfg(test)]
mod closure_tests {
    use super::*;

    /// Shifting the seam by exactly one knot spacing must only permute the cyclic basis columns.
    #[test]
    fn cyclic_basis_rigid_rotation_under_whole_knot_shift() {
        let degree = 3usize;
        let num_basis = 8usize;
        let start = 0.0_f64;
        let period = std::f64::consts::TAU;
        let h = period / num_basis as f64;
        let thetas = Array1::from_iter((0..40).map(|i| (i as f64 + 0.3) / 40.0 * period));
        let (b0, _) =
            create_cyclic_bspline_basis_dense(thetas.view(), start, start + period, degree, num_basis)
                .unwrap();
        let (b1, _) = create_cyclic_bspline_basis_dense(
            thetas.view(),
            start + h,
            start + h + period,
            degree,
            num_basis,
        )
        .unwrap();
        let mut best = f64::INFINITY;
        let mut best_shift = 0usize;
        for shift in 0..num_basis {
            let mut maxerr = 0.0_f64;
            for r in 0..b0.nrows() {
                for j in 0..num_basis {
                    let permuted = b0[[r, (j + shift) % num_basis]];
                    maxerr = maxerr.max((b1[[r, j]] - permuted).abs());
                }
            }
            if maxerr < best {
                best = maxerr;
                best_shift = shift;
            }
        }
        eprintln!("[cyclic-rigid] best cyclic-shift match err={best:.3e} at shift={best_shift}");
        assert!(
            best < 1e-10,
            "cyclic basis is NOT a rigid cyclic permutation under a whole-knot seam shift: \
             best max|ΔB| over all {num_basis} shifts = {best:.3e} (shift {best_shift})"
        );
    }

    /// Shifting the seam by a fraction of a knot spacing must leave the spanned space unchanged.
    #[test]
    fn cyclic_basis_span_invariant_to_subknot_seam_shift() {
        use gam_linalg::faer_ndarray::{fast_ab, fast_ata};
        let degree = 3usize;
        let num_basis = 12usize;
        let period = std::f64::consts::TAU;
        let h = period / num_basis as f64;
        let thetas = Array1::from_iter((0..200).map(|i| (i as f64 + 0.123) / 200.0 * period));
        let reference = {
            let (b, _) =
                create_cyclic_bspline_basis_dense(thetas.view(), 0.0, period, degree, num_basis)
                    .unwrap();
            b
        };
        let gram = fast_ata(&reference);
        for frac in [0.37_f64, 0.5, 0.81, 1.0, 1.5, 2.8] {
            let seam = frac * h;
            let (bs, _) = create_cyclic_bspline_basis_dense(
                thetas.view(),
                seam,
                seam + period,
                degree,
                num_basis,
            )
            .unwrap();
            let rtb = fast_ab(&reference.t().to_owned(), &bs);
            let gram_inv = {
                use faer::Side;
                use gam_linalg::faer_ndarray::FaerCholesky;
                let chol = gram.cholesky(Side::Lower).unwrap();
                let mut id = Array2::<f64>::eye(gram.nrows());
                chol.solve_mat_in_place(&mut id);
                id
            };
            // Least-squares projection of the shifted basis onto the reference span.
            let coef = fast_ab(&gram_inv, &rtb);
            let projected = fast_ab(&reference, &coef);
            let resid = &bs - &projected;
            let rel = resid.iter().map(|v| v * v).sum::<f64>().sqrt()
                / bs.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-300);
            eprintln!("[cyclic-span] seam={seam:.4} (frac {frac}) span residual={rel:.3e}");
            assert!(
                rel < 1e-9,
                "cyclic basis span is NOT seam-invariant at sub-knot shift frac={frac}: \
                 relative residual {rel:.3e} (the knot grid must anchor to a canonical phase, \
                 #1593)"
            );
        }
    }

    /// The closure jet's value must be the open penalty at gamma = 0 and the cyclic one at gamma = 1.
    #[test]
    fn closure_penalty_interpolates_open_to_cyclic() {
        let n = 8;
        let order = 2;
        let s_open = create_open_difference_penalty_matrix(n, order).unwrap();
        let s_circle = create_cyclic_difference_penalty_matrix(n, order).unwrap();
        let (s0, _, _) = create_closure_difference_penalty_jet(n, order, 0.0).unwrap();
        let (s1, _, _) = create_closure_difference_penalty_jet(n, order, 1.0).unwrap();
        assert!((&s0 - &s_open).iter().all(|v| v.abs() < 1e-12));
        assert!((&s1 - &s_circle).iter().all(|v| v.abs() < 1e-12));
    }

    /// The jet's first derivative must agree with a central finite difference in gamma.
    #[test]
    fn closure_penalty_gamma_derivative_matches_fd() {
        let n = 6;
        let order = 2;
        let g = 0.45;
        let (_, ds, _) = create_closure_difference_penalty_jet(n, order, g).unwrap();
        let h = 1e-6;
        let (sp, _, _) = create_closure_difference_penalty_jet(n, order, g + h).unwrap();
        let (sm, _, _) = create_closure_difference_penalty_jet(n, order, g - h).unwrap();
        let fd = (&sp - &sm).mapv(|v| v / (2.0 * h));
        assert!((&ds - &fd).iter().all(|v| v.abs() < 1e-6));
    }

    /// Second-order open differences annihilate constants and ramps. Cyclic ones annihilate only
    /// constants, because the ramp jumps at the seam.
    #[test]
    fn open_penalty_null_space_is_larger_than_cyclic() {
        let n = 7;
        let order = 2;
        let s_open = create_open_difference_penalty_matrix(n, order).unwrap();
        let s_circle = create_cyclic_difference_penalty_matrix(n, order).unwrap();
        let ones = ndarray::Array1::<f64>::ones(n);
        let ramp = ndarray::Array1::from_iter((0..n).map(|i| i as f64));

        let open_const = ones.dot(&s_open.dot(&ones));
        let open_ramp = ramp.dot(&s_open.dot(&ramp));
        assert!(open_const.abs() < 1e-10, "open S·1 ≠ 0: {open_const}");
        assert!(open_ramp.abs() < 1e-8, "open S·ramp ≠ 0: {open_ramp}");

        // Cyclic D²·ramp is zero except at the seam, where it is (-n, n), so ‖D²·ramp‖² = 2n².
        let cyclic_const = ones.dot(&s_circle.dot(&ones));
        let cyclic_ramp = ramp.dot(&s_circle.dot(&ramp));
        let expected_cyclic_ramp = 2.0 * (n * n) as f64;
        assert!(cyclic_const.abs() < 1e-10, "cyclic S·1 ≠ 0: {cyclic_const}");
        assert!(
            (cyclic_ramp - expected_cyclic_ramp).abs() < 1e-8,
            "cyclic S must penalize the ramp: got {cyclic_ramp}, expected {expected_cyclic_ramp}"
        );
    }
}