use gapsmith_db::ComplexSubunitTable;
use regex::Regex;
use std::collections::HashMap;
use std::sync::OnceLock;
/// Assigns a normalized subunit label to each descriptor of one reaction.
///
/// The result has one slot per entry of `descriptors`, in the same order;
/// `None` means "no usable subunit label". An empty input yields an empty
/// vector.
///
/// Per descriptor the label is built as follows:
/// 1. [`extract_subunit_phrase`] pulls out the raw phrase, which
///    [`canonicalize_subunit`] rewrites to the `Subunit <designator>` form.
/// 2. If `dict` has entries for `rxn_id`, a label equal to
///    `Subunit <synonym>` is replaced by `Subunit <canonical>`.
/// 3. [`apply_numeral_maps`] turns the designator into a number, and a single
///    trailing letter after that number (`Subunit 2A`) is stripped.
///
/// Afterwards the whole vector goes through [`apply_low_count_filter`] and
/// [`apply_numbered_quality_filter`]. Finally, if at most one fifth of the
/// slots still carry a label, every slot is cleared.
pub fn detect_subunits(
    rxn_id: &str,
    descriptors: &[&str],
    dict: &ComplexSubunitTable,
) -> Vec<Option<String>> {
    if descriptors.is_empty() {
        return Vec::new();
    }

    // Extraction and canonicalisation are independent per descriptor, so they
    // are done in one pass.
    let mut labels: Vec<Option<String>> = descriptors
        .iter()
        .map(|descriptor| {
            extract_subunit_phrase(descriptor).map(|phrase| canonicalize_subunit(&phrase))
        })
        .collect();

    // Reaction-specific synonyms are resolved before numeral mapping, while
    // the designator is still in its textual form.
    if let Some(entries) = dict.for_rxn(rxn_id) {
        let synonyms: HashMap<String, String> = entries
            .iter()
            .map(|(synonym, canonical)| {
                (format!("Subunit {synonym}"), format!("Subunit {canonical}"))
            })
            .collect();
        for label in labels.iter_mut().flatten() {
            if let Some(canonical) = synonyms.get(label.as_str()) {
                label.clone_from(canonical);
            }
        }
    }

    let trailing_letter = sub_num_letter_re();
    for label in labels.iter_mut().flatten() {
        let mapped = apply_numeral_maps(label);
        // Keep only the "Subunit <n>" part when a single letter follows it.
        let stripped = trailing_letter
            .captures(&mapped)
            .map(|caps| caps[1].to_string());
        *label = stripped.unwrap_or(mapped);
    }

    apply_low_count_filter(&mut labels);
    apply_numbered_quality_filter(&mut labels);

    // Integer form of "labelled / total <= 20 %".
    let labelled = labels.iter().flatten().count();
    if labelled * 5 <= labels.len() {
        labels.fill(None);
    }
    labels
}
/// Returns the first subunit-like phrase found in `descriptor`, verbatim.
///
/// A phrase is one of the words `subunit`, `chain`, `polypeptide` or
/// `component` combined with a designator: a number with an optional
/// capital letter, capital letters before or after the word, a Greek letter
/// name before, after or hyphenated in front of the word, a capitalised
/// name after the word, or `large`/`medium`/`small` in front of it.
/// Matching is case-sensitive. `None` is returned when nothing matches.
fn extract_subunit_phrase(descriptor: &str) -> Option<String> {
    static PATTERN: OnceLock<Regex> = OnceLock::new();
    let pattern = PATTERN.get_or_init(|| {
        let syn = r"(?:subunit|chain|polypeptide|component)";
        let greek = r"(?:alpha|beta|gamma|delta|epsilon|zeta|eta|theta|iota|kappa|lambda|my|ny|omikron|pi|rho|sigma)";
        // Order matters: for matches starting at the same position the regex
        // engine prefers the earlier alternative.
        let alternatives = [
            format!(r"{syn} [1-9]+(?:[A-Z])?\b"),
            format!(r"{syn} [A-Z]+\b"),
            format!(r"\b[A-Z]+ {syn}"),
            format!(r"{greek} {syn}"),
            format!(r"{syn} {greek}"),
            format!(r"{syn} [A-Z][A-Za-z]+\b"),
            format!(r"(?:large|medium|small) {syn}"),
            format!(r"{greek}-{syn}"),
        ];
        let combined = format!("({})", alternatives.join("|"));
        Regex::new(&combined).expect("complex pattern should compile")
    });

    let found = pattern.find(descriptor)?;
    Some(found.as_str().to_owned())
}
/// Rewrites an extracted phrase into the `Subunit <designator>` form.
///
/// Every occurrence of `subunit`, `chain`, `polypeptide` or `component` is
/// first replaced by `Subunit`. A designator standing in front of that word
/// (capital letters, a Greek letter name with a space or a hyphen, or a
/// size word) is then moved behind it, e.g. `alpha chain` becomes
/// `Subunit alpha`.
fn canonicalize_subunit(raw: &str) -> String {
    // (pattern, replacement) pairs, applied strictly in this order.
    static REWRITES: OnceLock<[(Regex, &'static str); 5]> = OnceLock::new();
    let rewrites = REWRITES.get_or_init(|| {
        [
            (
                Regex::new(r"subunit|chain|polypeptide|component").expect("syn regex"),
                "Subunit",
            ),
            (
                Regex::new(r"([A-Z]+) Subunit").expect("pre-caps regex"),
                "Subunit $1",
            ),
            (
                Regex::new(
                    r"(alpha|beta|gamma|delta|epsilon|zeta|eta|theta|iota|kappa|lambda|my|ny|omikron|pi|rho|sigma) Subunit",
                )
                .expect("greek pre-regex"),
                "Subunit $1",
            ),
            (
                Regex::new(
                    r"(alpha|beta|gamma|delta|epsilon|zeta|eta|theta|iota|kappa|lambda|my|ny|omikron|pi|rho|sigma)-Subunit",
                )
                .expect("dash greek regex"),
                "Subunit $1",
            ),
            (
                Regex::new(r"(small|medium|large) Subunit").expect("size regex"),
                "Subunit $1",
            ),
        ]
    });

    let mut text = raw.to_owned();
    for (pattern, replacement) in rewrites {
        text = pattern.replace_all(&text, *replacement).into_owned();
    }
    text
}
/// Rewrites the designators in `raw` as decimal numbers.
///
/// Four passes run in this order:
/// 1. Greek letter names (whole word, case-insensitive) become their 1-based
///    position in `GREEK` (`alpha` -> `1`, `sigma` -> `17`).
/// 2. Roman numerals `I`..`XV` (case-insensitive) become `1`..`15`. The
///    patterns are anchored at the start of a word only, so a numeral that
///    is followed by further letters is still converted (`IIA` -> `2A`).
///    The table is ordered so that a numeral is tried before any shorter
///    numeral that is a prefix of it.
/// 3. Single-letter words `A`..`Z` (case-insensitive) become `1`..`26`.
///    Because pass 2 runs first, `I`, `V` and `X` are read as Roman numerals.
/// 4. The substrings `small`, `medium`, `large` become `1`, `2`, `3`.
///
/// The Greek pass has to run before the Roman pass: `iota` starts with `i`,
/// and the start-anchored, case-insensitive `I` rule would otherwise turn it
/// into `1ota`, which the Greek pass can no longer recognise.
fn apply_numeral_maps(raw: &str) -> String {
    const GREEK: &[&str] = &[
        "alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta", "iota", "kappa",
        "lambda", "my", "ny", "omikron", "pi", "rho", "sigma",
    ];
    const LATIN: &[(&str, &str)] = &[
        ("XV", "15"),
        ("XIV", "14"),
        ("XIII", "13"),
        ("XII", "12"),
        ("XI", "11"),
        ("X", "10"),
        ("IX", "9"),
        ("VIII", "8"),
        ("VII", "7"),
        ("VI", "6"),
        ("V", "5"),
        ("IV", "4"),
        ("III", "3"),
        ("II", "2"),
        ("I", "1"),
    ];

    /// Applies every `(pattern, replacement)` of `table` to `text` in order,
    /// reallocating only when a pattern actually matched.
    fn run_pass(text: &mut String, table: &[(Regex, String)]) {
        for (pattern, replacement) in table {
            let rewritten = match pattern.replace_all(text.as_str(), replacement.as_str()) {
                std::borrow::Cow::Owned(changed) => Some(changed),
                std::borrow::Cow::Borrowed(_) => None,
            };
            if let Some(changed) = rewritten {
                *text = changed;
            }
        }
    }

    static GREEK_RES: OnceLock<Vec<(Regex, String)>> = OnceLock::new();
    let greek_rs = GREEK_RES.get_or_init(|| {
        GREEK
            .iter()
            .enumerate()
            .map(|(i, g)| {
                (
                    Regex::new(&format!(r"(?i)\b{g}\b")).expect("greek re"),
                    (i + 1).to_string(),
                )
            })
            .collect()
    });

    static LATIN_RES: OnceLock<Vec<(Regex, String)>> = OnceLock::new();
    let latin_rs = LATIN_RES.get_or_init(|| {
        LATIN
            .iter()
            .map(|(pat, rep)| {
                (
                    Regex::new(&format!(r"(?i)\b{pat}")).expect("latin re"),
                    (*rep).to_string(),
                )
            })
            .collect()
    });

    static LETTER_RES: OnceLock<Vec<(Regex, String)>> = OnceLock::new();
    let letter_rs = LETTER_RES.get_or_init(|| {
        (b'A'..=b'Z')
            .enumerate()
            .map(|(i, c)| {
                (
                    Regex::new(&format!(r"(?i)\b{}\b", c as char)).expect("letter re"),
                    (i + 1).to_string(),
                )
            })
            .collect()
    });

    let mut out = raw.to_string();
    run_pass(&mut out, greek_rs);
    run_pass(&mut out, latin_rs);
    run_pass(&mut out, letter_rs);
    out.replace("small", "1").replace("medium", "2").replace("large", "3")
}
/// Returns the shared regex that matches a label consisting of
/// `Subunit <digits>` followed by exactly one ASCII letter.
///
/// Capture group 1 holds the `Subunit <digits>` part without the letter.
fn sub_num_letter_re() -> &'static Regex {
    fn compile() -> Regex {
        Regex::new(r"^(Subunit [0-9]+)[A-Za-z]$").expect("sub-num-letter re")
    }

    static TRAILING_LETTER: OnceLock<Regex> = OnceLock::new();
    TRAILING_LETTER.get_or_init(compile)
}
/// Clears labels that are rare compared with the other labels in `hits`.
///
/// Occurrences are counted per distinct label. Nothing happens when there
/// are no labels at all, or when the mean number of occurrences per distinct
/// label is below 10. Otherwise every slot whose label occurs fewer than
/// 5 times is set to `None`.
fn apply_low_count_filter(hits: &mut [Option<String>]) {
    let mut occurrences: HashMap<&str, usize> = HashMap::new();
    for label in hits.iter().filter_map(|slot| slot.as_deref()) {
        *occurrences.entry(label).or_default() += 1;
    }
    if occurrences.is_empty() {
        return;
    }

    let total: usize = occurrences.values().sum();
    let mean_per_label = total as f64 / occurrences.len() as f64;
    if mean_per_label < 10.0 {
        return;
    }

    // The map borrows from `hits`, so decide per slot first and mutate after.
    let is_rare: Vec<bool> = hits
        .iter()
        .map(|slot| slot.as_deref().is_some_and(|label| occurrences[label] < 5))
        .collect();
    for (slot, rare) in hits.iter_mut().zip(is_rare) {
        if rare {
            *slot = None;
        }
    }
}
/// Clears non-numbered labels when numbered labels dominate `hits`.
///
/// A label is "numbered" when it starts with `Subunit ` followed by at least
/// one digit. The share of numbered labels is taken over all slots,
/// including the `None` ones. If that share is at least 0.66, every label
/// that is not numbered is set to `None`; otherwise `hits` is left as is.
fn apply_numbered_quality_filter(hits: &mut [Option<String>]) {
    static NUMBERED: OnceLock<Regex> = OnceLock::new();
    let numbered = NUMBERED.get_or_init(|| Regex::new(r"^Subunit [0-9]+").expect("numbered re"));

    if hits.is_empty() {
        return;
    }

    let numbered_total = hits
        .iter()
        .filter(|slot| slot.as_deref().is_some_and(|label| numbered.is_match(label)))
        .count();
    let coverage = numbered_total as f64 / hits.len() as f64;
    if coverage < 0.66 {
        return;
    }

    for slot in hits.iter_mut() {
        if slot.as_deref().is_some_and(|label| !numbered.is_match(label)) {
            *slot = None;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A table without any reaction-specific synonyms.
    fn empty_dict() -> ComplexSubunitTable {
        ComplexSubunitTable::default()
    }

    // --- extract_subunit_phrase -------------------------------------------

    #[test]
    fn extract_numeric_subunit() {
        let r = extract_subunit_phrase("DNA-directed RNA polymerase subunit 2");
        assert_eq!(r.as_deref(), Some("subunit 2"));
    }

    #[test]
    fn extract_greek_subunit() {
        let r = extract_subunit_phrase("ATP synthase alpha chain");
        assert_eq!(r.as_deref(), Some("alpha chain"));
    }

    #[test]
    fn extract_size_subunit() {
        let r = extract_subunit_phrase("Ribulose bisphosphate carboxylase large chain");
        assert_eq!(r.as_deref(), Some("large chain"));
    }

    #[test]
    fn no_match() {
        assert_eq!(extract_subunit_phrase("Acetaldehyde dehydrogenase"), None);
    }

    // --- canonicalize_subunit ---------------------------------------------

    #[test]
    fn canonicalize_reorders() {
        assert_eq!(canonicalize_subunit("alpha chain"), "Subunit alpha");
        assert_eq!(canonicalize_subunit("small subunit"), "Subunit small");
        assert_eq!(canonicalize_subunit("ABC component"), "Subunit ABC");
    }

    // --- apply_numeral_maps -----------------------------------------------

    #[test]
    fn numeral_mapping_greek() {
        assert_eq!(apply_numeral_maps("Subunit alpha"), "Subunit 1");
        assert_eq!(apply_numeral_maps("Subunit sigma"), "Subunit 17");
    }

    #[test]
    fn numeral_mapping_latin() {
        assert_eq!(apply_numeral_maps("Subunit XIII"), "Subunit 13");
        assert_eq!(apply_numeral_maps("Subunit IV"), "Subunit 4");
    }

    #[test]
    fn numeral_mapping_letters() {
        assert_eq!(apply_numeral_maps("Subunit A"), "Subunit 1");
        assert_eq!(apply_numeral_maps("Subunit C"), "Subunit 3");
    }

    #[test]
    fn numeral_mapping_size() {
        assert_eq!(apply_numeral_maps("Subunit small"), "Subunit 1");
        assert_eq!(apply_numeral_maps("Subunit large"), "Subunit 3");
    }

    // --- detect_subunits (whole pipeline) ---------------------------------

    #[test]
    fn full_pipeline_on_mixed_headers() {
        let dict = empty_dict();
        let descriptors = &[
            "DNA polymerase subunit alpha OS=...",
            "DNA polymerase subunit beta OS=...",
            "DNA polymerase subunit gamma OS=...",
            "Unrelated protein OS=...",
        ];
        let out = detect_subunits("RXN-TEST", descriptors, &dict);
        assert_eq!(out.len(), 4);
        assert_eq!(out[0].as_deref(), Some("Subunit 1"));
        assert_eq!(out[1].as_deref(), Some("Subunit 2"));
        assert_eq!(out[2].as_deref(), Some("Subunit 3"));
        assert_eq!(out[3], None);
    }

    #[test]
    fn coverage_under_20_percent_blanks_all() {
        let dict = empty_dict();
        let mut headers = vec!["random protein"; 9];
        headers.push("DNA pol subunit alpha");
        let out = detect_subunits("RXN", &headers, &dict);
        assert!(out.iter().all(|h| h.is_none()));
    }

    /// A Greek designator is already mapped to a number when the numbered
    /// quality filter runs, so it counts as numbered and is kept. The last
    /// slot is `None` only because "random" never produced a hit.
    #[test]
    fn greek_hit_survives_numbered_quality_filter() {
        let dict = empty_dict();
        let hs = (1..=7)
            .map(|i| format!("DNA pol subunit {i}"))
            .collect::<Vec<_>>();
        let mut input: Vec<String> = hs.clone();
        input.push("Ribonuclease subunit alpha".into());
        input.push("random".into());
        let refs: Vec<&str> = input.iter().map(|s| s.as_str()).collect();
        let out = detect_subunits("RXN", &refs, &dict);
        for (i, h) in out.iter().take(7).enumerate() {
            assert!(h.is_some(), "numbered input {} dropped: {h:?}", i + 1);
            let s = h.as_ref().unwrap();
            assert!(s.starts_with("Subunit "), "expected Subunit prefix: {s}");
        }
        assert_eq!(out[7].as_deref(), Some("Subunit 1"));
        assert_eq!(out[8], None);
    }

    /// A titled designator ("Rpb") has no numeric mapping, so it is still
    /// non-numbered when the quality filter runs. With 7 of 8 slots numbered
    /// the filter is active and must clear it.
    #[test]
    fn numbered_quality_filter_drops_non_numbered() {
        let dict = empty_dict();
        let mut input: Vec<String> = (1..=7)
            .map(|i| format!("DNA pol subunit {i}"))
            .collect();
        input.push("Polymerase subunit Rpb".into());
        let refs: Vec<&str> = input.iter().map(|s| s.as_str()).collect();
        let out = detect_subunits("RXN", &refs, &dict);
        assert_eq!(out.len(), 8);
        for (i, h) in out.iter().take(7).enumerate() {
            let want = format!("Subunit {}", i + 1);
            assert_eq!(h.as_deref(), Some(want.as_str()));
        }
        assert_eq!(out[7], None);
    }
}