use std::collections::HashMap;
use std::hash::Hash;
use std::io::Write;
use std::process::Stdio;
use std::sync::atomic::{AtomicBool, AtomicU8};
use std::sync::{LazyLock, Mutex, MutexGuard};
use std::time::{Duration};
use std::sync::Arc;
use lib_tsalign::a_star_aligner::{
alignment_geometry::AlignmentRange,
configurable_a_star_align::Aligner,
} ;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
use tokio::sync::watch::error::RecvError;
use tokio::sync::{watch, Semaphore};
use tracing::{error, instrument};
use crate::common::aligner::db::{Database, StaticAlignmentKey};
use crate::common::aligner::result::{SoftFailureReason, TwitcherAlignment, TwitcherAlignmentCase};
use crate::common::aligner::{
fpa::FourPointAligner,
result::{AlignmentFailure, TwitcherAlignmentResult},
};
use crate::common::coords::GenomeRegion;
use crate::common::{ImmutableSequence, SequencePair};
use crate::counter;
use crate::worker::{WorkerQuery, WorkerQueryMetadata};
mod db;
mod fpa;
pub mod cli;
pub mod result;
/// Number of alignment worker subprocesses currently running.
///
/// Incremented once a task obtains its concurrency permit and decremented after its result
/// has been published and stored.
/// NOTE(review): an `AtomicU8` wraps past 255 — fine only while parallelism stays below that.
pub static RUNNING: LazyLock<AtomicU8> = LazyLock::new(AtomicU8::default);
/// Process-local cache of alignment results, keyed by [`AlignmentKey`].
pub struct InMemoryCache {
    /// Alignments currently being computed. Each entry holds the sender that will publish
    /// the result, so later requests for the same key can subscribe instead of recomputing.
    in_progress: HashMap<AlignmentKey, watch::Sender<Option<Arc<TwitcherAlignmentResult>>>>,
    /// Completed alignment results (successes and failures alike).
    finished: HashMap<AlignmentKey, Arc<TwitcherAlignmentResult>>,
}
/// Schedules alignments, deduplicating them through an optional in-memory cache and an
/// optional persistent database, and runs the remaining ones in worker subprocesses.
pub struct AlignmentOrchestrator {
    /// Aligner configuration, shared with the background worker tasks.
    aligners: Arc<AlignerSelector>,
    /// Limits how many worker subprocesses run at once; each task holds one permit.
    parallelism: Arc<Semaphore>,
    /// Resource limits applied to every single alignment.
    per_alignment_settings: PerAlignmentSettings,
    /// In-memory result cache; `None` until `enable_cache` is called.
    in_memory_cache: Option<Arc<Mutex<InMemoryCache>>>,
    /// Persistent result store, if one was configured.
    database: Option<Arc<Mutex<Database>>>,
}
/// Identity of one alignment for caching purposes: the reference region, the alignment
/// ranges within it and the query sequence.
#[derive(PartialEq, Eq, Clone)]
pub struct AlignmentKey {
    /// Region of the reference genome the query is aligned against.
    reference_region: GenomeRegion,
    /// Reference/query offsets and limits of the alignment.
    alignment_ranges: AlignmentRange,
    /// The query sequence itself.
    query_sequence: ImmutableSequence,
}
impl Hash for AlignmentKey {
    /// Hashes the reference region, the four range bounds and the query sequence, in that
    /// order.
    ///
    /// `AlignmentRange` is hashed through its offsets and limits rather than as a whole.
    /// This stays consistent with the derived `PartialEq` as long as equal ranges have
    /// equal offsets and limits.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let ranges = &self.alignment_ranges;
        (
            &self.reference_region,
            ranges.reference_offset(),
            ranges.reference_limit(),
            ranges.query_offset(),
            ranges.query_limit(),
            &self.query_sequence,
        )
            .hash(state);
    }
}
/// Resource limits applied to each individual alignment.
pub struct PerAlignmentSettings {
    /// Memory budget handed to every worker (units as produced by `init_semaphore` —
    /// NOTE(review): presumably bytes; confirm).
    memory_allowance: usize,
    /// Wall-clock limit for one worker subprocess; `None` means no limit.
    timeout: Option<Duration>,
}
/// Handle to an alignment result that may still be being computed.
pub struct InProgress {
    /// Becomes `Some` once the result has been published.
    receiver: watch::Receiver<Option<Arc<TwitcherAlignmentResult>>>,
}
impl From<watch::Receiver<Option<Arc<TwitcherAlignmentResult>>>> for InProgress {
fn from(receiver: watch::Receiver<Option<Arc<TwitcherAlignmentResult>>>) -> Self {
Self { receiver }
}
}
impl InProgress {
    /// Waits until the alignment result is available and returns it.
    ///
    /// Every successful call also bumps the matching `alignments.results.*` counter.
    ///
    /// # Errors
    /// Returns [`RecvError`] if the channel closes before a result was published.
    pub async fn recv(&mut self) -> Result<Arc<TwitcherAlignmentResult>, RecvError> {
        let current = self.receiver.wait_for(Option::is_some).await?;
        // `wait_for` only returns values for which the predicate (`is_some`) held.
        let Some(result) = &*current else {
            unreachable!()
        };
        count_result(result);
        Ok(Arc::clone(result))
    }
}
impl AlignmentOrchestrator {
pub fn enable_cache(&mut self) {
let cache = InMemoryCache {
in_progress: HashMap::new(),
finished: HashMap::new(),
};
self.in_memory_cache = Some(Arc::new(Mutex::new(cache)));
}
fn lock_cache(&self) -> Option<MutexGuard<'_, InMemoryCache>> {
self.in_memory_cache.as_ref().map(|c| c.lock().unwrap())
}
fn lock_database(&self) -> Option<MutexGuard<'_, Database>> {
self.database.as_ref().map(|c| c.lock().unwrap())
}
#[instrument(name = "get_alignment", skip_all, fields(pos = %cluster_region))]
pub fn get_or_compute_alignment(
&self,
reference_sequence_name: &str,
reference_region: &GenomeRegion,
cluster_region: GenomeRegion,
query: AlignmentQuery,
) -> anyhow::Result<InProgress> {
counter!("alignments").inc(1);
let key = AlignmentKey {
reference_region: reference_region.clone(),
alignment_ranges: query.ranges.clone(),
query_sequence: query.sequences.query.clone(),
};
let cache_lock = self.lock_cache();
if let Some(ref cache) = cache_lock {
if let Some(sender) = cache.in_progress.get(&key) {
counter!("alignments.from_cache").inc(1);
return Ok(sender.subscribe().into());
}
if let Some(result) = cache.finished.get(&key) {
counter!("alignments.from_cache").inc(1);
let (_, rx) = watch::channel(Some(result.clone()));
return Ok(rx.into());
}
}
if let Some(mut db) = self.lock_database() {
if db.needs_init() {
db.init_with_config(&StaticAlignmentKey {
reference_name: reference_sequence_name,
aligner_config: self.aligners.describe()?,
})?;
}
if let Ok(Some(result)) = tokio::task::block_in_place(|| db.lookup(&key)) {
counter!("alignments.from_db").inc(1);
let result = Arc::new(result);
if let Some(mut cache) = cache_lock {
cache.finished.insert(key.clone(), result.clone());
}
let (_, rx) = watch::channel(Some(result));
return Ok(rx.into());
}
}
let metadata = WorkerQueryMetadata { cluster_region };
let (tx, rx) = watch::channel(None);
if let Some(mut cache) = cache_lock {
cache.in_progress.insert(key.clone(), tx.clone());
}
self.start_realignment_with_callback(query, metadata, key, tx);
Ok(rx.into())
}
fn start_realignment_with_callback(
&self,
query: AlignmentQuery,
metadata: WorkerQueryMetadata,
key: AlignmentKey,
sender: watch::Sender<Option<Arc<TwitcherAlignmentResult>>>,
) {
let aligner = self.aligners.clone();
let log_level = *crate::THIS_LOG_LEVEL.get_or_init(Default::default);
let memory = self.per_alignment_settings.memory_allowance;
let timeout = self.per_alignment_settings.timeout;
let queue = self.parallelism.clone().acquire_owned();
let state_mutex = self.in_memory_cache.clone();
let db_mutex = self.database.clone();
tokio::spawn(async move {
let Ok(_permit) = queue.await else {
error!("Cannot aquire concurrency permit to align {}", metadata.cluster_region);
return;
};
RUNNING.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let worker_query = WorkerQuery {
aligner: (&*aligner).into(),
log_level,
memory,
query,
metadata,
};
counter!("alignments.computations").inc(1);
let res = Arc::new(Self::run_alignment(&worker_query, timeout).await);
let mut state = state_mutex.as_ref().map(|c| c.lock().unwrap());
if let Some(ref mut cache) = state {cache.in_progress.remove(&key);}
let _ = sender
.send(Some(res.clone()))
.inspect_err(|e| error!("Error sending result: {e:?}"));
if let Some(mut db) = db_mutex.as_ref().map(|db| db.lock().unwrap()) {
if let Err(e) = tokio::task::block_in_place(|| db.store(key.clone(), res.clone())) {
error!("Can't write alignment to database: {e}");
}
}
if let Some(ref mut cache) = state {
cache.finished.insert(key.clone(), res);
}
RUNNING.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
});
}
#[instrument(skip_all)]
async fn run_alignment(wq: &WorkerQuery<'_>, timeout: Option<Duration>) -> TwitcherAlignmentResult {
let mut cmd = tokio::process::Command::new(&*crate::THIS_EXE);
cmd.arg("worker");
cmd.stdin(Stdio::piped());
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd.spawn()?;
let mut stdin = child.stdin.take().unwrap();
let msg = rmp_serde::to_vec(&wq).map_err(|e| {
AlignmentFailure::error(&format!("Can't encode and write worker query: {e}"))
})?;
stdin.write_all(&msg).await?;
std::mem::drop(stdin);
let stderr = child.stderr.take().unwrap();
let mut stderr_writer_handle = None;
let oom = Arc::new(AtomicBool::new(false));
let oom2 = oom.clone();
if let Some(mut log_w) = crate::STDERR_LOG_WRITER.get().cloned() {
stderr_writer_handle = Some(tokio::spawn(async move {
static WARNED_ALREADY_ABOUT_FW_DIR: AtomicBool = AtomicBool::new(false);
let mut err_br = tokio::io::BufReader::new(stderr).lines();
loop {
let l = match err_br.next_line().await {
Ok(Some(l)) => l,
Ok(None) => break,
Err(e) => {
error!("{e}");
continue;
},
};
if l.contains("memory allocation of") {
oom2.store(true, std::sync::atomic::Ordering::Relaxed);
continue;
}
if l.contains("Forward direction not yet supported in PreprocessedTemplateSwitchMinLengthStrategy") {
if WARNED_ALREADY_ABOUT_FW_DIR.load(std::sync::atomic::Ordering::Relaxed) {
continue;
}
WARNED_ALREADY_ABOUT_FW_DIR.store(true, std::sync::atomic::Ordering::Relaxed);
}
let _ = tokio::task::block_in_place(|| writeln!(log_w, "{l}"));
}
}));
}
let exit = match tokio::time::timeout(timeout.unwrap_or(Duration::MAX), child.wait()).await {
Ok(exit) => Some(exit?),
Err(_elapsed) => {
child.kill().await?;
None
},
};
if let Some(h) = stderr_writer_handle {
let _ = h.await;
}
if let Some(exit_status) = exit {
if exit_status.success() {
let mut result_bytes = Vec::new();
child.stdout.as_mut().unwrap().read_to_end(&mut result_bytes).await?;
rmp_serde::from_slice::<TwitcherAlignmentResult>(&result_bytes)
.map_err(|e| AlignmentFailure::error(&format!("Can't read result: {e:?}")))?
} else if let Some(exit_code) = exit_status.code() {
TwitcherAlignmentResult::Err(AlignmentFailure::error(&format!(
"Aligner exited with a non-zero exit code: {exit_code}",
)))
} else if oom.load(std::sync::atomic::Ordering::Relaxed) {
TwitcherAlignmentResult::Err(AlignmentFailure::oom())
} else {
TwitcherAlignmentResult::Err(AlignmentFailure::error(
"Aligner exited abnormally (no exit code). Perhaps it ran out of memory?",
))
}
} else {
TwitcherAlignmentResult::Err(AlignmentFailure::timeout(timeout.unwrap()))
}
}
pub fn clear_cache(&self) {
if let Some(cache) = &self.in_memory_cache {
let mut cache = cache.lock().unwrap();
cache.finished.clear();
}
}
}
fn count_result(res: &TwitcherAlignmentResult) {
let key = match res {
Ok(TwitcherAlignment {
result: TwitcherAlignmentCase::FoundTS { .. },
..
}) => "alignments.results.successful.with_ts",
Ok(TwitcherAlignment {
result: TwitcherAlignmentCase::NoTS { .. },
..
}) => "alignments.results.successful.without_ts",
Err(AlignmentFailure::SoftFailure {
reason: SoftFailureReason::OutOfMemory,
}) => "alignments.results.failed.oom",
Err(AlignmentFailure::SoftFailure {
reason: SoftFailureReason::Timeout(_),
}) => "alignments.results.failed.timeout",
Err(AlignmentFailure::SoftFailure {
reason: SoftFailureReason::Other(_),
}) => "alignments.results.failed",
Err(AlignmentFailure::Error { .. }) => "alignments.results.error",
};
counter!(key).inc(1);
}
impl TryFrom<&cli::CliAlignmentArgs> for AlignmentOrchestrator {
type Error = anyhow::Error;
fn try_from(value: &cli::CliAlignmentArgs) -> Result<Self, Self::Error> {
let (sem, mem_per_thread) = value.init_semaphore()?;
let alns = value.init_aligner()?;
let database: Option<Database> = (&value.database).try_into()?;
Ok(Self {
aligners: alns.into(),
parallelism: sem.into(),
per_alignment_settings: PerAlignmentSettings { memory_allowance: mem_per_thread, timeout: value.aligner_timeout },
in_memory_cache: None,
database: database.map(|db| Arc::new(Mutex::new(db))),
})
}
}
/// The sequences and alignment ranges of a single alignment request.
///
/// It is sent to worker subprocesses as part of the serialized worker query. Keep the
/// field order stable: `rmp_serde::to_vec` encodes structs as arrays by default.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AlignmentQuery {
    /// Reference and query sequences.
    pub sequences: SequencePair,
    /// Offsets and limits of the aligned parts of both sequences.
    pub ranges: AlignmentRange,
}
/// Pair of A* aligner configurations.
///
/// NOTE(review): presumably `ts` allows template switches and `no_ts` forbids them —
/// confirm against the `cli` module. Keep the field order stable: it is part of the
/// MessagePack description produced by `AlignerSelector::describe`.
#[derive(Deserialize, Serialize)]
pub struct AStarAlignerPair {
    pub ts: Aligner,
    pub no_ts: Aligner,
}
/// The alignment backend used for all alignments of an orchestrator.
///
/// Variant order is part of the serialized description (see `describe`); do not reorder.
#[derive(Deserialize, Serialize)]
// Only one selector is built per orchestrator and it is shared behind an `Arc`, so the size
// difference between variants is presumably irrelevant.
#[allow(clippy::large_enum_variant)]
pub enum AlignerSelector {
    /// A* aligners from `lib_tsalign`.
    AStar(AStarAlignerPair),
    /// Aligner from this crate's `fpa` module.
    Fpa(FourPointAligner),
}
/// MessagePack encoding of an [`AlignerSelector`], as produced by `AlignerSelector::describe`.
pub type AlignerSelectorDescription = Vec<u8>;

impl AlignerSelector {
    /// Serializes this configuration with `rmp_serde::to_vec`.
    ///
    /// The encoding is used as the aligner part of the database's static key, so stored
    /// results are tied to the exact aligner configuration.
    ///
    /// # Errors
    /// Fails if the configuration cannot be encoded.
    pub fn describe(&self) -> anyhow::Result<AlignerSelectorDescription> {
        let encoded = rmp_serde::to_vec(self)?;
        Ok(encoded)
    }
}