// twitcher 0.1.8
//
// Find template switch mutations in genomic data
use std::{fmt::Display, path::PathBuf, sync::Arc};

use anyhow::Context;
use clap::ValueEnum;
use itertools::Itertools;
use rusqlite::{Connection, OptionalExtension};
use tracing::{error, warn};

use crate::common::aligner::{
    AlignerSelectorDescription, AlignmentKey, result::TwitcherAlignmentResult,
};

/// Schema version stored in the SQLite `user_version` pragma of the database file.
///
/// A newly created database is stamped with this value; an existing file whose
/// `user_version` differs is rejected when it is opened.
const DB_SCHEMA_VERSION: i32 = 1;
/// Declarative description of every table in the database.
///
/// The same description is used to generate the `CREATE TABLE` statements for a new database
/// and to check that an existing database contains each table and each column name listed here.
const DB_TABLES: [DbTableSpec; 2] = [
    // One row per (reference, aligner configuration) pair; `alignments.config_id` points here.
    DbTableSpec {
        name: "configurations",
        colunms: &[
            ["id", "INTEGER PRIMARY KEY AUTOINCREMENT"],
            ["reference", "TEXT NOT NULL"],
            ["aligner", "BLOB NOT NULL"],
            ["created_at", "TEXT NOT NULL DEFAULT (datetime('now'))"],
        ],
        extra: None,
    },
    // One row per stored alignment result. The `query_seq*` columns hold the query sequence in
    // the packed form produced by `compress_sequence`; `result` holds the serialized result.
    DbTableSpec {
        name: "alignments",
        colunms: &[
            ["id", "INTEGER PRIMARY KEY AUTOINCREMENT"],
            [
                "config_id",
                "INTEGER NOT NULL REFERENCES configurations(id)",
            ],
            ["reference_region", "TEXT NOT NULL"],
            ["ref_offset", "INTEGER NOT NULL"],
            ["ref_limit", "INTEGER NOT NULL"],
            ["query_offset", "INTEGER NOT NULL"],
            ["query_limit", "INTEGER NOT NULL"],
            ["query_seq", "BLOB NOT NULL"],
            ["query_seq_n_pos", "BLOB NOT NULL"],
            ["query_seq_len", "INTEGER NOT NULL"],
            ["result", "BLOB NOT NULL"],
        ],
        // The full lookup key is unique, so `INSERT OR IGNORE` skips results that are already
        // stored.
        extra: Some(
            "UNIQUE(config_id, reference_region, ref_offset, ref_limit, query_offset, query_limit, query_seq, query_seq_n_pos, query_seq_len)",
        ),
    },
];

/// Number of pending `(key, result)` pairs that `Database::store` accumulates before it flushes
/// them to the database in a single transaction.
const WRITE_BUFFER_SIZE: usize = 128;

/// Static description of one database table.
#[derive(Debug)]
struct DbTableSpec {
    /// Table name as used in SQL.
    name: &'static str,
    /// Column definitions as `[column_name, type_and_constraints]` pairs, in declaration order.
    // NOTE(review): the field name is a typo for `columns`; renaming it means touching every
    // user of this struct at once.
    colunms: &'static [[&'static str; 2]],
    /// Optional table-level clause appended after the column definitions (e.g. a `UNIQUE`
    /// constraint).
    extra: Option<&'static str>,
}

// Command-line options selecting the optional SQLite database and how it is used.
//
// NOTE: plain `//` comments are used here on purpose. With clap's derive macros, `///` doc
// comments on these items are turned into user-visible help text.
#[derive(clap::Args, Debug, Default)]
pub struct CliDatabaseArgs {
    // `--db <PATH>`: path of the database file. When absent, no database is used at all.
    #[arg(long = "db")]
    path: Option<PathBuf>,

    // `--db-mode <MODE>`: only accepted together with `--db`; defaults to `DbMode::default()`.
    #[arg(long = "db-mode", default_value_t, requires = "path")]
    mode: DbMode,
}

// How the database is used during a run.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on `ValueEnum` variants
// would show up in the generated CLI help.
#[derive(clap::ValueEnum, Debug, Clone, Copy, Default, PartialEq, Eq)]
enum DbMode {
    // Existing results are looked up; storing a result is a no-op.
    ReadOnly,
    // New results are stored; lookups always miss.
    WriteOnly,
    // Existing results are looked up and new results are stored (the default).
    #[default]
    ReadWrite,
}

impl DbMode {
    /// Returns `true` when this mode allows looking up stored results.
    fn is_read(self) -> bool {
        match self {
            DbMode::ReadOnly | DbMode::ReadWrite => true,
            DbMode::WriteOnly => false,
        }
    }

    /// Returns `true` when this mode allows storing new results.
    fn is_write(self) -> bool {
        match self {
            DbMode::WriteOnly | DbMode::ReadWrite => true,
            DbMode::ReadOnly => false,
        }
    }
}

/// Formats the mode using the name clap assigns to the variant on the command line, so the
/// displayed text matches what the user would type for `--db-mode`.
impl Display for DbMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `to_possible_value` yields `None` only for variants marked as skipped; none are.
        f.write_str(self.to_possible_value().unwrap().get_name())
    }
}

/// Handle to the SQLite database of alignment results.
///
/// Results passed to `store` are buffered in memory and written in batches; whatever is still
/// buffered is flushed when the value is dropped.
pub struct Database {
    /// Open SQLite connection.
    conn: Connection,
    /// Selects whether lookups and/or stores are performed.
    mode: DbMode,
    /// Row id in `configurations` for the current run. `None` before `init_with_config` has
    /// run, or when no matching row exists and the mode does not allow creating one.
    config_id: Option<u32>,
    /// Results accepted by `store` that have not been written to the database yet.
    pending_writes: Vec<(AlignmentKey, Arc<TwitcherAlignmentResult>)>,
    /// Set once `init_with_config` has completed successfully.
    initialized: bool,
}

/// Opens (or creates) the database selected on the command line.
///
/// * Returns `Ok(None)` when no database path was given.
/// * An existing file must carry the expected `user_version` and contain all required tables
///   and columns.
/// * A missing file is created (tables plus `user_version`) only when the mode allows writing.
///
/// # Errors
///
/// Fails when the file cannot be opened, its schema version cannot be read or does not match,
/// a required table/column is missing, or the file does not exist and the mode is read-only.
impl TryFrom<&CliDatabaseArgs> for Option<Database> {
    type Error = anyhow::Error;

    fn try_from(args: &CliDatabaseArgs) -> anyhow::Result<Self> {
        let Some(path) = &args.path else {
            return Ok(None);
        };

        let conn = if path.exists() {
            let conn = Connection::open(path)
                .with_context(|| format!("Failed to open database '{}'", path.display()))?;
            // Report a failing pragma query (e.g. the file is not an SQLite database) as such
            // instead of folding it into the generic schema-mismatch message below.
            let version: i32 = conn
                .pragma_query_value(None, "user_version", |row| row.get(0))
                .with_context(|| {
                    format!(
                        "Failed to read the schema version of database '{}'",
                        path.display()
                    )
                })?;
            if version != DB_SCHEMA_VERSION {
                anyhow::bail!("The specified database file does not have the appropriate schema.")
            }
            Database::check_tables(&conn)?;
            conn
        } else if args.mode.is_write() {
            // Create database
            let conn = Connection::open(path)
                .with_context(|| format!("Failed to create database '{}'", path.display()))?;
            conn.pragma_update(None, "user_version", DB_SCHEMA_VERSION)?;
            Database::init_tables(&conn)?;
            conn
        } else {
            anyhow::bail!("Cannot create database: read-only was specified.");
        };

        // Switching to the write-ahead log is a persistent change to the database file (and
        // creates `-wal`/`-shm` side files), so it is only done when this run may write.
        // A read-only run leaves the file's journal mode untouched.
        if args.mode.is_write() {
            conn.pragma_update(None, "journal_mode", "WAL")?;
        }

        Ok(Some(Database {
            conn,
            mode: args.mode,
            config_id: None,
            pending_writes: Vec::new(),
            initialized: false,
        }))
    }
}

/// The part of the lookup key that stays the same for a whole run.
///
/// `Database::init_with_config` matches it against (or inserts it into) the `configurations`
/// table.
pub struct StaticAlignmentKey<'a> {
    /// Value compared with / stored in the `reference` column.
    pub reference_name: &'a str,
    /// Value compared with / stored in the `aligner` column.
    pub aligner_config: AlignerSelectorDescription,
}

impl Database {
    /// Creates every table described in `DB_TABLES` on a freshly created database.
    ///
    /// # Errors
    ///
    /// Returns the SQLite error of the first `CREATE TABLE` statement that fails.
    fn init_tables(conn: &Connection) -> Result<(), rusqlite::Error> {
        for table in &DB_TABLES {
            let name = table.name;
            // One "name type-and-constraints" entry per column, followed by the optional
            // table-level clause.
            let columns = table
                .colunms
                .iter()
                .map(|s| s.iter().join(" "))
                .chain(table.extra.map(str::to_string))
                .join(",\n");
            let query = format!(
                "CREATE TABLE {name} (
                    {columns}
                )"
            );
            conn.execute(&query, ())?;
        }
        Ok(())
    }

    /// Verifies that an existing database contains every table and column name listed in
    /// `DB_TABLES`. Column types and constraints are not compared.
    ///
    /// # Errors
    ///
    /// Fails when a table or column is missing, or when querying the schema fails.
    fn check_tables(conn: &Connection) -> anyhow::Result<()> {
        for table in &DB_TABLES {
            let exists: bool = conn.query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
                [table.name],
                |row| row.get::<_, i64>(0),
            )? > 0;

            if !exists {
                anyhow::bail!("Required table '{}' does not exist", table.name);
            }

            // Column 1 of `PRAGMA table_info` is the column name. Table names come from the
            // static `DB_TABLES`, so formatting them into the statement is safe.
            let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table.name))?;
            // Propagate row errors instead of dropping them: a failed read must not be
            // reported as a "missing column".
            let actual_columns = stmt
                .query_map([], |row| row.get::<_, String>(1))?
                .collect::<Result<Vec<String>, _>>()?;

            for &[col, _] in table.colunms {
                if !actual_columns.iter().any(|c| c == col) {
                    anyhow::bail!("Table '{}' is missing required column '{col}'", table.name);
                }
            }
        }
        Ok(())
    }

    /// Binds this handle to the configuration row matching `key`.
    ///
    /// An existing row is reused. If there is none, a new row is inserted when the mode allows
    /// writing; otherwise a warning is logged and the handle is left without a configuration
    /// id, so every lookup misses.
    ///
    /// # Errors
    ///
    /// Fails when querying or inserting the configuration row fails; the handle then stays in
    /// its previous state.
    pub fn init_with_config(&mut self, key: &StaticAlignmentKey) -> anyhow::Result<()> {
        let existing_config: Option<u32> = self
            .conn
            .query_row(
                "
                SELECT id FROM configurations WHERE reference = ? AND aligner = ? LIMIT 1
            ",
                (key.reference_name, &key.aligner_config),
                |row| row.get(0),
            )
            .optional()?;

        if let Some(existing) = existing_config {
            self.config_id = Some(existing);
        } else if self.mode.is_write() {
            let new_id: u32 = self.conn.query_one(
                "
                    INSERT INTO configurations (
                        reference,
                        aligner
                    ) VALUES (
                        ?,
                        ?
                    ) RETURNING id",
                (key.reference_name, &key.aligner_config),
                |row| row.get(0),
            )?;
            self.config_id = Some(new_id);
        } else {
            // Do not keep an id from an earlier initialization: it would make lookups return
            // results that belong to a different configuration.
            self.config_id = None;
            warn!(
                "read-only database, but no existing data matches the configuration of this execution"
            );
        }

        self.initialized = true;
        Ok(())
    }

    /// Returns `true` until `init_with_config` has completed successfully.
    pub fn needs_init(&self) -> bool {
        !self.initialized
    }

    /// Looks up a previously stored result for `key`.
    ///
    /// Returns `Ok(None)` when the mode does not allow reading, when no configuration row is
    /// bound, or when no row matches. Only rows already written to the database are
    /// considered; results still waiting in the write buffer are not found.
    ///
    /// # Errors
    ///
    /// Fails when the handle has not been initialized, when a range value does not fit into
    /// `isize`, when the query fails, or when the stored blob cannot be deserialized.
    pub fn lookup(&self, key: &AlignmentKey) -> anyhow::Result<Option<TwitcherAlignmentResult>> {
        if !self.initialized {
            anyhow::bail!("DB needs to be initialized before first use");
        }
        if !self.mode.is_read() {
            return Ok(None);
        }
        let Some(cid) = self.config_id else {
            return Ok(None);
        };
        let (qry_seq, qry_ns) = compress_sequence(&key.query_sequence);
        let row: Option<Vec<u8>> = self
            .conn
            .query_one(
                "SELECT result FROM alignments
                 WHERE config_id = ?
                   AND reference_region = ?
                   AND ref_offset = ? AND ref_limit = ?
                   AND query_offset = ? AND query_limit = ?
                   AND query_seq = ? AND query_seq_n_pos = ? AND query_seq_len = ?",
                (
                    cid,
                    key.reference_region.to_string(),
                    isize::try_from(key.alignment_ranges.reference_offset())?,
                    isize::try_from(key.alignment_ranges.reference_limit())?,
                    isize::try_from(key.alignment_ranges.query_offset())?,
                    isize::try_from(key.alignment_ranges.query_limit())?,
                    qry_seq,
                    qry_ns,
                    isize::try_from(key.query_sequence.len())?,
                ),
                |row| row.get(0),
            )
            .optional()?;

        match row {
            Some(blob) => {
                let result = rmp_serde::from_slice(&blob)?;
                Ok(Some(result))
            }
            None => Ok(None),
        }
    }

    /// Queues `result` for storage under `key`.
    ///
    /// Does nothing when the mode does not allow writing. The write buffer is flushed once it
    /// holds `WRITE_BUFFER_SIZE` entries.
    ///
    /// # Errors
    ///
    /// Fails when the handle has not been initialized or when flushing the buffer fails.
    pub fn store(
        &mut self,
        key: AlignmentKey,
        result: Arc<TwitcherAlignmentResult>,
    ) -> anyhow::Result<()> {
        if !self.initialized {
            anyhow::bail!("DB needs to be initialized before first use");
        }
        if !self.mode.is_write() {
            return Ok(());
        }

        self.pending_writes.push((key, result));

        if self.pending_writes.len() >= WRITE_BUFFER_SIZE {
            self.flush_writes()?;
        }

        Ok(())
    }

    /// Writes all buffered results to the database in one transaction.
    ///
    /// Rows whose key already exists are skipped (`INSERT OR IGNORE`). If any row fails, the
    /// transaction is rolled back and the buffered entries are discarded.
    ///
    /// # Errors
    ///
    /// Fails when no configuration id is bound, when a range value does not fit into `isize`,
    /// when serialization fails, or when SQLite reports an error.
    fn flush_writes(&mut self) -> anyhow::Result<()> {
        if self.pending_writes.is_empty() {
            return Ok(());
        }

        // The configuration id is the same for every row, so resolve it once up front.
        let config_id = self
            .config_id
            .context("We should have a config id here!")?;

        let tx = self.conn.transaction()?;
        // Prepare the statement once and reuse it for every buffered row.
        let mut insert = tx.prepare(
            "INSERT OR IGNORE INTO alignments
                    (config_id, reference_region, ref_offset, ref_limit,
                     query_offset, query_limit, query_seq, query_seq_n_pos,
                     query_seq_len, result)
                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
        )?;
        for (key, result) in self.pending_writes.drain(..) {
            let (qry_seq, qry_ns) = compress_sequence(&key.query_sequence);

            insert.execute((
                config_id,
                key.reference_region.to_string(),
                isize::try_from(key.alignment_ranges.reference_offset())?,
                isize::try_from(key.alignment_ranges.reference_limit())?,
                isize::try_from(key.alignment_ranges.query_offset())?,
                isize::try_from(key.alignment_ranges.query_limit())?,
                qry_seq,
                qry_ns,
                isize::try_from(key.query_sequence.len())?,
                rmp_serde::to_vec(&result)?,
            ))?;
        }
        // The statement borrows the transaction and must be released before committing.
        drop(insert);
        tx.commit()?;
        Ok(())
    }
}

impl Drop for Database {
    fn drop(&mut self) {
        if let Err(e) = self.flush_writes() {
            error!("Error in flush during drop: {e}");
        }
    }
}

/// Packs a nucleotide sequence into two bits per base.
///
/// Returns `(packed, n_positions)`:
///
/// * `packed` holds four bases per byte, first base in the two most significant bits, using
///   `A = 00`, `C = 01`, `G = 10`, `T = 11` (case-insensitive). A trailing partial byte is
///   zero-padded.
/// * `n_positions` lists, as consecutive big-endian `u16` values, the index of every byte that
///   is not one of `ACGTacgt`; such a byte is packed as `00`.
///
/// NOTE(review): indices are narrowed with `as u16`, so positions at or beyond 65536 wrap
/// around. This is kept as is because the output is part of the persisted lookup key.
fn compress_sequence(seq: &[u8]) -> (Vec<u8>, Vec<u8>) {
    let mut n_positions: Vec<u8> = Vec::new();
    let mut packed: Vec<u8> = Vec::with_capacity(seq.len().div_ceil(4));

    for (chunk_idx, chunk) in seq.chunks(4).enumerate() {
        let mut byte = 0u8;
        for (slot, base) in chunk.iter().enumerate() {
            let code: u8 = match base.to_ascii_uppercase() {
                b'A' => 0b00,
                b'C' => 0b01,
                b'G' => 0b10,
                b'T' => 0b11,
                _ => {
                    // N or anything else: remember where it was and pack it as zero.
                    let pos = (chunk_idx * 4 + slot) as u16;
                    n_positions.extend_from_slice(&pos.to_be_bytes());
                    0b00
                }
            };
            // Slot 0 occupies bits 7..6, slot 3 occupies bits 1..0.
            byte |= code << (6 - 2 * slot);
        }
        packed.push(byte);
    }

    (packed, n_positions)
}