1use std::cell::RefCell;
3
4use crate::alignment::{align, AlignBuffer, Alignment};
5use crate::error::{Error, Result};
6use crate::numbering::{apply_numbering, segment as segment_positions};
7use crate::scoring::ScoringMatrix;
8use crate::types::{Chain, Position, Scheme, TCR_CHAINS};
9
10#[cfg(feature = "python")]
11use pyo3::prelude::*;
12use serde::{Deserialize, Serialize};
13
/// Outcome of numbering one sequence with [`Annotator::number`].
#[cfg_attr(feature = "python", pyclass(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumberingResult {
    /// Chain whose scoring matrix produced the highest-scoring alignment.
    pub chain: Chain,
    /// Numbering scheme of the annotator that produced this result.
    pub scheme: Scheme,
    /// Numbered positions for the aligned part of the query, i.e. for the
    /// residues at indices `query_start..=query_end`.
    pub positions: Vec<Position>,
    /// Copied from the alignment's `cons_start`.
    // NOTE(review): presumably the first consensus position covered by the
    // alignment; confirm against the alignment module.
    pub cons_start: usize,
    /// Copied from the alignment's `cons_end`.
    // NOTE(review): presumably the last consensus position covered by the
    // alignment; confirm against the alignment module.
    pub cons_end: usize,
    /// Alignment confidence score divided by its maximum, clamped to `0.0..=1.0`
    /// (`0.0` when the maximum is not positive).
    pub confidence: f32,
    /// Index into the input sequence where the aligned region starts (inclusive).
    pub query_start: usize,
    /// Index into the input sequence where the aligned region ends (inclusive).
    pub query_end: usize,
}
35
/// Outcome of [`Annotator::segment`]: the sub-string assigned to each region.
///
/// Every field is taken from the map returned by the segmentation helper under
/// the key of the same name; a region that is absent from that map is left as
/// an empty string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SegmentResult {
    /// Text stored under the `"prefix"` key.
    pub prefix: String,
    /// Text stored under the `"fr1"` key.
    pub fr1: String,
    /// Text stored under the `"cdr1"` key.
    pub cdr1: String,
    /// Text stored under the `"fr2"` key.
    pub fr2: String,
    /// Text stored under the `"cdr2"` key.
    pub cdr2: String,
    /// Text stored under the `"fr3"` key.
    pub fr3: String,
    /// Text stored under the `"cdr3"` key.
    pub cdr3: String,
    /// Text stored under the `"fr4"` key.
    pub fr4: String,
    /// Text stored under the `"postfix"` key.
    pub postfix: String,
}
49
/// Confidence threshold used by [`Annotator::new`] when the caller passes `None`.
///
/// [`Annotator::number`] rejects any result whose confidence is below the
/// annotator's threshold.
pub const DEFAULT_MIN_CONFIDENCE: f32 = 0.5;

/// Shortest input, measured in bytes, that `validate_sequence` accepts.
pub const MIN_SEQUENCE_LENGTH: usize = 30;

/// Longest input, measured in bytes, that `validate_sequence` accepts.
pub const MAX_SEQUENCE_LENGTH: usize = 1000;
67
68fn validate_sequence(sequence: &str) -> Result<()> {
71 let len = sequence.len();
72 if len < MIN_SEQUENCE_LENGTH {
73 return Err(Error::InvalidSequence(format!(
74 "sequence length {} is below minimum {}",
75 len, MIN_SEQUENCE_LENGTH
76 )));
77 }
78 if len > MAX_SEQUENCE_LENGTH {
79 return Err(Error::InvalidSequence(format!(
80 "sequence length {} exceeds maximum {}",
81 len, MAX_SEQUENCE_LENGTH
82 )));
83 }
84 for (i, b) in sequence.bytes().enumerate() {
85 if !b.is_ascii_alphabetic() {
86 return Err(Error::InvalidSequence(format!(
87 "invalid character {:?} at position {i}",
88 b as char
89 )));
90 }
91 }
92 Ok(())
93}
94
/// Numbers and segments sequences against a fixed set of chain scoring matrices.
///
/// With the `python` feature enabled the type is exported as the pyclass
/// `_Annotator` in `immunum._internal`; with the `wasm` feature it is exported
/// through `wasm_bindgen`.
#[cfg_attr(
    feature = "python",
    pyclass(name = "_Annotator", module = "immunum._internal", unsendable)
)]
#[cfg_attr(feature = "wasm", wasm_bindgen::prelude::wasm_bindgen(skip_typescript))]
#[derive(Serialize, Deserialize)]
pub struct Annotator {
    /// One scoring matrix per configured chain, paired with the chain it was
    /// loaded for.
    pub(crate) matrices: Vec<(Chain, ScoringMatrix)>,
    /// Numbering scheme applied to every result.
    pub(crate) scheme: Scheme,
    /// Chains this annotator was configured with.
    pub(crate) chains: Vec<Chain>,
    /// Results whose confidence falls below this value are rejected by `number`.
    pub(crate) min_confidence: f32,
    /// Scratch buffer handed to `align`. It sits behind a `RefCell` so that
    /// methods taking `&self` can borrow it mutably. It is skipped during
    /// (de)serialization and a clone starts with a fresh buffer.
    #[serde(skip)]
    align_buf: RefCell<AlignBuffer>,
}
111
112impl Clone for Annotator {
113 fn clone(&self) -> Self {
114 Self {
115 matrices: self.matrices.clone(),
116 scheme: self.scheme,
117 chains: self.chains.clone(),
118 min_confidence: self.min_confidence,
119 align_buf: RefCell::new(AlignBuffer::new()),
120 }
121 }
122}
123
124impl Annotator {
125 pub fn new(chains: &[Chain], scheme: Scheme, min_confidence: Option<f32>) -> Result<Self> {
126 if chains.is_empty() {
127 return Err(Error::InvalidChain("chains cannot be empty".to_string()));
128 }
129
130 if scheme == Scheme::Kabat && chains.iter().any(|c| TCR_CHAINS.contains(c)) {
132 return Err(Error::InvalidScheme(
133 "Kabat scheme only supported for antibody chains (IGH, IGK, IGL)".to_string(),
134 ));
135 }
136
137 let mut matrices = Vec::new();
138 for &chain in chains {
139 let matrix = ScoringMatrix::load(chain)?;
140 matrices.push((chain, matrix));
141 }
142
143 Ok(Self {
144 matrices,
145 scheme,
146 chains: chains.to_vec(),
147 min_confidence: min_confidence.unwrap_or(DEFAULT_MIN_CONFIDENCE),
148 align_buf: RefCell::new(AlignBuffer::new()),
149 })
150 }
151
152 pub fn number(&self, sequence: &str) -> Result<NumberingResult> {
154 validate_sequence(sequence)?;
155
156 let (chain, alignment) = self.get_best_alignment(sequence)?;
157
158 let aligned_positions = &alignment.positions[alignment.query_start..=alignment.query_end];
160 let positions = apply_numbering(aligned_positions, self.scheme, chain);
161 let confidence = if alignment.max_confidence_score > 0.0 {
162 (alignment.confidence_score / alignment.max_confidence_score).clamp(0.0, 1.0)
163 } else {
164 0.0
165 };
166
167 if confidence < self.min_confidence {
168 return Err(Error::LowConfidence {
169 confidence,
170 threshold: self.min_confidence,
171 });
172 }
173
174 Ok(NumberingResult {
175 chain,
176 scheme: self.scheme,
177 positions,
178 cons_start: alignment.cons_start as usize,
179 cons_end: alignment.cons_end as usize,
180 confidence,
181 query_start: alignment.query_start,
182 query_end: alignment.query_end,
183 })
184 }
185
186 pub fn segment(&self, sequence: &str) -> Result<SegmentResult> {
188 let result = self.number(sequence)?;
189 let aligned_seq = &sequence[result.query_start..=result.query_end];
190 let mut map = segment_positions(&result.positions, aligned_seq, result.scheme);
191 Ok(SegmentResult {
192 prefix: map.remove("prefix").unwrap_or_default(),
193 fr1: map.remove("fr1").unwrap_or_default(),
194 cdr1: map.remove("cdr1").unwrap_or_default(),
195 fr2: map.remove("fr2").unwrap_or_default(),
196 cdr2: map.remove("cdr2").unwrap_or_default(),
197 fr3: map.remove("fr3").unwrap_or_default(),
198 cdr3: map.remove("cdr3").unwrap_or_default(),
199 fr4: map.remove("fr4").unwrap_or_default(),
200 postfix: map.remove("postfix").unwrap_or_default(),
201 })
202 }
203
204 fn get_best_alignment(&self, sequence: &str) -> Result<(Chain, Alignment)> {
209 let mut buf = self.align_buf.borrow_mut();
210 let mut best: Option<(Chain, Alignment)> = None;
212 for (chain, matrix) in &self.matrices {
213 let alignment = align(sequence, &matrix.positions, Some(&mut *buf));
214 let is_better = match &best {
215 Some((_, prev)) => alignment.score > prev.score,
216 None => true,
217 };
218 if is_better {
219 best = Some((*chain, alignment));
220 }
221 }
222 best.ok_or_else(|| Error::AlignmentError("failed to align to any chain type".to_string()))
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ALL_CHAINS;

    /// Query shared by the basic numbering and segmentation tests.
    const IGH_QUERY: &str =
        "QVQLVQSGAEVKRPGSSVTVSCKASGGSFSTYALSWVRQAPGRGLEWMGGVIPLLTITNYAPRFQGRITITADRSTSTAYLELNSLRPEDTAVYYCAREGTTGKPIGAFAHWGQGTLVTVSS";

    /// Query used as the core of the flanking-sequence tests.
    const FULL_IGH: &str = "EVQLVESGGGLVQPGGSLRLSCAASGFNVSYSSIHWVRQAPGKGLEWVAYIYPSSGYTSYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSYSTKLAMDYWGQGTLVTVSS";

    /// Builds an IMGT annotator with the default confidence threshold.
    fn imgt_annotator(chains: &[Chain]) -> Annotator {
        Annotator::new(chains, Scheme::IMGT, None).unwrap()
    }

    #[test]
    fn test_create_annotator() {
        let all = imgt_annotator(ALL_CHAINS);
        assert_eq!(all.matrices.len(), 7);
    }

    #[test]
    fn test_create_annotator_with_chains() {
        let pair = imgt_annotator(&[Chain::IGH, Chain::IGK]);
        assert_eq!(pair.matrices.len(), 2);
    }

    #[test]
    fn test_number_igh_sequence() {
        let numbered = imgt_annotator(ALL_CHAINS).number(IGH_QUERY).unwrap();

        assert_eq!(numbered.chain, Chain::IGH);
        assert_eq!(numbered.scheme, Scheme::IMGT);
        assert!(numbered.confidence > 0.0 && numbered.confidence <= 1.0);

        // One numbered position per residue of the aligned span.
        let aligned_len = numbered.query_end - numbered.query_start + 1;
        assert_eq!(numbered.positions.len(), aligned_len);
    }

    #[test]
    fn test_number_with_single_chain() {
        let numbered = imgt_annotator(&[Chain::IGH]).number(IGH_QUERY).unwrap();
        assert_eq!(numbered.chain, Chain::IGH);
    }

    #[test]
    fn test_empty_sequence() {
        let outcome = imgt_annotator(ALL_CHAINS).number("");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_number_no_flanking_has_zero_query_start_end() {
        let numbered = imgt_annotator(&[Chain::IGH]).number(FULL_IGH).unwrap();
        assert_eq!(numbered.query_start, 0);
        assert_eq!(numbered.query_end, FULL_IGH.len() - 1);
        assert_eq!(numbered.positions.len(), FULL_IGH.len());
    }

    #[test]
    fn test_number_with_prefix() {
        let prefix = "AAAAAA";
        let query = format!("{prefix}{FULL_IGH}");
        let numbered = imgt_annotator(&[Chain::IGH]).number(&query).unwrap();
        assert_eq!(numbered.chain, Chain::IGH);
        assert_eq!(numbered.query_start, prefix.len());
        assert_eq!(numbered.query_end, query.len() - 1);
        assert_eq!(numbered.positions.len(), FULL_IGH.len());
    }

    #[test]
    fn test_number_with_suffix() {
        let suffix = "AAAAAAA";
        let query = format!("{FULL_IGH}{suffix}");
        let numbered = imgt_annotator(&[Chain::IGH]).number(&query).unwrap();
        assert_eq!(numbered.chain, Chain::IGH);
        assert_eq!(numbered.query_start, 0);
        assert_eq!(numbered.query_end, FULL_IGH.len() - 1);
        assert_eq!(numbered.positions.len(), FULL_IGH.len());
    }

    #[test]
    fn test_segment_igh_sequence() {
        let regions = imgt_annotator(&[Chain::IGH]).segment(IGH_QUERY).unwrap();
        assert_eq!(regions.fr1, "QVQLVQSGAEVKRPGSSVTVSCKAS");
        assert_eq!(regions.cdr1, "GGSFSTYA");
        assert_eq!(regions.cdr3, "AREGTTGKPIGAFAH");
        assert_eq!(regions.fr4, "WGQGTLVTVSS");
        assert!(regions.prefix.is_empty());
        assert!(regions.postfix.is_empty());
    }

    #[test]
    fn test_number_with_both_flanking() {
        let prefix = "AAAAAA";
        let suffix = "AAAAAAA";
        let query = format!("{prefix}{FULL_IGH}{suffix}");
        let numbered = imgt_annotator(&[Chain::IGH]).number(&query).unwrap();
        assert_eq!(numbered.chain, Chain::IGH);
        assert_eq!(numbered.query_start, prefix.len());
        assert_eq!(numbered.query_end, prefix.len() + FULL_IGH.len() - 1);
        assert_eq!(numbered.positions.len(), FULL_IGH.len());
    }
}