sci-form 0.15.2

High-performance 3D molecular conformer generation using ETKDG distance geometry
Documentation
use serde::{Deserialize, Serialize};

/// Idealised coordination geometries that the topology analysis can assign to a metal centre.
///
/// Serialized through serde with snake_case variant names (for example `square_planar`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CoordinationGeometryGuess {
    /// Two ligands on opposite sides of the centre.
    Linear,
    /// Three ligands in one plane, 120° apart.
    Trigonal,
    /// Four ligands pointing at the corners of a tetrahedron.
    Tetrahedral,
    /// Four ligands in one plane, 90° apart.
    SquarePlanar,
    /// Five ligands: two axial and three equatorial.
    TrigonalBipyramidal,
    /// Six ligands along the positive and negative Cartesian axes.
    Octahedral,
    /// No template applies; reported when the ligand count is outside 2..=6.
    Unknown,
}

/// One transition-metal atom together with the ligands detected around it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetalCoordinationCenter {
    /// Index of the metal atom in the `elements` / `positions` inputs.
    pub atom_index: usize,
    /// Element code of the metal, copied from the `elements` input
    /// (the tests use atomic numbers, e.g. 78 for Pt).
    pub element: u8,
    /// Indices of the atoms that passed the distance-based coordination check.
    pub ligand_indices: Vec<usize>,
    /// Number of detected ligands; always equal to `ligand_indices.len()`.
    pub coordination_number: usize,
    /// Best-matching idealised geometry for the ligand directions.
    pub geometry: CoordinationGeometryGuess,
    /// Match quality for `geometry`, nominally in `[0, 1]` (1 = ideal);
    /// `0.0` when `geometry` is `Unknown`.
    pub geometry_score: f64,
}

/// Output of [`analyze_topology`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopologyAnalysisResult {
    /// One entry per transition-metal atom found, in input order.
    pub metal_centers: Vec<MetalCoordinationCenter>,
    /// Human-readable notes, e.g. that no transition-metal centre was found.
    pub warnings: Vec<String>,
}

/// Euclidean distance between the points `a` and `b`.
fn distance(a: [f64; 3], b: [f64; 3]) -> f64 {
    let delta = [a[0] - b[0], a[1] - b[1], a[2] - b[2]];
    let squared_length = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2];
    squared_length.sqrt()
}

/// Scales `v` to unit length.
///
/// A vector whose length is at most `1e-12` has no usable direction, so the
/// +z axis `[0, 0, 1]` is returned for it instead.
fn normalize(v: [f64; 3]) -> [f64; 3] {
    let [x, y, z] = v;
    let length = (x * x + y * y + z * z).sqrt();
    if length <= 1e-12 {
        return [0.0, 0.0, 1.0];
    }
    [x / length, y / length, z / length]
}

/// Dot product of two 3-component vectors.
fn dot(a: [f64; 3], b: [f64; 3]) -> f64 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}

/// Distance-based test for whether an atom is coordinated to a metal.
///
/// The atom counts as a ligand when `dist` is no larger than 1.3 times the sum
/// of the two covalent radii plus 0.25. Atoms whose element code is 1 are
/// never accepted, and the radii are not looked up for them.
fn is_likely_coordinated(metal_z: u8, ligand_z: u8, dist: f64) -> bool {
    match ligand_z {
        1 => false,
        _ => {
            let radius_sum = crate::graph::get_covalent_radius(metal_z)
                + crate::graph::get_covalent_radius(ligand_z);
            dist <= 1.3 * radius_sum + 0.25
        }
    }
}

/// Returns the unit-vector template (centre → ligand directions) for `geometry`.
///
/// `Unknown` has no template and yields an empty slice.
fn ideal_directions(geometry: CoordinationGeometryGuess) -> &'static [[f64; 3]] {
    // sin(60°) = √3 / 2 and 1 / √3, written out as literals.
    const SIN_60: f64 = 0.8660254037844386;
    const INV_SQRT_3: f64 = 0.5773502691896258;

    // Signed Cartesian axes shared by several templates.
    const POS_X: [f64; 3] = [1.0, 0.0, 0.0];
    const NEG_X: [f64; 3] = [-1.0, 0.0, 0.0];
    const POS_Y: [f64; 3] = [0.0, 1.0, 0.0];
    const NEG_Y: [f64; 3] = [0.0, -1.0, 0.0];
    const POS_Z: [f64; 3] = [0.0, 0.0, 1.0];
    const NEG_Z: [f64; 3] = [0.0, 0.0, -1.0];

    // In-plane directions at ±120° from +x.
    const PLUS_120: [f64; 3] = [-0.5, SIN_60, 0.0];
    const MINUS_120: [f64; 3] = [-0.5, -SIN_60, 0.0];

    const LINEAR: [[f64; 3]; 2] = [POS_X, NEG_X];
    const TRIGONAL: [[f64; 3]; 3] = [POS_X, PLUS_120, MINUS_120];
    const TETRAHEDRAL: [[f64; 3]; 4] = [
        [INV_SQRT_3, INV_SQRT_3, INV_SQRT_3],
        [INV_SQRT_3, -INV_SQRT_3, -INV_SQRT_3],
        [-INV_SQRT_3, INV_SQRT_3, -INV_SQRT_3],
        [-INV_SQRT_3, -INV_SQRT_3, INV_SQRT_3],
    ];
    const SQUARE_PLANAR: [[f64; 3]; 4] = [POS_X, POS_Y, NEG_X, NEG_Y];
    const TRIGONAL_BIPYRAMIDAL: [[f64; 3]; 5] = [POS_Z, NEG_Z, POS_X, PLUS_120, MINUS_120];
    const OCTAHEDRAL: [[f64; 3]; 6] = [POS_X, NEG_X, POS_Y, NEG_Y, POS_Z, NEG_Z];

    match geometry {
        CoordinationGeometryGuess::Unknown => &[],
        CoordinationGeometryGuess::Linear => &LINEAR,
        CoordinationGeometryGuess::Trigonal => &TRIGONAL,
        CoordinationGeometryGuess::Tetrahedral => &TETRAHEDRAL,
        CoordinationGeometryGuess::SquarePlanar => &SQUARE_PLANAR,
        CoordinationGeometryGuess::TrigonalBipyramidal => &TRIGONAL_BIPYRAMIDAL,
        CoordinationGeometryGuess::Octahedral => &OCTAHEDRAL,
    }
}

/// Mean mismatch between `vectors` and the best one-to-one choice of `ideals`.
///
/// Each vector is paired with a distinct ideal direction; a pair costs
/// `1 - dot(vector, ideal)`. All pairings are explored depth-first and the
/// smallest total, divided by `vectors.len()`, is returned. When there are
/// more vectors than ideals no complete pairing exists and the result is
/// infinite.
fn assignment_cost(vectors: &[[f64; 3]], ideals: &[[f64; 3]]) -> f64 {
    /// Places `observed[depth..]` onto the still-free targets, lowering
    /// `lowest` whenever a complete pairing with a smaller total is found.
    fn search(
        observed: &[[f64; 3]],
        targets: &[[f64; 3]],
        taken: &mut [bool],
        depth: usize,
        running: f64,
        lowest: &mut f64,
    ) {
        let vector = match observed.get(depth) {
            Some(v) => *v,
            None => {
                // Every vector has been placed: record the total if it improves.
                *lowest = lowest.min(running);
                return;
            }
        };
        // Prune branches that already cost at least as much as the best pairing.
        if running >= *lowest {
            return;
        }

        for (slot, target) in targets.iter().enumerate() {
            if taken[slot] {
                continue;
            }
            taken[slot] = true;
            let penalty = 1.0 - dot(vector, *target);
            search(observed, targets, taken, depth + 1, running + penalty, lowest);
            taken[slot] = false;
        }
    }

    let mut taken = vec![false; ideals.len()];
    let mut lowest = f64::INFINITY;
    search(vectors, ideals, &mut taken, 0, 0.0, &mut lowest);
    lowest / (vectors.len() as f64)
}

/// Picks the idealised geometry whose template best matches the ligand directions.
///
/// `vectors` holds the centre → ligand directions (unit vectors as produced by
/// `normalize`). The candidate templates depend only on the ligand count:
/// 2 → linear, 3 → trigonal, 4 → tetrahedral or square planar,
/// 5 → trigonal bipyramidal, 6 → octahedral. Any other count yields
/// `(Unknown, 0.0)`.
///
/// Returns the winning geometry and a score `1 - sqrt(cost / 2)` clamped to
/// `[0, 1]`, where `cost` is the mean assignment cost of the winning template
/// (1.0 means a perfect match).
fn classify_geometry(vectors: &[[f64; 3]]) -> (CoordinationGeometryGuess, f64) {
    let candidates: &[CoordinationGeometryGuess] = match vectors.len() {
        2 => &[CoordinationGeometryGuess::Linear],
        3 => &[CoordinationGeometryGuess::Trigonal],
        4 => &[
            CoordinationGeometryGuess::Tetrahedral,
            CoordinationGeometryGuess::SquarePlanar,
        ],
        5 => &[CoordinationGeometryGuess::TrigonalBipyramidal],
        6 => &[CoordinationGeometryGuess::Octahedral],
        _ => return (CoordinationGeometryGuess::Unknown, 0.0),
    };

    let mut best_geometry = CoordinationGeometryGuess::Unknown;
    let mut best_cost = f64::INFINITY;

    for &geometry in candidates {
        let cost = assignment_cost(vectors, ideal_directions(geometry));
        if cost < best_cost {
            best_cost = cost;
            best_geometry = geometry;
        }
    }

    // Rounding can push a dot product marginally above 1, which makes the cost
    // of a (near-)perfect match slightly negative. The square root of a
    // negative number is NaN and `clamp` propagates NaN, so floor the cost at
    // zero to keep the score a real number in [0, 1].
    let score = (1.0 - (best_cost.max(0.0) / 2.0).sqrt()).clamp(0.0, 1.0);
    (best_geometry, score)
}

/// Detects transition-metal centres and guesses their coordination geometry.
///
/// `elements[i]` and `positions[i]` describe atom `i`. For every atom that
/// `crate::eht::is_transition_metal` accepts, every other atom is tested with
/// the distance-based coordination check; the accepted atoms become the
/// centre's ligands and their normalised directions are classified against
/// the idealised geometry templates.
///
/// The result holds one [`MetalCoordinationCenter`] per metal atom, in input
/// order. `warnings` contains a note when no metal centre was found and,
/// when the two inputs differ in length, a note that only the atoms present
/// in both slices were analysed (the function does not panic on such input).
pub fn analyze_topology(elements: &[u8], positions: &[[f64; 3]]) -> TopologyAnalysisResult {
    let element_count = elements.len();
    let position_count = positions.len();

    // Only atoms that have both an element and a position can be analysed;
    // restricting to the common prefix avoids out-of-bounds indexing.
    let atom_count = element_count.min(position_count);
    let elements = &elements[..atom_count];
    let positions = &positions[..atom_count];

    let mut metal_centers = Vec::new();

    for (metal_idx, (&metal_z, &metal_pos)) in elements.iter().zip(positions).enumerate() {
        if !crate::eht::is_transition_metal(metal_z) {
            continue;
        }

        let mut ligand_indices = Vec::new();
        let mut vectors = Vec::new();
        for (ligand_idx, (&ligand_z, &ligand_pos)) in elements.iter().zip(positions).enumerate() {
            if ligand_idx == metal_idx {
                continue;
            }
            let dist = distance(metal_pos, ligand_pos);
            if is_likely_coordinated(metal_z, ligand_z, dist) {
                ligand_indices.push(ligand_idx);
                vectors.push(normalize([
                    ligand_pos[0] - metal_pos[0],
                    ligand_pos[1] - metal_pos[1],
                    ligand_pos[2] - metal_pos[2],
                ]));
            }
        }

        let (geometry, geometry_score) = classify_geometry(&vectors);
        let coordination_number = ligand_indices.len();
        metal_centers.push(MetalCoordinationCenter {
            atom_index: metal_idx,
            element: metal_z,
            ligand_indices,
            coordination_number,
            geometry,
            geometry_score,
        });
    }

    let mut warnings = if metal_centers.is_empty() {
        vec!["No transition-metal centers detected for topology analysis.".to_string()]
    } else {
        Vec::new()
    };

    // Appended last so existing warnings keep their position.
    if element_count != position_count {
        warnings.push(format!(
            "Mismatched input lengths: {} elements but {} positions; only the first {} atoms were analyzed.",
            element_count, position_count, atom_count
        ));
    }

    TopologyAnalysisResult {
        metal_centers,
        warnings,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the analysis and returns the geometry guessed for the first reported centre.
    fn first_geometry(elements: &[u8], positions: &[[f64; 3]]) -> CoordinationGeometryGuess {
        analyze_topology(elements, positions).metal_centers[0].geometry
    }

    #[test]
    fn test_detect_square_planar_pt() {
        // Element 78 at the origin with two element-17 and two element-7
        // neighbours in the xy-plane; the six element-1 atoms must not be
        // counted as ligands.
        let elements = [78u8, 17, 17, 7, 7, 1, 1, 1, 1, 1, 1];
        let positions = [
            [0.0, 0.0, 0.0],
            [2.32, 0.0, 0.0],
            [-2.32, 0.0, 0.0],
            [0.0, 2.05, 0.0],
            [0.0, -2.05, 0.0],
            [0.90, 2.65, 0.0],
            [-0.90, 2.65, 0.0],
            [0.00, 2.05, 0.95],
            [0.90, -2.65, 0.0],
            [-0.90, -2.65, 0.0],
            [0.00, -2.05, -0.95],
        ];

        let analysis = analyze_topology(&elements, &positions);
        assert_eq!(analysis.metal_centers.len(), 1);

        let center = &analysis.metal_centers[0];
        assert_eq!(center.coordination_number, 4);
        assert_eq!(center.geometry, CoordinationGeometryGuess::SquarePlanar);
    }

    #[test]
    fn test_detect_octahedral_fe() {
        // Six neighbours placed on the positive and negative Cartesian axes.
        let elements = [26u8, 17, 17, 17, 17, 17, 17];
        let positions = [
            [0.0, 0.0, 0.0],
            [2.30, 0.0, 0.0],
            [-2.30, 0.0, 0.0],
            [0.0, 2.30, 0.0],
            [0.0, -2.30, 0.0],
            [0.0, 0.0, 2.30],
            [0.0, 0.0, -2.30],
        ];

        assert_eq!(
            first_geometry(&elements, &positions),
            CoordinationGeometryGuess::Octahedral
        );
    }

    #[test]
    fn test_detect_trigonal_bipyramidal_center() {
        // Two axial neighbours on ±z and three equatorial ones 120° apart in the xy-plane.
        let elements = [25u8, 17, 17, 17, 17, 17];
        let positions = [
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 2.20],
            [0.0, 0.0, -2.20],
            [2.10, 0.0, 0.0],
            [-1.05, 1.8186533479473213, 0.0],
            [-1.05, -1.8186533479473213, 0.0],
        ];

        assert_eq!(
            first_geometry(&elements, &positions),
            CoordinationGeometryGuess::TrigonalBipyramidal
        );
    }
}