use crate::{Error, Result};
/// Topology of the comparator network used by [`DiffSortNet`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NetworkType {
/// Bitonic network; `DiffSortNet::new` rounds the requested size up to the
/// next power of two before generating comparators.
Bitonic,
/// Odd-even transposition network over adjacent wires; the requested size
/// is used as given.
OddEven,
}
/// Shape of the relaxed step function evaluated by `relaxed_sigmoid`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RelaxDist {
/// Logistic sigmoid `1 / (1 + exp(-x))`, saturated to 0 or 1 for `|x| > 500`.
Logistic,
/// Cauchy CDF `0.5 + atan(x) / pi`.
Cauchy,
/// Gaussian CDF `0.5 * (1 + erf(x / sqrt(2)))`, using `erf_approx`.
Gaussian,
}
/// A relaxed sorting network: a fixed list of comparators that are applied as
/// soft (sigmoid-weighted) swaps instead of hard min/max exchanges.
#[derive(Debug, Clone)]
pub struct DiffSortNet {
/// Comparator topology chosen at construction.
pub network_type: NetworkType,
/// Number of wires in the network. For `Bitonic` this is the requested size
/// rounded up to the next power of two.
pub size: usize,
/// Factor applied to the difference of two compared values before the
/// sigmoid. `sort` rejects values `<= 0.0`.
pub steepness: f64,
/// Relaxation function used for every comparator.
pub dist: RelaxDist,
// Ordered wire pairs `(a, b)`. `sort` moves the softly smaller value
// towards wire `a` and the softly larger value towards wire `b`.
comparators: Vec<(usize, usize)>,
}
impl DiffSortNet {
    /// Builds a network of the given topology for (at least) `size` inputs.
    ///
    /// For [`NetworkType::Bitonic`] the wire count is rounded up to the next
    /// power of two; the resulting count is stored in `self.size`. `steepness`
    /// is stored as-is and only validated when [`DiffSortNet::sort`] is called.
    pub fn new(network_type: NetworkType, size: usize, steepness: f64, dist: RelaxDist) -> Self {
        let (actual_size, comparators) = match network_type {
            NetworkType::Bitonic => {
                let padded = size.next_power_of_two();
                (padded, bitonic_comparators(padded))
            }
            NetworkType::OddEven => (size, odd_even_comparators(size)),
        };
        Self {
            network_type,
            size: actual_size,
            steepness,
            dist,
            comparators,
        }
    }

    /// Softly sorts `x` in ascending order.
    ///
    /// Returns `(sorted, perm)`. `sorted` has `x.len()` entries. `perm` is an
    /// `x.len() x x.len()` matrix whose row `i` holds the mixing weights of
    /// output position `i` over the input positions; it starts as the identity
    /// and receives the same convex row mixes as the values.
    ///
    /// Inputs shorter than the network are padded with `max(x) + 1e6`, and the
    /// padding rows/columns are trimmed from the result. Inputs longer than
    /// the network are sorted with a temporary network of the same topology
    /// sized for the input, instead of panicking on an out-of-range slice.
    ///
    /// # Errors
    ///
    /// * `Error::EmptyInput` if `x` is empty.
    /// * `Error::InvalidTemperature` if `steepness` is NaN or not strictly
    ///   positive.
    pub fn sort(&self, x: &[f64]) -> Result<(Vec<f64>, Vec<Vec<f64>>)> {
        if x.is_empty() {
            return Err(Error::EmptyInput);
        }
        // NaN compares false against everything, so `<= 0.0` alone would let
        // it through and poison every output.
        if self.steepness.is_nan() || self.steepness <= 0.0 {
            return Err(Error::InvalidTemperature(self.steepness));
        }

        let n = self.size;
        let out_n = x.len();
        if out_n > n {
            // The comparators only cover `n` wires. `new` always yields a
            // network with at least `out_n` wires, so this recurses once.
            return Self::new(self.network_type, out_n, self.steepness, self.dist).sort(x);
        }

        // Padding must exceed every real value so it settles on the trailing
        // wires that are trimmed away below.
        let pad_val = x.iter().copied().fold(f64::NEG_INFINITY, f64::max) + 1e6;
        let mut values: Vec<f64> = Vec::with_capacity(n);
        values.extend_from_slice(x);
        values.resize(n, pad_val);

        // Row-major n x n identity matrix.
        let mut perm_flat = vec![0.0_f64; n * n];
        for i in 0..n {
            perm_flat[i * n + i] = 1.0;
        }

        for &(a, b) in &self.comparators {
            let va = values[a];
            let vb = values[b];
            // sigma -> 1 when va > vb (swap), sigma -> 0 when va < vb (keep).
            let sigma = relaxed_sigmoid((va - vb) * self.steepness, self.dist);
            let one_minus = 1.0 - sigma;
            values[a] = one_minus * va + sigma * vb;
            values[b] = sigma * va + one_minus * vb;

            // Apply the same convex mix to the two permutation rows.
            let row_a = a * n;
            let row_b = b * n;
            for k in 0..n {
                let pa = perm_flat[row_a + k];
                let pb = perm_flat[row_b + k];
                perm_flat[row_a + k] = one_minus * pa + sigma * pb;
                perm_flat[row_b + k] = sigma * pa + one_minus * pb;
            }
        }

        let sorted = values[..out_n].to_vec();
        // `n >= out_n >= 1` here, so the chunk size is never zero.
        let trimmed_perm: Vec<Vec<f64>> = perm_flat
            .chunks_exact(n)
            .take(out_n)
            .map(|row| row[..out_n].to_vec())
            .collect();
        Ok((sorted, trimmed_perm))
    }

    /// Total number of comparators in the network.
    pub fn num_comparators(&self) -> usize {
        self.comparators.len()
    }

    /// The comparators in application order, as `(a, b)` wire pairs.
    pub fn comparator_pairs(&self) -> &[(usize, usize)] {
        &self.comparators
    }

    /// Number of stages obtained by greedily grouping consecutive comparators
    /// until one of them reuses a wire already touched in the current stage.
    ///
    /// Returns 0 for a network without comparators.
    pub fn depth(&self) -> usize {
        if self.comparators.is_empty() {
            return 0;
        }
        let mut depth = 1;
        let mut used_in_stage = vec![false; self.size];
        for &(i, j) in &self.comparators {
            if used_in_stage[i] || used_in_stage[j] {
                // Wire conflict: this comparator opens a new stage.
                depth += 1;
                used_in_stage.fill(false);
            }
            used_in_stage[i] = true;
            used_in_stage[j] = true;
        }
        depth
    }
}
/// Evaluates the relaxed step function selected by `dist` at `x`.
///
/// * `Logistic`: `1 / (1 + exp(-x))`, returning exactly 1 or 0 once
///   `|x|` exceeds 500.
/// * `Cauchy`: `0.5 + atan(x) / pi`.
/// * `Gaussian`: `0.5 * (1 + erf(x / sqrt(2)))` via `erf_approx`.
pub(crate) fn relaxed_sigmoid(x: f64, dist: RelaxDist) -> f64 {
    use std::f64::consts::{PI, SQRT_2};

    match dist {
        RelaxDist::Logistic if x > 500.0 => 1.0,
        RelaxDist::Logistic if x < -500.0 => 0.0,
        RelaxDist::Logistic => 1.0 / (1.0 + (-x).exp()),
        RelaxDist::Cauchy => 0.5 + x.atan() / PI,
        RelaxDist::Gaussian => 0.5 * (1.0 + erf_approx(x / SQRT_2)),
    }
}
/// Polynomial approximation of the error function.
///
/// The constants match the Abramowitz & Stegun 7.1.26 rational approximation:
/// `erf(x) ~= 1 - (a1 t + a2 t^2 + ... + a5 t^5) exp(-x^2)` with
/// `t = 1 / (1 + p x)` for `x >= 0`, extended to negative `x` by odd symmetry.
pub(crate) fn erf_approx(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Signed coefficients a1..a5 of the polynomial in `t`.
    const COEFFS: [f64; 5] = [
        0.254_829_592,
        -0.284_496_736,
        1.421_413_741,
        -1.453_152_027,
        1.061_405_429,
    ];

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // Accumulate a1*t + a2*t^2 + ... term by term, building each power of `t`
    // from the previous one.
    let mut power = t;
    let mut poly = COEFFS[0] * t;
    for &coeff in &COEFFS[1..] {
        power *= t;
        poly += coeff * power;
    }

    sign * (1.0 - poly * (-magnitude * magnitude).exp())
}
/// Generates the comparators of a bitonic sorting network on `n` wires.
///
/// `n` must be a power of two (checked in debug builds). Each pair `(a, b)`
/// lists first the wire that should receive the smaller value, so pairs in
/// descending blocks appear with the higher index first.
fn bitonic_comparators(n: usize) -> Vec<(usize, usize)> {
    debug_assert!(n.is_power_of_two());
    let mut pairs = Vec::new();

    // `block` is the size of the bitonic sequences being merged.
    let mut block = 2;
    while block <= n {
        // `stride` is the distance between the two wires of a comparator.
        let mut stride = block >> 1;
        while stride > 0 {
            pairs.extend((0..n).filter_map(|lo| {
                let hi = lo ^ stride;
                if hi <= lo {
                    // Each pair is emitted once, from its lower wire.
                    return None;
                }
                let ascending = lo & block == 0;
                Some(if ascending { (lo, hi) } else { (hi, lo) })
            }));
            stride >>= 1;
        }
        block <<= 1;
    }
    pairs
}
/// Generates the comparators of an odd-even transposition network on `n`
/// wires.
///
/// The network consists of `n` rounds. Each round first pairs the wires
/// `(0,1), (2,3), ...` and then `(1,2), (3,4), ...`.
fn odd_even_comparators(n: usize) -> Vec<(usize, usize)> {
    // Largest valid left wire is `n - 2`; saturate so n == 0 yields no pairs.
    let last_left = n.saturating_sub(1);
    let mut pairs = Vec::new();
    for _round in 0..n {
        for first in 0..2 {
            pairs.extend((first..last_left).step_by(2).map(|left| (left, left + 1)));
        }
    }
    pairs
}
/// Softly sorts `x` with a bitonic network sized for `x` and the logistic
/// relaxation.
///
/// Returns the `(sorted, perm)` pair produced by `DiffSortNet::sort`, or its
/// error.
pub fn bitonic_sort(x: &[f64], steepness: f64) -> Result<(Vec<f64>, Vec<Vec<f64>>)> {
    let wires = x.len();
    DiffSortNet::new(NetworkType::Bitonic, wires, steepness, RelaxDist::Logistic).sort(x)
}
/// Softly sorts `x` with an odd-even transposition network sized for `x` and
/// the logistic relaxation.
///
/// Returns the `(sorted, perm)` pair produced by `DiffSortNet::sort`, or its
/// error.
pub fn odd_even_sort(x: &[f64], steepness: f64) -> Result<(Vec<f64>, Vec<Vec<f64>>)> {
    let wires = x.len();
    DiffSortNet::new(NetworkType::OddEven, wires, steepness, RelaxDist::Logistic).sort(x)
}
/// Collapses each row of a soft permutation matrix into a single number: the
/// weighted sum of 1-based column indices, `sum_j (j + 1) * perm[i][j]`.
///
/// Returns one value per row, in row order.
pub fn ranks_from_permutation(perm: &[Vec<f64>]) -> Vec<f64> {
    let weighted_index = |(col, &weight): (usize, &f64)| (col as f64 + 1.0) * weight;

    let mut ranks = Vec::with_capacity(perm.len());
    for row in perm {
        let expected: f64 = row.iter().enumerate().map(weighted_index).sum();
        ranks.push(expected);
    }
    ranks
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unsorted four-element fixture shared by most tests.
    const MIXED: [f64; 4] = [3.0, 1.0, 4.0, 2.0];

    /// Asserts that every row of `perm` sums to 1 within 0.01.
    fn assert_rows_sum_to_one(perm: &[Vec<f64>]) {
        for row in perm {
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 0.01,
                "Row sum should be ~1.0, got {sum}"
            );
        }
    }

    /// Asserts that each element is smaller than its successor plus 0.1.
    fn assert_roughly_ascending(sorted: &[f64]) {
        for pair in sorted.windows(2) {
            assert!(pair[0] < pair[1] + 0.1);
        }
    }

    /// Largest entry of a permutation matrix, with 0.0 as the floor.
    fn largest_entry(perm: &[Vec<f64>]) -> f64 {
        perm.iter().flatten().copied().fold(0.0, f64::max)
    }

    #[test]
    fn test_bitonic_sort_basic() {
        let (sorted, perm) = bitonic_sort(&MIXED, 20.0).unwrap();
        assert_roughly_ascending(&sorted);
        assert_rows_sum_to_one(&perm);
    }

    #[test]
    fn test_odd_even_sort_basic() {
        let (sorted, perm) = odd_even_sort(&MIXED, 20.0).unwrap();
        assert_roughly_ascending(&sorted);
        assert_rows_sum_to_one(&perm);
    }

    #[test]
    fn test_already_sorted() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let (sorted, perm) = bitonic_sort(&x, 20.0).unwrap();
        for (i, (&got, &want)) in sorted.iter().zip(&x).enumerate() {
            assert!(
                (got - want).abs() < 0.01,
                "Already sorted input should remain sorted"
            );
            assert!(perm[i][i] > 0.9, "Should be near-identity permutation");
        }
    }

    #[test]
    fn test_reverse_sorted() {
        let x = [4.0, 3.0, 2.0, 1.0];
        let (sorted, _perm) = bitonic_sort(&x, 20.0).unwrap();
        assert!(sorted[0] < sorted[3]);
    }

    #[test]
    fn test_cauchy_distribution() {
        let net = DiffSortNet::new(NetworkType::Bitonic, 4, 10.0, RelaxDist::Cauchy);
        let (sorted, _) = net.sort(&MIXED).unwrap();
        assert!(sorted[0] < sorted[1] + 0.2);
        assert!(sorted[2] < sorted[3] + 0.2);
    }

    #[test]
    fn test_gaussian_distribution() {
        let net = DiffSortNet::new(NetworkType::Bitonic, 4, 10.0, RelaxDist::Gaussian);
        let (sorted, _) = net.sort(&MIXED).unwrap();
        assert!(sorted[0] < sorted[1] + 0.2);
        assert!(sorted[2] < sorted[3] + 0.2);
    }

    #[test]
    fn test_non_power_of_two() {
        let x = [3.0, 1.0, 5.0];
        let net = DiffSortNet::new(NetworkType::Bitonic, 3, 20.0, RelaxDist::Logistic);
        let (sorted, perm) = net.sort(&x).unwrap();
        assert_eq!(sorted.len(), 3);
        assert_eq!(perm.len(), 3);
        assert!(sorted[0] < sorted[2]);
    }

    #[test]
    fn test_odd_even_non_power_of_two() {
        let x = [5.0, 2.0, 7.0, 1.0, 3.0];
        let (sorted, _) = odd_even_sort(&x, 20.0).unwrap();
        assert_eq!(sorted.len(), 5);
        for (prev_idx, pair) in sorted.windows(2).enumerate() {
            let i = prev_idx + 1;
            assert!(
                pair[1] >= pair[0] - 0.2,
                "Not approximately sorted at {i}"
            );
        }
    }

    #[test]
    fn test_ranks_from_permutation() {
        let (sorted, perm) = bitonic_sort(&MIXED, 50.0).unwrap();
        let ranks = ranks_from_permutation(&perm);
        assert!(sorted[0] < sorted[3], "min < max: {sorted:?}");
        for &r in &ranks {
            assert!((0.5..=4.5).contains(&r), "Rank out of range: {r}");
        }
    }

    #[test]
    fn test_steepness_sharpness() {
        let (_, perm_low) = bitonic_sort(&MIXED, 1.0).unwrap();
        let (_, perm_high) = bitonic_sort(&MIXED, 100.0).unwrap();
        let max_low = largest_entry(&perm_low);
        let max_high = largest_entry(&perm_high);
        assert!(
            max_high > max_low,
            "Higher steepness should produce sharper permutations"
        );
    }

    #[test]
    fn test_empty_input() {
        assert!(bitonic_sort(&[], 10.0).is_err());
    }

    #[test]
    fn test_single_element() {
        let x = [42.0];
        let (sorted, perm) = odd_even_sort(&x, 10.0).unwrap();
        assert_eq!(sorted.len(), 1);
        assert!((sorted[0] - 42.0).abs() < 1e-10);
        assert!((perm[0][0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_num_comparators() {
        let net4 = DiffSortNet::new(NetworkType::Bitonic, 4, 1.0, RelaxDist::Logistic);
        assert!(net4.num_comparators() > 0);
        let net8 = DiffSortNet::new(NetworkType::Bitonic, 8, 1.0, RelaxDist::Logistic);
        assert!(net8.num_comparators() > net4.num_comparators());
    }

    #[test]
    fn test_column_sums_doubly_stochastic() {
        let (_, perm) = bitonic_sort(&MIXED, 20.0).unwrap();
        for j in 0..perm.len() {
            let col_sum: f64 = perm.iter().map(|row| row[j]).sum();
            assert!(
                (col_sum - 1.0).abs() < 0.05,
                "Column {j} sum should be ~1.0, got {col_sum}"
            );
        }
    }
}