use std::cell::RefCell;
use crate::alignment::{align, AlignBuffer, Alignment};
use crate::error::{Error, Result};
use crate::numbering::apply_numbering;
use crate::scoring::ScoringMatrix;
use crate::types::{Chain, Position, Scheme, TCR_CHAINS};
#[cfg(feature = "python")]
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};
#[cfg_attr(feature = "python", pyclass(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumberingResult {
pub chain: Chain,
pub scheme: Scheme,
pub positions: Vec<Position>,
pub cons_start: usize,
pub cons_end: usize,
pub confidence: f32,
pub query_start: usize,
pub query_end: usize,
}
pub const DEFAULT_MIN_CONFIDENCE: f32 = 0.5;
#[cfg_attr(
feature = "python",
pyclass(name = "_Annotator", module = "immunum._internal", unsendable)
)]
#[cfg_attr(feature = "wasm", wasm_bindgen::prelude::wasm_bindgen(skip_typescript))]
#[derive(Serialize, Deserialize)]
pub struct Annotator {
pub(crate) matrices: Vec<(Chain, ScoringMatrix)>,
pub(crate) scheme: Scheme,
pub(crate) chains: Vec<Chain>,
pub(crate) min_confidence: f32,
#[serde(skip)]
align_buf: RefCell<AlignBuffer>,
}
impl Clone for Annotator {
fn clone(&self) -> Self {
Self {
matrices: self.matrices.clone(),
scheme: self.scheme,
chains: self.chains.clone(),
min_confidence: self.min_confidence,
align_buf: RefCell::new(AlignBuffer::new()),
}
}
}
impl Annotator {
pub fn new(chains: &[Chain], scheme: Scheme, min_confidence: Option<f32>) -> Result<Self> {
if chains.is_empty() {
return Err(Error::InvalidChain("chains cannot be empty".to_string()));
}
if scheme == Scheme::Kabat && chains.iter().any(|c| TCR_CHAINS.contains(c)) {
return Err(Error::InvalidScheme(
"Kabat scheme only supported for antibody chains (IGH, IGK, IGL)".to_string(),
));
}
let mut matrices = Vec::new();
for &chain in chains {
let matrix = ScoringMatrix::load(chain)?;
matrices.push((chain, matrix));
}
Ok(Self {
matrices,
scheme,
chains: chains.to_vec(),
min_confidence: min_confidence.unwrap_or(DEFAULT_MIN_CONFIDENCE),
align_buf: RefCell::new(AlignBuffer::new()),
})
}
pub fn number(&self, sequence: &str) -> Result<NumberingResult> {
if sequence.is_empty() {
return Err(Error::InvalidSequence("empty sequence".to_string()));
}
let (chain, alignment) = self.get_best_alignment(sequence)?;
let aligned_positions = &alignment.positions[alignment.query_start..=alignment.query_end];
let positions = apply_numbering(aligned_positions, self.scheme, chain);
let confidence = if alignment.max_confidence_score > 0.0 {
(alignment.confidence_score / alignment.max_confidence_score).clamp(0.0, 1.0)
} else {
0.0
};
if confidence < self.min_confidence {
return Err(Error::LowConfidence {
confidence,
threshold: self.min_confidence,
});
}
Ok(NumberingResult {
chain,
scheme: self.scheme,
positions,
cons_start: alignment.cons_start as usize,
cons_end: alignment.cons_end as usize,
confidence,
query_start: alignment.query_start,
query_end: alignment.query_end,
})
}
pub fn get_best_alignment(&self, sequence: &str) -> Result<(Chain, Alignment)> {
let mut buf = self.align_buf.borrow_mut();
let mut best: Option<(Chain, Alignment)> = None;
for (chain, matrix) in &self.matrices {
if let Ok(alignment) = align(sequence, &matrix.positions, Some(&mut *buf)) {
let is_better = match &best {
Some((_, prev)) => alignment.score > prev.score,
None => true,
};
if is_better {
best = Some((*chain, alignment));
}
}
}
best.ok_or_else(|| Error::AlignmentError("failed to align to any chain type".to_string()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::ALL_CHAINS;
#[test]
fn test_create_annotator() {
let annotator = Annotator::new(ALL_CHAINS, Scheme::IMGT, None).unwrap();
assert_eq!(annotator.matrices.len(), 7);
}
#[test]
fn test_create_annotator_with_chains() {
let annotator = Annotator::new(&[Chain::IGH, Chain::IGK], Scheme::IMGT, None).unwrap();
assert_eq!(annotator.matrices.len(), 2);
}
#[test]
fn test_number_igh_sequence() {
let annotator = Annotator::new(ALL_CHAINS, Scheme::IMGT, None).unwrap();
let sequence =
"QVQLVQSGAEVKRPGSSVTVSCKASGGSFSTYALSWVRQAPGRGLEWMGGVIPLLTITNYAPRFQGRITITADRSTSTAYLELNSLRPEDTAVYYCAREGTTGKPIGAFAHWGQGTLVTVSS";
let result = annotator.number(sequence).unwrap();
assert_eq!(result.chain, Chain::IGH);
assert_eq!(result.scheme, Scheme::IMGT);
assert!(result.confidence > 0.0 && result.confidence <= 1.0);
assert_eq!(
result.positions.len(),
result.query_end - result.query_start + 1
);
}
#[test]
fn test_number_with_single_chain() {
let annotator = Annotator::new(&[Chain::IGH], Scheme::IMGT, None).unwrap();
let sequence =
"QVQLVQSGAEVKRPGSSVTVSCKASGGSFSTYALSWVRQAPGRGLEWMGGVIPLLTITNYAPRFQGRITITADRSTSTAYLELNSLRPEDTAVYYCAREGTTGKPIGAFAHWGQGTLVTVSS";
let result = annotator.number(sequence).unwrap();
assert_eq!(result.chain, Chain::IGH);
}
#[test]
fn test_empty_sequence() {
let annotator = Annotator::new(ALL_CHAINS, Scheme::IMGT, None).unwrap();
let result = annotator.number("");
assert!(result.is_err());
}
const FULL_IGH: &str = "EVQLVESGGGLVQPGGSLRLSCAASGFNVSYSSIHWVRQAPGKGLEWVAYIYPSSGYTSYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSYSTKLAMDYWGQGTLVTVSS";
#[test]
fn test_number_no_flanking_has_zero_query_start_end() {
let annotator = Annotator::new(&[Chain::IGH], Scheme::IMGT, None).unwrap();
let result = annotator.number(FULL_IGH).unwrap();
assert_eq!(result.query_start, 0);
assert_eq!(result.query_end, FULL_IGH.len() - 1);
assert_eq!(result.positions.len(), FULL_IGH.len());
}
#[test]
fn test_number_with_prefix() {
let annotator = Annotator::new(&[Chain::IGH], Scheme::IMGT, None).unwrap();
let prefix = "AAAAAA";
let sequence = format!("{prefix}{FULL_IGH}");
let result = annotator.number(&sequence).unwrap();
assert_eq!(result.chain, Chain::IGH);
assert_eq!(result.query_start, prefix.len());
assert_eq!(result.query_end, sequence.len() - 1);
assert_eq!(result.positions.len(), FULL_IGH.len());
}
#[test]
fn test_number_with_suffix() {
let annotator = Annotator::new(&[Chain::IGH], Scheme::IMGT, None).unwrap();
let suffix = "AAAAAAA";
let sequence = format!("{FULL_IGH}{suffix}");
let result = annotator.number(&sequence).unwrap();
assert_eq!(result.chain, Chain::IGH);
assert_eq!(result.query_start, 0);
assert_eq!(result.query_end, FULL_IGH.len() - 1);
assert_eq!(result.positions.len(), FULL_IGH.len());
}
#[test]
fn test_number_with_both_flanking() {
let annotator = Annotator::new(&[Chain::IGH], Scheme::IMGT, None).unwrap();
let prefix = "AAAAAA";
let suffix = "AAAAAAA";
let sequence = format!("{prefix}{FULL_IGH}{suffix}");
let result = annotator.number(&sequence).unwrap();
assert_eq!(result.chain, Chain::IGH);
assert_eq!(result.query_start, prefix.len());
assert_eq!(result.query_end, prefix.len() + FULL_IGH.len() - 1);
assert_eq!(result.positions.len(), FULL_IGH.len());
}
}