use crate::grid::{GRID_COUNT, angle_to_grid};
use crate::math::atan2f;
use crate::rotamer::Rotamer;
/// Precomputed statistics for one rotamer at one `(phi, psi)` grid point.
///
/// `build_iter` reads a table of these indexed as
/// `table[phi_cell][psi_cell][rotamer]` and blends the four surrounding
/// entries bilinearly.
///
/// NOTE(review): the layout is pinned by `#[repr(C)]`, presumably because the
/// table is generated or loaded as raw data — confirm before reordering fields.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(C)]
pub struct GridEntry<const N: usize> {
    /// Probability weight of this rotamer at this grid point. `build_iter`
    /// blends it bilinearly, then rescales across rotamers so the result sums to 1.
    pub prob: f32,
    /// Per-χ sine component. `build_iter` blends it together with `chi_cos` and
    /// recovers an angle via `atan2`. Presumably this is the sine of the χ mean.
    pub chi_sin: [f32; N],
    /// Per-χ cosine component, the partner of `chi_sin`.
    pub chi_cos: [f32; N],
    /// Per-χ spread, interpolated linearly between grid points.
    /// NOTE(review): units presumably degrees — confirm against the table source.
    pub chi_sigma: [f32; N],
}
/// Iterator that yields `R` precomputed rotamers by value, in index order.
///
/// Produced by `build_iter`, which fills `items` with interpolated rotamers
/// whose probabilities have been rescaled to sum to 1.
pub struct RotamerIter<const N: usize, const R: usize> {
    // All rotamers to yield. This is a fixed-size array, so iteration never allocates.
    items: [Rotamer<N>; R],
    // Index of the next rotamer to yield. `next` never advances it past `R`.
    idx: usize,
}
impl<const N: usize, const R: usize> Iterator for RotamerIter<N, R> {
    type Item = Rotamer<N>;

    /// Returns a copy of the rotamer at the cursor and advances it.
    ///
    /// Once all `R` rotamers have been handed out, this returns `None`
    /// and leaves the cursor where it is.
    #[inline]
    fn next(&mut self) -> Option<Rotamer<N>> {
        let current = *self.items.get(self.idx)?;
        self.idx += 1;
        Some(current)
    }

    /// Exact bounds: both are the number of rotamers not yet yielded.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = R - self.idx;
        (left, Some(left))
    }
}
impl<const N: usize, const R: usize> ExactSizeIterator for RotamerIter<N, R> {
    /// Number of rotamers not yet yielded. This matches the bounds reported
    /// by `size_hint`, since `items.len()` is `R`.
    #[inline]
    fn len(&self) -> usize {
        self.items.len() - self.idx
    }
}
// Sound because `next` returns `None` without touching `idx` once `idx == R`.
// Every call after exhaustion therefore keeps returning `None`.
impl<const N: usize, const R: usize> core::iter::FusedIterator for RotamerIter<N, R> {}
/// Multiplier that converts an angle in radians to degrees (180 / π).
const RAD_TO_DEG: f32 = 180.0_f32 / core::f32::consts::PI;
/// Weighted circular mean of four angles given as sine/cosine pairs, in degrees.
///
/// The weighted sines and cosines are summed as vector components and turned
/// back into an angle with `atan2f`. Because of this, inputs on either side of
/// the ±180° seam (e.g. 170° and -170°) average to roughly ±180° rather than 0°.
#[inline]
fn chi_mean_from_sc(weights: [f32; 4], sins: [f32; 4], coss: [f32; 4]) -> f32 {
    // Accumulate from zero in index order, one (y, x) component pair per step.
    let (y, x) = weights
        .iter()
        .zip(sins.iter().zip(coss.iter()))
        .fold((0.0_f32, 0.0_f32), |(y, x), (&wt, (&s, &c))| {
            (y + wt * s, x + wt * c)
        });
    atan2f(y, x) * RAD_TO_DEG
}
/// Weighted sum of four values: `Σ weights[c] * values[c]`.
///
/// The terms are added strictly left to right. `build_iter` uses this with
/// bilinear corner weights.
#[inline]
fn bilinear(weights: [f32; 4], values: [f32; 4]) -> f32 {
    let [w0, w1, w2, w3] = weights;
    let [v0, v1, v2, v3] = values;
    let mut acc = w0 * v0;
    acc += w1 * v1;
    acc += w2 * v2;
    acc += w3 * v3;
    acc
}
/// Interpolates the rotamer library at backbone angles `(phi, psi)` and
/// returns the `R` resulting rotamers as an exact-size iterator.
///
/// `table` is indexed `[phi cell][psi cell][rotamer]`. The four grid corners
/// around `(phi, psi)`, as located by `angle_to_grid`, are blended with
/// bilinear weights:
/// * `prob` and `chi_sigma` are interpolated linearly.
/// * `chi_mean` is interpolated as a weighted sin/cos vector and converted
///   back to degrees, so angles near the ±180° seam average correctly.
///
/// Rotamer `k` is labelled with `keys[k]`. Afterwards the probabilities are
/// rescaled so they sum to 1, up to rounding.
///
/// # Panics
///
/// * In debug builds, if a non-empty rotamer set's interpolated probabilities
///   do not sum to a finite positive value. Release builds skip this check.
///   In that case the rescaled probabilities are meaningless (for example,
///   NaN or infinite when the sum is zero).
/// * On out-of-bounds indexing if `angle_to_grid` returns a low cell index of
///   `GRID_COUNT - 1`. NOTE(review): presumably `angle_to_grid` guarantees
///   it does not — confirm.
#[must_use]
pub fn build_iter<const N: usize, const R: usize>(
    table: &[[[GridEntry<N>; R]; GRID_COUNT]; GRID_COUNT],
    keys: &[[u8; N]; R],
    phi: f32,
    psi: f32,
) -> RotamerIter<N, R> {
    let (lo_phi, frac_phi) = angle_to_grid(phi);
    let (lo_psi, frac_psi) = angle_to_grid(psi);

    // Bilinear weights. `w[c]` belongs to `corners[c]` below.
    let w = [
        (1.0 - frac_phi) * (1.0 - frac_psi),
        frac_phi * (1.0 - frac_psi),
        (1.0 - frac_phi) * frac_psi,
        frac_phi * frac_psi,
    ];
    // Relies on `lo + 1` staying below `GRID_COUNT` (see `# Panics`).
    let corners = [
        &table[lo_phi][lo_psi],
        &table[lo_phi + 1][lo_psi],
        &table[lo_phi][lo_psi + 1],
        &table[lo_phi + 1][lo_psi + 1],
    ];

    let mut items: [Rotamer<N>; R] = core::array::from_fn(|k| {
        let prob = bilinear(w, gather_corners(&corners, k, |e| e.prob));
        let chi_mean: [f32; N] = core::array::from_fn(|i| {
            chi_mean_from_sc(
                w,
                gather_corners(&corners, k, |e| e.chi_sin[i]),
                gather_corners(&corners, k, |e| e.chi_cos[i]),
            )
        });
        let chi_sigma: [f32; N] = core::array::from_fn(|i| {
            bilinear(w, gather_corners(&corners, k, |e| e.chi_sigma[i]))
        });
        Rotamer {
            r: keys[k],
            prob,
            chi_mean,
            chi_sigma,
        }
    });

    let prob_sum: f32 = items.iter().map(|rot| rot.prob).sum();
    // An empty rotamer set legitimately sums to 0. Only non-empty sets must
    // have a finite, positive total.
    debug_assert!(
        R == 0 || (prob_sum.is_finite() && prob_sum > 0.0),
        "prob_sum must be finite and positive, got {prob_sum}"
    );
    let inv = 1.0 / prob_sum;
    for rot in &mut items {
        rot.prob *= inv;
    }
    RotamerIter { items, idx: 0 }
}

/// Collects one scalar for rotamer `k` from each of the four grid corners,
/// in `corners` order, so it lines up with the bilinear weights.
#[inline]
fn gather_corners<const N: usize, const R: usize>(
    corners: &[&[GridEntry<N>; R]; 4],
    k: usize,
    field: impl Fn(&GridEntry<N>) -> f32,
) -> [f32; 4] {
    core::array::from_fn(|c| field(&corners[c][k]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Converts an angle in degrees to its `(sin, cos)` pair.
    fn deg_to_sc(deg: f32) -> (f32, f32) {
        let rad = deg * (core::f32::consts::PI / 180.0);
        (rad.sin(), rad.cos())
    }

    #[test]
    fn test_chi_mean_from_sc_no_wrap() {
        let (s, c) = deg_to_sc(60.0);
        let mean = chi_mean_from_sc([0.25; 4], [s; 4], [c; 4]);
        assert_relative_eq!(mean, 60.0, epsilon = 0.01);
    }

    /// 170° and -170° straddle the ±180° seam. A naive arithmetic mean gives
    /// 0°, whereas the vector mean must land near ±180°.
    #[test]
    fn test_chi_mean_from_sc_wrap_around() {
        let (s0, c0) = deg_to_sc(170.0);
        let (s1, c1) = deg_to_sc(-170.0);
        let mean = chi_mean_from_sc([0.5, 0.5, 0.0, 0.0], [s0, s1, 0.0, 0.0], [c0, c1, 0.0, 0.0]);
        assert!(mean.abs() > 170.0, "expected |mean| > 170°, got {mean}");
    }

    #[test]
    fn test_bilinear_uniform() {
        let result = bilinear([0.25; 4], [4.0; 4]);
        assert_relative_eq!(result, 4.0, epsilon = 1e-6);
    }

    #[test]
    fn test_bilinear_weighted() {
        let result = bilinear([1.0, 0.0, 0.0, 0.0], [10.0, 20.0, 30.0, 40.0]);
        assert_relative_eq!(result, 10.0, epsilon = 1e-6);
    }

    #[test]
    fn test_bilinear_mixed_weights() {
        // 0.1*1 + 0.2*2 + 0.3*3 + 0.4*4 = 3.0
        let result = bilinear([0.1, 0.2, 0.3, 0.4], [1.0, 2.0, 3.0, 4.0]);
        assert_relative_eq!(result, 3.0, epsilon = 1e-5);
    }

    #[test]
    fn test_rotamer_iter_exact_size() {
        let items = [
            Rotamer {
                r: [1],
                prob: 0.6,
                chi_mean: [60.0],
                chi_sigma: [10.0],
            },
            Rotamer {
                r: [2],
                prob: 0.4,
                chi_mean: [-60.0],
                chi_sigma: [12.0],
            },
        ];
        let mut iter = RotamerIter { items, idx: 0 };
        assert_eq!(iter.len(), 2);
        let _ = iter.next();
        assert_eq!(iter.len(), 1);
        let _ = iter.next();
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    /// `size_hint` must stay exact, and the iterator must stay exhausted,
    /// as required by the `FusedIterator` impl.
    #[test]
    fn test_rotamer_iter_size_hint_and_fused() {
        let items = [Rotamer {
            r: [7],
            prob: 1.0,
            chi_mean: [180.0],
            chi_sigma: [9.0],
        }];
        let mut iter = RotamerIter { items, idx: 0 };
        assert_eq!(iter.size_hint(), (1, Some(1)));
        let first = iter.next().expect("iterator holds one rotamer");
        assert_eq!(first.r, [7]);
        assert_eq!(iter.size_hint(), (0, Some(0)));
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
        assert_eq!(iter.len(), 0);
    }
}