// nail 0.5.0
//
// nail is an alignment inference tool.
use std::time::{Duration, Instant};

use anyhow::bail;
use derive_builder::Builder;
use libnail::{
    align::{
        backward, forward, null_one_score, null_two_score, optimal_accuracy, p_value, posterior,
        structs::{Alignment, AlignmentBuilder, DpMatrixSparse, RowBounds, Trace},
        traceback, Bits,
    },
    structs::{Profile, Sequence},
};

use crate::args::SearchArgs;

use super::StageResult;

/// Result of an [`AlignStage`] run: an [`Alignment`] plus the stage statistics.
pub type AlignStageResult = StageResult<Alignment, AlignStageStats>;

/// Statistics collected while running an align stage.
///
/// Built incrementally through the derived `AlignStageStatsBuilder`. Because of
/// `#[builder(default)]`, any field that is never set (e.g. the post-filter timings
/// when a target is filtered out early) falls back to its `Default` value.
#[derive(Builder, Default)]
#[builder(setter(strip_option), default)]
pub struct AlignStageStats {
    /// `RowBounds::num_cells` recorded for the Forward pass.
    pub forward_cells: usize,
    /// `RowBounds::num_cells` recorded for the Backward pass.
    pub backward_cells: usize,
    /// Forward score in bits, after subtracting the null-one score.
    pub score: Bits,
    /// P-value computed from the Forward bit score.
    pub p_value: f64,
    /// Time spent preparing DP matrices (can be accumulated over several phases).
    pub memory_init_time: Duration,
    /// Time spent in the Forward pass (including the null-one correction).
    pub forward_time: Duration,
    /// Time spent in the Backward pass.
    pub backward_time: Duration,
    /// Time spent computing posterior probabilities.
    pub posterior_time: Duration,
    /// Time spent in the optimal-accuracy pass.
    pub optimal_accuracy_time: Duration,
    /// Time spent allocating the trace and running traceback.
    pub traceback_time: Duration,
    /// Time spent computing the null-two score (zero when it is not computed).
    pub null_two_time: Duration,
}

impl AlignStageStatsBuilder {
    /// Adds `duration` to the memory-initialization time recorded so far.
    ///
    /// The generated `memory_init_time` setter overwrites. This helper accumulates
    /// instead. When nothing has been recorded yet, the total starts at zero, so the
    /// first call simply stores `duration`.
    fn add_memory_init_time(&mut self, duration: Duration) {
        let accumulated = self
            .memory_init_time
            .get_or_insert_with(Duration::default);
        *accumulated += duration;
    }
}

/// Options controlling the optional parts of the alignment stage.
#[derive(Clone, Debug)]
pub struct AlignConfig {
    /// Whether to compute the null-two score for targets that reach alignment.
    pub do_null_two: bool,
}

impl Default for AlignConfig {
    /// Null-two scoring is enabled by default.
    fn default() -> Self {
        Self { do_null_two: true }
    }
}

pub trait AlignStage: dyn_clone::DynClone + Send + Sync {
    fn run(
        &mut self,
        profile: &mut Profile,
        target: &Sequence,
        bounds: &RowBounds,
    ) -> StageResult<Alignment, AlignStageStats>;
}

dyn_clone::clone_trait_object!(AlignStage);

/// The standard [`AlignStage`]: sparse Forward/Backward, posterior decoding,
/// optimal-accuracy alignment and traceback, with optional null-two scoring.
///
/// The four DP matrices are stored on the stage and prepared with `reuse` on each
/// call to `run`. The derived `Default` leaves `forward_p_value_threshold` at `0.0`
/// and `target_count` at `0`; `new` fills both from `SearchArgs`.
#[derive(Default, Clone)]
pub struct DefaultAlignStage {
    /// Written by `forward`; later read by `posterior`.
    forward_matrix: DpMatrixSparse,
    /// Written by `backward`; later read by `posterior`.
    backward_matrix: DpMatrixSparse,
    /// Written by `posterior`; read by `optimal_accuracy`, `traceback` and `null_two_score`.
    posterior_matrix: DpMatrixSparse,
    /// Written by `optimal_accuracy`; read by `traceback`.
    optimal_matrix: DpMatrixSparse,
    /// Targets whose Forward P-value is `>=` this value are filtered out.
    forward_p_value_threshold: f64,
    /// Target database size, passed to each `Alignment` via `with_database_size`.
    target_count: usize,
    /// Stage options (currently: whether to compute the null-two score).
    config: AlignConfig,
}

impl DefaultAlignStage {
    /// Creates an align stage configured from the search arguments.
    ///
    /// The filter threshold comes from `pipeline_args.forward_pvalue_threshold`.
    /// Null-two scoring is enabled unless `expert_args.no_null_two` is set. The DP
    /// matrices start from their `Default` values.
    ///
    /// # Errors
    ///
    /// Fails if `expert_args.target_database_size` is `None`.
    pub fn new(args: &SearchArgs) -> anyhow::Result<Self> {
        let target_count = match args.expert_args.target_database_size {
            Some(size) => size,
            None => bail!("no target database size"),
        };
        let forward_p_value_threshold = args.pipeline_args.forward_pvalue_threshold;
        let config = AlignConfig {
            do_null_two: !args.expert_args.no_null_two,
        };

        Ok(Self {
            target_count,
            forward_p_value_threshold,
            config,
            ..Default::default()
        })
    }
}

impl AlignStage for DefaultAlignStage {
    fn run(
        &mut self,
        profile: &mut Profile,
        target: &Sequence,
        bounds: &RowBounds,
    ) -> StageResult<Alignment, AlignStageStats> {
        let mut stats = AlignStageStatsBuilder::default();

        let now = Instant::now();
        self.forward_matrix
            .reuse(target.length, profile.length, bounds);
        stats.memory_init_time(now.elapsed());

        // we use the forward score to compute the final bit score (later)
        let now = Instant::now();

        let raw_forward_score = forward(profile, target, &mut self.forward_matrix, bounds);

        // the denominator is the null one score
        let null_one = null_one_score(target.length);

        let forward_score = (raw_forward_score - null_one).to_bits();

        stats.forward_time(now.elapsed());
        stats.forward_cells(bounds.num_cells);

        // for now we compute the P-value for filtering purposes
        let forward_p_value = p_value(forward_score, profile.fwd_lambda, profile.fwd_tau);
        stats.score(forward_score);
        stats.p_value(forward_p_value);

        if forward_p_value >= self.forward_p_value_threshold {
            return StageResult::Filtered {
                stats: stats.build().unwrap(),
            };
        }

        let now = Instant::now();
        self.backward_matrix
            .reuse(target.length, profile.length, bounds);
        self.posterior_matrix
            .reuse(target.length, profile.length, bounds);
        self.optimal_matrix
            .reuse(target.length, profile.length, bounds);
        stats.add_memory_init_time(now.elapsed());

        let now = Instant::now();
        backward(profile, target, &mut self.backward_matrix, bounds);
        stats.backward_time(now.elapsed());
        stats.backward_cells(bounds.num_cells);

        let now = Instant::now();
        posterior(
            profile,
            &self.forward_matrix,
            &self.backward_matrix,
            &mut self.posterior_matrix,
            bounds,
        );
        stats.posterior_time(now.elapsed());

        let now = Instant::now();
        optimal_accuracy(
            profile,
            &self.posterior_matrix,
            &mut self.optimal_matrix,
            bounds,
        );
        stats.optimal_accuracy_time(now.elapsed());

        let now = Instant::now();
        let mut trace = Trace::new(target.length, profile.length);
        traceback(
            profile,
            &self.posterior_matrix,
            &self.optimal_matrix,
            &mut trace,
            bounds.seq_end,
        );
        stats.traceback_time(now.elapsed());

        let null_two_score = if self.config.do_null_two {
            let now = Instant::now();
            let score = Some(null_two_score(
                &self.posterior_matrix,
                profile,
                target,
                bounds,
            ));
            stats.null_two_time(now.elapsed());
            score
        } else {
            None
        };

        StageResult::Passed {
            data: AlignmentBuilder::default()
                .with_profile(profile)
                .with_target(target)
                .with_database_size(self.target_count)
                .with_cell_count(bounds.num_cells)
                .with_forward_score(forward_score)
                .with_trace(&trace)
                .with_null_two(null_two_score)
                .build()
                .unwrap(),
            stats: stats.build().unwrap(),
        }
    }
}