twitcher 0.1.8

Find template switch mutations in genomic data
use std::collections::HashMap;
use std::hash::Hash;
use std::io::Write;
use std::process::Stdio;
use std::sync::atomic::{AtomicBool, AtomicU8};
use std::sync::{LazyLock, Mutex, MutexGuard};
use std::time::{Duration};
use std::sync::Arc;

use lib_tsalign::a_star_aligner::{
        alignment_geometry::AlignmentRange,
        configurable_a_star_align::Aligner,
    } ;
    use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
use tokio::sync::watch::error::RecvError;
use tokio::sync::{watch, Semaphore};
use tracing::{error, instrument};

use crate::common::aligner::db::{Database, StaticAlignmentKey};
use crate::common::aligner::result::{SoftFailureReason, TwitcherAlignment, TwitcherAlignmentCase};
use crate::common::aligner::{
    fpa::FourPointAligner,
    result::{AlignmentFailure, TwitcherAlignmentResult},
};
use crate::common::coords::GenomeRegion;
use crate::common::{ImmutableSequence, SequencePair};
use crate::counter;
use crate::worker::{WorkerQuery, WorkerQueryMetadata};

mod db;
mod fpa;
pub mod cli;
pub mod result;

/// Use this to monitor how many alignment tasks are currently running.
///
/// The counter is incremented by a spawned alignment task once it has obtained its concurrency
/// permit and is decremented again when that task has finished publishing its result.
// NOTE(review): the counter is a `u8`; this presumably relies on the semaphore never admitting
// more than 255 concurrent tasks - confirm against `init_semaphore`.
pub static RUNNING: LazyLock<AtomicU8> = LazyLock::new(|| AtomicU8::new(0));

/// Process-local cache of alignments, keyed by [`AlignmentKey`].
pub struct InMemoryCache {
    /// Alignments that are currently being computed. New requests for the same key subscribe
    /// to the stored sender instead of starting a second computation.
    in_progress: HashMap<AlignmentKey, watch::Sender<Option<Arc<TwitcherAlignmentResult>>>>,
    /// Results of alignments that have completed (successfully or not).
    finished: HashMap<AlignmentKey, Arc<TwitcherAlignmentResult>>,
}

/// Schedules alignments in worker processes and deduplicates them through an optional
/// in-memory cache and an optional database.
pub struct AlignmentOrchestrator {
    /// Aligner configuration handed to every worker; also serialised to describe the database
    /// contents when the database is initialised.
    aligners: Arc<AlignerSelector>,
    /// Limits how many worker tasks run at the same time; each task holds one permit.
    parallelism: Arc<Semaphore>,
    /// Resource limits applied to each individual alignment.
    per_alignment_settings: PerAlignmentSettings,
    /// In-memory result cache; `None` until [`AlignmentOrchestrator::enable_cache`] is called.
    in_memory_cache: Option<Arc<Mutex<InMemoryCache>>>,
    /// Optional database consulted before computing and written to after computing.
    database: Option<Arc<Mutex<Database>>>,
}

/// Identifies one alignment problem for cache and database lookups.
///
/// Equality is derived over all fields; hashing is implemented by hand in the `Hash` impl.
#[derive(PartialEq, Eq, Clone)]
pub struct AlignmentKey {
    // Implicit parts of the key (not stored here, they must be identical for all keys that
    // share a cache or database):
    // - reference
    // - aligner plus settings
    // - cost function
    /// Region of the reference the alignment refers to.
    reference_region: GenomeRegion,
    /// Offsets and limits of the aligned ranges in reference and query.
    alignment_ranges: AlignmentRange,
    /// The query sequence that is aligned against the reference.
    query_sequence: ImmutableSequence,
}

impl Hash for AlignmentKey {
    /// Feeds the reference region, the four range bounds (reference offset/limit, query
    /// offset/limit) and the query sequence into `state`, in that order.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Destructuring makes the compiler point here whenever a field is added to the key.
        let Self {
            reference_region,
            alignment_ranges,
            query_sequence,
        } = self;

        reference_region.hash(state);

        // The range type is hashed through its accessors rather than as a whole.
        alignment_ranges.reference_offset().hash(state);
        alignment_ranges.reference_limit().hash(state);
        alignment_ranges.query_offset().hash(state);
        alignment_ranges.query_limit().hash(state);

        query_sequence.hash(state);
    }
}

/// Resource limits that apply to every single alignment run.
pub struct PerAlignmentSettings {
    /// Memory budget forwarded to the worker as `WorkerQuery::memory`.
    // NOTE(review): presumably a number of bytes - confirm against the worker implementation.
    memory_allowance: usize,
    /// Maximum time to wait for the worker process to exit; `None` means no limit.
    timeout: Option<Duration>,
}

/// Handle to an alignment result that may still be in the process of being computed.
///
/// Call [`InProgress::recv`] to wait for the result.
pub struct InProgress {
    /// Receiving side of the channel on which the result is published; holds `None` until the
    /// result is available.
    receiver: watch::Receiver<Option<Arc<TwitcherAlignmentResult>>>,
}

impl From<watch::Receiver<Option<Arc<TwitcherAlignmentResult>>>> for InProgress {
    fn from(receiver: watch::Receiver<Option<Arc<TwitcherAlignmentResult>>>) -> Self {
        Self { receiver }
    }
}


impl InProgress {
    /// Waits until the alignment result has been published and returns a shared handle to it.
    ///
    /// Every successful call also increments the result counter that matches the outcome.
    ///
    /// # Errors
    /// Returns a [`RecvError`] if the channel is closed before a result was published.
    pub async fn recv(&mut self) -> Result<Arc<TwitcherAlignmentResult>, RecvError> {
        let current = self.receiver.wait_for(Option::is_some).await?;
        // `wait_for` only resolves successfully once the predicate holds, so the value is `Some`.
        let Some(result) = &*current else {
            unreachable!()
        };
        count_result(result);
        Ok(Arc::clone(result))
    }
}

impl AlignmentOrchestrator {
    /// Installs a fresh, empty in-memory cache.
    ///
    /// Any cache that was installed before is replaced by the new, empty one.
    pub fn enable_cache(&mut self) {
        let empty = InMemoryCache {
            in_progress: HashMap::new(),
            finished: HashMap::new(),
        };
        let shared = Arc::new(Mutex::new(empty));
        self.in_memory_cache = Some(shared);
    }
    
    /// Locks the in-memory cache, or returns `None` if no cache is enabled.
    ///
    /// # Panics
    /// Panics if the cache mutex is poisoned.
    fn lock_cache(&self) -> Option<MutexGuard<'_, InMemoryCache>> {
        let cache = self.in_memory_cache.as_ref()?;
        Some(cache.lock().unwrap())
    }

    /// Locks the database, or returns `None` if no database is configured.
    ///
    /// # Panics
    /// Panics if the database mutex is poisoned.
    fn lock_database(&self) -> Option<MutexGuard<'_, Database>> {
        let database = self.database.as_ref()?;
        Some(database.lock().unwrap())
    }
    
    #[instrument(name = "get_alignment", skip_all, fields(pos = %cluster_region))]
    pub fn get_or_compute_alignment(
        &self,
        reference_sequence_name: &str,
        reference_region: &GenomeRegion,
        cluster_region: GenomeRegion,
        query: AlignmentQuery,
    ) -> anyhow::Result<InProgress> {
        counter!("alignments").inc(1);
        let key = AlignmentKey {
            reference_region: reference_region.clone(),
            alignment_ranges: query.ranges.clone(),
            query_sequence: query.sequences.query.clone(),
        };

        let cache_lock = self.lock_cache();
        
        if let Some(ref cache) = cache_lock {
            if let Some(sender) = cache.in_progress.get(&key) {
                // Alignment is already in progress.
                counter!("alignments.from_cache").inc(1);
                return Ok(sender.subscribe().into());
            }

            if let Some(result) = cache.finished.get(&key) {
                // We have a finished result at hand.
                counter!("alignments.from_cache").inc(1);
                let (_, rx) = watch::channel(Some(result.clone()));
                return Ok(rx.into());
            }
        }

        // Next, check if the DB has a result
        if let Some(mut db) = self.lock_database() {
            if db.needs_init() {
                db.init_with_config(&StaticAlignmentKey {
                    reference_name: reference_sequence_name,
                    aligner_config: self.aligners.describe()?,
                })?;
            }
            if let Ok(Some(result)) = tokio::task::block_in_place(|| db.lookup(&key)) {
                counter!("alignments.from_db").inc(1);
                let result = Arc::new(result);
                // write to in-memory cache
                if let Some(mut cache) = cache_lock {
                    cache.finished.insert(key.clone(), result.clone());
                }
                let (_, rx) = watch::channel(Some(result));
                return Ok(rx.into());
            }
        }

        let metadata = WorkerQueryMetadata { cluster_region };
        let (tx, rx) = watch::channel(None);

        if let Some(mut cache) = cache_lock {
            cache.in_progress.insert(key.clone(), tx.clone());
        }

        self.start_realignment_with_callback(query, metadata, key,  tx);
        Ok(rx.into())
    }

    /// Spawns a background task that computes the alignment for `query` and publishes the result.
    ///
    /// The task waits for a concurrency permit, runs the alignment in a worker process and then
    /// 1. moves `key` from the in-progress map to the finished map of the in-memory cache,
    /// 2. sends the result through `sender`,
    /// 3. stores the result in the database.
    ///
    /// Steps 1 and 3 are skipped when the cache or the database is not configured. A failing
    /// database write is only logged.
    ///
    /// If no permit can be acquired the task gives up: it removes the in-progress entry so that
    /// all senders are dropped and waiting receivers observe a closed channel instead of
    /// waiting forever.
    fn start_realignment_with_callback(
        &self,
        query: AlignmentQuery,
        metadata: WorkerQueryMetadata,
        key: AlignmentKey,
        sender: watch::Sender<Option<Arc<TwitcherAlignmentResult>>>,
    ) {
        let aligner = self.aligners.clone();
        let log_level = *crate::THIS_LOG_LEVEL.get_or_init(Default::default);
        let memory = self.per_alignment_settings.memory_allowance;
        let timeout = self.per_alignment_settings.timeout;

        let queue = self.parallelism.clone().acquire_owned();
        let state_mutex = self.in_memory_cache.clone();
        let db_mutex = self.database.clone();
        tokio::spawn(async move {
            let Ok(_permit) = queue.await else {
                error!("Cannot aquire concurrency permit to align {}", metadata.cluster_region);
                // The cache holds a clone of the sender. Without removing it the channel would
                // never close and every subscriber would hang.
                if let Some(cache) = state_mutex.as_ref() {
                    cache.lock().unwrap().in_progress.remove(&key);
                }
                return;
            };
            RUNNING.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            let worker_query = WorkerQuery {
                aligner: (&*aligner).into(),
                log_level,
                memory,
                query,
                metadata,
            };

            counter!("alignments.computations").inc(1);
            let res = Arc::new(Self::run_alignment(&worker_query, timeout).await);

            // Switch the key from "in progress" to "finished" under a single lock so that no
            // caller can observe the key in neither map. The lock is released before the
            // blocking database write below so other lookups are not stalled by it.
            if let Some(cache) = state_mutex.as_ref() {
                let mut cache = cache.lock().unwrap();
                cache.in_progress.remove(&key);
                cache.finished.insert(key.clone(), res.clone());
            }

            let _ = sender
                .send(Some(res.clone()))
                .inspect_err(|e| error!("Error sending result: {e:?}"));

            if let Some(db_mutex) = db_mutex.as_ref() {
                let mut db = db_mutex.lock().unwrap();
                if let Err(e) = tokio::task::block_in_place(|| db.store(key.clone(), res.clone())) {
                    error!("Can't write alignment to database: {e}");
                }
            }
            RUNNING.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
        });
    }

    #[instrument(skip_all)]
    async fn run_alignment(wq: &WorkerQuery<'_>, timeout: Option<Duration>) -> TwitcherAlignmentResult {
        let mut cmd = tokio::process::Command::new(&*crate::THIS_EXE);
        cmd.arg("worker");
        cmd.stdin(Stdio::piped());
        cmd.stdout(Stdio::piped());
        cmd.stderr(Stdio::piped());

        let mut child = cmd.spawn()?;

        let mut stdin = child.stdin.take().unwrap();
        let msg = rmp_serde::to_vec(&wq).map_err(|e| {
            AlignmentFailure::error(&format!("Can't encode and write worker query: {e}"))
        })?;
        stdin.write_all(&msg).await?;
        std::mem::drop(stdin);

        let stderr = child.stderr.take().unwrap();
        let mut stderr_writer_handle = None;
        let oom = Arc::new(AtomicBool::new(false));
        let oom2 = oom.clone();
        if let Some(mut log_w) = crate::STDERR_LOG_WRITER.get().cloned() {
            stderr_writer_handle = Some(tokio::spawn(async move {
                static WARNED_ALREADY_ABOUT_FW_DIR: AtomicBool = AtomicBool::new(false);
                // Forward any stderr lines to the stderr writer that lies on top of the progress bar
                let mut err_br = tokio::io::BufReader::new(stderr).lines();
                loop {
                    let l = match err_br.next_line().await {
                        Ok(Some(l)) => l,
                        Ok(None) => break,
                        Err(e) => {
                            error!("{e}");
                            continue;
                        },
                    };
                    if l.contains("memory allocation of") {
                        oom2.store(true, std::sync::atomic::Ordering::Relaxed);
                        continue;
                    }
                    if l.contains("Forward direction not yet supported in PreprocessedTemplateSwitchMinLengthStrategy") {
                        if WARNED_ALREADY_ABOUT_FW_DIR.load(std::sync::atomic::Ordering::Relaxed) {
                           continue; 
                        }
                        WARNED_ALREADY_ABOUT_FW_DIR.store(true, std::sync::atomic::Ordering::Relaxed);
                    }

                    let _ = tokio::task::block_in_place(|| writeln!(log_w, "{l}"));
                }
            }));
        }

        let exit = match tokio::time::timeout(timeout.unwrap_or(Duration::MAX), child.wait()).await {
            Ok(exit) => Some(exit?),
            Err(_elapsed) => {
                child.kill().await?;
                None
            },
        };

        if let Some(h) = stderr_writer_handle {
            let _ = h.await;
        }

        if let Some(exit_status) = exit {
            if exit_status.success() {
                let mut result_bytes = Vec::new();
                child.stdout.as_mut().unwrap().read_to_end(&mut result_bytes).await?;
                // Deserialize the result and return it
                rmp_serde::from_slice::<TwitcherAlignmentResult>(&result_bytes)
                    .map_err(|e| AlignmentFailure::error(&format!("Can't read result: {e:?}")))?
            } else if let Some(exit_code) = exit_status.code() {
                TwitcherAlignmentResult::Err(AlignmentFailure::error(&format!(
                    "Aligner exited with a non-zero exit code: {exit_code}",
                )))
            } else if oom.load(std::sync::atomic::Ordering::Relaxed) {
                TwitcherAlignmentResult::Err(AlignmentFailure::oom())
            } else {
                TwitcherAlignmentResult::Err(AlignmentFailure::error(
                    "Aligner exited abnormally (no exit code). Perhaps it ran out of memory?",
                ))
            }
        } else {
            TwitcherAlignmentResult::Err(AlignmentFailure::timeout(timeout.unwrap()))
        }
    }

    /// Drops all finished results from the in-memory cache.
    ///
    /// Alignments that are still in progress stay registered. Does nothing if no cache is
    /// enabled.
    ///
    /// # Panics
    /// Panics if the cache mutex is poisoned.
    pub fn clear_cache(&self) {
        let Some(shared) = self.in_memory_cache.as_ref() else {
            return;
        };
        let mut guard = shared.lock().unwrap();
        guard.finished.clear();
    }
}

fn count_result(res: &TwitcherAlignmentResult) {
    let key = match res {
        Ok(TwitcherAlignment {
            result: TwitcherAlignmentCase::FoundTS { .. },
            ..
        }) => "alignments.results.successful.with_ts",
        Ok(TwitcherAlignment {
            result: TwitcherAlignmentCase::NoTS { .. },
            ..
        }) => "alignments.results.successful.without_ts",
        Err(AlignmentFailure::SoftFailure {
            reason: SoftFailureReason::OutOfMemory,
        }) => "alignments.results.failed.oom",
        Err(AlignmentFailure::SoftFailure {
            reason: SoftFailureReason::Timeout(_),
        }) => "alignments.results.failed.timeout",
        Err(AlignmentFailure::SoftFailure {
            reason: SoftFailureReason::Other(_),
        }) => "alignments.results.failed",
        Err(AlignmentFailure::Error { .. }) => "alignments.results.error",
    };
    counter!(key).inc(1);
}

impl TryFrom<&cli::CliAlignmentArgs> for AlignmentOrchestrator {
    type Error = anyhow::Error;

    /// Builds an orchestrator from the command-line arguments.
    ///
    /// The in-memory cache starts out disabled; the database is used only if the arguments
    /// yield one.
    ///
    /// # Errors
    /// Fails if the semaphore, the aligner or the database cannot be set up from `value`.
    fn try_from(value: &cli::CliAlignmentArgs) -> Result<Self, Self::Error> {
        let (semaphore, memory_allowance) = value.init_semaphore()?;
        let aligners = value.init_aligner()?;
        let database: Option<Database> = (&value.database).try_into()?;

        let per_alignment_settings = PerAlignmentSettings {
            memory_allowance,
            timeout: value.aligner_timeout,
        };

        Ok(Self {
            aligners: aligners.into(),
            parallelism: semaphore.into(),
            per_alignment_settings,
            in_memory_cache: None,
            database: database.map(Mutex::new).map(Arc::new),
        })
    }
}

/// The input of a single alignment: the sequences and the ranges within them to align.
///
/// Serialisable because it is sent to the worker process as part of the worker query.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AlignmentQuery {
    /// The sequences to align; the `query` member also becomes part of the cache key.
    pub sequences: SequencePair,
    /// The ranges to align; also part of the cache key.
    pub ranges: AlignmentRange,
}


/// Two configured A* aligners that are used together.
// NOTE(review): judging by the field names, `ts` presumably aligns with template switches
// allowed and `no_ts` without them - confirm against the worker.
#[derive(Deserialize, Serialize)]
pub struct AStarAlignerPair {
    pub ts: Aligner,
    pub no_ts: Aligner,
}

/// The alignment backend to use, together with its configuration.
///
/// Serialisable so that it can be passed to worker processes and serialised into an
/// [`AlignerSelectorDescription`].
#[derive(Deserialize, Serialize)]
// The size difference between the variants is accepted rather than boxing the larger one.
#[allow(clippy::large_enum_variant)]
pub enum AlignerSelector {
    /// A pair of A* aligners.
    AStar(AStarAlignerPair),
    /// The four-point aligner.
    Fpa(FourPointAligner),
}

/// Serialised form of an [`AlignerSelector`], as produced by [`AlignerSelector::describe`].
pub type AlignerSelectorDescription = Vec<u8>;

impl AlignerSelector {
    /// Serialises this selector, including its configuration, with `rmp_serde`.
    ///
    /// # Errors
    /// Fails if the selector cannot be serialised.
    pub fn describe(&self) -> anyhow::Result<AlignerSelectorDescription> {
        let encoded = rmp_serde::to_vec(self)?;
        Ok(encoded)
    }
}