twitcher 0.1.9

Find template switch mutations in genomic data
use std::{
    fs::File,
    io::{self, BufWriter, Read},
    path::PathBuf,
    sync::Arc,
};

use anyhow::anyhow;
use lib_tsalign::{
    a_star_aligner::{
        alignment_result::{AlignmentResult, AlignmentStatistics, a_star_sequences},
        template_switch_distance::AlignmentType,
    },
    costs::U64Cost,
};
use lib_tsshow::svg::SvgConfig;
use tracing::{error, info};

use crate::{
    RunnableCommand,
    common::{
        ImmutableSequence, MutableSequence,
        alignment::{ForwardAlignment, cigar_to_alignment},
        coords::GenomeRegion,
        csv::CSVRecord,
        reference::{CliReferenceArg, ReferenceReader},
    },
};

// Command-line arguments for rendering one SVG file per record of an input CSV.
//
// NOTE: the `///` comments on the fields are consumed by clap's derive macro as the
// user-visible help text, so editing them changes the CLI output.
#[derive(Debug, clap::Args)]
pub struct Command {
    /// The input csv file. If omitted, read from stdin.
    input: Option<PathBuf>,

    #[arg(short, long)]
    /// The output directory, where one svg file per csv line will be created
    output_dir: PathBuf,

    // Converted to a string and handed to `CliReferenceArg::from` in `run`.
    // No help text is attached to this option.
    #[arg(long)]
    reference: PathBuf,

    /// Render arrows
    #[arg(long)]
    arrows: bool,

    /// Render more complement
    #[arg(long)]
    more_complement: bool,

    /// Overwrite existing output files
    #[arg(long)]
    overwrite: bool,

    /// IDs to process. Accepts individual IDs, ranges (10..15, ..15, 10..), and comma-separated
    /// combinations (4,5,10..15). Can be repeated: --ids 4 --ids 10..15
    // Parsed by `parse_ids`; an empty list means "process every record".
    #[arg(long)]
    ids: Vec<String>,
}

impl RunnableCommand for Command {
    async fn run(self) -> anyhow::Result<()> {
        let cli_ref_args = CliReferenceArg::from(self.reference.as_os_str().to_str().unwrap());
        let reference_reader = ReferenceReader::try_from(&cli_ref_args)?;
        if !self.output_dir.exists() {
            info!("Creating output folder {}", self.output_dir.display());
            tokio::fs::create_dir(&self.output_dir).await?;
        }

        let input = if let Some(path) = &self.input {
            Box::new(File::open(path)?) as Box<dyn Read + Send>
        } else {
            Box::new(io::stdin())
        };

        let id_filters = parse_ids(&self.ids)?;

        let reader = csv::Reader::from_reader(input);

        for record in reader.into_deserialize::<CSVRecord>() {
            match record {
                Ok(record) => {
                    let id = record.cluster_id.parse::<usize>()?;
                    if !id_filters.is_empty() && !id_filters.iter().any(|f| f.matches(id)) {
                        continue;
                    }
                    if let Err(e) = self.process_record(&record, &reference_reader).await {
                        // log error, but continue with the next record.
                        error!("{e}");
                    }
                }
                Err(e) => error!("Unable to read record: {e}"),
            }
        }

        Ok(())
    }
}

impl Command {
    /// Renders a single CSV record to `<output_dir>/<cluster_id>.svg`.
    ///
    /// An existing output file is left untouched (and the record skipped) unless
    /// `--overwrite` was given.
    ///
    /// # Errors
    ///
    /// Returns an error if the sequences or alignments cannot be reconstructed from the
    /// record, the output file cannot be created, or SVG rendering fails. When rendering
    /// fails, the incomplete output file is removed on a best-effort basis.
    async fn process_record(
        &self,
        record: &CSVRecord,
        reference: &ReferenceReader,
    ) -> anyhow::Result<()> {
        let path = self
            .output_dir
            .join(&record.cluster_id)
            .with_extension("svg");

        if path.exists() && !self.overwrite {
            info!("Skipping existing file {}", path.display());
            return Ok(());
        }

        let sequences = record.get_sequences(reference).await?;

        let ts_alignment = record.get_ts_alignment_result(sequences.clone())?;
        let fw_alignment = record.get_forward_alignment_result(sequences)?;

        let out = BufWriter::new(File::create(&path)?);

        let rendered = lib_tsshow::svg::create_ts_svg(
            out,
            &ts_alignment,
            &Some(fw_alignment),
            &SvgConfig {
                render_arrows: self.arrows,
                render_more_complement: self.more_complement,
            },
        );

        if rendered.is_err() {
            // The file was already created (and truncated) above. Leaving an empty or
            // partial SVG behind would make later runs without `--overwrite` skip this
            // record as if it had been rendered. Removal is best effort: the rendering
            // error is the one worth reporting.
            let _ = std::fs::remove_file(&path);
        }
        rendered?;
        Ok(())
    }
}

impl CSVRecord {
    /// Reconstructs the reference/query sequence pair described by this record.
    ///
    /// The reference is fetched for `ref_ctx_region`; the query is derived from it by
    /// replaying `fw_cigar_ctx` together with the bases in `fw_mi_ctx`. Reverse
    /// complements of both sequences are included.
    async fn get_sequences(
        &self,
        reference: &ReferenceReader,
    ) -> anyhow::Result<a_star_sequences::SequencePair> {
        let region = GenomeRegion::parse(self.ref_ctx_region.as_bytes())?;
        let reference_bases = reference.get_seq_exact_unmasked(region).await?;

        let context_alignment = ForwardAlignment(cigar_to_alignment(&self.fw_cigar_ctx)?);
        let query_bases =
            apply_cigar_and_mi(&reference_bases, &context_alignment, &self.fw_mi_ctx);

        let reference_rc_bases = reverse_complement(&reference_bases);
        let query_rc_bases = reverse_complement(&query_bases);

        let pair = a_star_sequences::SequencePair {
            reference_name: reference.get_name().to_string(),
            reference: String::from_utf8_lossy(&reference_bases).into_owned(),
            reference_rc: String::from_utf8_lossy(&reference_rc_bases).into_owned(),
            query_name: self.read_id.clone().unwrap_or_default(),
            query: String::from_utf8_lossy(&query_bases).into_owned(),
            query_rc: String::from_utf8_lossy(&query_rc_bases).into_owned(),
        };
        Ok(pair)
    }

    /// Builds the template-switch alignment result from `ts_cigar` and `ts_cost`.
    ///
    /// Only the cost, offsets, sequences and template switch count carry real values;
    /// the remaining statistics are placeholders set to zero.
    fn get_ts_alignment_result(
        &self,
        sequences: a_star_sequences::SequencePair,
    ) -> anyhow::Result<AlignmentResult<AlignmentType, U64Cost>> {
        let alignment = cigar_to_alignment(&self.ts_cigar)?;

        let total_cost = self.ts_cost;
        let switch_count = self.ts_num as f64;

        let statistics: AlignmentStatistics<U64Cost> = AlignmentStatistics {
            result: generic_a_star::AStarResult::FoundTarget {
                identifier: (),
                cost: total_cost.into(),
            },
            sequences,
            reference_offset: self.ref_cluster_offset,
            query_offset: self.alt_cluster_offset,
            cost: (total_cost as f64).try_into().unwrap(),
            // Placeholder values from here on, except for the template switch count.
            cost_per_base: 0f64.try_into().unwrap(),
            duration_seconds: 0f64.try_into().unwrap(),
            opened_nodes: 0f64.try_into().unwrap(),
            closed_nodes: 0f64.try_into().unwrap(),
            suboptimal_opened_nodes: 0f64.try_into().unwrap(),
            suboptimal_opened_nodes_ratio: 0f64.try_into().unwrap(),
            template_switch_amount: switch_count.try_into().unwrap(),
            runtime: 0f64.try_into().unwrap(),
            memory: 0f64.try_into().unwrap(),
        };

        let result = AlignmentResult::WithTarget {
            alignment,
            statistics,
        };
        Ok(result)
    }

    /// Builds the forward (no template switch) alignment result from `fw_cigar` and
    /// `fw_cost`.
    ///
    /// Mirrors `get_ts_alignment_result`: the same offsets and template switch count are
    /// used, and the remaining statistics are placeholders set to zero.
    fn get_forward_alignment_result(
        &self,
        sequences: a_star_sequences::SequencePair,
    ) -> anyhow::Result<AlignmentResult<AlignmentType, U64Cost>> {
        let alignment = cigar_to_alignment(&self.fw_cigar)?;

        let total_cost = self.fw_cost;
        let switch_count = self.ts_num as f64;

        let statistics: AlignmentStatistics<U64Cost> = AlignmentStatistics {
            result: generic_a_star::AStarResult::FoundTarget {
                identifier: (),
                cost: total_cost.into(),
            },
            sequences,
            reference_offset: self.ref_cluster_offset,
            query_offset: self.alt_cluster_offset,
            cost: (total_cost as f64).try_into().unwrap(),
            // Placeholder values from here on, except for the template switch count.
            cost_per_base: 0f64.try_into().unwrap(),
            duration_seconds: 0f64.try_into().unwrap(),
            opened_nodes: 0f64.try_into().unwrap(),
            closed_nodes: 0f64.try_into().unwrap(),
            suboptimal_opened_nodes: 0f64.try_into().unwrap(),
            suboptimal_opened_nodes_ratio: 0f64.try_into().unwrap(),
            template_switch_amount: switch_count.try_into().unwrap(),
            runtime: 0f64.try_into().unwrap(),
            memory: 0f64.try_into().unwrap(),
        };

        let result = AlignmentResult::WithTarget {
            alignment,
            statistics,
        };
        Ok(result)
    }
}

/// Returns the reverse complement of a nucleotide sequence.
///
/// The bytes are reversed and `A`/`T` and `C`/`G` are swapped. Case is preserved, so
/// lowercase bases (`a`, `c`, `g`, `t`) are complemented to their lowercase partner.
/// Any other byte (e.g. `N`) is copied through unchanged.
fn reverse_complement(seq: &ImmutableSequence) -> ImmutableSequence {
    let rc: MutableSequence = seq
        .iter()
        .rev()
        .map(|&b| match b {
            b'A' => b'T',
            b'T' => b'A',
            b'C' => b'G',
            b'G' => b'C',
            // Lowercase bases must be complemented too; otherwise the result is only
            // reversed at those positions.
            b'a' => b't',
            b't' => b'a',
            b'c' => b'g',
            b'g' => b'c',
            other => other,
        })
        .collect();
    Arc::from(rc.as_slice())
}

/// Reconstructs the query sequence by replaying a forward alignment over the reference.
///
/// Matches copy the current reference base. Substitutions and insertions take the next
/// ASCII-alphabetic byte of `mi_string`, in order; if `mi_string` has run out, nothing
/// is emitted for that operation. Deletions and substitutions advance the reference
/// position without copying from it. All other alignment types are ignored.
///
/// Indexing into `reference` is unchecked against the alignment length, so a match
/// operation beyond the end of `reference` panics.
fn apply_cigar_and_mi(
    reference: &ImmutableSequence,
    forward_alignment: &ForwardAlignment,
    mi_string: &str,
) -> ImmutableSequence {
    let mut replacement_bases = mi_string.bytes().filter(u8::is_ascii_alphabetic);
    let mut rebuilt = MutableSequence::new();
    let mut cursor = 0usize;

    for step in forward_alignment.iter_flat_cloned() {
        // Classify the operation as (base to emit, whether it consumes a reference base).
        let (emitted, consumes_reference) = match step {
            AlignmentType::PrimaryMatch | AlignmentType::PrimaryFlankMatch => {
                (Some(reference[cursor]), true)
            }
            AlignmentType::PrimarySubstitution | AlignmentType::PrimaryFlankSubstitution => {
                (replacement_bases.next(), true)
            }
            AlignmentType::PrimaryInsertion | AlignmentType::PrimaryFlankInsertion => {
                (replacement_bases.next(), false)
            }
            AlignmentType::PrimaryDeletion | AlignmentType::PrimaryFlankDeletion => (None, true),
            _ => (None, false),
        };

        if let Some(base) = emitted {
            rebuilt.push(base);
        }
        if consumes_reference {
            cursor += 1;
        }
    }

    Arc::from(rebuilt.as_slice())
}

/// A filter selecting which numeric record ids are processed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IdFilter {
    /// Matches exactly one id.
    Single(usize),
    /// Matches every id between the lower and upper bound, both inclusive.
    /// A bound of `None` leaves that side open.
    Range(Option<usize>, Option<usize>),
}

impl IdFilter {
    fn matches(&self, id: usize) -> bool {
        match self {
            IdFilter::Single(n) => *n == id,
            IdFilter::Range(lower, upper) => {
                lower.is_none_or(|lb| lb <= id) && upper.is_none_or(|ub| ub >= id)
            }
        }
    }
}

fn parse_ids(args: &[String]) -> anyhow::Result<Vec<IdFilter>> {
    args.iter()
        .flat_map(|arg| arg.split(','))
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(parse_id_token)
        .collect()
}

/// Parses a single id token into an [`IdFilter`].
///
/// Accepted forms are a plain number (`1`) and ranges with an optional lower and/or
/// upper bound separated by `..` (`10..`, `..15`, `10..15`).
///
/// # Errors
///
/// Returns an error naming the offending token and the accepted forms if the token is
/// neither a number nor a range, or if a range bound is not a valid number.
fn parse_id_token(token: &str) -> anyhow::Result<IdFilter> {
    if let Ok(num) = token.parse::<usize>() {
        return Ok(IdFilter::Single(num));
    }

    let Some((first, second)) = token.split_once("..") else {
        return Err(anyhow!(
            "Cannot parse '{token}' as an id. Valid forms: `1`, `10..`, `..15`, `10..15`"
        ));
    };

    // An empty side leaves the range open. A malformed bound is reported together with
    // the full token and the accepted forms rather than as a bare integer parse error.
    let parse_bound = |s: &str| -> anyhow::Result<Option<usize>> {
        if s.is_empty() {
            return Ok(None);
        }
        s.parse::<usize>().map(Some).map_err(|e| {
            anyhow!(
                "Cannot parse bound '{s}' of id range '{token}': {e}. Valid forms: `1`, `10..`, `..15`, `10..15`"
            )
        })
    };

    Ok(IdFilter::Range(parse_bound(first)?, parse_bound(second)?))
}