Skip to main content

immunum/
annotator.rs

1//! High-level API for sequence annotation and chain detection
2use std::cell::RefCell;
3
4use crate::alignment::{align, AlignBuffer, Alignment};
5use crate::error::{Error, Result};
6use crate::numbering::{apply_numbering, segment as segment_positions};
7use crate::scoring::ScoringMatrix;
8use crate::types::{Chain, Position, Scheme, TCR_CHAINS};
9
10#[cfg(feature = "python")]
11use pyo3::prelude::*;
12use serde::{Deserialize, Serialize};
13
14/// Result of numbering a sequence
15#[cfg_attr(feature = "python", pyclass(get_all))]
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct NumberingResult {
18    /// Detected chain type
19    pub chain: Chain,
20    /// Numbering scheme used
21    pub scheme: Scheme,
22    /// Numbered positions for the aligned region only (length == query_end - query_start + 1)
23    pub positions: Vec<Position>,
24    /// First aligned consensus position
25    pub cons_start: usize,
26    /// Last aligned consensus position
27    pub cons_end: usize,
28    /// Confidence score (normalized alignment score)
29    pub confidence: f32,
30    /// 0-based index of the first antibody residue in the query (0 when no prefix)
31    pub query_start: usize,
32    /// 0-based index of the last antibody residue in the query (query.len()-1 when no suffix)
33    pub query_end: usize,
34}
35
36/// Result of segmenting a sequence into FR/CDR regions
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct SegmentResult {
39    pub prefix: String,
40    pub fr1: String,
41    pub cdr1: String,
42    pub fr2: String,
43    pub cdr2: String,
44    pub fr3: String,
45    pub cdr3: String,
46    pub fr4: String,
47    pub postfix: String,
48}
49
/// Confidence threshold used when the caller does not supply one.
///
/// Chosen from an empirical survey of validated sequences:
/// - antibody chains (IGH/IGK/IGL): minimum ~0.51, median ~0.78-0.85
/// - TCR chains (TRA/TRB/TRG/TRD): minimum ~0.28, median ~0.62-0.83
///
/// At 0.5 non-immunoglobulin sequences are filtered out while every validated
/// antibody sequence is kept. Low-scoring TCR sequences (notably TCR-A
/// p5=0.39 and TCR-B p5=0.49) can fall below the threshold because their
/// consensus data is less complete. Use 0.0 to turn the filter off.
pub const DEFAULT_MIN_CONFIDENCE: f32 = 0.5;

/// Shortest input sequence that is accepted.
pub const MIN_SEQUENCE_LENGTH: usize = 30;

/// Longest input sequence that is accepted.
pub const MAX_SEQUENCE_LENGTH: usize = 1000;
67
68/// Validate that `sequence` contains only standard amino acid characters
69/// (case-insensitive) and that its length is within the allowed bounds.
70fn validate_sequence(sequence: &str) -> Result<()> {
71    let len = sequence.len();
72    if len < MIN_SEQUENCE_LENGTH {
73        return Err(Error::InvalidSequence(format!(
74            "sequence length {} is below minimum {}",
75            len, MIN_SEQUENCE_LENGTH
76        )));
77    }
78    if len > MAX_SEQUENCE_LENGTH {
79        return Err(Error::InvalidSequence(format!(
80            "sequence length {} exceeds maximum {}",
81            len, MAX_SEQUENCE_LENGTH
82        )));
83    }
84    for (i, b) in sequence.bytes().enumerate() {
85        if !b.is_ascii_alphabetic() {
86            return Err(Error::InvalidSequence(format!(
87                "invalid character {:?} at position {i}",
88                b as char
89            )));
90        }
91    }
92    Ok(())
93}
94
95/// Annotator for numbering sequences
96#[cfg_attr(
97    feature = "python",
98    pyclass(name = "_Annotator", module = "immunum._internal", unsendable)
99)]
100#[cfg_attr(feature = "wasm", wasm_bindgen::prelude::wasm_bindgen(skip_typescript))]
101#[derive(Serialize, Deserialize)]
102pub struct Annotator {
103    pub(crate) matrices: Vec<(Chain, ScoringMatrix)>,
104    pub(crate) scheme: Scheme,
105    pub(crate) chains: Vec<Chain>,
106    pub(crate) min_confidence: f32,
107    /// Reusable alignment buffer to avoid per-alignment allocation
108    #[serde(skip)]
109    align_buf: RefCell<AlignBuffer>,
110}
111
112impl Clone for Annotator {
113    fn clone(&self) -> Self {
114        Self {
115            matrices: self.matrices.clone(),
116            scheme: self.scheme,
117            chains: self.chains.clone(),
118            min_confidence: self.min_confidence,
119            align_buf: RefCell::new(AlignBuffer::new()),
120        }
121    }
122}
123
124impl Annotator {
125    pub fn new(chains: &[Chain], scheme: Scheme, min_confidence: Option<f32>) -> Result<Self> {
126        if chains.is_empty() {
127            return Err(Error::InvalidChain("chains cannot be empty".to_string()));
128        }
129
130        // Validate: Kabat only supported for antibody chains
131        if scheme == Scheme::Kabat && chains.iter().any(|c| TCR_CHAINS.contains(c)) {
132            return Err(Error::InvalidScheme(
133                "Kabat scheme only supported for antibody chains (IGH, IGK, IGL)".to_string(),
134            ));
135        }
136
137        let mut matrices = Vec::new();
138        for &chain in chains {
139            let matrix = ScoringMatrix::load(chain)?;
140            matrices.push((chain, matrix));
141        }
142
143        Ok(Self {
144            matrices,
145            scheme,
146            chains: chains.to_vec(),
147            min_confidence: min_confidence.unwrap_or(DEFAULT_MIN_CONFIDENCE),
148            align_buf: RefCell::new(AlignBuffer::new()),
149        })
150    }
151
152    /// Number a sequence by aligning to the configured chain types and applying the numbering scheme
153    pub fn number(&self, sequence: &str) -> Result<NumberingResult> {
154        validate_sequence(sequence)?;
155
156        let (chain, alignment) = self.get_best_alignment(sequence)?;
157
158        // Apply numbering only to the aligned subregion of the query
159        let aligned_positions = &alignment.positions[alignment.query_start..=alignment.query_end];
160        let positions = apply_numbering(aligned_positions, self.scheme, chain);
161        let confidence = if alignment.max_confidence_score > 0.0 {
162            (alignment.confidence_score / alignment.max_confidence_score).clamp(0.0, 1.0)
163        } else {
164            0.0
165        };
166
167        if confidence < self.min_confidence {
168            return Err(Error::LowConfidence {
169                confidence,
170                threshold: self.min_confidence,
171            });
172        }
173
174        Ok(NumberingResult {
175            chain,
176            scheme: self.scheme,
177            positions,
178            cons_start: alignment.cons_start as usize,
179            cons_end: alignment.cons_end as usize,
180            confidence,
181            query_start: alignment.query_start,
182            query_end: alignment.query_end,
183        })
184    }
185
186    /// Segment a sequence into FR/CDR regions
187    pub fn segment(&self, sequence: &str) -> Result<SegmentResult> {
188        let result = self.number(sequence)?;
189        let aligned_seq = &sequence[result.query_start..=result.query_end];
190        let mut map = segment_positions(&result.positions, aligned_seq, result.scheme);
191        Ok(SegmentResult {
192            prefix: map.remove("prefix").unwrap_or_default(),
193            fr1: map.remove("fr1").unwrap_or_default(),
194            cdr1: map.remove("cdr1").unwrap_or_default(),
195            fr2: map.remove("fr2").unwrap_or_default(),
196            cdr2: map.remove("cdr2").unwrap_or_default(),
197            fr3: map.remove("fr3").unwrap_or_default(),
198            cdr3: map.remove("cdr3").unwrap_or_default(),
199            fr4: map.remove("fr4").unwrap_or_default(),
200            postfix: map.remove("postfix").unwrap_or_default(),
201        })
202    }
203
204    /// Align the sequence to all loaded chain types and return the best match
205    /// If multiple chains were provided during initialization, this will align to all
206    /// of them and return the best match. If only one chain was provided, it will
207    /// align to that chain directly.
208    fn get_best_alignment(&self, sequence: &str) -> Result<(Chain, Alignment)> {
209        let mut buf = self.align_buf.borrow_mut();
210        // Align to all loaded chain types and find best match by raw alignment score
211        let mut best: Option<(Chain, Alignment)> = None;
212        for (chain, matrix) in &self.matrices {
213            let alignment = align(sequence, &matrix.positions, Some(&mut *buf));
214            let is_better = match &best {
215                Some((_, prev)) => alignment.score > prev.score,
216                None => true,
217            };
218            if is_better {
219                best = Some((*chain, alignment));
220            }
221        }
222        best.ok_or_else(|| Error::AlignmentError("failed to align to any chain type".to_string()))
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ALL_CHAINS;

    /// Known IGH sequence shared by the chain-detection and segmentation tests.
    const KNOWN_IGH: &str =
        "QVQLVQSGAEVKRPGSSVTVSCKASGGSFSTYALSWVRQAPGRGLEWMGGVIPLLTITNYAPRFQGRITITADRSTSTAYLELNSLRPEDTAVYYCAREGTTGKPIGAFAHWGQGTLVTVSS";

    /// Full IGH (FR1 through FR4) used by the flanking-residue tests.
    const FULL_IGH: &str = "EVQLVESGGGLVQPGGSLRLSCAASGFNVSYSSIHWVRQAPGKGLEWVAYIYPSSGYTSYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSYSTKLAMDYWGQGTLVTVSS";

    /// Flanking residues placed before the domain.
    const PREFIX: &str = "AAAAAA";

    /// Flanking residues placed after the domain.
    const SUFFIX: &str = "AAAAAAA";

    /// IMGT annotator with the default confidence threshold.
    fn imgt_annotator(chains: &[Chain]) -> Annotator {
        Annotator::new(chains, Scheme::IMGT, None).unwrap()
    }

    #[test]
    fn test_create_annotator() {
        let annotator = imgt_annotator(ALL_CHAINS);
        assert_eq!(annotator.matrices.len(), 7);
    }

    #[test]
    fn test_create_annotator_with_chains() {
        let annotator = imgt_annotator(&[Chain::IGH, Chain::IGK]);
        assert_eq!(annotator.matrices.len(), 2);
    }

    #[test]
    fn test_number_igh_sequence() {
        let result = imgt_annotator(ALL_CHAINS).number(KNOWN_IGH).unwrap();

        // Should detect as IGH
        assert_eq!(result.chain, Chain::IGH);
        assert_eq!(result.scheme, Scheme::IMGT);
        assert!(result.confidence > 0.0 && result.confidence <= 1.0);
        assert_eq!(
            result.positions.len(),
            result.query_end - result.query_start + 1
        );
    }

    #[test]
    fn test_number_with_single_chain() {
        let result = imgt_annotator(&[Chain::IGH]).number(KNOWN_IGH).unwrap();
        assert_eq!(result.chain, Chain::IGH);
    }

    #[test]
    fn test_empty_sequence() {
        let outcome = imgt_annotator(ALL_CHAINS).number("");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_number_no_flanking_has_zero_query_start_end() {
        let result = imgt_annotator(&[Chain::IGH]).number(FULL_IGH).unwrap();
        assert_eq!(result.query_start, 0);
        assert_eq!(result.query_end, FULL_IGH.len() - 1);
        assert_eq!(result.positions.len(), FULL_IGH.len());
    }

    #[test]
    fn test_number_with_prefix() {
        let sequence = format!("{PREFIX}{FULL_IGH}");
        let result = imgt_annotator(&[Chain::IGH]).number(&sequence).unwrap();
        assert_eq!(result.chain, Chain::IGH);
        assert_eq!(result.query_start, PREFIX.len());
        assert_eq!(result.query_end, sequence.len() - 1);
        assert_eq!(result.positions.len(), FULL_IGH.len());
    }

    #[test]
    fn test_number_with_suffix() {
        let sequence = format!("{FULL_IGH}{SUFFIX}");
        let result = imgt_annotator(&[Chain::IGH]).number(&sequence).unwrap();
        assert_eq!(result.chain, Chain::IGH);
        assert_eq!(result.query_start, 0);
        assert_eq!(result.query_end, FULL_IGH.len() - 1);
        assert_eq!(result.positions.len(), FULL_IGH.len());
    }

    #[test]
    fn test_segment_igh_sequence() {
        let segments = imgt_annotator(&[Chain::IGH]).segment(KNOWN_IGH).unwrap();
        assert_eq!(segments.fr1, "QVQLVQSGAEVKRPGSSVTVSCKAS");
        assert_eq!(segments.cdr1, "GGSFSTYA");
        assert_eq!(segments.cdr3, "AREGTTGKPIGAFAH");
        assert_eq!(segments.fr4, "WGQGTLVTVSS");
        assert!(segments.prefix.is_empty());
        assert!(segments.postfix.is_empty());
    }

    #[test]
    fn test_number_with_both_flanking() {
        let sequence = format!("{PREFIX}{FULL_IGH}{SUFFIX}");
        let result = imgt_annotator(&[Chain::IGH]).number(&sequence).unwrap();
        assert_eq!(result.chain, Chain::IGH);
        assert_eq!(result.query_start, PREFIX.len());
        assert_eq!(result.query_end, PREFIX.len() + FULL_IGH.len() - 1);
        assert_eq!(result.positions.len(), FULL_IGH.len());
    }
}