use crate::graph::Molecule;
use crate::smarts::{parse_smarts, substruct_match};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A parsed SMIRKS reaction transform, as produced by `parse_smirks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmirksTransform {
/// Reactant-side component patterns: the trimmed text before `>>`, split on `.`.
pub reactant_smarts: Vec<String>,
/// Product-side component patterns: the trimmed text after `>>`, split on `.`.
pub product_smarts: Vec<String>,
/// Atom-map numbers that occur on both sides of the transform.
/// `parse_smirks` stores each such number mapped to itself.
pub atom_map: HashMap<usize, usize>,
/// Changes reported by `detect_bond_changes` for map numbers that occur on
/// only one side of the transform.
pub bond_changes: Vec<BondChange>,
/// The SMIRKS string this transform was parsed from, stored unmodified.
pub smirks: String,
}
/// A single change record attached to a `SmirksTransform`.
///
/// In this file the records are created only by `detect_bond_changes`, which
/// emits one per atom-map number that is present on exactly one side.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BondChange {
/// Atom-map number of the atom the change refers to.
pub atom1_map: usize,
/// Atom-map number of the partner atom. `detect_bond_changes` always writes
/// `0` here (the partner is not resolved).
pub atom2_map: usize,
/// Bond order before the transform; `None` when the map number only exists
/// on the product side. `detect_bond_changes` uses the literal `"SINGLE"`.
pub old_order: Option<String>,
/// Bond order after the transform; `None` when the map number only exists
/// on the reactant side. `detect_bond_changes` uses the literal `"SINGLE"`.
pub new_order: Option<String>,
}
/// Outcome of applying a SMIRKS transform to one or more input molecules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmirksResult {
/// Product strings. On success the functions in this file copy the
/// transform's `product_smarts` here; on failure the list is empty.
pub products: Vec<String>,
/// Atom-map number -> value taken from the first substructure match
/// (presumably the matched atom's index in the input molecule — confirm
/// against `substruct_match`). Empty on failure.
pub atom_mapping: HashMap<usize, usize>,
/// Number of transforms applied: `1` on success, `0` on failure.
pub n_transforms: usize,
/// `true` when every reactant pattern matched its input molecule.
pub success: bool,
/// Human-readable notes describing what matched or why nothing did.
pub messages: Vec<String>,
}
/// Parses a SMIRKS string of the form `reactants>>products` into a
/// [`SmirksTransform`].
///
/// Both sides are trimmed and split on `.` into component patterns. Atom-map
/// numbers are collected from each side with `extract_atom_maps`; the
/// resulting `atom_map` holds every number that occurs on both sides, mapped
/// to itself. Numbers that occur on only one side are accepted and are
/// reported through `bond_changes` (see `detect_bond_changes`).
///
/// # Errors
///
/// Returns `Err` when the input does not contain exactly one `>>`, when either
/// side is empty after trimming, or when atom-map extraction fails for either
/// side (for example because a map number is repeated within one side).
pub fn parse_smirks(smirks: &str) -> Result<SmirksTransform, String> {
    // Pull at most three pieces: exactly two means exactly one separator.
    let mut sides = smirks.split(">>");
    let (reactant_part, product_part) = match (sides.next(), sides.next(), sides.next()) {
        (Some(lhs), Some(rhs), None) => (lhs.trim(), rhs.trim()),
        _ => return Err("SMIRKS must contain exactly one '>>' separator".to_string()),
    };
    if reactant_part.is_empty() || product_part.is_empty() {
        return Err("SMIRKS reactant and product parts must be non-empty".to_string());
    }

    let reactant_smarts: Vec<String> = reactant_part.split('.').map(str::to_string).collect();
    let product_smarts: Vec<String> = product_part.split('.').map(str::to_string).collect();

    let reactant_maps = extract_atom_maps(reactant_part)?;
    let product_maps = extract_atom_maps(product_part)?;

    // Identity map over the numbers shared by both sides. No separate
    // "bijectivity" check is performed: key and value of every entry are the
    // same number, so comparing the key set with the value set could never
    // fail, and one-sided map numbers are intentionally tolerated.
    let atom_map: HashMap<usize, usize> = reactant_maps
        .keys()
        .filter(|map_num| product_maps.contains_key(*map_num))
        .map(|&map_num| (map_num, map_num))
        .collect();

    let bond_changes =
        detect_bond_changes(reactant_part, product_part, &reactant_maps, &product_maps);

    Ok(SmirksTransform {
        reactant_smarts,
        product_smarts,
        atom_map,
        bond_changes,
        smirks: smirks.to_string(),
    })
}
/// Applies a SMIRKS transform to a SMILES string.
///
/// When the transform has more than one reactant component, `smiles` is split
/// on `.` and the work is delegated to `apply_smirks_multi_inner`. Otherwise
/// the single reactant pattern is matched against the molecule parsed from
/// `smiles` and the first match (if any) becomes the reported atom mapping.
///
/// A pattern that does not match is not an error: the function returns
/// `Ok` with `success == false`.
///
/// # Errors
///
/// Propagates failures from `parse_smirks`, `Molecule::from_smiles` and
/// `match_smarts_pattern`.
pub fn apply_smirks(smirks: &str, smiles: &str) -> Result<SmirksResult, String> {
    let transform = parse_smirks(smirks)?;

    if transform.reactant_smarts.len() > 1 {
        let components: Vec<&str> = smiles.split('.').collect();
        return apply_smirks_multi_inner(&transform, &components);
    }

    let molecule = Molecule::from_smiles(smiles)?;
    let hits = match_smarts_pattern(&molecule, &transform.reactant_smarts[0])?;

    // Only the first match is used.
    match hits.into_iter().next() {
        None => Ok(SmirksResult {
            products: vec![],
            atom_mapping: HashMap::new(),
            n_transforms: 0,
            success: false,
            messages: vec!["No match found for reactant pattern".to_string()],
        }),
        Some(first_hit) => {
            let summary = format!("Transform applied: {} atoms mapped", first_hit.len());
            Ok(SmirksResult {
                products: transform.product_smarts.clone(),
                atom_mapping: first_hit,
                n_transforms: 1,
                success: true,
                messages: vec![summary],
            })
        }
    }
}
/// Applies a multi-component SMIRKS transform to a list of reactant SMILES.
///
/// The SMIRKS string is parsed first; the parsed transform and the reactants
/// are then handed to `apply_smirks_multi_inner`.
///
/// # Errors
///
/// Returns the parse error from `parse_smirks`, or any error produced by
/// `apply_smirks_multi_inner`.
pub fn apply_smirks_multi(smirks: &str, reactant_smiles: &[&str]) -> Result<SmirksResult, String> {
    parse_smirks(smirks).and_then(|transform| apply_smirks_multi_inner(&transform, reactant_smiles))
}
/// Matches each reactant pattern of `transform` against the reactant SMILES at
/// the same position and merges the first match of every component.
///
/// Patterns and reactants are paired positionally; reactants beyond the number
/// of patterns are not looked at. Supplying fewer reactants than patterns, or
/// a component that does not match, yields `Ok` with `success == false`.
///
/// # Errors
///
/// Propagates failures from `Molecule::from_smiles` and
/// `match_smarts_pattern`.
fn apply_smirks_multi_inner(
    transform: &SmirksTransform,
    reactant_smiles: &[&str],
) -> Result<SmirksResult, String> {
    // Shared shape of every "did not apply" result.
    let failure = |message: String| SmirksResult {
        products: vec![],
        atom_mapping: HashMap::new(),
        n_transforms: 0,
        success: false,
        messages: vec![message],
    };

    let expected = transform.reactant_smarts.len();
    let supplied = reactant_smiles.len();
    if supplied < expected {
        return Ok(failure(format!(
            "Expected {} reactant(s) but got {}",
            expected, supplied
        )));
    }

    let mut combined_mapping: HashMap<usize, usize> = HashMap::new();
    let mut all_messages = Vec::new();

    let pairs = transform.reactant_smarts.iter().zip(reactant_smiles.iter());
    for (idx, (pattern, smiles)) in pairs.enumerate() {
        let mol = Molecule::from_smiles(smiles)?;
        let hits = match_smarts_pattern(&mol, pattern)?;

        let first_hit = match hits.first() {
            Some(hit) => hit,
            None => {
                return Ok(failure(format!(
                    "No match for reactant component {} (pattern: {})",
                    idx, pattern
                )));
            }
        };

        // Later components overwrite earlier ones on a shared map number.
        combined_mapping.extend(
            first_hit
                .iter()
                .map(|(map_num, atom_idx)| (*map_num, *atom_idx)),
        );
        all_messages.push(format!(
            "Component {} matched: {} atoms",
            idx,
            first_hit.len()
        ));
    }

    Ok(SmirksResult {
        products: transform.product_smarts.clone(),
        atom_mapping: combined_mapping,
        n_transforms: 1,
        success: true,
        messages: all_messages,
    })
}
/// Scans a SMARTS/SMILES-style pattern and collects its atom-map numbers.
///
/// The returned map goes from atom-map number to the zero-based ordinal of the
/// atom that carries it, counted in scan order:
///
/// * every top-level bracket expression `[...]` counts as one atom; brackets
///   nested inside it (recursive SMARTS such as `[$([CX3]=O):1]`) belong to
///   that same atom;
/// * outside brackets, each ASCII uppercase letter and each of the aromatic
///   symbols `c`, `n`, `o`, `s` counts as one atom. Other lowercase symbols
///   are not counted.
///
/// Inside a bracket expression the text after the last `:` is the map number.
/// If that text is not a valid `usize` the atom is treated as unmapped.
///
/// # Errors
///
/// Returns `Err` when the same map number appears twice, or when a `[` has no
/// matching `]`.
fn extract_atom_maps(pattern: &str) -> Result<HashMap<usize, usize>, String> {
    let bytes = pattern.as_bytes();
    let mut maps = HashMap::new();
    let mut atom_idx = 0usize;
    let mut pos = 0usize;

    while pos < bytes.len() {
        let byte = bytes[pos];

        if byte == b'[' {
            // Find the `]` that closes this bracket, honouring nesting.
            // `depth` is at least 1 whenever a `]` is seen because the scan
            // starts on the opening `[` itself.
            let mut depth = 0usize;
            let mut close = None;
            for (offset, &inner_byte) in bytes[pos..].iter().enumerate() {
                match inner_byte {
                    b'[' => depth += 1,
                    b']' => {
                        depth -= 1;
                        if depth == 0 {
                            close = Some(pos + offset);
                            break;
                        }
                    }
                    _ => {}
                }
            }
            let close = close.ok_or_else(|| {
                format!("unterminated '[' at byte {} in pattern '{}'", pos, pattern)
            })?;

            // `[` and `]` are single ASCII bytes, so both bounds are valid
            // char boundaries even when the content is not ASCII.
            let inner = &pattern[pos + 1..close];
            if let Some(colon) = inner.rfind(':') {
                if let Ok(map_num) = inner[colon + 1..].parse::<usize>() {
                    if maps.insert(map_num, atom_idx).is_some() {
                        return Err(format!(
                            "duplicate atom map :{} in pattern '{}'",
                            map_num, pattern
                        ));
                    }
                }
            }

            atom_idx += 1;
            pos = close + 1;
        } else {
            if byte.is_ascii_uppercase() || matches!(byte, b'c' | b'n' | b'o' | b's') {
                atom_idx += 1;
            }
            pos += 1;
        }
    }

    Ok(maps)
}
/// Builds one `BondChange` for every atom-map number that occurs on exactly
/// one side of the transform.
///
/// Numbers found only in `reactant_maps` come first and carry
/// `old_order = Some("SINGLE")`, `new_order = None`; numbers found only in
/// `product_maps` follow with the two fields swapped. `atom2_map` is always
/// `0`. The pattern strings `_reactant` and `_product` are not inspected.
fn detect_bond_changes(
    _reactant: &str,
    _product: &str,
    reactant_maps: &HashMap<usize, usize>,
    product_maps: &HashMap<usize, usize>,
) -> Vec<BondChange> {
    let only_in_reactant = reactant_maps
        .keys()
        .filter(|label| !product_maps.contains_key(*label))
        .map(|&label| BondChange {
            atom1_map: label,
            atom2_map: 0,
            old_order: Some("SINGLE".to_string()),
            new_order: None,
        });

    let only_in_product = product_maps
        .keys()
        .filter(|label| !reactant_maps.contains_key(*label))
        .map(|&label| BondChange {
            atom1_map: label,
            atom2_map: 0,
            old_order: None,
            new_order: Some("SINGLE".to_string()),
        });

    only_in_reactant.chain(only_in_product).collect()
}
/// Runs a substructure search for `pattern` on `mol` and re-keys every match
/// by atom-map number.
///
/// Each returned `HashMap` corresponds to one match reported by
/// `substruct_match`: for every pattern atom that carries a map number, the
/// map number is the key and the entry of the match at that pattern atom's
/// position is the value. Pattern atoms without a map number are left out.
///
/// # Errors
///
/// Returns the error from `parse_smarts` when `pattern` cannot be parsed.
fn match_smarts_pattern(
    mol: &Molecule,
    pattern: &str,
) -> Result<Vec<HashMap<usize, usize>>, String> {
    let query = parse_smarts(pattern)?;

    // (position in the pattern, atom-map number) for every mapped pattern atom.
    let mut labelled: Vec<(usize, usize)> = Vec::new();
    for (pattern_idx, atom) in query.atoms.iter().enumerate() {
        if let Some(map_idx) = atom.map_idx {
            labelled.push((pattern_idx, map_idx as usize));
        }
    }

    let mut keyed_matches = Vec::new();
    for matched_atoms in substruct_match(mol, &query) {
        let mut assignment: HashMap<usize, usize> = HashMap::new();
        for &(pattern_idx, map_num) in &labelled {
            assignment.insert(map_num, matched_atoms[pattern_idx]);
        }
        keyed_matches.push(assignment);
    }

    Ok(keyed_matches)
}
// Unit tests for SMIRKS parsing (`parse_smirks`, `extract_atom_maps`) and
// application (`apply_smirks`, `apply_smirks_multi`).
#[cfg(test)]
mod tests {
use super::*;
// A single-component SMIRKS parses into one reactant, one product and a
// non-empty atom map.
#[test]
fn test_parse_smirks_basic() {
let result = parse_smirks("[C:1](=O)[OH:2]>>[C:1](=O)[O-:2]");
assert!(result.is_ok());
let t = result.unwrap();
assert_eq!(t.reactant_smarts.len(), 1);
assert_eq!(t.product_smarts.len(), 1);
assert!(!t.atom_map.is_empty());
}
// Input without `>>`, and input that is only `>>`, are both rejected.
#[test]
fn test_parse_smirks_invalid() {
assert!(parse_smirks("no_separator").is_err());
assert!(parse_smirks(">>").is_err());
}
// Map numbers inside bracket atoms are picked up.
#[test]
fn test_extract_atom_maps() {
let maps = extract_atom_maps("[C:1](=O)[OH:2]").unwrap();
assert!(maps.contains_key(&1));
assert!(maps.contains_key(&2));
}
// The same map number used twice within one pattern is an error.
#[test]
fn test_extract_atom_maps_rejects_duplicates() {
let err = extract_atom_maps("[C:1][O:1]").unwrap_err();
assert!(err.contains("duplicate atom map"));
}
// A matching input reports success, one transform and one mapping entry per
// mapped pattern atom.
#[test]
fn test_apply_smirks() {
let result = apply_smirks("[C:1](=O)[OH:2]>>[C:1](=O)[O-:2]", "CC(=O)O");
let result = result.unwrap();
assert!(result.success);
assert_eq!(result.n_transforms, 1);
assert_eq!(result.atom_mapping.len(), 2);
}
// A pattern that cannot match the input yields `Ok` with `success == false`.
#[test]
fn test_apply_smirks_requires_real_match() {
let result = apply_smirks("[N:1]>>[N:1]", "CCO").unwrap();
assert!(!result.success);
assert_eq!(result.n_transforms, 0);
}
// Two reactant components with two reactant SMILES: only checks that the
// call returns `Ok`; the result fields are not inspected.
#[test]
fn test_apply_smirks_multi_component_transform() {
let result = apply_smirks_multi(
"[C:1](=[O:2])[OH:3].[C:4][OH:5]>>[C:1](=[O:2])[O:5][C:4]",
&["CC(=O)O", "CO"],
);
assert!(result.is_ok());
}
// Same SMIRKS and input as `test_apply_smirks`, without the mapping-size check.
#[test]
fn test_deprotonation_reaction() {
let result = apply_smirks("[C:1](=O)[OH:2]>>[C:1](=O)[O-:2]", "CC(=O)O").unwrap();
assert!(result.success);
assert_eq!(result.n_transforms, 1);
}
// Two reactant components, one product component. Map 3 appears only on the
// reactant side and parsing must still succeed.
#[test]
fn test_esterification_pattern() {
let smirks = "[C:1](=[O:2])[OH:3].[C:4][OH:5]>>[C:1](=[O:2])[O:5][C:4]";
let parsed = parse_smirks(smirks);
assert!(parsed.is_ok());
let t = parsed.unwrap();
assert_eq!(t.reactant_smarts.len(), 2); assert_eq!(t.product_smarts.len(), 1);
}
// Map 3 is dropped on the product side; maps 1 and 2 must be in `atom_map`.
#[test]
fn test_oxidation_pattern() {
let smirks = "[C:1][C:2]([H:3])[OH:4]>>[C:1][C:2](=[O:4])";
let parsed = parse_smirks(smirks);
assert!(parsed.is_ok());
let t = parsed.unwrap();
assert!(t.atom_map.contains_key(&1));
assert!(t.atom_map.contains_key(&2));
}
// Lowercase (aromatic) bracket atoms parse.
#[test]
fn test_halogenation_pattern() {
let smirks = "[c:1][H:2]>>[c:1][Cl:2]";
let parsed = parse_smirks(smirks);
assert!(parsed.is_ok());
}
// One component on each side.
#[test]
fn test_reduction_pattern() {
let smirks = "[C:1]=[O:2]>>[C:1][OH:2]";
let parsed = parse_smirks(smirks);
assert!(parsed.is_ok());
let t = parsed.unwrap();
assert_eq!(t.reactant_smarts.len(), 1);
assert_eq!(t.product_smarts.len(), 1);
}
// Two reactant components; maps 3 and 5 appear only on the reactant side.
#[test]
fn test_amide_formation_pattern() {
let smirks = "[C:1](=[O:2])[OH:3].[N:4][H:5]>>[C:1](=[O:2])[N:4]";
let parsed = parse_smirks(smirks);
assert!(parsed.is_ok());
let t = parsed.unwrap();
assert_eq!(t.reactant_smarts.len(), 2);
}
// Maps 3 and 4 appear only on the product side and parsing must still succeed.
#[test]
fn test_nitration_pattern() {
let smirks = "[c:1][H:2]>>[c:1][N+:2](=[O:3])[O-:4]";
let parsed = parse_smirks(smirks);
assert!(parsed.is_ok());
}
// Five mapped atoms on each side; only a lower bound on `atom_map` is checked.
#[test]
fn test_complex_atom_map() {
let smirks = "[C:1]([H:2])([H:3])([H:4])[Br:5]>>[C:1]([H:2])([H:3])([H:4])[OH:5]";
let parsed = parse_smirks(smirks);
assert!(parsed.is_ok());
let t = parsed.unwrap();
assert!(t.atom_map.len() >= 2); }
// NOTE(review): this test passes whether parsing succeeds or fails. The
// message is only inspected in the `Err` case; an `Ok` result asserts nothing.
#[test]
fn test_invalid_non_bijective_map() {
let result = parse_smirks("[C:1][O:2]>>[C:1]");
if let Err(e) = result {
assert!(e.contains("bijective") || e.contains("map"));
}
}
// An aromatic two-atom pattern must match the six-membered aromatic ring input.
#[test]
fn test_smirks_with_aromatic() {
let result = apply_smirks("[c:1][c:2]>>[c:1][c:2]", "c1ccccc1");
assert!(result.is_ok());
let res = result.unwrap();
assert!(res.success); }
// NOTE(review): despite the name, this only checks that the call returns
// `Ok`; `success` is never asserted, so a match would not fail the test.
#[test]
fn test_smirks_no_match_wrong_functional_group() {
let _result = apply_smirks("[C:1][OH:2]>>[C:1][O-:2]", "CCOC").unwrap();
}
// The `[C:1][OH:2]` pattern must match "CCO".
#[test]
fn test_smirks_ethanol_match() {
let result = apply_smirks("[C:1][OH:2]>>[C:1][O-:2]", "CCO").unwrap();
assert!(result.success);
assert_eq!(result.n_transforms, 1);
}
// `.` on the reactant side splits into separate components.
#[test]
fn test_smirks_parse_multiple_reactants() {
let smirks = "[C:1].[O:2]>>[C:1][O:2]";
let parsed = parse_smirks(smirks);
assert!(parsed.is_ok());
let t = parsed.unwrap();
assert_eq!(t.reactant_smarts.len(), 2);
assert_eq!(t.product_smarts.len(), 1);
}
// The empty string has no separator and is rejected.
#[test]
fn test_empty_smirks_rejected() {
assert!(parse_smirks("").is_err());
}
// A bare separator is rejected with the "non-empty" message.
#[test]
fn test_smirks_only_separator() {
let err = parse_smirks(">>").unwrap_err();
assert!(err.contains("non-empty"));
}
// More than one `>>` is rejected with the "exactly one" message.
#[test]
fn test_multiple_separators_rejected() {
let err = parse_smirks("A>>B>>C").unwrap_err();
assert!(err.contains("exactly one"));
}
}