//! twitcher 0.2.2
//!
//! Find template switch mutations in genomic data
use std::{
    fs::File,
    io::{self, BufRead, BufReader, BufWriter, Read},
    path::{Path, PathBuf},
    sync::Arc,
};

use lib_tsalign::{
    a_star_aligner::{
        alignment_result::{AlignmentResult, AlignmentStatistics, a_star_sequences},
        template_switch_distance::AlignmentType,
    },
    costs::U64Cost,
};
use lib_tsshow::svg::SvgConfig;
use tracing::{error, info};

use crate::{
    RunnableCommand,
    common::{
        ImmutableSequence, MutableSequence,
        alignment::{ForwardAlignment, cigar_to_alignment},
        coords::GenomeRegion,
        csv::CSVRecord,
        reference::{CliReferenceArg, ReferenceReader},
    },
    counter,
};

// Command-line arguments of the visualisation command: `run` reads CSV records and
// `process_record` writes one SVG per record into `output_dir`.
//
// NOTE: maintainer notes in this struct use plain `//` comments on purpose. With clap's derive
// macros, `///` doc comments on fields become user-visible help text, so they are left untouched.
//
// The boolean fields are independent rendering/behaviour switches, hence the lint exemption.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, clap::Args)]
pub struct Command {
    /// The input csv file. If omitted, read from stdin.
    input: Option<PathBuf>,

    #[arg(short = 'o', long = "output-dir")]
    /// The output directory, where one svg file per csv line will be created
    output_dir: PathBuf,

    // Handed to `CliReferenceArg::from` as a string in `run`; presumably the path of the
    // reference genome the records refer to (TODO confirm accepted formats).
    #[arg(short = 'r', long)]
    reference: PathBuf,

    // Forwarded to `SvgConfig::render_arrows`.
    /// Render arrows
    #[arg(short = 'a', long = "arrows")]
    arrows: bool,

    // Forwarded to `SvgConfig::render_more_complement`.
    /// Render more complement
    #[arg(short = 'c', long = "more-complement")]
    more_complement: bool,

    // Forwarded to `SvgConfig::restrict_context`; `None` means no restriction is requested.
    /// "Zoom" in to show only the specified number of bases as context.
    /// Default: no zoom
    #[arg(short = 'z', long = "zoom")]
    context: Option<usize>,

    // Forwarded to `SvgConfig::visualise_equal_cost_ranges`.
    /// Visualize the equal-cost ranges.
    #[arg(short = 'e', long = "equal-cost-ranges")]
    visualise_equal_cost_ranges: bool,

    // When false, `process_record` skips records whose output file already exists.
    /// Overwrite existing output files
    #[arg(short = 'f', long = "overwrite")]
    overwrite: bool,

    // Merged with the contents of `id_file` by `parse_ids`; records whose `id` is not in the
    // merged list are skipped. If both sources are empty, no filtering takes place.
    /// IDs to process. Accepts individual IDs or a comma-separated list. See also -I
    #[arg(short = 'i', long)]
    ids: Vec<String>,

    /// File with IDs to process. One ID per line.
    #[arg(short = 'I', long)]
    id_file: Option<PathBuf>,
}

impl RunnableCommand for Command {
    async fn run(self) -> anyhow::Result<()> {
        let cli_ref_args = CliReferenceArg::from(self.reference.as_os_str().to_str().unwrap());
        let reference_reader = ReferenceReader::try_from(&cli_ref_args)?;
        if !self.output_dir.exists() {
            info!("Creating output folder {}", self.output_dir.display());
            tokio::fs::create_dir(&self.output_dir).await?;
        }

        let input = if let Some(path) = &self.input {
            Box::new(File::open(path)?) as Box<dyn Read + Send>
        } else {
            Box::new(io::stdin())
        };

        let ids = parse_ids(&self.ids, self.id_file.as_ref())?;

        let reader = csv::Reader::from_reader(input);

        for record in reader.into_deserialize::<CSVRecord>() {
            match record {
                Ok(record) => {
                    if ids
                        .as_ref()
                        .is_some_and(|ids| ids.binary_search(&record.id).is_err())
                    {
                        counter!("viz.filtered").inc(1);
                        continue;
                    }
                    if let Err(e) = self.process_record(&record, &reference_reader).await {
                        counter!("viz.error").inc(1);
                        error!("{e}");
                    }
                }
                Err(e) => error!("Unable to read record: {e}"),
            }
        }

        Ok(())
    }
}

impl Command {
    /// Renders a single CSV record to `<output_dir>/<record id>.svg`.
    ///
    /// An existing output file is left untouched (and the record counted as skipped) unless
    /// `overwrite` is set.
    ///
    /// # Errors
    ///
    /// Returns an error if the sequences or alignments cannot be reconstructed from the record,
    /// the output file cannot be created, or rendering fails. If rendering fails, the incomplete
    /// output file is removed again (best effort).
    async fn process_record(
        &self,
        record: &CSVRecord,
        reference: &ReferenceReader,
    ) -> anyhow::Result<()> {
        // Append the extension instead of using `Path::with_extension`: the latter replaces
        // everything after the last '.' of the ID, so IDs such as "a.1" and "a.2" would both end
        // up in "a.svg".
        let path = self.output_dir.join(format!("{}.svg", record.id));

        if path.exists() && !self.overwrite {
            info!("Skipping existing file {}", path.display());
            counter!("viz.skipped").inc(1);
            return Ok(());
        }

        let sequences = record.get_sequences(reference).await?;

        // Both alignment results own their statistics (including the sequences), hence the clone.
        let ts_alignment = record.get_ts_alignment_result(sequences.clone())?;
        let fw_alignment = record.get_forward_alignment_result(sequences)?;

        let out = BufWriter::new(File::create(&path)?);

        let rendered = lib_tsshow::svg::create_ts_svg(
            out,
            &ts_alignment,
            &Some(fw_alignment),
            &SvgConfig {
                render_arrows: self.arrows,
                render_more_complement: self.more_complement,
                restrict_context: self.context,
                visualise_equal_cost_ranges: self.visualise_equal_cost_ranges,
            },
        );
        if rendered.is_err() {
            // Do not leave an empty or truncated file behind: without `--overwrite`, a later run
            // would skip this record as already rendered. Cleanup is best effort; the rendering
            // error is the one worth reporting.
            let _ = std::fs::remove_file(&path);
        }
        rendered?;

        counter!("viz.success").inc(1);
        Ok(())
    }
}

impl CSVRecord {
    async fn get_sequences(
        &self,
        reference: &ReferenceReader,
    ) -> anyhow::Result<a_star_sequences::SequencePair> {
        let ref_region = GenomeRegion::parse(self.ref_ctx_region.as_bytes())?;
        let ref_seq = reference.get_seq_exact_unmasked(ref_region).await?;

        let forward_alignment = ForwardAlignment(cigar_to_alignment(&self.fw_cigar_ctx)?);
        let query_seq = apply_cigar_and_mi(&ref_seq, &forward_alignment, &self.fw_mi_ctx);

        Ok(a_star_sequences::SequencePair {
            reference_name: reference.get_name().to_string(),
            reference: String::from_utf8_lossy(&ref_seq).into_owned(),
            reference_rc: String::from_utf8_lossy(&reverse_complement(&ref_seq)).into_owned(),
            query_name: self.read_id.clone().unwrap_or_default(),
            query: String::from_utf8_lossy(&query_seq).into_owned(),
            query_rc: String::from_utf8_lossy(&reverse_complement(&query_seq)).into_owned(),
        })
    }

    fn get_ts_alignment_result(
        &self,
        sequences: a_star_sequences::SequencePair,
    ) -> anyhow::Result<AlignmentResult<AlignmentType, U64Cost>> {
        let alignment = cigar_to_alignment(&self.ts_cigar)?;
        let statistics: AlignmentStatistics<U64Cost> = AlignmentStatistics {
            result: generic_a_star::AStarResult::FoundTarget {
                identifier: (),
                cost: self.ts_cost.into(),
            },
            sequences,
            reference_offset: self.ref_cluster_offset,
            query_offset: self.alt_cluster_offset,
            cost: (self.ts_cost as f64).try_into().unwrap(),
            // Everything after this is just filled with dummy values.
            cost_per_base: 0f64.try_into().unwrap(),
            duration_seconds: 0f64.try_into().unwrap(),
            opened_nodes: 0f64.try_into().unwrap(),
            closed_nodes: 0f64.try_into().unwrap(),
            suboptimal_opened_nodes: 0f64.try_into().unwrap(),
            suboptimal_opened_nodes_ratio: 0f64.try_into().unwrap(),
            template_switch_amount: (self.ts_num as f64).try_into().unwrap(),
            runtime: 0f64.try_into().unwrap(),
            memory: 0f64.try_into().unwrap(),
        };
        Ok(AlignmentResult::WithTarget {
            alignment,
            statistics,
        })
    }

    fn get_forward_alignment_result(
        &self,
        sequences: a_star_sequences::SequencePair,
    ) -> anyhow::Result<AlignmentResult<AlignmentType, U64Cost>> {
        let alignment = cigar_to_alignment(&self.fw_cigar)?;
        let statistics: AlignmentStatistics<U64Cost> = AlignmentStatistics {
            result: generic_a_star::AStarResult::FoundTarget {
                identifier: (),
                cost: self.fw_cost.into(),
            },
            sequences,
            reference_offset: self.ref_cluster_offset,
            query_offset: self.alt_cluster_offset,
            cost: (self.fw_cost as f64).try_into().unwrap(),
            // Everything after this is just filled with dummy values.
            cost_per_base: 0f64.try_into().unwrap(),
            duration_seconds: 0f64.try_into().unwrap(),
            opened_nodes: 0f64.try_into().unwrap(),
            closed_nodes: 0f64.try_into().unwrap(),
            suboptimal_opened_nodes: 0f64.try_into().unwrap(),
            suboptimal_opened_nodes_ratio: 0f64.try_into().unwrap(),
            template_switch_amount: (self.ts_num as f64).try_into().unwrap(),
            runtime: 0f64.try_into().unwrap(),
            memory: 0f64.try_into().unwrap(),
        };
        Ok(AlignmentResult::WithTarget {
            alignment,
            statistics,
        })
    }
}

/// Returns the reverse complement of a nucleotide sequence.
///
/// `A`/`T` and `C`/`G` are swapped, preserving case, and the order of the bases is reversed.
/// Every other byte (e.g. `N`) is kept as is, only its position changes.
fn reverse_complement(seq: &ImmutableSequence) -> ImmutableSequence {
    let rc: MutableSequence = seq
        .iter()
        .rev()
        .map(|&b| match b {
            b'A' => b'T',
            b'T' => b'A',
            b'C' => b'G',
            b'G' => b'C',
            // Lower-case bases must be complemented too, otherwise they would only be reversed.
            b'a' => b't',
            b't' => b'a',
            b'c' => b'g',
            b'g' => b'c',
            other => other,
        })
        .collect();
    Arc::from(rc.as_slice())
}

/// Rebuilds the query sequence from the reference, a forward alignment and its MI string.
///
/// The alignment is walked operation by operation:
/// * matches copy the current reference base,
/// * substitutions and insertions take the next alphabetic character of `mi_string` (nothing is
///   emitted once the MI string is exhausted),
/// * deletions emit nothing,
/// * any other alignment type is ignored.
///
/// Matches, substitutions and deletions advance the position in the reference.
///
/// # Panics
///
/// Panics if a match refers to a position beyond the end of `reference`.
fn apply_cigar_and_mi(
    reference: &ImmutableSequence,
    forward_alignment: &ForwardAlignment,
    mi_string: &str,
) -> ImmutableSequence {
    let mut replacement_bases = mi_string.bytes().filter(u8::is_ascii_alphabetic);
    let mut rebuilt = MutableSequence::new();
    let mut cursor = 0usize;

    for alignment_type in forward_alignment.iter_flat_cloned() {
        // For each operation decide whether it consumes a reference base and which base, if
        // any, it contributes to the query.
        let (consumes_reference, emitted) = match alignment_type {
            AlignmentType::PrimaryMatch | AlignmentType::PrimaryFlankMatch => {
                (true, Some(reference[cursor]))
            }
            AlignmentType::PrimarySubstitution | AlignmentType::PrimaryFlankSubstitution => {
                (true, replacement_bases.next())
            }
            AlignmentType::PrimaryInsertion | AlignmentType::PrimaryFlankInsertion => {
                (false, replacement_bases.next())
            }
            AlignmentType::PrimaryDeletion | AlignmentType::PrimaryFlankDeletion => (true, None),
            _ => (false, None),
        };

        if let Some(base) = emitted {
            rebuilt.push(base);
        }
        if consumes_reference {
            cursor += 1;
        }
    }

    Arc::from(rebuilt.as_slice())
}

/// Collects the IDs to process from the command line and an optional ID file.
///
/// Each entry of `args` may hold a single ID or a comma-separated list; `file` is read line by
/// line with one ID per line. Surrounding whitespace is trimmed, and empty tokens and blank
/// lines are ignored.
///
/// Returns `None` if no ID was given at all (no filtering), otherwise the sorted, de-duplicated
/// list of IDs, suitable for `binary_search`.
///
/// # Errors
///
/// Returns an error if the ID file cannot be opened or read.
fn parse_ids<P: AsRef<Path>>(
    args: &[String],
    file: Option<P>,
) -> anyhow::Result<Option<Vec<String>>> {
    let mut ids: Vec<String> = args
        .iter()
        .flat_map(|arg| arg.split(','))
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(str::to_string)
        .collect();

    if let Some(file) = file {
        for line in BufReader::new(File::open(file.as_ref())?).lines() {
            let line = line?;
            let id = line.trim();
            // Blank lines (e.g. a trailing newline at the end of the file) are not IDs. Keeping
            // them would turn an effectively empty file into a filter that rejects every record.
            if !id.is_empty() {
                ids.push(id.to_string());
            }
        }
    }

    if ids.is_empty() {
        return Ok(None);
    }

    ids.sort_unstable();
    ids.dedup();
    Ok(Some(ids))
}