#[cfg(feature = "python")]
use pyo3::prelude::*;
use crate::error::{Error, Result};
use crate::types::Chain;
use serde::{Deserialize, Serialize};
/// Fallback score returned by `PositionScores::score_for` when the queried
/// byte is not an uppercase ASCII letter (`b'A'..=b'Z'`).
const DEFAULT_SCORE: f32 = -4.0;
/// A scoring matrix: a list of per-position score tables.
///
/// Values are deserialized from JSON by `ScoringMatrix::load`. When the
/// `python` feature is enabled the type is additionally annotated with
/// `pyclass(get_all)`.
#[cfg_attr(feature = "python", pyclass(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringMatrix {
/// One score table per position.
/// NOTE(review): presumably ordered by `PositionScores::position` — confirm
/// against the code that generates the JSON.
pub positions: Vec<PositionScores>,
}
/// Scores and penalties attached to a single position of a `ScoringMatrix`.
///
/// When the `python` feature is enabled the type is additionally annotated
/// with `pyclass(get_all)`.
#[cfg_attr(feature = "python", pyclass(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionScores {
/// Identifier of this position.
/// NOTE(review): the numbering scheme is not visible in this file — confirm.
pub position: u8,
/// Per-letter scores, indexed by the byte's offset from `b'A'`
/// (index 0 is `A`, index 25 is `Z`); see `score_for`.
pub scores: [f32; 26],
/// Penalty for a gap in the query at this position (the unit tests expect
/// it to be non-positive).
pub gap_penalty: f32,
/// Penalty for a gap in the consensus at this position (the unit tests
/// expect it to be strictly negative).
pub insertion_penalty: f32,
/// NOTE(review): presumably the largest value in `scores`; nothing in this
/// file computes or checks it — confirm against the matrix generator.
pub max_score: f32,
/// NOTE(review): presumably marks whether this position takes part in a
/// confidence calculation done elsewhere — confirm against the consumer.
pub counts_for_confidence: bool,
}
impl PositionScores {
    /// Returns the score stored for the residue byte `aa` at this position.
    ///
    /// Uppercase ASCII letters (`b'A'..=b'Z'`) are looked up in `scores` by
    /// their offset from `b'A'`. Every other byte, including lowercase
    /// letters, yields `DEFAULT_SCORE`.
    #[inline(always)]
    pub fn score_for(&self, aa: u8) -> f32 {
        // Bytes below `b'A'` wrap around to offsets >= 191, so a single
        // checked lookup rejects everything outside `A..=Z`.
        let offset = usize::from(aa.wrapping_sub(b'A'));
        self.scores.get(offset).copied().unwrap_or(DEFAULT_SCORE)
    }
}
impl ScoringMatrix {
pub fn load(chain: Chain) -> Result<Self> {
let json = match chain {
Chain::IGH => include_str!(concat!(env!("OUT_DIR"), "/matrices/IGH.json")),
Chain::IGK => include_str!(concat!(env!("OUT_DIR"), "/matrices/IGK.json")),
Chain::IGL => include_str!(concat!(env!("OUT_DIR"), "/matrices/IGL.json")),
Chain::TRA => include_str!(concat!(env!("OUT_DIR"), "/matrices/TRA.json")),
Chain::TRB => include_str!(concat!(env!("OUT_DIR"), "/matrices/TRB.json")),
Chain::TRG => include_str!(concat!(env!("OUT_DIR"), "/matrices/TRG.json")),
Chain::TRD => include_str!(concat!(env!("OUT_DIR"), "/matrices/TRD.json")),
};
serde_json::from_str(json).map_err(|e| {
Error::ConsensusParseError(format!("Failed to parse scoring matrix: {}", e))
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded IGH matrix deserializes and has more than 100 positions.
    #[test]
    fn test_load_igh_matrix() {
        let matrix = ScoringMatrix::load(Chain::IGH).unwrap();
        let count = matrix.positions.len();
        // `count > 100` already implies the matrix is non-empty.
        assert!(
            count > 100,
            "IGH matrix has {} positions, expected more than 100",
            count
        );
    }

    /// Every supported chain has an embedded matrix that deserializes to a
    /// non-empty list of positions.
    #[test]
    fn test_load_all_matrices() {
        for chain in [
            Chain::IGH,
            Chain::IGK,
            Chain::IGL,
            Chain::TRA,
            Chain::TRB,
            Chain::TRG,
            Chain::TRD,
        ] {
            let matrix = ScoringMatrix::load(chain).unwrap();
            assert!(!matrix.positions.is_empty());
        }
    }

    /// Gap penalties in the IGH matrix never reward a gap: the query-gap
    /// penalty is non-positive and the consensus-gap penalty is negative.
    #[test]
    fn test_gap_penalties() {
        let matrix = ScoringMatrix::load(Chain::IGH).unwrap();
        for pos in &matrix.positions {
            assert!(
                pos.gap_penalty <= 0.0,
                "Gap in query penalty should be non-positive at position {}, got {}",
                pos.position,
                pos.gap_penalty
            );
            assert!(
                pos.insertion_penalty < 0.0,
                "Gap in consensus penalty should be negative at position {}, got {}",
                pos.position,
                pos.insertion_penalty
            );
        }
    }

    /// Every per-letter score in the IGH matrix lies within `-10.0..=15.0`.
    #[test]
    fn test_scores_reasonable() {
        let matrix = ScoringMatrix::load(Chain::IGH).unwrap();
        for pos in &matrix.positions {
            // `scores` holds one entry per letter `A..=Z`, in that order.
            for (aa, &score) in ('A'..='Z').zip(pos.scores.iter()) {
                assert!(
                    (-10.0..=15.0).contains(&score),
                    "Score for {} at position {} is out of range: {}",
                    aa,
                    pos.position,
                    score
                );
            }
        }
    }
}