use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
use strum_macros::{Display, EnumString};
#[cfg(feature = "python")]
use pyo3::prelude::*;
/// Chain type selector.
///
/// Parsing (derived `EnumString`) is ASCII case-insensitive. For every variant it
/// accepts the three-letter name (e.g. `"IGH"`), the single-letter `to_string`
/// value (e.g. `"H"`) and the spelled-out alias (e.g. `"heavy"`).
/// Formatting (derived `Display`) prints the single-letter `to_string` value.
#[cfg_attr(feature = "python", pyclass(get_all))]
#[derive(Debug, EnumString, Display, PartialEq, Serialize, Deserialize, Clone, Copy)]
pub enum Chain {
/// Heavy chain: parses from `IGH`, `H` or `heavy`; displays as `H`.
#[strum(
serialize = "IGH",
to_string = "H",
serialize = "heavy",
ascii_case_insensitive
)]
IGH,
/// Kappa chain: parses from `IGK`, `K` or `kappa`; displays as `K`.
#[strum(
serialize = "IGK",
to_string = "K",
serialize = "kappa",
ascii_case_insensitive
)]
IGK,
/// Lambda chain: parses from `IGL`, `L` or `lambda`; displays as `L`.
#[strum(
serialize = "IGL",
to_string = "L",
serialize = "lambda",
ascii_case_insensitive
)]
IGL,
/// Alpha chain: parses from `TRA`, `A` or `alpha`; displays as `A`.
#[strum(
serialize = "TRA",
to_string = "A",
serialize = "alpha",
ascii_case_insensitive
)]
TRA,
/// Beta chain: parses from `TRB`, `B` or `beta`; displays as `B`.
#[strum(
serialize = "TRB",
to_string = "B",
serialize = "beta",
ascii_case_insensitive
)]
TRB,
/// Gamma chain: parses from `TRG`, `G` or `gamma`; displays as `G`.
#[strum(
serialize = "TRG",
to_string = "G",
serialize = "gamma",
ascii_case_insensitive
)]
TRG,
/// Delta chain: parses from `TRD`, `D` or `delta`; displays as `D`.
#[strum(
serialize = "TRD",
to_string = "D",
serialize = "delta",
ascii_case_insensitive
)]
TRD,
}
/// Every [`Chain`] variant, in declaration order (the `IG*` chains first, then the `TR*` chains).
pub const ALL_CHAINS: &[Chain] = &[
Chain::IGH,
Chain::IGK,
Chain::IGL,
Chain::TRA,
Chain::TRB,
Chain::TRG,
Chain::TRD,
];
/// The three `IG*` chains: heavy, kappa and lambda.
pub const IG_CHAINS: &[Chain] = &[Chain::IGH, Chain::IGK, Chain::IGL];
/// The four `TR*` chains: alpha, beta, gamma and delta.
pub const TCR_CHAINS: &[Chain] = &[Chain::TRA, Chain::TRB, Chain::TRG, Chain::TRD];
impl Chain {
    /// Parses a chain specification into the list of chains it names.
    ///
    /// The specification is either
    /// * one of the group keywords `all`, `ig` or `tcr` (case-insensitive), which
    ///   expand to [`ALL_CHAINS`], [`IG_CHAINS`] and [`TCR_CHAINS`] respectively, or
    /// * a comma-separated list of individual chain names, each parsed with
    ///   [`Chain::from_str`] and returned in the order given (duplicates are kept).
    ///
    /// Surrounding whitespace is ignored both around the whole specification and
    /// around each comma-separated entry. Group keywords are only recognised when
    /// they are the entire specification, not as entries of a comma list.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidChain`] naming the first entry that is not a valid
    /// chain (this includes empty entries such as in `"h,,k"`).
    pub fn parse_chain_spec(s: &str) -> Result<Vec<Chain>> {
        // Trim before keyword matching so that " ig " behaves like "ig", the same
        // way individual list entries are already trimmed below.
        let spec = s.trim();
        match spec.to_lowercase().as_str() {
            "all" => Ok(ALL_CHAINS.to_vec()),
            "ig" => Ok(IG_CHAINS.to_vec()),
            "tcr" => Ok(TCR_CHAINS.to_vec()),
            _ => spec
                .split(',')
                .map(|entry| {
                    let name = entry.trim();
                    Chain::from_str(name).map_err(|_| {
                        Error::InvalidChain(format!(
                            "unknown chain '{}' (options: h,k,l,a,b,g,d,ig,tcr,all)",
                            name
                        ))
                    })
                })
                // Collecting into `Result<Vec<_>>` stops at the first invalid entry.
                .collect(),
        }
    }
}
/// Scheme selector.
///
/// Parsing (derived `EnumString`) is ASCII case-insensitive and accepts either the
/// full name or the one-letter alias; formatting (derived `Display`) prints the
/// full name given by `to_string`.
#[cfg_attr(feature = "python", pyclass(get_all))]
#[derive(Debug, EnumString, Display, PartialEq, Serialize, Deserialize, Clone, Copy)]
pub enum Scheme {
/// Parses from `IMGT` or `i`; displays as `IMGT`.
#[strum(to_string = "IMGT", serialize = "i", ascii_case_insensitive)]
IMGT,
/// Parses from `Kabat` or `k`; displays as `Kabat`.
#[strum(to_string = "Kabat", serialize = "k", ascii_case_insensitive)]
Kabat,
}
/// A position label made of a number and an optional single-character insertion
/// code, written as e.g. `111` or `111A`.
#[cfg_attr(feature = "python", pyclass(get_all))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Position {
/// Numeric part of the label.
pub number: u8,
/// Insertion character that follows the number, if any.
pub insertion: Option<char>,
}
impl Position {
pub fn new(number: u8) -> Self {
Self {
number,
insertion: None,
}
}
pub fn with_insertion(number: u8, insertion: char) -> Self {
Self {
number,
insertion: Some(insertion),
}
}
}
impl fmt::Display for Position {
    /// Writes the number, immediately followed by the insertion character when
    /// one is present (e.g. `111` or `111A`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.insertion {
            Some(code) => write!(f, "{}{}", self.number, code),
            None => write!(f, "{}", self.number),
        }
    }
}
impl FromStr for Position {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let s = s.trim();
if s.is_empty() {
return Err(Error::InvalidPosition("empty string".to_string()));
}
let digit_end = s
.chars()
.position(|c| !c.is_ascii_digit())
.unwrap_or(s.len());
if digit_end == 0 {
return Err(Error::InvalidPosition(format!("no numeric part: {}", s)));
}
let number: u8 = s[..digit_end]
.parse()
.map_err(|_| Error::InvalidPosition(format!("invalid number: {}", s)))?;
let insertion = match &s[digit_end..] {
"" => None,
rest if rest.len() == 1 && rest.chars().next().unwrap().is_alphabetic() => {
Some(rest.chars().next().unwrap())
}
_ => {
return Err(Error::InvalidPosition(format!(
"invalid insertion part: {}",
s
)))
}
};
Ok(Self { number, insertion })
}
}
/// Region labels, declared in the order `FR1`, `CDR1`, `FR2`, `CDR2`, `FR3`,
/// `CDR3`, `FR4`.
///
/// No strum attributes are set, so the derived `EnumString` / `Display` use the
/// variant names as written.
// NOTE(review): presumably FR = framework region and CDR =
// complementarity-determining region — confirm with the domain owner.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumString, Display)]
pub enum Region {
FR1,
CDR1,
FR2,
CDR2,
FR3,
CDR3,
FR4,
}
/// Maps an inclusive range of alignment positions onto an inclusive range of
/// output numbers, with optional deletion and insertion behaviour.
#[derive(Debug, Clone, Copy)]
pub struct NumberingRule {
    /// First alignment position covered by this rule (inclusive).
    pub align_start: u8,
    /// Last alignment position covered by this rule (inclusive).
    pub align_end: u8,
    /// Number assigned to `align_start`.
    pub num_start: u8,
    /// Last number of the rule's numbering range.
    pub num_end: u8,
    /// NOTE(review): presumably the order in which numbers are dropped when the
    /// sequence is shorter than the range — confirm against the consumer.
    pub deletion_order: &'static [u8],
    /// How extra residues are labelled; see [`Insertion`].
    pub insertion: Insertion,
}

impl NumberingRule {
    /// Fixed rule whose numbering range is identical to its alignment range,
    /// with no deletions and no insertions.
    pub const fn fr(start: u8, end: u8) -> Self {
        Self {
            align_start: start,
            align_end: end,
            num_start: start,
            num_end: end,
            deletion_order: &[],
            insertion: Insertion::None,
        }
    }

    /// Fixed rule whose numbering range is the alignment range shifted by
    /// `offset`, with no deletions and no insertions.
    ///
    /// # Panics
    ///
    /// Panics if `align_end < align_start`, or if the shifted range does not fit
    /// in `u8` (i.e. the shifted start would be negative or the shifted end would
    /// exceed 255). When evaluated in a `const` context this is a compile-time
    /// error instead of a silently wrapped value.
    pub const fn offset(align_start: u8, align_end: u8, offset: i8) -> Self {
        // Do all arithmetic in i16: every intermediate value fits, so nothing
        // can wrap before it has been range-checked.
        let span = align_end as i16 - align_start as i16;
        let shifted_start = align_start as i16 + offset as i16;
        let shifted_end = shifted_start + span;
        assert!(
            span >= 0,
            "NumberingRule::offset: align_end must not be smaller than align_start"
        );
        assert!(
            shifted_start >= 0 && shifted_end <= u8::MAX as i16,
            "NumberingRule::offset: shifted numbering range does not fit in u8"
        );
        Self {
            align_start,
            align_end,
            // Lossless: both values were just checked to lie in 0..=255.
            num_start: shifted_start as u8,
            num_end: shifted_end as u8,
            deletion_order: &[],
            insertion: Insertion::None,
        }
    }

    /// Rule with every field given explicitly; no validation is performed.
    pub const fn variable(
        align_start: u8,
        align_end: u8,
        num_start: u8,
        num_end: u8,
        deletion_order: &'static [u8],
        insertion: Insertion,
    ) -> Self {
        Self {
            align_start,
            align_end,
            num_start,
            num_end,
            deletion_order,
            insertion,
        }
    }

    /// Returns `true` when `pos` lies within `align_start..=align_end`.
    #[inline]
    pub const fn contains(&self, pos: u8) -> bool {
        pos >= self.align_start && pos <= self.align_end
    }
}

/// Insertion behaviour attached to a [`NumberingRule`].
///
/// NOTE(review): the meaning of the payloads is not visible in this file; the
/// notes on the variants are assumptions to be confirmed against the consumer.
#[derive(Debug, Clone, Copy)]
pub enum Insertion {
    /// No insertions.
    None,
    /// Presumably insertions placed one after another at the given number.
    Sequential(u8),
    /// Presumably insertions distributed around the `left` / `right` numbers.
    Symmetric { left: u8, right: u8 },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chain names parse from the three-letter form, the one-letter form and the
    /// spelled-out alias, regardless of ASCII case; unknown names are rejected.
    #[test]
    fn test_chain_parsing() {
        let accepted = [
            ("IGH", Chain::IGH),
            ("igh", Chain::IGH),
            ("H", Chain::IGH),
            ("heavy", Chain::IGH),
            ("TRA", Chain::TRA),
            ("A", Chain::TRA),
        ];
        for &(text, expected) in &accepted {
            assert_eq!(text.parse::<Chain>().unwrap(), expected);
        }
        assert!("invalid".parse::<Chain>().is_err());
    }

    /// Positions parse with and without an insertion letter; empty input, a
    /// missing number and a multi-letter suffix are rejected.
    #[test]
    fn test_position_parsing() {
        assert_eq!("111".parse::<Position>().unwrap(), Position::new(111));
        assert_eq!(
            "111A".parse::<Position>().unwrap(),
            Position::with_insertion(111, 'A')
        );
        for rejected in ["", "A", "111AB"].iter() {
            assert!(rejected.parse::<Position>().is_err());
        }
    }

    /// The group keywords expand to the expected chain lists.
    #[test]
    fn test_parse_chain_spec_groups() {
        assert_eq!(
            Chain::parse_chain_spec("ig").unwrap(),
            [Chain::IGH, Chain::IGK, Chain::IGL]
        );
        assert_eq!(
            Chain::parse_chain_spec("tcr").unwrap(),
            [Chain::TRA, Chain::TRB, Chain::TRG, Chain::TRD]
        );
        assert_eq!(Chain::parse_chain_spec("all").unwrap().len(), 7);
    }

    /// A comma-separated list yields the chains in the order given.
    #[test]
    fn test_parse_chain_spec_csv() {
        assert_eq!(
            Chain::parse_chain_spec("h,k,l").unwrap(),
            [Chain::IGH, Chain::IGK, Chain::IGL]
        );
    }

    /// An unknown chain name is an error.
    #[test]
    fn test_parse_chain_spec_invalid() {
        assert!(Chain::parse_chain_spec("xyz").is_err());
    }
}