use ndarray::{Array4, s};
use ruvector_solver::neumann::NeumannSolver;
use ruvector_solver::types::CsrMatrix;
/// Linearly resample the last axis (axis 3, the subcarrier axis) of `arr` to `target_sc` points.
///
/// Each lane `arr[t, tx, rx, ..]` is resampled independently using the stencils
/// from [`compute_interp_weights`], so the first and last output samples equal
/// the first and last input samples. If the input already has `target_sc`
/// subcarriers the array is returned as a clone.
///
/// Works for any memory layout (C order, Fortran order, inverted axes): lanes
/// are read through an indexed view rather than requiring a contiguous slice.
///
/// # Panics
/// If `target_sc == 0`, or if the input has zero subcarriers while
/// `target_sc != 0` (via [`compute_interp_weights`]).
pub fn interpolate_subcarriers(arr: &Array4<f32>, target_sc: usize) -> Array4<f32> {
    assert!(target_sc > 0, "target_sc must be > 0");
    let shape = arr.shape();
    let (n_t, n_tx, n_rx, n_sc) = (shape[0], shape[1], shape[2], shape[3]);
    if n_sc == target_sc {
        return arr.clone();
    }
    // The stencils depend only on the two lengths, so compute them once for all lanes.
    let weights = compute_interp_weights(n_sc, target_sc);
    let mut out = Array4::<f32>::zeros((n_t, n_tx, n_rx, target_sc));
    for t in 0..n_t {
        for tx in 0..n_tx {
            for rx in 0..n_rx {
                // Indexing the view works regardless of stride; `as_slice()` would
                // fail (and previously panicked) for non-unit subcarrier strides.
                let src = arr.slice(s![t, tx, rx, ..]);
                let mut dst = out.slice_mut(s![t, tx, rx, ..]);
                for (d, &(i0, i1, w)) in dst.iter_mut().zip(&weights) {
                    *d = src[i0] * (1.0 - w) + src[i1] * w;
                }
            }
        }
    }
    out
}
/// Precompute linear-interpolation stencils for resampling `src_sc` points onto `target_sc` points.
///
/// Both grids are spread evenly over the same span, so the first and last
/// target points coincide with the first and last source points. Entry `k` of
/// the result is `(i0, i1, frac)`: target `k` equals
/// `src[i0] * (1 - frac) + src[i1] * frac`. When target `k` lands exactly on a
/// source sample, `i0 == i1` and `frac == 0.0`. With `target_sc == 1` the
/// single target maps to source index 0.
///
/// Positions are computed exactly as the integer quotient and remainder of
/// `k * (src_sc - 1) / (target_sc - 1)`. An `f32` position loses integer
/// precision once `k * (src_sc - 1)` exceeds 2^24 and can then miss the last
/// source sample; its fractional part also loses precision as the position grows.
///
/// # Panics
/// If `src_sc == 0` or `target_sc == 0`.
pub fn compute_interp_weights(src_sc: usize, target_sc: usize) -> Vec<(usize, usize, f32)> {
    assert!(src_sc > 0, "src_sc must be > 0");
    assert!(target_sc > 0, "target_sc must be > 0");
    if target_sc == 1 {
        return vec![(0, 0, 0.0)];
    }
    let span = (src_sc - 1) as u128;
    let steps = (target_sc - 1) as u128;
    (0..target_sc)
        .map(|k| {
            // u128 keeps `k * span` from overflowing for any pair of usize inputs.
            let num = k as u128 * span;
            // num / steps <= span = src_sc - 1, so it always fits in usize.
            let i0 = (num / steps) as usize;
            let rem = num % steps;
            if rem == 0 {
                (i0, i0, 0.0)
            } else {
                // rem > 0 implies num / steps < span, so i0 + 1 <= src_sc - 1.
                (i0, i0 + 1, (rem as f64 / steps as f64) as f32)
            }
        })
        .collect()
}
pub fn interpolate_subcarriers_sparse(arr: &Array4<f32>, target_sc: usize) -> Array4<f32> {
assert!(target_sc > 0, "target_sc must be > 0");
let shape = arr.shape();
let (n_t, n_tx, n_rx, n_sc) = (shape[0], shape[1], shape[2], shape[3]);
if n_sc == target_sc {
return arr.clone();
}
let sigma = 0.15_f32;
let sigma_sq = sigma * sigma;
let src_pos: Vec<f32> = (0..n_sc).map(|j| {
if n_sc == 1 { 0.0 } else { j as f32 / (n_sc - 1) as f32 }
}).collect();
let tgt_pos: Vec<f32> = (0..target_sc).map(|k| {
if target_sc == 1 { 0.0 } else { k as f32 / (target_sc - 1) as f32 }
}).collect();
let threshold = 1e-4_f32;
let lambda = 0.1_f32; let mut ata_coo: Vec<(usize, usize, f32)> = Vec::new();
let mut ata = vec![vec![0.0_f32; target_sc]; target_sc];
for j in 0..n_sc {
for k1 in 0..target_sc {
let diff1 = src_pos[j] - tgt_pos[k1];
let a_jk1 = (-diff1 * diff1 / sigma_sq).exp();
if a_jk1 < threshold { continue; }
for k2 in 0..target_sc {
let diff2 = src_pos[j] - tgt_pos[k2];
let a_jk2 = (-diff2 * diff2 / sigma_sq).exp();
if a_jk2 < threshold { continue; }
ata[k1][k2] += a_jk1 * a_jk2;
}
}
}
for k in 0..target_sc {
for k2 in 0..target_sc {
let val = ata[k][k2] + if k == k2 { lambda } else { 0.0 };
if val.abs() > 1e-8 {
ata_coo.push((k, k2, val));
}
}
}
let normal_matrix = CsrMatrix::<f32>::from_coo(target_sc, target_sc, ata_coo);
let solver = NeumannSolver::new(1e-5, 500);
let mut out = Array4::<f32>::zeros((n_t, n_tx, n_rx, target_sc));
for t in 0..n_t {
for tx in 0..n_tx {
for rx in 0..n_rx {
let src_slice: Vec<f32> = (0..n_sc).map(|s| arr[[t, tx, rx, s]]).collect();
let mut atb = vec![0.0_f32; target_sc];
for j in 0..n_sc {
let b_j = src_slice[j];
for k in 0..target_sc {
let diff = src_pos[j] - tgt_pos[k];
let a_jk = (-diff * diff / sigma_sq).exp();
if a_jk > threshold {
atb[k] += a_jk * b_j;
}
}
}
match solver.solve(&normal_matrix, &atb) {
Ok(result) => {
for k in 0..target_sc {
out[[t, tx, rx, k]] = result.solution[k];
}
}
Err(_) => {
let weights = compute_interp_weights(n_sc, target_sc);
for (k, &(i0, i1, w)) in weights.iter().enumerate() {
out[[t, tx, rx, k]] = src_slice[i0] * (1.0 - w) + src_slice[i1] * w;
}
}
}
}
}
}
out
}
/// Return the indices of the `k` subcarriers (axis 3) with the highest variance,
/// sorted ascending.
///
/// Each subcarrier's variance is the population variance (divide by N) over
/// all elements of `arr[.., .., .., sc]`, accumulated in `f64`. Ties keep
/// ascending index order, so lower indices win. A variance that comes out as
/// NaN (NaN samples, or zero elements per subcarrier) ranks below every real
/// variance. This keeps the sort comparator a true total order.
///
/// # Panics
/// If `k == 0` or `k > n_sc`.
pub fn select_subcarriers_by_variance(arr: &Array4<f32>, k: usize) -> Vec<usize> {
    let shape = arr.shape();
    let n_sc = shape[3];
    assert!(k > 0, "k must be > 0");
    assert!(k <= n_sc, "k ({k}) must be <= n_sc ({n_sc})");
    let total_elems = shape[0] * shape[1] * shape[2];
    let count = total_elems as f64;
    let variances: Vec<f64> = (0..n_sc)
        .map(|sc| {
            let col = arr.slice(s![.., .., .., sc]);
            let mean = col.iter().map(|&v| v as f64).sum::<f64>() / count;
            let var = col.iter().map(|&v| (v as f64 - mean).powi(2)).sum::<f64>() / count;
            // Mapping NaN to -inf keeps such columns out of the top-k whenever possible.
            if var.is_nan() {
                f64::NEG_INFINITY
            } else {
                var
            }
        })
        .collect();
    let mut ranked: Vec<usize> = (0..n_sc).collect();
    // Descending by variance; `total_cmp` is a total order, and the stable sort
    // keeps equal variances in ascending index order.
    ranked.sort_by(|&a, &b| variances[b].total_cmp(&variances[a]));
    ranked.truncate(k);
    ranked.sort_unstable();
    ranked
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn identity_resample() {
        let arr = Array4::<f32>::from_shape_fn((4, 3, 3, 56), |(t, tx, rx, k)| {
            (t + tx + rx + k) as f32
        });
        let out = interpolate_subcarriers(&arr, 56);
        assert_eq!(out.shape(), arr.shape());
        for v in arr.iter().zip(out.iter()) {
            assert_abs_diff_eq!(v.0, v.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn upsample_endpoints_preserved() {
        let arr = Array4::<f32>::from_shape_fn((1, 1, 1, 4), |(_, _, _, k)| k as f32);
        let out = interpolate_subcarriers(&arr, 8);
        assert_eq!(out.shape(), &[1, 1, 1, 8]);
        assert_abs_diff_eq!(out[[0, 0, 0, 0]], 0.0_f32, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[0, 0, 0, 7]], 3.0_f32, epsilon = 1e-6);
    }

    #[test]
    fn downsample_endpoints_preserved() {
        let arr = Array4::<f32>::from_shape_fn((1, 1, 1, 8), |(_, _, _, k)| k as f32 * 2.0);
        let out = interpolate_subcarriers(&arr, 4);
        assert_eq!(out.shape(), &[1, 1, 1, 4]);
        assert_abs_diff_eq!(out[[0, 0, 0, 0]], 0.0_f32, epsilon = 1e-5);
        assert_abs_diff_eq!(out[[0, 0, 0, 3]], 14.0_f32, epsilon = 1e-5);
    }

    #[test]
    fn upsample_linear_ramp_hits_midpoints() {
        // A linear ramp must be reproduced exactly by linear interpolation.
        let arr = Array4::<f32>::from_shape_fn((1, 1, 1, 3), |(_, _, _, k)| k as f32);
        let out = interpolate_subcarriers(&arr, 5);
        let expected = [0.0_f32, 0.5, 1.0, 1.5, 2.0];
        for (k, &e) in expected.iter().enumerate() {
            assert_abs_diff_eq!(out[[0, 0, 0, k]], e, epsilon = 1e-6);
        }
    }

    #[test]
    fn compute_interp_weights_identity() {
        let w = compute_interp_weights(5, 5);
        assert_eq!(w.len(), 5);
        for (k, &(i0, i1, frac)) in w.iter().enumerate() {
            assert_eq!(i0, k);
            assert_eq!(i1, k);
            assert_abs_diff_eq!(frac, 0.0_f32, epsilon = 1e-6);
        }
    }

    #[test]
    fn compute_interp_weights_midpoints() {
        let w = compute_interp_weights(3, 5);
        let expected: [(usize, usize, f32); 5] =
            [(0, 0, 0.0), (0, 1, 0.5), (1, 1, 0.0), (1, 2, 0.5), (2, 2, 0.0)];
        assert_eq!(w.len(), expected.len());
        for (&(i0, i1, frac), &(e0, e1, ef)) in w.iter().zip(expected.iter()) {
            assert_eq!((i0, i1), (e0, e1));
            assert_abs_diff_eq!(frac, ef, epsilon = 1e-6);
        }
    }

    #[test]
    fn select_subcarriers_returns_correct_count() {
        let arr = Array4::<f32>::from_shape_fn((10, 3, 3, 56), |(t, _, _, k)| {
            (t * k) as f32
        });
        let selected = select_subcarriers_by_variance(&arr, 8);
        assert_eq!(selected.len(), 8);
    }

    #[test]
    fn select_subcarriers_sorted_ascending() {
        let arr = Array4::<f32>::from_shape_fn((10, 3, 3, 56), |(t, _, _, k)| {
            (t * k) as f32
        });
        let selected = select_subcarriers_by_variance(&arr, 10);
        for w in selected.windows(2) {
            assert!(w[0] < w[1], "Indices must be sorted ascending");
        }
    }

    #[test]
    fn select_subcarriers_picks_highest_variance() {
        // Variance of column k is k^2 * var(t), so it grows with k.
        let arr = Array4::<f32>::from_shape_fn((10, 1, 1, 4), |(t, _, _, k)| (t * k) as f32);
        let selected = select_subcarriers_by_variance(&arr, 2);
        assert_eq!(selected, vec![2, 3]);
    }

    #[test]
    fn select_subcarriers_all_same_returns_lowest_indices() {
        // All variances tie at zero; ties are broken by ascending index.
        let arr = Array4::<f32>::ones((5, 2, 2, 20));
        let selected = select_subcarriers_by_variance(&arr, 5);
        assert_eq!(selected, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn sparse_interpolation_114_to_56_shape() {
        let arr = Array4::<f32>::from_shape_fn((4, 1, 3, 114), |(t, _, rx, k)| {
            ((t + rx + k) as f32).sin()
        });
        let out = interpolate_subcarriers_sparse(&arr, 56);
        assert_eq!(out.shape(), &[4, 1, 3, 56]);
    }

    #[test]
    fn sparse_interpolation_identity() {
        let arr = Array4::<f32>::from_shape_fn((2, 1, 1, 20), |(_, _, _, k)| k as f32);
        let out = interpolate_subcarriers_sparse(&arr, 20);
        assert_eq!(out.shape(), &[2, 1, 1, 20]);
        // Same length must be a pass-through, not merely the right shape.
        for (a, b) in arr.iter().zip(out.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-6);
        }
    }
}