immunum 1.1.2

Fast antibody and T-cell receptor numbering in Rust and Python
Documentation
use postcard::{from_bytes, to_allocvec};
use std::str::FromStr;

use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::annotator::Annotator;
use crate::numbering::segment;
use crate::types::{Chain, Scheme};

#[pymethods]
impl Annotator {
    // python methods
    #[new]
    #[pyo3(signature = (chains, scheme, min_confidence=None))]
    pub fn init(
        chains: Vec<String>,
        scheme: String,
        min_confidence: Option<f32>,
    ) -> PyResult<Self> {
        let parsed_chains = chains
            .iter()
            .map(|chain| {
                Chain::from_str(chain).map_err(|_| {
                    PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Invalid chain: {}",
                        chain
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        let parsed_scheme = Scheme::from_str(&scheme).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid scheme: {}", scheme))
        })?;
        let annotator = Annotator::new(parsed_chains.as_slice(), parsed_scheme, min_confidence)
            .map_err(|e| {
                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                    "Failed to initialize annotator: {}",
                    e
                ))
            })?;

        Ok(annotator)
    }

    #[pyo3(signature = (sequence), name = "number")]
    pub fn _number<'py>(&self, py: Python<'py>, sequence: &str) -> PyResult<Bound<'py, PyDict>> {
        let dict = PyDict::new(py);
        match self.number(sequence) {
            Ok(result) => {
                let numbering = PyDict::new(py);
                let aligned_seq = &sequence[result.query_start..=result.query_end];
                for (pos, ch) in result.positions.iter().zip(aligned_seq.chars()) {
                    numbering.set_item(pos.to_string(), ch.to_string())?;
                }
                dict.set_item("chain", result.chain.to_string())?;
                dict.set_item("scheme", result.scheme.to_string())?;
                dict.set_item("confidence", result.confidence)?;
                dict.set_item("numbering", numbering)?;
                dict.set_item("query_start", result.query_start)?;
                dict.set_item("query_end", result.query_end)?;
                dict.set_item("error", py.None())?;
            }
            Err(e) => {
                dict.set_item("chain", py.None())?;
                dict.set_item("scheme", py.None())?;
                dict.set_item("confidence", py.None())?;
                dict.set_item("numbering", py.None())?;
                dict.set_item("query_start", py.None())?;
                dict.set_item("query_end", py.None())?;
                dict.set_item("error", e.to_string())?;
            }
        }
        Ok(dict)
    }

    #[pyo3(signature = (sequence), name = "segment")]
    pub fn _segment<'py>(&self, py: Python<'py>, sequence: &str) -> PyResult<Bound<'py, PyDict>> {
        let dict = PyDict::new(py);
        match self.number(sequence) {
            Ok(result) => {
                let aligned_seq = &sequence[result.query_start..=result.query_end];
                let segments = segment(&result.positions, aligned_seq, result.scheme);
                for (region, seq) in segments {
                    dict.set_item(region, seq)?;
                }
                dict.set_item("error", py.None())?;
            }
            Err(e) => {
                dict.set_item("error", e.to_string())?;
            }
        }
        Ok(dict)
    }

    pub fn __setstate__(
        &mut self,
        state: &pyo3::Bound<'_, pyo3::types::PyBytes>,
    ) -> pyo3::PyResult<()> {
        let annotator: Annotator = from_bytes(state.as_bytes()).unwrap();
        self.matrices = annotator.matrices;
        self.scheme = annotator.scheme;
        self.chains = annotator.chains;
        self.min_confidence = annotator.min_confidence;
        Ok(())
    }

    pub fn __getstate__<'py>(
        &self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::types::PyBytes>> {
        Ok(pyo3::types::PyBytes::new(py, &to_allocvec(&self).unwrap()))
    }

    pub fn __getnewargs__(&self) -> pyo3::PyResult<(Vec<String>, String)> {
        Ok((
            self.chains
                .clone()
                .iter()
                .map(move |s| s.to_string())
                .collect(),
            self.scheme.to_string().clone(),
        ))
    }
}

#[pymodule]
fn _internal(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    m.add_class::<Annotator>()?;
    Ok(())
}