use crate::graph::Molecule;
use petgraph::graph::NodeIndex;
use petgraph::visit::EdgeRef;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, BTreeSet, VecDeque};
/// One ring reported by [`compute_sssr`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RingInfo {
/// Graph node indices of the ring atoms, listed in the order they are
/// visited along the ring path found by `compute_sssr`.
pub atoms: Vec<usize>,
/// Number of atoms in the ring (`compute_sssr` sets this to `atoms.len()`).
pub size: usize,
/// Result of the aromaticity heuristic `compute_sssr` applies to the ring.
pub is_aromatic: bool,
}
/// Ring perception output produced by [`compute_sssr`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SssrResult {
/// The selected rings, ordered by non-decreasing size.
pub rings: Vec<RingInfo>,
/// For each atom index, the number of entries in `rings` that contain it.
pub atom_ring_count: Vec<usize>,
/// For each atom index, the sizes of the entries in `rings` that contain it.
pub atom_ring_sizes: Vec<Vec<usize>>,
/// `ring_size_histogram[s]` is the number of rings of size `s`.
/// `compute_sssr` returns it empty for a molecule without atoms or bonds,
/// and with length `largest ring size + 1` otherwise.
pub ring_size_histogram: Vec<usize>,
}
/// Perceives a set of small, linearly independent rings for `mol`.
///
/// The procedure is:
/// 1. Build an undirected adjacency list and a sorted, de-duplicated bond list.
/// 2. For every bond, find the shortest path between its endpoints that avoids
///    the bond itself; path plus bond is a candidate ring.
/// 3. Drop duplicate candidates (same ring up to rotation and direction) and
///    order them by size, smallest first.
/// 4. Greedily keep candidates whose bond sets are linearly independent over
///    GF(2), stopping once the cycle rank of the graph has been reached.
///
/// Returns the kept rings together with per-atom ring membership and a
/// histogram of ring sizes. For a molecule with no atoms or no bonds the
/// histogram is empty; otherwise its length is `largest ring size + 1`
/// (a single zero entry when bonds exist but no ring was found).
///
/// NOTE(review): only one shortest ring per bond is generated as a candidate,
/// so fewer rings than the cycle rank may be returned for some topologies —
/// confirm whether callers rely on a complete basis.
pub fn compute_sssr(mol: &Molecule) -> SssrResult {
    let n = mol.graph.node_count();
    let m = mol.graph.edge_count();
    if n == 0 || m == 0 {
        return SssrResult {
            rings: vec![],
            atom_ring_count: vec![0; n],
            atom_ring_sizes: vec![vec![]; n],
            ring_size_histogram: vec![],
        };
    }

    // Undirected adjacency plus a normalised (low, high) bond list.
    let mut adj: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); n];
    let mut edges: Vec<(usize, usize)> = Vec::with_capacity(m);
    for edge in mol.graph.edge_references() {
        let u = edge.source().index();
        let v = edge.target().index();
        adj[u].insert(v);
        adj[v].insert(u);
        edges.push((u.min(v), u.max(v)));
    }
    edges.sort_unstable();
    edges.dedup();

    // Cycle rank: bonds - atoms + components. The component count must be
    // added *before* saturating, otherwise every graph with fewer bonds than
    // atoms reports `components` instead of the true rank. The de-duplicated
    // bond count is used because ring bitsets are expressed over that list.
    let n_expected = (edges.len() + connected_components(&adj)).saturating_sub(n);

    // One candidate per bond, keeping the first representative of each ring.
    // Skipped entirely for acyclic graphs, where no candidate could be kept.
    let mut unique_rings: Vec<Vec<usize>> = Vec::new();
    if n_expected > 0 {
        let mut seen: BTreeSet<Vec<usize>> = BTreeSet::new();
        for &(u, v) in &edges {
            let Some(ring) = bfs_shortest_path_excluding_edge(&adj, u, v, n) else {
                continue;
            };
            if ring.len() >= 3 && seen.insert(canonicalize_ring(&ring)) {
                unique_rings.push(ring);
            }
        }
    }
    // Stable sort: rings of equal size keep their discovery order.
    unique_rings.sort_by_key(|ring| ring.len());

    // Greedy selection of GF(2)-independent rings, smallest first.
    let edge_to_bit: BTreeMap<(usize, usize), usize> = edges
        .iter()
        .enumerate()
        .map(|(bit, &edge)| (edge, bit))
        .collect();
    let n_words = edges.len().div_ceil(64);
    let mut basis_rows: Vec<(usize, Vec<u64>)> = Vec::new();
    let mut selected_rings: Vec<Vec<usize>> =
        Vec::with_capacity(n_expected.min(unique_rings.len()));
    for ring in unique_rings {
        if selected_rings.len() >= n_expected {
            break;
        }
        let ring_bits = ring_bitset(&ring, &edge_to_bit, n_words);
        if insert_basis_row(&mut basis_rows, ring_bits) {
            selected_rings.push(ring);
        }
    }

    let rings: Vec<RingInfo> = selected_rings
        .into_iter()
        .map(|atoms| RingInfo {
            size: atoms.len(),
            is_aromatic: check_ring_aromaticity(mol, &atoms),
            atoms,
        })
        .collect();

    // Per-atom membership statistics.
    let mut atom_ring_count = vec![0usize; n];
    let mut atom_ring_sizes: Vec<Vec<usize>> = vec![vec![]; n];
    for ring in &rings {
        for &atom in &ring.atoms {
            atom_ring_count[atom] += 1;
            atom_ring_sizes[atom].push(ring.size);
        }
    }

    // Histogram indexed directly by ring size.
    let max_size = rings.iter().map(|ring| ring.size).max().unwrap_or(0);
    let mut ring_size_histogram = vec![0usize; max_size + 1];
    for ring in &rings {
        ring_size_histogram[ring.size] += 1;
    }

    SssrResult {
        rings,
        atom_ring_count,
        atom_ring_sizes,
        ring_size_histogram,
    }
}
/// Breadth-first search for a shortest path from `start` to `end` that does
/// not use the direct `start`–`end` edge.
///
/// `adj` is the undirected adjacency list and `n` the size of the bookkeeping
/// arrays; every node index that is visited must be smaller than `n`.
///
/// Returns the node sequence beginning with `start` and finishing with `end`,
/// or `None` when `end` cannot be reached without the excluded edge.
fn bfs_shortest_path_excluding_edge(
    adj: &[BTreeSet<usize>],
    start: usize,
    end: usize,
    n: usize,
) -> Option<Vec<usize>> {
    let mut seen = vec![false; n];
    let mut predecessor = vec![usize::MAX; n];
    let mut frontier = VecDeque::new();

    seen[start] = true;
    frontier.push_back(start);

    while let Some(node) = frontier.pop_front() {
        for &neighbor in &adj[node] {
            // The excluded edge may not be traversed in either direction.
            let is_excluded_edge =
                (node, neighbor) == (start, end) || (node, neighbor) == (end, start);
            if is_excluded_edge || seen[neighbor] {
                continue;
            }

            seen[neighbor] = true;
            predecessor[neighbor] = node;

            if neighbor != end {
                frontier.push_back(neighbor);
                continue;
            }

            // Target reached: follow predecessors back to `start`, then flip.
            let mut path = Vec::new();
            let mut at = end;
            loop {
                path.push(at);
                if at == start {
                    break;
                }
                at = predecessor[at];
            }
            path.reverse();
            return Some(path);
        }
    }

    None
}
/// Counts the connected components of the undirected graph described by
/// `adj`, where `adj[i]` holds the neighbours of node `i`.
///
/// The result is never below 1, even for an empty adjacency list.
fn connected_components(adj: &[BTreeSet<usize>]) -> usize {
    let node_count = adj.len();
    let mut seen = vec![false; node_count];
    let mut count = 0usize;

    for root in 0..node_count {
        if seen[root] {
            continue;
        }

        // Unvisited node: it seeds a new component, flood-filled below.
        count += 1;
        seen[root] = true;
        let mut frontier = VecDeque::new();
        frontier.push_back(root);

        while let Some(node) = frontier.pop_front() {
            for &neighbor in &adj[node] {
                if seen[neighbor] {
                    continue;
                }
                seen[neighbor] = true;
                frontier.push_back(neighbor);
            }
        }
    }

    if count == 0 {
        1
    } else {
        count
    }
}
/// Encodes the bonds of `ring` as a bitset of `n_words` 64-bit words.
///
/// `edge_to_bit` maps each normalised bond to its bit position. Ring bonds
/// that are missing from the map are skipped.
fn ring_bitset(
    ring: &[usize],
    edge_to_bit: &BTreeMap<(usize, usize), usize>,
    n_words: usize,
) -> Vec<u64> {
    let mut words = vec![0u64; n_words];
    let positions = ring_to_edges(ring)
        .into_iter()
        .filter_map(|bond| edge_to_bit.get(&bond).copied());
    for position in positions {
        let (word, offset) = (position / 64, position % 64);
        words[word] |= 1u64 << offset;
    }
    words
}
/// Tries to add `row` to a GF(2) basis kept in `basis_rows`.
///
/// Each basis entry is `(pivot, bits)`. `row` is first reduced against every
/// existing entry; if nothing is left it is linearly dependent and `false` is
/// returned with the basis untouched. Otherwise the reduced row's highest set
/// bit becomes its pivot, that bit is cleared from all other entries, the row
/// is stored, and the basis is re-sorted by descending pivot.
fn insert_basis_row(basis_rows: &mut Vec<(usize, Vec<u64>)>, mut row: Vec<u64>) -> bool {
    /// Tests a single bit of a word-packed bitset.
    fn bit_is_set(bits: &[u64], index: usize) -> bool {
        (bits[index / 64] >> (index % 64)) & 1 == 1
    }

    // Forward reduction: cancel every existing leading bit present in `row`.
    for (_, existing) in basis_rows.iter() {
        match highest_set_bit(existing) {
            Some(leading) if bit_is_set(&row, leading) => xor_bitsets(&mut row, existing),
            _ => {}
        }
    }

    let pivot = match highest_set_bit(&row) {
        Some(bit) => bit,
        // Fully cancelled: the row adds nothing new.
        None => return false,
    };

    // Back substitution keeps the new pivot column exclusive to the new row.
    for (_, existing) in basis_rows.iter_mut() {
        if bit_is_set(existing, pivot) {
            xor_bitsets(existing, &row);
        }
    }

    basis_rows.push((pivot, row));
    basis_rows.sort_by(|(lhs, _), (rhs, _)| rhs.cmp(lhs));
    true
}
/// XORs `other` into `target` word by word, in place.
///
/// Only the words both slices have in common are touched.
fn xor_bitsets(target: &mut [u64], other: &[u64]) {
    target
        .iter_mut()
        .zip(other)
        .for_each(|(dst, &src)| *dst ^= src);
}
/// Returns the index of the most significant set bit in a word-packed bitset
/// (word 0 holds bits 0..=63), or `None` when every word is zero.
fn highest_set_bit(bits: &[u64]) -> Option<usize> {
    // Last non-zero word; `?` covers the empty / all-zero case.
    let word_index = bits.iter().rposition(|&word| word != 0)?;
    let bit_in_word = 63 - bits[word_index].leading_zeros() as usize;
    Some(word_index * 64 + bit_in_word)
}
/// Produces a rotation- and direction-independent representation of a ring.
///
/// The ring is rotated so that the first occurrence of its smallest atom
/// index comes first, then the lexicographically smaller of the two traversal
/// directions is returned. An empty ring yields an empty vector.
fn canonicalize_ring(ring: &[usize]) -> Vec<usize> {
    // Position of the first minimal element; `None` only for an empty ring.
    let Some(anchor) = (0..ring.len()).min_by_key(|&i| ring[i]) else {
        return Vec::new();
    };

    let mut clockwise = ring.to_vec();
    clockwise.rotate_left(anchor);

    // Same starting atom, opposite direction: reverse everything after it.
    let mut counter_clockwise = clockwise.clone();
    counter_clockwise[1..].reverse();

    if counter_clockwise < clockwise {
        counter_clockwise
    } else {
        clockwise
    }
}
/// Lists the bonds of a ring given as an ordered atom sequence.
///
/// Consecutive atoms are paired, and the last atom is paired with the first.
/// Each bond is normalised to `(low, high)` and the list is returned sorted.
fn ring_to_edges(ring: &[usize]) -> Vec<(usize, usize)> {
    // Pair every atom with its successor, wrapping around at the end.
    let successors = ring.iter().cycle().skip(1);
    let mut bonds: Vec<(usize, usize)> = ring
        .iter()
        .zip(successors)
        .map(|(&a, &b)| (a.min(b), a.max(b)))
        .collect();
    bonds.sort();
    bonds
}
/// Heuristic aromaticity test for the ring formed by `ring_atoms`.
///
/// The ring qualifies only if every atom has element number 6, 7, 8 or 16 and
/// either carries at least one aromatic bond or is SP2 hybridised. Electrons
/// are then tallied per atom: 2 for a nitrogen with a hydrogen neighbour in
/// the graph, 2 for oxygen and sulfur, 1 otherwise. The ring is reported
/// aromatic when the total is at least 2 and has the form `4k + 2`.
fn check_ring_aromaticity(mol: &Molecule, ring_atoms: &[usize]) -> bool {
    use crate::graph::BondOrder;

    // Gate: each ring atom must be an eligible element in a conjugated state.
    let atom_qualifies = |idx: usize| -> bool {
        let node = NodeIndex::new(idx);
        if !matches!(mol.graph[node].element, 6 | 7 | 8 | 16) {
            return false;
        }
        let on_aromatic_bond = mol
            .graph
            .edges(node)
            .any(|bond| matches!(bond.weight().order, BondOrder::Aromatic));
        on_aromatic_bond
            || matches!(
                mol.graph[node].hybridization,
                crate::graph::Hybridization::SP2
            )
    };
    if !ring_atoms.iter().all(|&idx| atom_qualifies(idx)) {
        return false;
    }

    // Tally the electron contribution of every ring atom.
    let pi_electrons: i32 = ring_atoms
        .iter()
        .map(|&idx| {
            let node = NodeIndex::new(idx);
            match mol.graph[node].element {
                7 => {
                    let bears_hydrogen = mol
                        .graph
                        .neighbors(node)
                        .any(|neighbor| mol.graph[neighbor].element == 1);
                    if bears_hydrogen {
                        2
                    } else {
                        1
                    }
                }
                8 | 16 => 2,
                // Carbon, and any other element, contributes one electron.
                _ => 1,
            }
        })
        .sum();

    pi_electrons >= 2 && (pi_electrons - 2) % 4 == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `smiles` and runs ring perception on the resulting molecule.
    fn sssr_of(smiles: &str) -> SssrResult {
        let molecule = Molecule::from_smiles(smiles).unwrap();
        compute_sssr(&molecule)
    }

    /// Benzene: a single aromatic six-membered ring.
    #[test]
    fn test_benzene_sssr() {
        let sssr = sssr_of("c1ccccc1");
        assert_eq!(sssr.rings.len(), 1, "Benzene should have 1 ring in SSSR");
        let ring = &sssr.rings[0];
        assert_eq!(ring.size, 6, "Ring should be size 6");
        assert!(ring.is_aromatic, "Ring should be aromatic");
    }

    /// Naphthalene: two fused aromatic six-membered rings.
    #[test]
    fn test_naphthalene_sssr() {
        let sssr = sssr_of("c1ccc2ccccc2c1");
        let ring_count = sssr.rings.len();
        assert_eq!(
            ring_count, 2,
            "Naphthalene SSSR should have 2 rings, got {}",
            ring_count
        );
        for ring in sssr.rings.iter() {
            assert_eq!(ring.size, 6, "All naphthalene rings should be size 6");
            assert!(ring.is_aromatic, "All naphthalene rings should be aromatic");
        }
    }

    /// Cyclohexane: one saturated six-membered ring.
    #[test]
    fn test_cyclohexane_sssr() {
        let sssr = sssr_of("C1CCCCC1");
        assert_eq!(sssr.rings.len(), 1, "Cyclohexane should have 1 ring");
        let ring = &sssr.rings[0];
        assert_eq!(ring.size, 6);
        assert!(!ring.is_aromatic, "Cyclohexane is not aromatic");
    }

    /// Ethane: acyclic, so no rings are reported.
    #[test]
    fn test_ethane_no_rings() {
        let sssr = sssr_of("CC");
        assert_eq!(sssr.rings.len(), 0, "Ethane should have no rings");
    }

    /// Two rotations of the same ring share one canonical form.
    #[test]
    fn test_ring_canonicalization() {
        let rotation_a = [3, 0, 1, 2];
        let rotation_b = [1, 2, 3, 0];
        assert_eq!(
            canonicalize_ring(&rotation_a),
            canonicalize_ring(&rotation_b)
        );
    }

    /// Every atom of the benzene ring is counted as a ring member.
    #[test]
    fn test_atom_ring_membership() {
        let sssr = sssr_of("c1ccccc1");
        for &idx in sssr.rings[0].atoms.iter() {
            assert!(
                sssr.atom_ring_count[idx] >= 1,
                "Atom {} should be in at least 1 ring",
                idx
            );
        }
    }

    /// A three-atom chain plus a separate two-atom chain form two components.
    #[test]
    fn test_connected_components_helper() {
        let adj: Vec<BTreeSet<usize>> = vec![vec![1], vec![0, 2], vec![1], vec![4], vec![3]]
            .into_iter()
            .map(|neighbors| neighbors.into_iter().collect())
            .collect();
        assert_eq!(connected_components(&adj), 2);
    }
}