use serde::{Deserialize, Serialize};
/// A single harmonic "steering" spring that pulls one atom toward a fixed
/// target point.
///
/// When summed by `SteeringForces::accumulate`, the spring contributes energy
/// `0.5 * spring_k * |r - target_xyz|^2` and force `-spring_k * (r - target_xyz)`,
/// where `r` is the atom's current position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SteeringForce {
/// Index of the steered atom; its coordinates live at `3 * atom_index ..
/// 3 * atom_index + 3` in the flat position/force arrays.
pub atom_index: usize,
/// Point (x, y, z) the atom is pulled toward. Units are whatever the
/// positions use — NOTE(review): not established here, confirm with caller.
pub target_xyz: [f64; 3],
/// Spring constant `k` scaling both the energy and the restoring force.
pub spring_k: f64,
}
/// The set of steering springs currently applied to a system.
///
/// `apply` keeps at most one spring per atom. The `forces` field is public, so
/// code that pushes into it directly can bypass that invariant.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SteeringForces {
/// Active springs, in the order they were (last) applied.
pub forces: Vec<SteeringForce>,
}
impl SteeringForces {
    /// Creates an empty set with no active steering springs.
    pub fn new() -> Self {
        Self { forces: Vec::new() }
    }

    /// Attaches a harmonic spring pulling atom `atom_index` toward `target_xyz`
    /// with spring constant `spring_k`.
    ///
    /// Any spring already registered for the same atom is replaced, so each atom
    /// carries at most one steering spring after this call.
    pub fn apply(&mut self, atom_index: usize, target_xyz: [f64; 3], spring_k: f64) {
        self.clear(atom_index);
        self.forces.push(SteeringForce {
            atom_index,
            target_xyz,
            spring_k,
        });
    }

    /// Removes the spring on `atom_index`, if there is one.
    pub fn clear(&mut self, atom_index: usize) {
        self.forces.retain(|f| f.atom_index != atom_index);
    }

    /// Removes every steering spring.
    pub fn clear_all(&mut self) {
        self.forces.clear();
    }

    /// Returns `true` if at least one steering spring is registered.
    pub fn is_active(&self) -> bool {
        !self.forces.is_empty()
    }

    /// Adds all steering forces into `forces_flat` and returns their total energy.
    ///
    /// Both slices store coordinates packed as `[x0, y0, z0, x1, y1, z1, ...]`, so
    /// atom `i` occupies indices `3 * i .. 3 * i + 3`. For each spring with
    /// displacement `d = r - target_xyz`:
    ///
    /// * the energy `0.5 * k * |d|^2` is added to the returned total, and
    /// * the force `-k * d` is *added to* (not written over) the atom's entries
    ///   in `forces_flat`.
    ///
    /// # Panics
    ///
    /// Panics if a spring refers to an atom whose coordinates lie outside
    /// `positions_flat` or `forces_flat`, including indices so large that
    /// `3 * atom_index` would overflow `usize`.
    pub fn accumulate(&self, positions_flat: &[f64], forces_flat: &mut [f64]) -> f64 {
        let mut total_energy = 0.0;
        for sf in &self.forces {
            let i = sf.atom_index;
            // Checked arithmetic: in release builds a wrapped `3 * i` could land
            // in range and silently steer a different atom instead of failing.
            let span = i
                .checked_mul(3)
                .and_then(|start| start.checked_add(3).map(|end| start..end))
                .unwrap_or_else(|| {
                    panic!("steering atom index {i} overflows the flat coordinate index")
                });
            let n_positions = positions_flat.len();
            let pos = positions_flat.get(span.clone()).unwrap_or_else(|| {
                panic!("steering atom index {i} is out of range for {n_positions} position values")
            });
            let n_forces = forces_flat.len();
            let force = forces_flat.get_mut(span).unwrap_or_else(|| {
                panic!("steering atom index {i} is out of range for {n_forces} force values")
            });

            // Summation order (x, then y, then z) matches a plain dx*dx + dy*dy + dz*dz.
            let mut d2 = 0.0;
            for ((f, &p), &t) in force.iter_mut().zip(pos).zip(&sf.target_xyz) {
                let delta = p - t;
                d2 += delta * delta;
                *f -= sf.spring_k * delta;
            }
            total_energy += 0.5 * sf.spring_k * d2;
        }
        total_energy
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn apply_and_clear() {
        let mut sf = SteeringForces::new();
        assert!(!sf.is_active());
        sf.apply(0, [1.0, 0.0, 0.0], 100.0);
        assert!(sf.is_active());
        sf.clear(0);
        assert!(!sf.is_active());
    }

    #[test]
    fn apply_replaces_existing_force_on_same_atom() {
        let mut sf = SteeringForces::new();
        sf.apply(2, [1.0, 0.0, 0.0], 100.0);
        sf.apply(2, [0.0, 1.0, 0.0], 50.0);
        assert_eq!(sf.forces.len(), 1);
        assert_eq!(sf.forces[0].target_xyz, [0.0, 1.0, 0.0]);
        assert_eq!(sf.forces[0].spring_k, 50.0);
    }

    #[test]
    fn clear_only_removes_the_named_atom() {
        let mut sf = SteeringForces::new();
        sf.apply(0, [1.0, 0.0, 0.0], 100.0);
        sf.apply(1, [0.0, 1.0, 0.0], 50.0);
        sf.clear(0);
        assert!(sf.is_active());
        assert_eq!(sf.forces.len(), 1);
        assert_eq!(sf.forces[0].atom_index, 1);
    }

    #[test]
    fn accumulate_force_toward_target() {
        let mut sf = SteeringForces::new();
        sf.apply(0, [1.0, 0.0, 0.0], 100.0);
        let positions = [0.0, 0.0, 0.0, 5.0, 0.0, 0.0];
        let mut forces = [0.0; 6];
        let energy = sf.accumulate(&positions, &mut forces);
        assert!(forces[0] > 0.0, "force should be in +x direction");
        assert!(energy > 0.0, "energy should be positive");
        // Displacement d = (-1, 0, 0): F = -k * d = (100, 0, 0), E = 0.5 * k * 1 = 50.
        // Atom 1 has no spring and must be left untouched.
        assert_eq!(forces, [100.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        assert_eq!(energy, 50.0);
    }

    #[test]
    fn accumulate_adds_to_existing_forces_and_sums_energy() {
        let mut sf = SteeringForces::new();
        sf.apply(0, [1.0, 0.0, 0.0], 2.0);
        sf.apply(1, [0.0, 0.0, 0.0], 4.0);
        let positions = [0.0, 0.0, 0.0, 0.0, 3.0, 0.0];
        let mut forces = [1.0; 6];
        let energy = sf.accumulate(&positions, &mut forces);
        // Atom 0: d = (-1, 0, 0) -> F = (2, 0, 0),   E = 1.
        // Atom 1: d = (0, 3, 0)  -> F = (0, -12, 0), E = 18.
        assert_eq!(forces, [3.0, 1.0, 1.0, 1.0, -11.0, 1.0]);
        assert_eq!(energy, 19.0);
    }

    #[test]
    #[should_panic]
    fn accumulate_panics_on_out_of_range_atom() {
        let mut sf = SteeringForces::new();
        sf.apply(5, [0.0; 3], 1.0);
        let positions = [0.0; 6];
        let mut forces = [0.0; 6];
        let _ = sf.accumulate(&positions, &mut forces);
    }

    #[test]
    fn clear_all_removes_everything() {
        let mut sf = SteeringForces::new();
        sf.apply(0, [1.0, 0.0, 0.0], 100.0);
        sf.apply(1, [0.0, 1.0, 0.0], 50.0);
        sf.clear_all();
        assert!(!sf.is_active());
    }
}