/// Sentinel value stored for a pruned candidate.
///
/// [`SelfEnergyTable::prune`] writes this value and [`SelfEnergyTable::is_pruned`]
/// compares the stored value against it.
pub const PRUNED: f32 = f32::INFINITY;

/// One `f32` per (slot, candidate) pair, stored in a single flat buffer.
///
/// Slot `s` owns `data[offsets[s]..offsets[s + 1]]`.
pub struct SelfEnergyTable {
    /// Values of all slots, concatenated in slot order.
    data: Vec<f32>,
    /// Running totals of the per-slot candidate counts; always one entry
    /// longer than the number of slots, starting at 0.
    offsets: Vec<usize>,
}

impl SelfEnergyTable {
    /// Creates a table with `counts[s]` candidates for slot `s`, every entry
    /// initialised to `0.0`.
    pub fn new(counts: &[u16]) -> Self {
        let mut offsets = Vec::with_capacity(counts.len() + 1);
        let mut total = 0usize;
        offsets.push(total);
        for &count in counts {
            total += usize::from(count);
            offsets.push(total);
        }
        Self {
            data: vec![0.0; total],
            offsets,
        }
    }

    /// Number of slots in the table.
    pub fn n_slots(&self) -> usize {
        self.offsets.len() - 1
    }

    /// Number of candidates stored for slot `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is not a valid slot index.
    pub fn n_candidates(&self, s: usize) -> usize {
        debug_assert!(
            s < self.n_slots(),
            "slot {s} out of bounds (n_slots={})",
            self.n_slots(),
        );
        self.offsets[s + 1] - self.offsets[s]
    }

    /// Returns the value stored for candidate `r` of slot `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is not a valid slot or `r` is not a valid candidate of
    /// slot `s`. Debug builds report this through the assertions below;
    /// release builds through the slice bounds checks.
    pub fn get(&self, s: usize, r: usize) -> f32 {
        debug_assert!(
            s < self.n_slots(),
            "slot {s} out of bounds (n_slots={})",
            self.n_slots(),
        );
        debug_assert!(
            r < self.n_candidates(s),
            "candidate {r} out of bounds (n_candidates={})",
            self.n_candidates(s),
        );
        // Index through the slot's own sub-slice: an oversized `r` must never
        // silently read the storage of the following slot when the debug
        // assertions are compiled out.
        let slot = &self.data[self.offsets[s]..self.offsets[s + 1]];
        slot[r]
    }

    /// Stores `val` for candidate `r` of slot `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is not a valid slot or `r` is not a valid candidate of
    /// slot `s`. Debug builds report this through the assertions below;
    /// release builds through the slice bounds checks.
    pub fn set(&mut self, s: usize, r: usize, val: f32) {
        debug_assert!(
            s < self.n_slots(),
            "slot {s} out of bounds (n_slots={})",
            self.n_slots(),
        );
        debug_assert!(
            r < self.n_candidates(s),
            "candidate {r} out of bounds (n_candidates={})",
            self.n_candidates(s),
        );
        // Same reasoning as in `get`: restrict the write to this slot's range
        // so an oversized `r` cannot overwrite a neighbouring slot.
        let slot = &mut self.data[self.offsets[s]..self.offsets[s + 1]];
        slot[r] = val;
    }

    /// Marks candidate `r` of slot `s` as pruned by storing [`PRUNED`].
    pub fn prune(&mut self, s: usize, r: usize) {
        self.set(s, r, PRUNED);
    }

    /// Returns `true` if the value stored for candidate `r` of slot `s`
    /// equals [`PRUNED`].
    pub fn is_pruned(&self, s: usize, r: usize) -> bool {
        self.get(s, r) == PRUNED
    }
}
/// One `f32` block per edge, all blocks stored back to back in a single
/// buffer.
///
/// Edge `e` with dimensions `(ni, nj)` owns the `ni * nj` values in
/// `data[offsets[e]..offsets[e + 1]]`.
pub struct PairEnergyTable {
    /// Values of all edges, concatenated in edge order.
    data: Vec<f32>,
    /// Running totals of the per-edge element counts; one entry longer than
    /// the number of edges, starting at 0.
    offsets: Vec<usize>,
    /// Per-edge dimensions exactly as passed to [`PairEnergyTable::new`].
    sizes: Vec<(u16, u16)>,
}

impl PairEnergyTable {
    /// Creates a table holding a zero-filled block of `ni * nj` values for
    /// every `(ni, nj)` entry of `dims`.
    pub fn new(dims: &[(u16, u16)]) -> Self {
        let mut offsets = Vec::with_capacity(dims.len() + 1);
        let mut total = 0usize;
        offsets.push(total);
        for &(first, second) in dims {
            total += first as usize * second as usize;
            offsets.push(total);
        }
        Self {
            data: vec![0.0; total],
            offsets,
            sizes: dims.to_vec(),
        }
    }

    /// Number of edges in the table.
    pub fn n_edges(&self) -> usize {
        self.sizes.len()
    }

    /// Dimensions of `edge`, widened to `usize`.
    ///
    /// # Panics
    ///
    /// Panics if `edge` is not a valid edge index.
    pub fn dims(&self, edge: usize) -> (usize, usize) {
        debug_assert!(
            edge < self.n_edges(),
            "edge {edge} out of bounds (n_edges={})",
            self.n_edges(),
        );
        let (first, second) = self.sizes[edge];
        (usize::from(first), usize::from(second))
    }

    /// Read-only view of the values stored for `edge`.
    ///
    /// # Panics
    ///
    /// Panics if `edge` is not a valid edge index.
    pub fn matrix(&self, edge: usize) -> &[f32] {
        debug_assert!(
            edge < self.n_edges(),
            "edge {edge} out of bounds (n_edges={})",
            self.n_edges(),
        );
        let start = self.offsets[edge];
        let end = self.offsets[edge + 1];
        &self.data[start..end]
    }

    /// Returns one mutable view per edge, in edge order.
    ///
    /// The views are disjoint pieces of the same buffer, so all of them can
    /// be held and written at the same time.
    pub fn matrices_mut(&mut self) -> Vec<&mut [f32]> {
        let mut views = Vec::with_capacity(self.sizes.len());
        let mut remaining: &mut [f32] = &mut self.data[..];
        // Consecutive offset pairs give each edge's length; peel that many
        // elements off the front of what is left of the buffer.
        for bounds in self.offsets.windows(2) {
            let len = bounds[1] - bounds[0];
            let (head, tail) = std::mem::take(&mut remaining).split_at_mut(len);
            views.push(head);
            remaining = tail;
        }
        views
    }
}
/// Per-slot lists of `f32` values flattened into one buffer.
///
/// Slot `s` owns `data[offsets[s]..offsets[s + 1]]`.
pub struct RotamerBias {
    /// Values of all slots, concatenated in slot order.
    data: Vec<f32>,
    /// Running totals of the per-slot lengths; one entry longer than the
    /// number of slots, starting at 0.
    offsets: Vec<usize>,
}

impl RotamerBias {
    /// Builds the flattened table from one vector of values per slot.
    ///
    /// The order of slots and of the values inside each slot is preserved.
    pub fn new(per_slot: Vec<Vec<f32>>) -> Self {
        let mut offsets = Vec::with_capacity(per_slot.len() + 1);
        let mut total = 0usize;
        offsets.push(total);
        for values in &per_slot {
            total += values.len();
            offsets.push(total);
        }
        // The final size is known up front, so reserve it once.
        let mut data = Vec::with_capacity(total);
        data.extend(per_slot.into_iter().flatten());
        Self { data, offsets }
    }

    /// Number of slots in the table.
    pub fn n_slots(&self) -> usize {
        self.offsets.len() - 1
    }

    /// Values stored for slot `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is not a valid slot index.
    pub fn slot(&self, s: usize) -> &[f32] {
        debug_assert!(
            s < self.n_slots(),
            "slot {s} out of bounds (n_slots = {})",
            self.n_slots(),
        );
        let start = self.offsets[s];
        let end = self.offsets[s + 1];
        &self.data[start..end]
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SelfEnergyTable`, `PairEnergyTable` and `RotamerBias`.
    //!
    //! The out-of-bounds tests are compiled only with `debug_assertions`
    //! because they expect the panic raised by the accessors' `debug_assert!`.

    use super::*;

    /// Flat index of `(row, col)` in a block that is `n_cols` wide.
    ///
    /// Keeps the index arithmetic symbolic so the tests contain no literal
    /// expressions such as `0 * 3 + 0`, which clippy flags (`erasing_op`,
    /// `identity_op`).
    fn flat(row: usize, col: usize, n_cols: usize) -> usize {
        row * n_cols + col
    }

    // ----- SelfEnergyTable -------------------------------------------------

    #[test]
    fn self_new_empty_has_zero_slots() {
        let t = SelfEnergyTable::new(&[]);
        assert_eq!(t.n_slots(), 0);
    }

    #[test]
    fn self_n_slots_matches_counts() {
        let t = SelfEnergyTable::new(&[3, 5, 2]);
        assert_eq!(t.n_slots(), 3);
    }

    #[test]
    fn self_n_candidates_matches_each_count() {
        let t = SelfEnergyTable::new(&[4, 7, 1]);
        assert_eq!(t.n_candidates(0), 4);
        assert_eq!(t.n_candidates(1), 7);
        assert_eq!(t.n_candidates(2), 1);
    }

    #[test]
    fn self_all_entries_zero_after_new() {
        let t = SelfEnergyTable::new(&[3, 2]);
        for s in 0..t.n_slots() {
            for r in 0..t.n_candidates(s) {
                assert_eq!(t.get(s, r), 0.0);
            }
        }
    }

    #[test]
    fn self_set_then_get_round_trips() {
        let mut t = SelfEnergyTable::new(&[3, 2]);
        t.set(0, 2, 1.5);
        t.set(1, 0, -3.0);
        assert_eq!(t.get(0, 2), 1.5);
        assert_eq!(t.get(1, 0), -3.0);
    }

    #[test]
    fn self_slots_are_independent() {
        let mut t = SelfEnergyTable::new(&[2, 2]);
        t.set(0, 0, 99.0);
        assert_eq!(t.get(1, 0), 0.0);
    }

    #[test]
    fn self_is_pruned_false_initially() {
        let t = SelfEnergyTable::new(&[3]);
        assert!(!t.is_pruned(0, 0));
        assert!(!t.is_pruned(0, 2));
    }

    #[test]
    fn self_prune_sets_infinity() {
        let mut t = SelfEnergyTable::new(&[4]);
        t.prune(0, 1);
        assert_eq!(t.get(0, 1), PRUNED);
    }

    #[test]
    fn self_is_pruned_true_after_prune() {
        let mut t = SelfEnergyTable::new(&[3]);
        t.prune(0, 2);
        assert!(t.is_pruned(0, 2));
        assert!(!t.is_pruned(0, 0));
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "out of bounds")]
    fn self_get_panics_slot_out_of_bounds() {
        let t = SelfEnergyTable::new(&[3]);
        let _ = t.get(1, 0);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "out of bounds")]
    fn self_get_panics_candidate_out_of_bounds() {
        let t = SelfEnergyTable::new(&[3]);
        let _ = t.get(0, 3);
    }

    // ----- PairEnergyTable -------------------------------------------------

    #[test]
    fn pair_new_empty_has_zero_edges() {
        let t = PairEnergyTable::new(&[]);
        assert_eq!(t.n_edges(), 0);
    }

    #[test]
    fn pair_n_edges_matches_dims() {
        let t = PairEnergyTable::new(&[(3, 2), (4, 5)]);
        assert_eq!(t.n_edges(), 2);
    }

    #[test]
    fn pair_dims_match_input() {
        let t = PairEnergyTable::new(&[(3, 2), (4, 5)]);
        assert_eq!(t.dims(0), (3, 2));
        assert_eq!(t.dims(1), (4, 5));
    }

    #[test]
    fn pair_all_entries_zero_after_new() {
        let t = PairEnergyTable::new(&[(2, 3)]);
        assert!(t.matrix(0).iter().all(|&x| x == 0.0));
    }

    #[test]
    fn pair_matrix_length_matches_product() {
        let t = PairEnergyTable::new(&[(3, 4), (2, 5)]);
        assert_eq!(t.matrix(0).len(), 12);
        assert_eq!(t.matrix(1).len(), 10);
    }

    #[test]
    fn pair_matrices_mut_empty_returns_empty() {
        let mut t = PairEnergyTable::new(&[]);
        assert!(t.matrices_mut().is_empty());
    }

    #[test]
    fn pair_matrices_mut_count_matches_edges() {
        let mut t = PairEnergyTable::new(&[(2, 3), (4, 1), (1, 5)]);
        assert_eq!(t.matrices_mut().len(), 3);
    }

    #[test]
    fn pair_matrices_mut_slice_lengths_match_dims() {
        let mut t = PairEnergyTable::new(&[(3, 4), (2, 5)]);
        let slices = t.matrices_mut();
        assert_eq!(slices[0].len(), 12);
        assert_eq!(slices[1].len(), 10);
    }

    #[test]
    fn pair_matrices_mut_written_data_visible_via_matrix() {
        let mut t = PairEnergyTable::new(&[(2, 3)]);
        {
            // Scope the mutable views so `t` can be read again afterwards.
            let mut slices = t.matrices_mut();
            slices[0][flat(0, 0, 3)] = 1.0;
            slices[0][flat(0, 2, 3)] = 3.0;
            slices[0][flat(1, 1, 3)] = 5.0;
        }
        assert_eq!(t.matrix(0), &[1.0, 0.0, 3.0, 0.0, 5.0, 0.0]);
    }

    #[test]
    fn pair_matrices_mut_edges_are_independent() {
        let mut t = PairEnergyTable::new(&[(2, 2), (2, 2)]);
        {
            let mut slices = t.matrices_mut();
            slices[0][flat(1, 0, 2)] = 7.0;
        }
        assert_eq!(t.matrix(1)[flat(1, 0, 2)], 0.0);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "out of bounds")]
    fn pair_dims_panics_edge_out_of_bounds() {
        let t = PairEnergyTable::new(&[(2, 3)]);
        let _ = t.dims(1);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "out of bounds")]
    fn pair_matrix_panics_edge_out_of_bounds() {
        let t = PairEnergyTable::new(&[(2, 3)]);
        let _ = t.matrix(1);
    }

    // ----- RotamerBias -----------------------------------------------------

    #[test]
    fn bias_empty_has_zero_slots() {
        let b = RotamerBias::new(vec![]);
        assert_eq!(b.n_slots(), 0);
    }

    #[test]
    fn bias_n_slots_matches() {
        let b = RotamerBias::new(vec![vec![1.0, 2.0], vec![3.0]]);
        assert_eq!(b.n_slots(), 2);
    }

    #[test]
    fn bias_slot_returns_correct_slice() {
        let b = RotamerBias::new(vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0]]);
        assert_eq!(b.slot(0), [1.0, 2.0]);
        assert_eq!(b.slot(1), [3.0, 4.0, 5.0]);
    }

    #[test]
    fn bias_data_is_flat() {
        let b = RotamerBias::new(vec![vec![1.0], vec![2.0, 3.0]]);
        assert_eq!(b.data.len(), 3);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "out of bounds")]
    fn bias_slot_panics_out_of_bounds() {
        let b = RotamerBias::new(vec![vec![1.0]]);
        let _ = b.slot(1);
    }
}