twitcher 0.3.2

Find template switch mutations in genomic data
use std::{fmt::Display, path::PathBuf, sync::Arc, time::Duration};

use anyhow::Context;
use clap::ValueEnum;
use itertools::Itertools;
use rusqlite::{Connection, OptionalExtension};
use tracing::{error, warn};

use crate::common::aligner::{
    AlignerSelectorDescription, AlignmentKey,
    result::{AlignmentFailure, SoftFailureReason, TwitcherAlignmentResult},
};

/// Schema version of the database layout described by `DB_TABLES`.
///
/// It is written to SQLite's `user_version` pragma when a database file is
/// created and compared against that pragma when an existing file is opened.
const DB_SCHEMA_VERSION: i32 = 2;
/// Declarative description of every table this module creates and validates.
///
/// Each column entry is a `[name, definition]` pair; `extra` is an optional
/// table-level clause that is appended after the column definitions when the
/// `CREATE TABLE` statement is assembled.
const DB_TABLES: [DbTableSpec; 2] = [
    // One row per (reference, aligner configuration) combination.
    DbTableSpec {
        name: "configurations",
        colunms: &[
            ["id", "INTEGER PRIMARY KEY AUTOINCREMENT"],
            ["reference", "TEXT NOT NULL"],
            ["aligner", "BLOB NOT NULL"],
            ["created_at", "TEXT NOT NULL DEFAULT (datetime('now'))"],
        ],
        extra: None,
    },
    // One row per stored alignment result, linked to its configuration.
    DbTableSpec {
        name: "alignments",
        colunms: &[
            ["id", "INTEGER PRIMARY KEY AUTOINCREMENT"],
            [
                "config_id",
                "INTEGER NOT NULL REFERENCES configurations(id)",
            ],
            ["reference_region", "TEXT NOT NULL"],
            ["ref_offset", "INTEGER NOT NULL"],
            ["ref_limit", "INTEGER NOT NULL"],
            ["query_offset", "INTEGER NOT NULL"],
            ["query_limit", "INTEGER NOT NULL"],
            // First output of `compress_sequence` (2 bits per base).
            ["query_seq", "BLOB NOT NULL"],
            // Second output of `compress_sequence` (positions of non-ACGT bases).
            ["query_seq_n_pos", "BLOB NOT NULL"],
            // Length of the uncompressed query sequence.
            ["query_seq_len", "INTEGER NOT NULL"],
            // Memory allowance and timeout that were in effect when the result
            // was produced; `timeout_ns` is -1 when no timeout was set.
            ["memory_allowance", "INTEGER NOT NULL"],
            ["timeout_ns", "INTEGER NOT NULL"],
            // Result serialized with `rmp_serde`.
            ["result", "BLOB NOT NULL"],
        ],
        // The lookup key: at most one stored result per configuration,
        // region, alignment ranges and query sequence.
        extra: Some(
            "UNIQUE(config_id, reference_region, ref_offset, ref_limit, query_offset, query_limit, query_seq, query_seq_n_pos, query_seq_len)",
        ),
    },
];

/// Number of buffered results at which `Database::store` triggers a flush of
/// the pending writes to the database.
const WRITE_BUFFER_SIZE: usize = 128;

/// Static description of a single database table.
#[derive(Debug)]
struct DbTableSpec {
    /// Table name as used in SQL statements.
    name: &'static str,
    /// Column specs as `[name, definition]` pairs, in declaration order.
    /// (The field name is spelled `colunms` throughout this file.)
    colunms: &'static [[&'static str; 2]],
    /// Optional table-level clause appended after the column definitions.
    extra: Option<&'static str>,
}

// Command-line arguments that select and configure the alignment database.
//
// NOTE(review): plain `//` comments are used here on purpose; with clap's
// derive macros, `///` doc comments would presumably become help text.
#[derive(clap::Args, Debug, Default)]
pub struct CliDatabaseArgs {
    // `--db <PATH>`: location of the database file. When absent, no database
    // is used at all.
    #[arg(long = "db")]
    path: Option<PathBuf>,

    // `--db-mode <MODE>`: how the database is accessed. Falls back to the
    // default `DbMode` and may only be given together with `--db`.
    #[arg(long = "db-mode", default_value_t, requires = "path")]
    mode: DbMode,
}

// Access mode of the database, selectable on the command line.
#[derive(clap::ValueEnum, Debug, Clone, Copy, Default, PartialEq, Eq)]
enum DbMode {
    // Existing results are looked up; `store` calls are ignored.
    ReadOnly,
    // New results are stored; `lookup` always reports a miss.
    WriteOnly,
    // Results are both looked up and stored (the default).
    #[default]
    ReadWrite,
}

impl DbMode {
    /// Returns `true` when this mode allows reading existing results.
    const fn is_read(self) -> bool {
        match self {
            Self::ReadOnly | Self::ReadWrite => true,
            Self::WriteOnly => false,
        }
    }

    /// Returns `true` when this mode allows storing new results.
    const fn is_write(self) -> bool {
        match self {
            Self::WriteOnly | Self::ReadWrite => true,
            Self::ReadOnly => false,
        }
    }
}

impl Display for DbMode {
    /// Writes the name under which clap exposes this mode as a possible
    /// value, so the printed form matches what is accepted on the command line.
    ///
    /// # Panics
    /// Panics if `to_possible_value` returns `None` for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let possible_value = self.to_possible_value().unwrap();
        let cli_name: &str = possible_value.get_name();
        f.write_str(cli_name)
    }
}

/// Handle to the SQLite database that stores alignment results.
///
/// Must be initialized with `init_with_config` before `lookup` or `store`
/// are used. Buffered writes are flushed when the handle is dropped.
pub struct Database {
    /// Open SQLite connection.
    conn: Connection,
    /// Access mode; decides whether `lookup` and `store` do anything.
    mode: DbMode,
    /// Id of the `configurations` row matching the current run, if one was
    /// found or inserted by `init_with_config`.
    config_id: Option<u32>,
    /// Buffered entries `(key, result, memory allowance, timeout in ns)`
    /// waiting to be written by `flush_writes`.
    pending_writes: Vec<(AlignmentKey, Arc<TwitcherAlignmentResult>, usize, i64)>,
    /// Set to `true` once `init_with_config` has completed successfully.
    initialized: bool,
}

impl TryFrom<&CliDatabaseArgs> for Option<Database> {
    type Error = anyhow::Error;

    /// Opens (or creates) the database selected on the command line.
    ///
    /// Returns `Ok(None)` when no database path was given. An existing file
    /// must carry the expected `user_version` and contain all tables and
    /// columns from `DB_TABLES`. A missing file is created and initialized,
    /// but only if the mode allows writing. In both cases the journal mode is
    /// switched to WAL afterwards.
    ///
    /// # Errors
    /// Fails if the file cannot be opened, its schema version cannot be read
    /// or does not match `DB_SCHEMA_VERSION`, a table or column is missing,
    /// or the file does not exist and the mode is read-only.
    fn try_from(args: &CliDatabaseArgs) -> anyhow::Result<Self> {
        let Some(path) = &args.path else {
            return Ok(None);
        };
        let conn = if path.exists() {
            let conn = Connection::open(path)
                .with_context(|| format!("Cannot open database file '{}'", path.display()))?;
            // Surface the underlying error (I/O failure, not a database, ...)
            // instead of reporting every failure as a schema mismatch.
            let version: i32 = conn
                .pragma_query_value(None, "user_version", |row| row.get(0))
                .with_context(|| {
                    format!(
                        "Cannot read the schema version of database file '{}'",
                        path.display()
                    )
                })?;
            if version != DB_SCHEMA_VERSION {
                anyhow::bail!(
                    "The specified database file does not have the appropriate schema \
                     (found version {version}, expected {DB_SCHEMA_VERSION})."
                );
            }
            Database::check_tables(&conn)?;
            conn
        } else if args.mode.is_write() {
            // Create database
            let conn = Connection::open(path)
                .with_context(|| format!("Cannot create database file '{}'", path.display()))?;
            conn.pragma_update(None, "user_version", DB_SCHEMA_VERSION)?;
            Database::init_tables(&conn)?;
            conn
        } else {
            anyhow::bail!("Cannot create database: read-only was specified.");
        };
        conn.pragma_update(None, "journal_mode", "WAL")?; // write-ahead log
        Ok(Some(Database {
            conn,
            mode: args.mode,
            config_id: None,
            pending_writes: Vec::new(),
            initialized: false,
        }))
    }
}

/// The part of the lookup key that `Database::init_with_config` matches
/// against the `configurations` table.
pub struct StaticAlignmentKey<'a> {
    /// Compared against (and stored in) the `reference` column.
    pub reference_name: &'a str,
    /// Compared against (and stored in) the `aligner` column.
    pub aligner_config: AlignerSelectorDescription,
}

/// Converts an optional timeout into the signed nanosecond count stored in
/// the `timeout_ns` column.
///
/// `None` (no timeout) maps to the sentinel `-1`. A duration whose nanosecond
/// count does not fit into an `i64` saturates to `i64::MAX`; a plain `as`
/// cast would truncate instead and can even produce the `-1` sentinel
/// (as it does for `Duration::MAX`).
fn to_timeout_ns(t: Option<Duration>) -> i64 {
    t.map_or(-1, |d| i64::try_from(d.as_nanos()).unwrap_or(i64::MAX))
}

impl Database {
    /// Creates every table listed in `DB_TABLES` on `conn`.
    ///
    /// The body of each `CREATE TABLE` statement consists of the column specs
    /// (name and definition separated by a space) followed by the optional
    /// `extra` table-level clause, all joined by `",\n"`.
    ///
    /// # Errors
    /// Returns the first error raised while executing a statement; tables
    /// after the failing one are not created.
    fn init_tables(conn: &Connection) -> Result<(), rusqlite::Error> {
        DB_TABLES.iter().try_for_each(|spec| {
            let name = spec.name;
            let mut clauses: Vec<String> = spec
                .colunms
                .iter()
                .map(|column| column.iter().join(" "))
                .collect();
            clauses.extend(spec.extra.map(str::to_string));
            let columns = clauses.join(",\n");
            let query = format!(
                "CREATE TABLE {name} (
                    {columns}
                )"
            );
            conn.execute(&query, ()).map(drop)
        })
    }

    /// Verifies that `conn` contains every table and column from `DB_TABLES`.
    ///
    /// Only the presence of tables and column names is checked; column types
    /// and constraints are not compared.
    ///
    /// # Errors
    /// Fails if a table or column is missing, or if any of the metadata
    /// queries fails. Row-level read errors are propagated instead of being
    /// dropped, so they cannot show up as a misleading "missing column".
    fn check_tables(conn: &Connection) -> anyhow::Result<()> {
        for table in &DB_TABLES {
            let exists: bool = conn.query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
                [table.name],
                |row| row.get::<_, i64>(0),
            )? > 0;

            if !exists {
                anyhow::bail!("Required table '{}' does not exist", table.name);
            }

            // Column 1 of `PRAGMA table_info` is read as the column name.
            let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table.name))?;
            let actual_columns = stmt
                .query_map([], |row| row.get::<_, String>(1))?
                .collect::<Result<Vec<String>, _>>()?;

            for &[col, _] in table.colunms {
                if !actual_columns.iter().any(|c| c == col) {
                    anyhow::bail!("Table '{}' is missing required column '{col}'", table.name);
                }
            }
        }
        Ok(())
    }

    /// Binds this database handle to the configuration described by `key`.
    ///
    /// An existing `configurations` row with the same reference and aligner
    /// is reused. If none exists, a new row is inserted when the mode allows
    /// writing; otherwise a warning is logged and `config_id` is left
    /// untouched. On success the handle is marked as initialized.
    ///
    /// # Errors
    /// Propagates any error from the lookup or insert query.
    pub fn init_with_config(&mut self, key: &StaticAlignmentKey) -> anyhow::Result<()> {
        let existing_config: Option<u32> = self
            .conn
            .query_row(
                "SELECT id FROM configurations WHERE reference = ? AND aligner = ? LIMIT 1",
                (key.reference_name, &key.aligner_config),
                |row| row.get(0),
            )
            .optional()?;

        match existing_config {
            Some(found_id) => self.config_id = Some(found_id),
            None if self.mode.is_write() => {
                let inserted_id: u32 = self.conn.query_one(
                    "INSERT INTO configurations (reference, aligner) VALUES (?, ?) RETURNING id",
                    (key.reference_name, &key.aligner_config),
                    |row| row.get(0),
                )?;
                self.config_id = Some(inserted_id);
            }
            None => warn!(
                "read-only database, but no existing data matches the configuration of this execution"
            ),
        }

        self.initialized = true;
        Ok(())
    }

    /// Returns `true` until `init_with_config` has completed successfully;
    /// `lookup` and `store` fail while this is the case.
    pub const fn needs_init(&self) -> bool {
        !self.initialized
    }

    /// Looks up a previously stored result for `key`.
    ///
    /// Returns `Ok(None)` when the mode does not allow reading, when no
    /// configuration id is known, or when no matching row exists. A stored
    /// soft failure is also reported as a miss when the current run has more
    /// resources than the one that produced it:
    /// * an out-of-memory failure, if `memory` exceeds the stored allowance;
    /// * a timeout failure, if the stored run had a timeout and the current
    ///   `timeout` is absent or longer.
    ///
    /// # Errors
    /// Fails if the database has not been initialized, if an offset, limit or
    /// length does not fit into an `isize`, if the query fails, or if the
    /// stored result cannot be deserialized.
    pub fn lookup(
        &self,
        key: &AlignmentKey,
        memory: usize,
        timeout: Option<Duration>,
    ) -> anyhow::Result<Option<TwitcherAlignmentResult>> {
        if !self.initialized {
            anyhow::bail!("DB needs to be initialized before first use");
        }
        let cid = match self.config_id {
            Some(cid) if self.mode.is_read() => cid,
            _ => return Ok(None),
        };

        let (qry_seq, qry_ns) = compress_sequence(&key.query_sequence);
        let row: Option<(Vec<u8>, i64, i64)> = self
            .conn
            .query_one(
                "SELECT result, memory_allowance, timeout_ns FROM alignments
                 WHERE config_id = ?
                   AND reference_region = ?
                   AND ref_offset = ? AND ref_limit = ?
                   AND query_offset = ? AND query_limit = ?
                   AND query_seq = ? AND query_seq_n_pos = ? AND query_seq_len = ?",
                (
                    cid,
                    key.reference_region.to_string(),
                    isize::try_from(key.alignment_ranges.reference_offset())?,
                    isize::try_from(key.alignment_ranges.reference_limit())?,
                    isize::try_from(key.alignment_ranges.query_offset())?,
                    isize::try_from(key.alignment_ranges.query_limit())?,
                    qry_seq,
                    qry_ns,
                    isize::try_from(key.query_sequence.len())?,
                ),
                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
            )
            .optional()?;

        let Some((bytes, stored_memory, stored_timeout_ns)) = row else {
            return Ok(None);
        };
        let result: TwitcherAlignmentResult = rmp_serde::from_slice(&bytes)?;

        // -1 means "no timeout": a run stored without a timeout is never
        // retried, while any bounded stored run is retried when the current
        // run is unbounded or has a longer timeout.
        let current_timeout_ns = to_timeout_ns(timeout);
        let retry_timeout = stored_timeout_ns != -1
            && (current_timeout_ns == -1 || current_timeout_ns > stored_timeout_ns);

        let worth_retrying = match &result {
            Err(AlignmentFailure::SoftFailure {
                reason: SoftFailureReason::OutOfMemory,
            }) => memory as i64 > stored_memory,
            Err(AlignmentFailure::SoftFailure {
                reason: SoftFailureReason::Timeout(_),
            }) => retry_timeout,
            _ => false,
        };

        Ok((!worth_retrying).then_some(result))
    }

    /// Queues `result` for `key` to be written to the database.
    ///
    /// Nothing happens when the mode does not allow writing. Otherwise the
    /// entry is buffered together with the memory allowance and timeout, and
    /// the buffer is flushed once it holds `WRITE_BUFFER_SIZE` entries.
    ///
    /// # Errors
    /// Fails if the database has not been initialized or if a triggered
    /// flush fails.
    pub fn store(
        &mut self,
        key: AlignmentKey,
        result: Arc<TwitcherAlignmentResult>,
        memory: usize,
        timeout: Option<Duration>,
    ) -> anyhow::Result<()> {
        if !self.initialized {
            anyhow::bail!("DB needs to be initialized before first use");
        }

        if self.mode.is_write() {
            let timeout_ns = to_timeout_ns(timeout);
            self.pending_writes.push((key, result, memory, timeout_ns));

            let buffer_full = self.pending_writes.len() >= WRITE_BUFFER_SIZE;
            if buffer_full {
                self.flush_writes()?;
            }
        }

        Ok(())
    }

    /// Writes all buffered results to the `alignments` table in a single
    /// transaction.
    ///
    /// Does nothing when the buffer is empty. Rows that collide with an
    /// existing entry (same unique key) replace it. The INSERT statement is
    /// prepared once and reused for every buffered entry.
    ///
    /// # Errors
    /// Fails if no configuration id is known, if an offset, limit or length
    /// does not fit into an `isize`, if a result cannot be serialized, or if
    /// any database operation fails. When a failure happens after draining
    /// has started, the remaining buffered entries are discarded and the
    /// transaction is not committed.
    fn flush_writes(&mut self) -> anyhow::Result<()> {
        if self.pending_writes.is_empty() {
            return Ok(());
        }

        // Loop-invariant: resolve the configuration id once, up front.
        let config_id = self.config_id.context("We should have a config id here!")?;

        let tx = self.conn.transaction()?;
        {
            // Compile the statement once instead of once per buffered row.
            // The scope ends the statement's borrow of `tx` before `commit`.
            let mut insert = tx.prepare(
                "INSERT OR REPLACE INTO alignments
                    (config_id, reference_region, ref_offset, ref_limit,
                     query_offset, query_limit, query_seq, query_seq_n_pos,
                     query_seq_len, memory_allowance, timeout_ns, result)
                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            )?;

            for (key, result, memory, timeout_ns) in self.pending_writes.drain(..) {
                let (qry_seq, qry_ns) = compress_sequence(&key.query_sequence);

                insert.execute((
                    config_id,
                    key.reference_region.to_string(),
                    isize::try_from(key.alignment_ranges.reference_offset())?,
                    isize::try_from(key.alignment_ranges.reference_limit())?,
                    isize::try_from(key.alignment_ranges.query_offset())?,
                    isize::try_from(key.alignment_ranges.query_limit())?,
                    qry_seq,
                    qry_ns,
                    isize::try_from(key.query_sequence.len())?,
                    memory as i64,
                    timeout_ns,
                    rmp_serde::to_vec(&result)?,
                ))?;
            }
        }
        tx.commit()?;
        Ok(())
    }
}

impl Drop for Database {
    fn drop(&mut self) {
        if let Err(e) = self.flush_writes() {
            error!("Error in flush during drop: {e}");
        }
    }
}

/// Packs a nucleotide sequence into two bits per base.
///
/// Returns `(bits, n_positions)`:
/// * `bits` holds four bases per byte, the first base in the two most
///   significant bits, using `A = 00`, `C = 01`, `G = 10`, `T = 11`
///   (case-insensitive). A trailing partial byte is zero-padded.
/// * `n_positions` lists, as big-endian `u16` values, the index of every
///   byte that is not one of `ACGT`/`acgt`; such bases are packed as `00`.
///
/// NOTE: positions are truncated to 16 bits, so an index of 65536 or more
/// wraps around. All non-ACGT symbols are treated alike.
fn compress_sequence(seq: &[u8]) -> (Vec<u8>, Vec<u8>) {
    let mut n_positions: Vec<u8> = Vec::new();
    let mut bits: Vec<u8> = Vec::with_capacity(seq.len().div_ceil(4));

    for (chunk_index, chunk) in seq.chunks(4).enumerate() {
        let mut packed = 0u8;
        for (slot, base) in chunk.iter().enumerate() {
            let code: u8 = match base.to_ascii_uppercase() {
                b'A' => 0b00,
                b'C' => 0b01,
                b'G' => 0b10,
                b'T' => 0b11,
                _ => {
                    // N or anything else: remember where it was, pack as 0.
                    let pos = (chunk_index * 4 + slot) as u16;
                    n_positions.extend_from_slice(&pos.to_be_bytes());
                    0b00
                }
            };
            packed |= code << (6 - 2 * slot);
        }
        bits.push(packed);
    }

    (bits, n_positions)
}