1use super::*;
2
/// Bundle of dense operator matrices together with the collocation points they refer to.
///
/// NOTE(review): the per-field descriptions below are inferred from the field names only;
/// confirm them against the code that constructs this struct.
#[derive(Debug, Clone)]
pub struct CollocationOperatorMatrices {
    /// Presumably the order-0 (function value) operator — TODO confirm.
    pub d0: Array2<f64>,
    /// Presumably the order-1 (first-derivative) operator — TODO confirm.
    pub d1: Array2<f64>,
    /// Presumably the order-2 (second-derivative) operator — TODO confirm.
    pub d2: Array2<f64>,
    /// Collocation points; presumably one point per row — TODO confirm layout.
    pub collocation_points: Array2<f64>,
    /// Optional transform associated with the kernel nullspace; `None` when absent.
    pub kernel_nullspace_transform: Option<Array2<f64>>,
    /// Number of columns belonging to the polynomial block — verify against the builder.
    pub polynomial_block_cols: usize,
}
17
/// Three dense penalty matrices used by the Duchon operator penalty.
///
/// NOTE(review): which derivative order each matrix penalizes is not visible here; the
/// names suggest order 0 (`mass`), 1 (`tension`) and 2 (`stiffness`) — confirm with the builder.
#[derive(Debug, Clone)]
pub struct DuchonOperatorPenaltyMatrices {
    /// Mass penalty matrix.
    pub mass: Array2<f64>,
    /// Tension penalty matrix.
    pub tension: Array2<f64>,
    /// Stiffness penalty matrix.
    pub stiffness: Array2<f64>,
}
24
/// Wrapper around the single dense penalty matrix of a thin-plate basis.
#[derive(Debug, Clone)]
pub struct ThinPlatePenaltyMatrix {
    /// The penalty matrix itself.
    pub penalty: Array2<f64>,
}
29
30pub(crate) fn validate_center_count(num_centers: usize) -> Result<(), BasisError> {
31 if num_centers == 0 {
32 crate::bail_invalid_basis!("center count must be positive");
33 }
34 Ok(())
35}
36
/// Picks `num_centers` rows of `data` as centers by recursively bisecting the rows into
/// equally sized leaves and returning, for each leaf, the data row closest to the leaf centroid.
///
/// The largest leaf (the first one encountered on ties) is repeatedly split at its midpoint
/// after ordering its rows along the leaf's principal axis. When no usable axis exists the
/// rows are ordered lexicographically by coordinate instead. Every ordering breaks ties by
/// row index so the result is deterministic.
///
/// # Errors
/// Fails through `bail_invalid_basis!` when `num_centers` is zero, when it exceeds the number
/// of rows, when `data` has no columns, or when the partition ends with fewer leaves than
/// requested.
pub(crate) fn select_equal_mass_centers(
    data: ArrayView2<'_, f64>,
    num_centers: usize,
) -> Result<Array2<f64>, BasisError> {
    validate_center_count(num_centers)?;
    let n = data.nrows();
    let d = data.ncols();
    if num_centers > n {
        crate::bail_invalid_basis!(
            "equal-mass center selection requested {num_centers} centers but data has {n} rows"
        );
    }
    if d == 0 {
        crate::bail_invalid_basis!("equal-mass center selection requires at least one column");
    }
    // A leaf is the half-open range `order[start..end]` of row indices.
    #[derive(Clone, Copy)]
    struct Leaf {
        pub(crate) start: usize,
        pub(crate) end: usize,
    }

    // `order` is a permutation of the row indices; splitting a leaf only reorders the
    // indices inside that leaf's range.
    let mut order: Vec<usize> = (0..n).collect();
    let mut leaves = vec![Leaf { start: 0, end: n }];

    // Returns a sign-normalized direction of largest variance for the rows in `slice`,
    // or `None` when no such direction can be determined.
    let principal_axis = |slice: &[usize]| -> Option<Vec<f64>> {
        let m = slice.len();
        if m < 2 {
            return None;
        }
        // Leaf centroid.
        let mut centroid = vec![0.0_f64; d];
        for &idx in slice {
            for j in 0..d {
                centroid[j] += data[[idx, j]];
            }
        }
        let inv = 1.0 / m as f64;
        for v in &mut centroid {
            *v *= inv;
        }
        // Covariance (normalized by `m`): accumulate the upper triangle only, then scale
        // and mirror it into the lower triangle.
        let mut cov = Array2::<f64>::zeros((d, d));
        for &idx in slice {
            for a in 0..d {
                let da = data[[idx, a]] - centroid[a];
                for b in a..d {
                    let db = data[[idx, b]] - centroid[b];
                    cov[[a, b]] += da * db;
                }
            }
        }
        for a in 0..d {
            cov[[a, a]] *= inv;
            for b in (a + 1)..d {
                cov[[a, b]] *= inv;
                cov[[b, a]] = cov[[a, b]];
            }
        }
        // Non-finite input (NaN/inf coordinates) makes the axis meaningless.
        if cov.iter().any(|v| !v.is_finite()) {
            return None;
        }
        let mut axis: Vec<f64> = if d == 2 {
            // Closed form for a symmetric 2x2 matrix: the leading eigenvector makes the
            // angle 0.5 * atan2(2*sxy, sxx - syy) with the first coordinate axis.
            let sxx = cov[[0, 0]];
            let syy = cov[[1, 1]];
            let sxy = cov[[0, 1]];
            // Isotropic covariance: every direction is equally good, so report "no axis".
            if sxy == 0.0 && sxx == syy {
                return None;
            }
            let angle = 0.5 * (2.0 * sxy).atan2(sxx - syy);
            vec![angle.cos(), angle.sin()]
        } else {
            let (evals, evecs) = cov.eigh(Side::Lower).ok()?;
            // NOTE(review): uses the last eigenpair, which presumes `eigh` returns
            // eigenvalues in ascending order — confirm against the eigensolver in use.
            let last = evals.len().checked_sub(1)?;
            // Written as a negated `>` so that a NaN eigenvalue is rejected as well.
            if !(evals[last] > 0.0) {
                return None;
            }
            (0..d).map(|r| evecs[[r, last]]).collect()
        };
        if axis.iter().any(|v| !v.is_finite()) {
            return None;
        }
        // An eigenvector is only defined up to sign. Fix the sign so that the row farthest
        // from the centroid (lowest row index on ties) has a non-negative projection.
        let mut far_idx = slice[0];
        let mut far_d2 = f64::NEG_INFINITY;
        for &idx in slice {
            let mut d2 = 0.0_f64;
            for j in 0..d {
                let delta = data[[idx, j]] - centroid[j];
                d2 += delta * delta;
            }
            if d2 > far_d2 || (d2 == far_d2 && idx < far_idx) {
                far_d2 = d2;
                far_idx = idx;
            }
        }
        let mut proj = 0.0_f64;
        for j in 0..d {
            proj += (data[[far_idx, j]] - centroid[j]) * axis[j];
        }
        if proj < 0.0 {
            for v in &mut axis {
                *v = -*v;
            }
        } else if proj == 0.0 {
            // The farthest row gives no sign information; fall back to making the
            // largest-magnitude component (first one on ties) non-negative.
            let mut pivot = 0usize;
            for r in 1..d {
                if axis[r].abs() > axis[pivot].abs() {
                    pivot = r;
                }
            }
            if axis[pivot] < 0.0 {
                for v in &mut axis {
                    *v = -*v;
                }
            }
        }
        Some(axis)
    };

    while leaves.len() < num_centers {
        // Choose the largest leaf that still holds more than one row; the strict `>` keeps
        // the first such leaf on ties.
        let mut split_pos = None;
        let mut split_size = 0usize;
        for (i, leaf) in leaves.iter().enumerate() {
            let leaf_size = leaf.end - leaf.start;
            if leaf_size > split_size && leaf_size > 1 {
                split_size = leaf_size;
                split_pos = Some(i);
            }
        }
        // Every leaf is a single row: nothing left to split.
        let Some(pos) = split_pos else {
            break;
        };

        let leaf = leaves.swap_remove(pos);
        let axis = principal_axis(&order[leaf.start..leaf.end]);
        match axis {
            Some(axis) => {
                // Order rows by their projection onto the axis, ties broken by row index.
                order[leaf.start..leaf.end].sort_by(|&a, &b| {
                    let mut pa = 0.0_f64;
                    let mut pb = 0.0_f64;
                    for j in 0..d {
                        pa += data[[a, j]] * axis[j];
                        pb += data[[b, j]] * axis[j];
                    }
                    let ord = pa.total_cmp(&pb);
                    if ord.is_eq() { a.cmp(&b) } else { ord }
                });
            }
            None => {
                // No usable axis: order lexicographically by coordinate, then by row index.
                order[leaf.start..leaf.end].sort_by(|&a, &b| {
                    for j in 0..d {
                        let ord = data[[a, j]].total_cmp(&data[[b, j]]);
                        if !ord.is_eq() {
                            return ord;
                        }
                    }
                    a.cmp(&b)
                });
            }
        }
        let mid = leaf.start + (split_size / 2);

        // Defensive only: `split_size > 1` already places `mid` strictly inside the leaf.
        if mid == leaf.start || mid == leaf.end {
            leaves.push(leaf);
            break;
        }

        leaves.push(Leaf {
            start: leaf.start,
            end: mid,
        });
        leaves.push(Leaf {
            start: mid,
            end: leaf.end,
        });
    }

    if leaves.len() < num_centers {
        crate::bail_invalid_basis!(
            "equal-mass partition produced {} leaves, expected {num_centers}",
            leaves.len()
        );
    }

    // One center per leaf: the actual data row nearest to the leaf centroid.
    let mut centers = Array2::<f64>::zeros((num_centers, d));
    for (c, leaf) in leaves.iter().take(num_centers).enumerate() {
        let slice = &order[leaf.start..leaf.end];
        let m = slice.len() as f64;
        let mut centroid = vec![0.0_f64; d];
        for &idx in slice {
            for j in 0..d {
                centroid[j] += data[[idx, j]];
            }
        }
        for v in &mut centroid {
            *v /= m.max(1.0);
        }

        // Rows with a non-finite distance are skipped. The reduction keeps the smaller
        // distance and, on equal distances, the lower row index, so its outcome does not
        // depend on how the parallel iterator groups the work. If every distance is
        // non-finite the first row of the leaf is used.
        let best_idx = slice
            .par_iter()
            .filter_map(|&idx| {
                let mut d2 = 0.0;
                for j in 0..d {
                    let delta = data[[idx, j]] - centroid[j];
                    d2 += delta * delta;
                }
                if d2.is_finite() {
                    Some((idx, d2))
                } else {
                    None
                }
            })
            .reduce_with(|a, b| {
                if b.1 < a.1 || (b.1 == a.1 && b.0 < a.0) {
                    b
                } else {
                    a
                }
            })
            .map(|(idx, _)| idx)
            .unwrap_or(slice[0]);
        centers.row_mut(c).assign(&data.row(best_idx));
    }
    Ok(centers)
}
323
324pub(crate) fn select_equal_mass_covar_representative_centers(
325 data: ArrayView2<'_, f64>,
326 num_centers: usize,
327) -> Result<Array2<f64>, BasisError> {
328 validate_center_count(num_centers)?;
329 let n = data.nrows();
330 let d = data.ncols();
331 if num_centers > n {
332 crate::bail_invalid_basis!(
333 "equal-mass covariate-representative center selection requested {num_centers} centers but data has {n} rows"
334 );
335 }
336 if d == 0 {
337 crate::bail_invalid_basis!(
338 "equal-mass covariate-representative center selection requires at least one column"
339 .to_string(),
340 );
341 }
342
343 let mut split_dim = 0usize;
344 let mut best_span = f64::NEG_INFINITY;
345 for j in 0..d {
346 let mut minv = f64::INFINITY;
347 let mut maxv = f64::NEG_INFINITY;
348 for i in 0..n {
349 let v = data[[i, j]];
350 if v < minv {
351 minv = v;
352 }
353 if v > maxv {
354 maxv = v;
355 }
356 }
357 let span = maxv - minv;
358 if span > best_span {
359 best_span = span;
360 split_dim = j;
361 }
362 }
363
364 let mut sorted: Vec<usize> = (0..n).collect();
365 sorted.sort_by(|&a, &b| {
366 let ord = data[[a, split_dim]].total_cmp(&data[[b, split_dim]]);
367 if ord.is_eq() { a.cmp(&b) } else { ord }
368 });
369
370 let mut centers = Array2::<f64>::zeros((num_centers, d));
371 for c in 0..num_centers {
372 let lo = (c * n) / num_centers;
373 let hi = ((c + 1) * n) / num_centers;
374 let chunk = &sorted[lo..hi.max(lo + 1)];
375 let mid = chunk[chunk.len() / 2];
376 centers.row_mut(c).assign(&data.row(mid));
377 }
378 Ok(centers)
379}
380
381pub(crate) fn select_kmeans_centers(
382 data: ArrayView2<'_, f64>,
383 num_centers: usize,
384 max_iter: usize,
385) -> Result<Array2<f64>, BasisError> {
386 validate_center_count(num_centers)?;
387 let n = data.nrows();
388 let d = data.ncols();
389 if num_centers > n {
390 crate::bail_invalid_basis!("kmeans requested {num_centers} centers but data has {n} rows");
391 }
392 const KMEANS_PILOT_MAX_ROWS: usize = 20_000;
393 if n > KMEANS_PILOT_MAX_ROWS {
394 let pilot_n = KMEANS_PILOT_MAX_ROWS.max(num_centers);
395 log::info!(
400 "kmeans center selection using {}-row pilot subsample instead of full {} rows",
401 pilot_n,
402 n
403 );
404 let pilot = select_equal_mass_covar_representative_centers(data, pilot_n)?;
405 return select_kmeans_centers(pilot.view(), num_centers, max_iter);
406 }
407 let mut centers = select_thin_plate_knots(data, num_centers)?;
408 let mut assign = vec![0usize; n];
409 let iters = max_iter.max(1);
410
411 let use_parallel = n >= 10_000;
414
415 for _ in 0..iters {
416 if use_parallel {
418 const KMEANS_CHUNK: usize = 4096;
419 assign
420 .par_chunks_mut(KMEANS_CHUNK)
421 .enumerate()
422 .for_each(|(ci, chunk)| {
423 let base = ci * KMEANS_CHUNK;
424 for (local, slot) in chunk.iter_mut().enumerate() {
425 let i = base + local;
426 let mut best = 0usize;
427 let mut best_d2 = f64::INFINITY;
428 for k in 0..num_centers {
429 let mut d2 = 0.0;
430 for c in 0..d {
431 let delta = data[[i, c]] - centers[[k, c]];
432 d2 += delta * delta;
433 }
434 if d2 < best_d2 {
435 best_d2 = d2;
436 best = k;
437 }
438 }
439 *slot = best;
440 }
441 });
442 } else {
443 for i in 0..n {
444 let mut best = 0usize;
445 let mut best_d2 = f64::INFINITY;
446 for k in 0..num_centers {
447 let mut d2 = 0.0;
448 for c in 0..d {
449 let delta = data[[i, c]] - centers[[k, c]];
450 d2 += delta * delta;
451 }
452 if d2 < best_d2 {
453 best_d2 = d2;
454 best = k;
455 }
456 }
457 assign[i] = best;
458 }
459 }
460 let mut sums = Array2::<f64>::zeros((num_centers, d));
462 let mut counts = vec![0usize; num_centers];
463 for i in 0..n {
464 let k = assign[i];
465 counts[k] += 1;
466 for c in 0..d {
467 sums[[k, c]] += data[[i, c]];
468 }
469 }
470 for k in 0..num_centers {
471 if counts[k] == 0 {
472 continue;
473 }
474 let inv = 1.0 / counts[k] as f64;
475 for c in 0..d {
476 centers[[k, c]] = sums[[k, c]] * inv;
477 }
478 }
479 }
480 Ok(centers)
481}
482
483pub(crate) fn cartesian_grid_axes(axes: &[Array1<f64>]) -> Result<Array2<f64>, BasisError> {
484 if axes.is_empty() {
485 crate::bail_invalid_basis!("uniform grid requires at least one axis");
486 }
487 let d = axes.len();
488 let total = axes.iter().try_fold(1usize, |acc, axis| {
489 acc.checked_mul(axis.len())
490 .ok_or_else(|| BasisError::DimensionMismatch("uniform grid is too large".to_string()))
491 })?;
492 let mut out = Array2::<f64>::zeros((total, d));
493 for r in 0..total {
494 let mut q = r;
495 for c in (0..d).rev() {
496 let len = axes[c].len();
497 let idx = q % len;
498 q /= len;
499 out[[r, c]] = axes[c][idx];
500 }
501 }
502 Ok(out)
503}
504
505pub(crate) fn select_uniform_grid_centers(
506 data: ArrayView2<'_, f64>,
507 points_per_dim: usize,
508) -> Result<Array2<f64>, BasisError> {
509 if points_per_dim == 0 {
510 crate::bail_invalid_basis!("uniform-grid points_per_dim must be positive");
511 }
512 let d = data.ncols();
513 if d == 0 {
514 crate::bail_invalid_basis!("uniform-grid center selection requires at least one column");
515 }
516 let mut axes = Vec::with_capacity(d);
517 for c in 0..d {
518 let col = data.column(c);
519 let minv = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
520 let maxv = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
521 axes.push(Array::linspace(minv, maxv, points_per_dim));
522 }
523 cartesian_grid_axes(&axes)
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplier of the xorshift64* output scrambling step.
    const XORSHIFT_MULTIPLIER: u64 = 0x2545_F491_4F6C_DD1D;

    /// Minimal deterministic xorshift64* generator, so the tests need no RNG dependency
    /// and always see the same pseudo-random data.
    struct XorShift64Star(u64);

    impl XorShift64Star {
        /// Advances the state and returns the next 64-bit output.
        fn next_u64(&mut self) -> u64 {
            self.0 ^= self.0 >> 12;
            self.0 ^= self.0 << 25;
            self.0 ^= self.0 >> 27;
            self.0.wrapping_mul(XORSHIFT_MULTIPLIER)
        }

        /// Returns the next output mapped to `[0, 1)` using its top 53 bits.
        fn next_f64(&mut self) -> f64 {
            ((self.next_u64() >> 11) as f64) / ((1u64 << 53) as f64)
        }
    }

    /// Builds an 11 x 11 sheared lattice of 2-D points with a small deterministic jitter
    /// (at most 0.025 in absolute value) added to every coordinate.
    fn make_points() -> Array2<f64> {
        let n_side = 11usize;
        let n = n_side * n_side;
        let mut pts = Array2::<f64>::zeros((n, 2));
        let mut rng = XorShift64Star(0x9E37_79B9_7F4A_7C15);
        let mut r = 0usize;
        for i in 0..n_side {
            for j in 0..n_side {
                let x = i as f64;
                let y = 0.35 * i as f64 + 1.7 * j as f64;
                pts[[r, 0]] = x + 0.05 * (rng.next_f64() - 0.5);
                pts[[r, 1]] = y + 0.05 * (rng.next_f64() - 0.5);
                r += 1;
            }
        }
        pts
    }

    /// Asserts that two sets of 2-D centers agree up to row order.
    ///
    /// Every expected center is greedily paired with the nearest not-yet-used actual
    /// center; the largest pairing distance must not exceed `tol`. Only the first two
    /// columns are compared.
    fn assert_center_sets_match(
        expected: ArrayView2<'_, f64>,
        actual: ArrayView2<'_, f64>,
        tol: f64,
    ) {
        assert_eq!(expected.nrows(), actual.nrows(), "center counts differ");
        let k = expected.nrows();
        let mut used = vec![false; k];
        let mut worst = 0.0_f64;
        for ei in 0..k {
            let mut best = usize::MAX;
            let mut best_d2 = f64::INFINITY;
            for ai in 0..k {
                if used[ai] {
                    continue;
                }
                let dx = expected[[ei, 0]] - actual[[ai, 0]];
                let dy = expected[[ei, 1]] - actual[[ai, 1]];
                let d2 = dx * dx + dy * dy;
                if d2 < best_d2 {
                    best_d2 = d2;
                    best = ai;
                }
            }
            assert!(best != usize::MAX, "no unmatched center available");
            used[best] = true;
            worst = worst.max(best_d2.sqrt());
        }
        // The message stays neutral because this helper backs both the permutation and
        // the rotation test.
        assert!(
            worst <= tol,
            "center sets differ: worst center match residual {worst:.3e} > tol {tol:.3e}"
        );
    }

    /// Shuffling the input rows must not change the selected set of centers.
    #[test]
    fn equal_mass_centers_are_permutation_invariant() {
        let pts = make_points();
        let num_centers = 16usize;
        let base = select_equal_mass_centers(pts.view(), num_centers).unwrap();

        // Deterministic Fisher-Yates shuffle of the row indices.
        let n = pts.nrows();
        let mut perm: Vec<usize> = (0..n).collect();
        let mut rng = XorShift64Star(0xD1B5_4A32_D192_ED03);
        for i in (1..n).rev() {
            let j = (rng.next_u64() % (i as u64 + 1)) as usize;
            perm.swap(i, j);
        }
        let mut permuted = Array2::<f64>::zeros((n, 2));
        for (new_r, &old_r) in perm.iter().enumerate() {
            permuted[[new_r, 0]] = pts[[old_r, 0]];
            permuted[[new_r, 1]] = pts[[old_r, 1]];
        }
        let permuted_centers = select_equal_mass_centers(permuted.view(), num_centers).unwrap();
        assert_center_sets_match(base.view(), permuted_centers.view(), 1e-13);
    }

    /// Rotating the data about its centroid must rotate the selected centers with it.
    #[test]
    fn equal_mass_centers_are_rotation_equivariant() {
        // Sheared, anisotropic 2-D cloud.
        let n = 300usize;
        let mut pts = Array2::<f64>::zeros((n, 2));
        let mut rng = XorShift64Star(0x1234_5678_9ABC_DEF0);
        for r in 0..n {
            let u = rng.next_f64();
            let v = rng.next_f64();
            pts[[r, 0]] = 2.0 * u - 1.0 + 0.6 * (2.0 * v - 1.0);
            pts[[r, 1]] = 2.0 * v - 1.0;
        }
        let num_centers = 48usize;
        let base = select_equal_mass_centers(pts.view(), num_centers).unwrap();

        // Centroid used as the rotation pivot.
        let mut cx = 0.0;
        let mut cy = 0.0;
        for r in 0..n {
            cx += pts[[r, 0]];
            cy += pts[[r, 1]];
        }
        cx /= n as f64;
        cy /= n as f64;

        // (cos, sin) pairs: a quarter turn and a generic angle of 0.7 rad.
        for &(ca, sa) in &[(0.0_f64, 1.0_f64), (0.7f64.cos(), 0.7f64.sin())] {
            let mut rot = Array2::<f64>::zeros((n, 2));
            for r in 0..n {
                let x = pts[[r, 0]] - cx;
                let y = pts[[r, 1]] - cy;
                rot[[r, 0]] = ca * x - sa * y + cx;
                rot[[r, 1]] = sa * x + ca * y + cy;
            }
            let rotated_centers = select_equal_mass_centers(rot.view(), num_centers).unwrap();
            // Apply the inverse rotation so the centers can be compared in the original frame.
            let mut unrotated = Array2::<f64>::zeros((num_centers, 2));
            for r in 0..num_centers {
                let x = rotated_centers[[r, 0]] - cx;
                let y = rotated_centers[[r, 1]] - cy;
                unrotated[[r, 0]] = ca * x + sa * y + cx;
                unrotated[[r, 1]] = -sa * x + ca * y + cy;
            }
            assert_center_sets_match(base.view(), unrotated.view(), 1e-9);
        }
    }
}