use crate::graph::{BondOrder, Molecule};
use petgraph::graph::NodeIndex;
use petgraph::visit::EdgeRef;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
/// HOSE-style description of one atom's neighbourhood, as built by
/// [`generate_hose_codes`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HoseCode {
/// Index of the centre atom's node in the molecule graph.
pub atom_index: usize,
/// Atomic number of the centre atom (the value passed to `element_symbol`).
pub element: u8,
/// `spheres[0]` is the centre's element symbol; `spheres[k]` (k >= 1) is a
/// comma-separated, sorted list of `{bond}{symbol}` entries for atoms first
/// reached at bond distance `k`, or an empty string when there are none.
pub spheres: Vec<String>,
/// `spheres[0]`, a `/`, then the remaining spheres joined with `/`.
pub full_code: String,
}
/// Prefix written in front of a sphere entry to record the bond through which
/// that atom was reached.
///
/// Single bonds and bonds of unknown order are left implicit (empty prefix);
/// double, triple and aromatic bonds are marked with `=`, `#` and `*`.
fn bond_symbol(order: BondOrder) -> &'static str {
    match order {
        BondOrder::Double => "=",
        BondOrder::Triple => "#",
        BondOrder::Aromatic => "*",
        BondOrder::Single | BondOrder::Unknown => "",
    }
}
/// Returns the chemical symbol for atomic number `z`.
///
/// Covers the whole periodic table (Z = 1..=118), so distinct elements never
/// share a symbol inside a HOSE code. Any other value (0 or > 118) maps to
/// the placeholder `"X"`.
fn element_symbol(z: u8) -> &'static str {
    // SYMBOLS[i] is the symbol for Z = i + 1. A `static` (not `const`) so the
    // borrowed element lives for 'static.
    static SYMBOLS: [&str; 118] = [
        "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", //
        "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", //
        "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", //
        "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", //
        "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", //
        "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", //
        "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", //
        "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", //
        "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", //
        "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", //
        "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", //
        "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og",
    ];
    usize::from(z)
        .checked_sub(1)
        .and_then(|index| SYMBOLS.get(index))
        .copied()
        .unwrap_or("X")
}
/// Generates a HOSE (Hierarchically Ordered Spherical Environment) code for
/// every atom of `mol`, describing its surroundings out to `max_radius` bonds.
///
/// A breadth-first search is run from each atom in turn:
///
/// * `spheres[0]` is the centre's element symbol.
/// * `spheres[k]` (for `k` in `1..=max_radius`) lists every atom first reached
///   at bond distance `k` as `{bond}{symbol}`. `{bond}` is the [`bond_symbol`]
///   prefix of the bond through which that atom was first dequeued.
/// * Entries within a sphere are sorted. Repeated entries are kept, so atoms
///   that differ only in how many identical neighbours they have still get
///   different codes (e.g. CH3 vs CH2 when hydrogens are explicit nodes).
/// * A sphere with no atoms is an empty string.
/// * `full_code` is `spheres[0]`, a `/`, then the remaining spheres joined
///   with `/`. With `max_radius == 0` it is e.g. `"C/"`.
///
/// Only atoms that are nodes of `mol.graph` are described; hydrogens appear
/// only if the molecule stores them explicitly. Codes are returned in
/// atom-index order, one per node.
pub fn generate_hose_codes(mol: &Molecule, max_radius: usize) -> Vec<HoseCode> {
    // For an edge given as (source, target), returns the endpoint that is not `from`.
    let other_end = |source: NodeIndex, target: NodeIndex, from: NodeIndex| {
        if source == from {
            target
        } else {
            source
        }
    };

    let n = mol.graph.node_count();
    let mut codes = Vec::with_capacity(n);
    // Scratch buffer shared by all centres; reset at the start of each search.
    let mut visited = vec![false; n];

    for atom_idx in 0..n {
        let center = NodeIndex::new(atom_idx);
        let center_element = mol.graph[center].element;

        let mut spheres = Vec::with_capacity(max_radius + 1);
        spheres.push(element_symbol(center_element).to_string());

        visited.fill(false);
        visited[atom_idx] = true;

        // Each frontier entry pairs an atom with the order of the bond it was reached through.
        let mut current_frontier: Vec<(NodeIndex, BondOrder)> = mol
            .graph
            .edges(center)
            .map(|edge| {
                (
                    other_end(edge.source(), edge.target(), center),
                    edge.weight().order,
                )
            })
            .collect();

        for _radius in 1..=max_radius {
            let mut sphere_parts: Vec<String> = Vec::with_capacity(current_frontier.len());
            let mut next_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();

            for &(node, bond_order) in &current_frontier {
                // An atom can be queued more than once (ring closures, two frontier
                // atoms sharing a neighbour); only its first occurrence counts.
                if visited[node.index()] {
                    continue;
                }
                visited[node.index()] = true;

                let sym = element_symbol(mol.graph[node].element);
                sphere_parts.push(format!("{}{}", bond_symbol(bond_order), sym));

                for edge in mol.graph.edges(node) {
                    let next = other_end(edge.source(), edge.target(), node);
                    if !visited[next.index()] {
                        next_frontier.push((next, edge.weight().order));
                    }
                }
            }

            // Sort for a canonical order but keep duplicates: the number of
            // identical neighbours is part of the environment. An exhausted
            // frontier yields an empty string.
            sphere_parts.sort_unstable();
            spheres.push(sphere_parts.join(","));
            current_frontier = next_frontier;
        }

        let full_code = format!("{}/{}", spheres[0], spheres[1..].join("/"));
        codes.push(HoseCode {
            atom_index: atom_idx,
            element: center_element,
            spheres,
            full_code,
        });
    }
    codes
}
/// Result of looking a [`HoseCode`] up in one of the built-in shift tables via
/// [`predict_shift_from_hose`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HoseShiftLookup {
/// Copied from the queried code's `atom_index`.
pub atom_index: usize,
/// The `nucleus` argument the lookup was made for (1 or 6), not the element
/// of the centre atom.
pub element: u8,
/// Shift stored in the table for the matched pattern, in ppm.
pub shift_ppm: f64,
/// The table pattern that was accepted, e.g. `"C/H,H,H,C"`.
pub matched_hose: String,
/// Sphere depth at which the match was accepted.
pub match_radius: usize,
/// Heuristic score derived from the match depth; higher means a more
/// specific match.
pub confidence: f64,
}
/// Reference ¹H shifts (ppm) keyed by one-sphere HOSE patterns.
///
/// Each pattern is `centre/neighbours` in the notation produced by
/// [`generate_hose_codes`]: a `*`, `=` or `#` prefix marks an aromatic, double
/// or triple bond to that neighbour. Unlike generated codes, the neighbour
/// entries here are written in free (unsorted) order. Presumably each shift is
/// for protons attached to the centre atom — confirm against the data source.
fn h1_hose_database() -> Vec<(&'static str, f64)> {
    const TABLE: &[(&str, f64)] = &[
        // Centre with only single-bonded neighbours.
        ("C/H,H,H,C", 0.90),
        ("C/H,H,C,C", 1.30),
        ("C/H,C,C,C", 1.50),
        ("C/H,H,H,O", 3.40),
        ("C/H,H,O", 3.50),
        ("C/H,H,H,N", 2.30),
        ("C/H,H,N", 2.60),
        // NOTE(review): three single bonds plus a double bond would give carbon
        // five bonds; presumably "methyl next to C=O" was intended — confirm.
        ("C/H,H,H,=O", 2.10),
        ("C/H,H,=O", 2.50),
        ("C/H,H,H,*C", 2.30),
        // Centre with aromatic bonds.
        ("C/*C,*C,H", 7.27),
        ("C/*C,*C,*C", 7.50),
        ("C/*C,*N,H", 7.80),
        // Centre with a C=C double bond.
        ("C/=C,H,H", 5.25),
        ("C/=C,H,C", 5.40),
        ("C/=C,=C,H", 6.30),
        // Centre with a triple bond.
        ("C/#C,H", 2.50),
        // Centre double-bonded to oxygen.
        ("C/=O,H,C", 9.50),
        ("C/=O,H,H", 9.60),
        // Oxygen and nitrogen centres.
        ("O/H,C", 2.50),
        ("O/H,*C", 5.50),
        ("N/H,H,C", 1.50),
        ("N/H,C,C", 2.20),
    ];
    TABLE.to_vec()
}
/// Reference ¹³C shifts (ppm) keyed by one-sphere HOSE patterns.
///
/// Patterns use the `centre/neighbours` notation of [`generate_hose_codes`]
/// (`*` aromatic, `=` double, `#` triple bond prefix), with neighbour entries
/// written in free (unsorted) order.
fn c13_hose_database() -> Vec<(&'static str, f64)> {
    const TABLE: &[(&str, f64)] = &[
        // Centre with only single-bonded neighbours.
        ("C/H,H,H,C", 15.0),
        ("C/H,H,C,C", 25.0),
        ("C/H,C,C,C", 35.0),
        ("C/C,C,C,C", 40.0),
        ("C/H,H,H,O", 55.0),
        ("C/H,H,O,C", 65.0),
        ("C/H,H,H,N", 32.0),
        ("C/H,H,N", 45.0),
        // NOTE(review): three single bonds plus a double bond would give carbon
        // five bonds; presumably "methyl next to C=O" was intended — confirm.
        ("C/H,H,H,=O", 30.0),
        // Centre with aromatic bonds.
        ("C/*C,*C,H", 128.0),
        ("C/*C,*C,C", 137.0),
        ("C/*C,*C,O", 155.0),
        ("C/*C,*C,N", 148.0),
        ("C/*C,*C,F", 163.0),
        ("C/*C,*C,Cl", 134.0),
        // Centre with a C=C double bond.
        ("C/=C,H,H", 115.0),
        ("C/=C,H,C", 130.0),
        ("C/=C,C,C", 140.0),
        // Centre double-bonded to oxygen.
        ("C/=O,O,C", 175.0),
        ("C/=O,N,C", 170.0),
        ("C/=O,C,C", 205.0),
        ("C/=O,H,C", 200.0),
        // Centre with a triple bond.
        ("C/#C,H", 70.0),
        ("C/#C,C", 85.0),
    ];
    TABLE.to_vec()
}
/// Predicts a chemical shift for one atom by looking its HOSE code up in the
/// built-in reference table for `nucleus`.
///
/// `nucleus` selects the table: `1` for ¹H, `6` for ¹³C. Any other value
/// returns `None`.
///
/// The lookup has two stages:
///
/// 1. **Exact match.** A table pattern `centre/s1/.../sk` matches when its
///    centre equals the atom's centre symbol and each of its `k` spheres holds
///    the same entries as the atom's corresponding sphere. Entries are
///    compared as multisets: order is ignored, because generated spheres are
///    sorted but table entries are not. Counts must agree, so `H,H,H,C` does
///    not match `C,C,H,H`. The deepest matching pattern wins; ties go to the
///    earlier table entry. The result reports `match_radius = k` and
///    `confidence = 0.5 + 0.1 * k`, capped at `1.0`.
/// 2. **Fuzzy fallback.** If nothing matches exactly, the first table entry
///    accepted by `fuzzy_hose_match` on the first sphere is used. It reports
///    `match_radius = 1` and `confidence = 0.5`, since no sphere matched
///    exactly.
///
/// Returns `None` when the code has no spheres beyond the centre, or when
/// neither stage finds a pattern.
pub fn predict_shift_from_hose(hose_code: &HoseCode, nucleus: u8) -> Option<HoseShiftLookup> {
    // Splits a sphere into its entries, sorted so two spheres compare as multisets.
    fn sorted_entries(sphere: &str) -> Vec<&str> {
        let mut entries: Vec<&str> = sphere.split(',').collect();
        entries.sort_unstable();
        entries
    }

    let database = match nucleus {
        1 => h1_hose_database(),
        6 => c13_hose_database(),
        _ => return None,
    };
    let (center, atom_spheres) = hose_code.spheres.split_first()?;
    let first_sphere = atom_spheres.first()?;

    // Stage 1: exact, order-insensitive comparison; keep the deepest match.
    let mut best_exact: Option<(usize, &'static str, f64)> = None;
    for &(pattern, shift) in &database {
        let mut segments = pattern.split('/');
        if segments.next() != Some(center.as_str()) {
            continue;
        }
        let pattern_spheres: Vec<&str> = segments.collect();
        let depth = pattern_spheres.len();
        // A pattern deeper than the generated code cannot be verified
        // (zip would silently truncate it).
        if depth == 0 || depth > atom_spheres.len() {
            continue;
        }
        let is_match = pattern_spheres
            .iter()
            .zip(atom_spheres)
            .all(|(pat, atom)| sorted_entries(pat) == sorted_entries(atom));
        // Strict `>` keeps the earliest table entry among equally deep matches.
        if is_match && best_exact.map_or(true, |(best_depth, _, _)| depth > best_depth) {
            best_exact = Some((depth, pattern, shift));
        }
    }

    let (match_radius, pattern, shift, confidence) = match best_exact {
        Some((depth, pattern, shift)) => {
            (depth, pattern, shift, (0.5 + 0.1 * depth as f64).min(1.0))
        }
        None => {
            // Stage 2: approximate comparison of the first sphere only.
            let query = format!("{}/{}", center, first_sphere);
            let &(pattern, shift) = database
                .iter()
                .find(|&&(pattern, _)| fuzzy_hose_match(&query, pattern))?;
            (1, pattern, shift, 0.5)
        }
    };

    Some(HoseShiftLookup {
        atom_index: hose_code.atom_index,
        element: nucleus,
        shift_ppm: shift,
        matched_hose: pattern.to_string(),
        match_radius,
        confidence,
    })
}
/// Approximate comparison of the first sphere of two HOSE strings.
///
/// Both strings are read as `centre/sphere1[/...]`; anything after the first
/// sphere is ignored. They match when the centres are equal and the Jaccard
/// similarity of the *distinct* entries of their first spheres
/// (|A ∩ B| / |A ∪ B|) is strictly greater than 0.5. A string with no `/`
/// never matches.
fn fuzzy_hose_match(hose: &str, pattern: &str) -> bool {
    // Splits off the centre symbol and the first sphere; None when there is no '/'.
    fn centre_and_first_sphere(code: &str) -> Option<(&str, &str)> {
        let mut segments = code.split('/');
        let centre = segments.next()?;
        let first = segments.next()?;
        Some((centre, first))
    }

    match (centre_and_first_sphere(hose), centre_and_first_sphere(pattern)) {
        (Some((hose_centre, hose_sphere)), Some((pat_centre, pat_sphere)))
            if hose_centre == pat_centre =>
        {
            let ours: BTreeSet<&str> = hose_sphere.split(',').collect();
            let theirs: BTreeSet<&str> = pat_sphere.split(',').collect();
            let union_size = ours.union(&theirs).count();
            let shared = ours.intersection(&theirs).count();
            union_size != 0 && (shared as f64 / union_size as f64) > 0.5
        }
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hose_codes_ethanol() {
        let mol = Molecule::from_smiles("CCO").unwrap();
        let codes = generate_hose_codes(&mol, 3);
        assert_eq!(codes.len(), mol.graph.node_count());
        for code in &codes {
            // Centre symbol plus one entry per radius, even when a sphere is empty.
            assert_eq!(code.spheres.len(), 4);
            assert!(!code.spheres[0].is_empty());
            assert!(!code.full_code.is_empty());
        }
    }

    #[test]
    fn test_hose_codes_benzene() {
        let mol = Molecule::from_smiles("c1ccccc1").unwrap();
        let codes = generate_hose_codes(&mol, 3);
        assert_eq!(codes.len(), mol.graph.node_count());
        let carbon_codes: Vec<&HoseCode> = codes.iter().filter(|c| c.element == 6).collect();
        assert!(!carbon_codes.is_empty());
    }

    #[test]
    fn test_hose_codes_deterministic() {
        let mol = Molecule::from_smiles("CC(=O)O").unwrap();
        let codes1 = generate_hose_codes(&mol, 3);
        let codes2 = generate_hose_codes(&mol, 3);
        // `zip` stops at the shorter input, so compare lengths explicitly.
        assert_eq!(codes1.len(), codes2.len());
        for (c1, c2) in codes1.iter().zip(codes2.iter()) {
            assert_eq!(c1.full_code, c2.full_code);
        }
    }

    #[test]
    fn test_bond_symbol() {
        assert_eq!(bond_symbol(BondOrder::Single), "");
        assert_eq!(bond_symbol(BondOrder::Double), "=");
        assert_eq!(bond_symbol(BondOrder::Triple), "#");
        assert_eq!(bond_symbol(BondOrder::Aromatic), "*");
        assert_eq!(bond_symbol(BondOrder::Unknown), "");
    }

    #[test]
    fn test_fuzzy_hose_match() {
        assert!(fuzzy_hose_match("C/C,H", "C/H,H,H,C"));
        assert!(!fuzzy_hose_match("C/C", "C/H,H,H,C"));
        assert!(!fuzzy_hose_match("O/C,H", "C/H,H,H,C"));
        assert!(!fuzzy_hose_match("C", "C/H"));
    }

    #[test]
    fn test_predict_rejects_unsupported_or_empty_codes() {
        let code = HoseCode {
            atom_index: 0,
            element: 7,
            spheres: vec!["N".to_string(), "C,H,H".to_string()],
            full_code: "N/C,H,H".to_string(),
        };
        // Only 1H and 13C tables exist.
        assert!(predict_shift_from_hose(&code, 7).is_none());

        let centre_only = HoseCode {
            atom_index: 0,
            element: 6,
            spheres: vec!["C".to_string()],
            full_code: "C/".to_string(),
        };
        assert!(predict_shift_from_hose(&centre_only, 6).is_none());
    }
}