gosh-elastic 0.1.2

Elastic Potential Energy
Documentation
// [[file:../elastic.note::531551a8][531551a8]]
#![deny(warnings)]

use super::*;

/// Dynamically sized `f64` matrix; used in this module for square (N by N)
/// pairwise distance, distance-difference and weight matrices.
type DistMatrix = nalgebra::DMatrix<f64>;
/// Cartesian coordinates stored column-wise (3 by N, one column per atom);
/// used for both position and force matrices.
type CartMatrix = Vector3fVec;

use std::collections::HashMap;
// 531551a8 ends here

// [[file:../elastic.note::6474199c][6474199c]]
/// Compute spring forces from a precomputed map of pairwise unit vectors.
///
/// Every ordered pair `(i, j)` with `i != j` adds `-w_ij * (r_ij - d_ij) * x_ij`
/// to column `i` of the result. `x_ij` is `dmap[(i, j)]` when `i < j` and
/// `-dmap[(j, i)]` otherwise, so `dmap` only needs to hold one entry per
/// unordered pair.
///
/// # Arguments
///
/// * dm     : distance matrix of current image (square, N by N)
/// * dm_ref : distance matrix of reference image (square, N by N)
/// * wm     : weight matrix for pairs (square, N by N)
/// * dmap   : a hash map containing the normalized displacement vector (cartesian)
///            of each pair `(i, j)` with `i < j`
///
/// # Returns
///
/// A 3 by N matrix whose column `i` is the accumulated spring force on atom `i`.
///
/// # Panics
///
/// Panics if the matrices are not all N by N, if `dmap` does not hold exactly
/// N(N-1)/2 entries, or if a pair `(i, j)` with `i < j` is missing from `dmap`.
// Dead-code lint silenced: the only visible caller is `spring_force_new`,
// which is itself marked unused.
#[allow(unused)]
fn spring_force_generic(
    dm: &DistMatrix,
    dm_ref: &DistMatrix,
    wm: &DistMatrix,
    dmap: &HashMap<(usize, usize), Vector3f>,
) -> CartMatrix {
    let (n, m) = dm.shape();
    assert_eq!(n, m);
    assert_eq!((n, m), dm_ref.shape());
    assert_eq!((n, m), wm.shape());
    // `saturating_sub` keeps the expected pair count at zero for an empty system,
    // where `n - 1` would underflow and panic.
    assert_eq!(n * n.saturating_sub(1) / 2, dmap.len());

    // initialize force matrix in 3xN
    let mut fm = CartMatrix::zeros(n);
    for i in 0..n {
        for j in 0..n {
            if i != j {
                let wij = wm[(i, j)];
                let rij = dm[(i, j)];
                let dij = dm_ref[(i, j)];
                // only pairs with i < j are stored; the reverse direction is the negation
                let xij = if i < j { dmap[&(i, j)] } else { -dmap[&(j, i)] };

                // hint:
                // e_spring = 1/2 k (r-r0)^2
                // f_spring = -k (r-r0)
                let v = -wij * (rij - dij) * xij;

                // accumulate into column i (the force on atom i)
                for k in 0..3 {
                    fm[(k, i)] += v[k];
                }
            }
        }
    }

    fm
}

/// Compute spring forces by first tabulating the unit displacement vector of
/// every unordered pair `i < j`, then delegating to `spring_force_generic`.
///
/// NOTE: this version is slower than `spring_force`.
///
/// # Arguments
///
/// * dm     : distance matrix of current image (square, N by N)
/// * dm_ref : distance matrix of reference image (square, N by N)
/// * wm     : weight matrix for pairs (square, N by N)
/// * xm     : position matrix of current image (cartesian matrix, 3xN)
///
/// # Panics
///
/// Panics if `xm` is not 3 by N (N being the column count of `dm`), and under the
/// same conditions as `spring_force_generic`.
// Not called anywhere in the visible code; kept as an alternative implementation.
#[allow(unused)]
fn spring_force_new(dm: &DistMatrix, dm_ref: &DistMatrix, wm: &DistMatrix, xm: &CartMatrix) -> CartMatrix {
    let natoms = dm.ncols();
    assert_eq!((3, natoms), xm.shape());

    // normalized (x_i - x_j), stored once per unordered pair with i < j
    let mut unit_vectors = HashMap::new();
    for i in 0..natoms {
        for j in (i + 1)..natoms {
            let xij = (&xm.column(i) - &xm.column(j)).normalize();
            unit_vectors.insert((i, j), xij);
        }
    }

    spring_force_generic(dm, dm_ref, wm, &unit_vectors)
}

/// Compute spring force vector between two images (configurations)
///
/// # Arguments
///
/// * dm     : distance matrix of current image (NxN)
/// * dm_ref : distance matrix of reference image (NxN)
/// * wm     : weight matrix for pairs (NxN)
/// * xm     : position matrix of current image (3xN)
pub fn spring_force(dm: &DistMatrix, dm_ref: &DistMatrix, wm: &DistMatrix, xm: &CartMatrix) -> CartMatrix {
    let n = dm.ncols();

    // initialize force vector in 3xN
    let mut fm = xm.map(|_| 0.0);

    for i in 0..n {
        for j in 0..n {
            if i != j {
                let wij = wm[(i, j)];
                let rij = dm[(i, j)];
                let dij = dm_ref[(i, j)];
                let xij = &xm.column(i) - &xm.column(j);

                // hint:
                // e_spring = 1/2 k (r-r0)^2
                // f_spring = -k (r-r0)
                let v = -wij * (rij - dij) * xij.normalize();

                // update column vectors
                for k in 0..3 {
                    fm[(k, i)] += v[k];
                }
            }
        }
    }

    fm
}

/// Compute the total spring energy between two images (configurations).
///
/// Only the strictly lower triangle (`j < i`) of the matrices is visited, so each
/// unordered pair contributes exactly once:
/// `E = sum_{j < i} 1/2 * d_ij^2 * w_ij`.
///
/// # Arguments
///
/// * dm: square matrix of per-pair deviations (e.g. current minus reference distances)
/// * wm: square matrix of per-pair weights
///
pub fn spring_energy(dm: &DistMatrix, wm: &DistMatrix) -> f64 {
    let nrows = dm.nrows();
    // explicit fold from +0.0 keeps the accumulation order and the empty-input result
    // exactly as a hand-written running sum
    (0..nrows)
        .flat_map(|i| (0..i).map(move |j| (i, j)))
        .fold(0.0, |acc, (i, j)| acc + 0.5 * dm[(i, j)].powi(2) * wm[(i, j)])
}

/// Compute the elastic energy and forces of image 2 (`positions2`) with respect
/// to the reference image 1 (`positions1`), working in distance space.
///
/// Each atom pair `(i, j)` is a spring with weight `1 / d_ij^2`, where `d_ij` is the
/// reference distance in image 1; diagonal weights are zero. The energy penalizes
/// deviations of image-2 distances from image-1 distances. The forces act on the
/// atoms of image 2.
///
/// Returns `(energy, forces)`, with one `[x, y, z]` force per atom of `positions2`.
pub fn compute_elastic_energy_and_force_distance_space(positions1: &[Array3], positions2: &[Array3]) -> (f64, Vec<Array3>) {
    // image 1 is the reference, image 2 the current configuration
    let dist_ref = positions1.distance_matrix();
    let dist_cur = positions2.distance_matrix();
    let weights = dist_ref.map_with_location(|i, j, d| if i == j { 0.0 } else { d.powi(-2) });

    let forces = spring_force(&dist_cur, &dist_ref, &weights, &positions2.to_matrix());
    let energy = spring_energy(&(&dist_cur - &dist_ref), &weights);

    (energy, forces.as_3d().to_vec())
}
// 6474199c ends here

// [[file:../elastic.note::ebef0ca6][ebef0ca6]]
// Regression test: elastic energy and forces between the first two frames of a
// small MD trajectory, compared against stored reference values.
#[test]
fn test_elastic_matrix() -> Result<()> {
    let frames: Vec<_> = gchemol::io::read("./tests/files/md-traj-small.xyz")?.collect();
    let reference: Vec<_> = frames[0].positions().collect();
    let current: Vec<_> = frames[1].positions().collect();

    let (energy, forces) = compute_elastic_energy_and_force_distance_space(&reference, &current);
    #[rustfmt::skip]
    let forces_expected = [[-0.00295488, -0.00693774,  0.01931514],
                           [-0.00055362, -0.00547091, -0.00126652],
                           [-0.00721916, -0.00592138, -0.01877952],
                           [ 0.01529853,  0.02228969,  0.00206666],
                           [-0.00457087, -0.00395966, -0.00133576]];
    approx::assert_relative_eq!(0.001980439069872339, energy, epsilon = 1E-5);
    approx::assert_relative_eq!(forces_expected.to_matrix(), forces.to_matrix(), epsilon = 1E-5);

    Ok(())
}
// ebef0ca6 ends here