use std::{fmt::Display, path::PathBuf, sync::Arc};
use anyhow::Context;
use clap::ValueEnum;
use itertools::Itertools;
use rusqlite::{Connection, OptionalExtension};
use tracing::{error, warn};
use crate::common::aligner::{
AlignerSelectorDescription, AlignmentKey, result::TwitcherAlignmentResult,
};
/// Schema version of the database file.
///
/// Written to SQLite's `user_version` pragma when a new database is created and
/// compared against that pragma when an existing file is opened.
const DB_SCHEMA_VERSION: i32 = 1;
/// Declarative description of every table the database must contain.
///
/// `Database::init_tables` turns each entry into a `CREATE TABLE` statement, and
/// `Database::check_tables` verifies that an existing file has each table and each
/// listed column name.
const DB_TABLES: [DbTableSpec; 2] = [
// One row per (reference, aligner) combination; `alignments.config_id` points here.
DbTableSpec {
name: "configurations",
colunms: &[
["id", "INTEGER PRIMARY KEY AUTOINCREMENT"],
["reference", "TEXT NOT NULL"],
["aligner", "BLOB NOT NULL"],
["created_at", "TEXT NOT NULL DEFAULT (datetime('now'))"],
],
extra: None,
},
// One row per stored alignment result, keyed by every column except `id` and `result`.
DbTableSpec {
name: "alignments",
colunms: &[
["id", "INTEGER PRIMARY KEY AUTOINCREMENT"],
[
"config_id",
"INTEGER NOT NULL REFERENCES configurations(id)",
],
["reference_region", "TEXT NOT NULL"],
["ref_offset", "INTEGER NOT NULL"],
["ref_limit", "INTEGER NOT NULL"],
["query_offset", "INTEGER NOT NULL"],
["query_limit", "INTEGER NOT NULL"],
["query_seq", "BLOB NOT NULL"],
["query_seq_n_pos", "BLOB NOT NULL"],
["query_seq_len", "INTEGER NOT NULL"],
["result", "BLOB NOT NULL"],
],
// The uniqueness constraint is what lets the `INSERT OR IGNORE` in
// `Database::flush_writes` skip keys that are already stored.
extra: Some(
"UNIQUE(config_id, reference_region, ref_offset, ref_limit, query_offset, query_limit, query_seq, query_seq_n_pos, query_seq_len)",
),
},
];
/// Number of buffered results at which `Database::store` flushes the pending writes
/// to the database in a single transaction.
const WRITE_BUFFER_SIZE: usize = 128;
/// Static description of one database table, used both to create it and to
/// validate an existing file.
#[derive(Debug)]
struct DbTableSpec {
/// Table name as used in SQL.
name: &'static str,
/// Column definitions as `[column_name, type_and_constraints]` pairs; the two parts
/// are joined with a space when the `CREATE TABLE` statement is built.
///
/// NOTE: the field is spelled `colunms` everywhere in this file; renaming it means
/// touching every user at once.
colunms: &'static [[&'static str; 2]],
/// Optional table-level clause (e.g. a `UNIQUE(...)` constraint) appended after the
/// column definitions.
extra: Option<&'static str>,
}
// Command-line arguments that configure the optional alignment database.
//
// Plain `//` comments are used here on purpose: clap's derive macros turn `///`
// doc comments on these items into `--help` text, which would change program output.
#[derive(clap::Args, Debug, Default)]
pub struct CliDatabaseArgs {
// `--db <PATH>`: location of the SQLite file. When absent, no database is used
// (the `TryFrom` conversion below yields `None`).
#[arg(long = "db")]
path: Option<PathBuf>,
// `--db-mode <MODE>`: how the database is used; only accepted together with `--db`
// and defaults to `DbMode::default()` (read-write).
#[arg(long = "db-mode", default_value_t, requires = "path")]
mode: DbMode,
}
// How the database may be used during this run.
//
// Plain `//` comments are used on purpose: `///` doc comments on `ValueEnum`
// variants would be picked up by clap as help text for the possible values.
#[derive(clap::ValueEnum, Debug, Clone, Copy, Default, PartialEq, Eq)]
enum DbMode {
// Existing results are looked up; `Database::store` is a no-op and a missing
// database file is an error instead of being created.
ReadOnly,
// Results are stored; `Database::lookup` always returns `None`.
WriteOnly,
// Results are both looked up and stored. This is the default mode.
#[default]
ReadWrite,
}
impl DbMode {
    /// Whether this mode permits reading stored results from the database.
    fn is_read(self) -> bool {
        match self {
            DbMode::ReadOnly | DbMode::ReadWrite => true,
            DbMode::WriteOnly => false,
        }
    }

    /// Whether this mode permits writing to (and therefore creating) the database.
    fn is_write(self) -> bool {
        match self {
            DbMode::WriteOnly | DbMode::ReadWrite => true,
            DbMode::ReadOnly => false,
        }
    }
}
/// Formats the mode using the same name clap accepts for it on the command line.
impl Display for DbMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `to_possible_value` only yields `None` for variants marked as skipped;
        // no variant of `DbMode` is, so the unwrap is not expected to fire.
        f.write_str(self.to_possible_value().unwrap().get_name())
    }
}
/// Handle to the SQLite database that stores alignment results.
///
/// Must be initialised with [`Database::init_with_config`] before `lookup` or
/// `store` is called. Buffered writes are flushed when the value is dropped.
pub struct Database {
/// Open SQLite connection.
conn: Connection,
/// Whether lookups and/or stores are permitted.
mode: DbMode,
/// Id of the `configurations` row for this run. `None` before initialisation, and
/// also afterwards when no matching row exists and the mode does not allow writing.
config_id: Option<u32>,
/// Results accepted by `store` that have not been written yet; flushed once
/// `WRITE_BUFFER_SIZE` entries have accumulated and again on drop.
pending_writes: Vec<(AlignmentKey, Arc<TwitcherAlignmentResult>)>,
/// Set to `true` by `init_with_config`.
initialized: bool,
}
/// Opens (or creates) the database described by the command-line arguments.
///
/// * No `--db` path given: returns `Ok(None)`.
/// * The file exists: its `user_version` must equal [`DB_SCHEMA_VERSION`] and all
///   tables/columns from [`DB_TABLES`] must be present.
/// * The file does not exist: it is created and initialised, but only when the mode
///   permits writing; otherwise an error is returned before anything is created.
///
/// In every successful case the journal mode is switched to WAL.
///
/// # Errors
/// Fails when the file cannot be opened, has a missing or unexpected schema, or
/// when any of the SQLite statements involved fails.
impl TryFrom<&CliDatabaseArgs> for Option<Database> {
    type Error = anyhow::Error;

    fn try_from(args: &CliDatabaseArgs) -> anyhow::Result<Self> {
        let Some(path) = &args.path else {
            return Ok(None);
        };

        // Decide before opening: `Connection::open` would create a missing file,
        // which must not happen in read-only mode.
        let exists = path.exists();
        if !exists && !args.mode.is_write() {
            anyhow::bail!("Cannot create database: read-only was specified.");
        }

        let conn = Connection::open(path)
            .with_context(|| format!("Failed to open database file '{}'", path.display()))?;

        if exists {
            // Keep the underlying SQLite error in the chain (e.g. "file is not a
            // database") instead of discarding it.
            let version: i32 = conn
                .pragma_query_value(None, "user_version", |row| row.get(0))
                .context("The specified database file does not have the appropriate schema.")?;
            if version != DB_SCHEMA_VERSION {
                anyhow::bail!(
                    "The specified database file does not have the appropriate schema \
                     (found version {version}, expected {DB_SCHEMA_VERSION})."
                );
            }
            Database::check_tables(&conn)?;
        } else {
            // Stamp the version only after the tables exist, so a failed
            // initialisation never leaves a file that claims a valid schema.
            Database::init_tables(&conn)?;
            conn.pragma_update(None, "user_version", DB_SCHEMA_VERSION)?;
        }

        conn.pragma_update(None, "journal_mode", "WAL")?;

        Ok(Some(Database {
            conn,
            mode: args.mode,
            config_id: None,
            pending_writes: Vec::new(),
            initialized: false,
        }))
    }
}
/// The part of the database key stored once per `configurations` row rather than
/// per alignment; passed to [`Database::init_with_config`] to find or create that row.
pub struct StaticAlignmentKey<'a> {
/// Value matched against / stored in the `reference` column.
pub reference_name: &'a str,
/// Value matched against / stored in the `aligner` column.
pub aligner_config: AlignerSelectorDescription,
}
impl Database {
    /// Creates every table listed in [`DB_TABLES`] on `conn`.
    ///
    /// Each column pair is joined with a space, the optional `extra` clause is
    /// appended, and the parts are joined with `",\n"` to form the table body.
    fn init_tables(conn: &Connection) -> Result<(), rusqlite::Error> {
        for table in &DB_TABLES {
            let name = table.name;
            let columns = table
                .colunms
                .iter()
                .map(|s| s.iter().join(" "))
                // `Option` is itself iterable, so it can be chained directly.
                .chain(table.extra.map(str::to_string))
                .join(",\n");
            let query = format!(
                "CREATE TABLE {name} (
{columns}
)"
            );
            conn.execute(&query, ())?;
        }
        Ok(())
    }

    /// Verifies that `conn` contains every table and column named in [`DB_TABLES`].
    ///
    /// Only table and column *names* are checked, not types or constraints.
    ///
    /// # Errors
    /// Fails when a table or column is missing, or when querying the schema fails.
    fn check_tables(conn: &Connection) -> anyhow::Result<()> {
        for table in &DB_TABLES {
            let exists: bool = conn.query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
                [table.name],
                |row| row.get::<_, i64>(0),
            )? > 0;
            if !exists {
                anyhow::bail!("Required table '{}' does not exist", table.name);
            }

            // Table names come from the static `DB_TABLES`, never from user input.
            let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table.name))?;
            // Propagate row errors: silently dropping them would misreport a read
            // failure as a missing column.
            let actual_columns = stmt
                .query_map([], |row| row.get::<_, String>(1))?
                .collect::<Result<Vec<String>, _>>()?;

            for &[col, _] in table.colunms {
                if !actual_columns.iter().any(|c| c == col) {
                    anyhow::bail!("Table '{}' is missing required column '{col}'", table.name);
                }
            }
        }
        Ok(())
    }

    /// Binds this handle to the `configurations` row matching `key`.
    ///
    /// An existing row with the same reference and aligner is reused. If there is
    /// none, a new row is inserted when the mode permits writing; otherwise a
    /// warning is logged and no configuration id is set, so later lookups find
    /// nothing. In all successful cases the handle is marked as initialised.
    ///
    /// # Errors
    /// Fails when the lookup or the insert fails.
    pub fn init_with_config(&mut self, key: &StaticAlignmentKey) -> anyhow::Result<()> {
        let existing_config: Option<u32> = self
            .conn
            .query_row(
                "
SELECT id FROM configurations WHERE reference = ? AND aligner = ? LIMIT 1
",
                (key.reference_name, &key.aligner_config),
                |row| row.get(0),
            )
            .optional()?;

        if let Some(existing) = existing_config {
            self.config_id = Some(existing);
        } else if self.mode.is_write() {
            let new_id: u32 = self.conn.query_one(
                "
INSERT INTO configurations (
reference,
aligner
) VALUES (
?,
?
) RETURNING id",
                (key.reference_name, &key.aligner_config),
                |row| row.get(0),
            )?;
            self.config_id = Some(new_id);
        } else {
            warn!(
                "read-only database, but no existing data matches the configuration of this execution"
            );
        }

        self.initialized = true;
        Ok(())
    }

    /// Returns `true` until [`Database::init_with_config`] has completed successfully.
    pub fn needs_init(&self) -> bool {
        !self.initialized
    }

    /// Looks up a previously stored result for `key`.
    ///
    /// Returns `Ok(None)` when the mode does not permit reading, when no
    /// configuration id is set, or when no matching row exists.
    ///
    /// # Errors
    /// Fails when the handle is not initialised, when a range or length does not
    /// fit into `isize`, when the query fails, or when the stored blob cannot be
    /// deserialised.
    pub fn lookup(&self, key: &AlignmentKey) -> anyhow::Result<Option<TwitcherAlignmentResult>> {
        if !self.initialized {
            anyhow::bail!("DB needs to be initialized before first use");
        }
        if !self.mode.is_read() {
            return Ok(None);
        }
        let Some(cid) = self.config_id else {
            return Ok(None);
        };

        let (qry_seq, qry_ns) = compress_sequence(&key.query_sequence);
        let row: Option<Vec<u8>> = self
            .conn
            .query_one(
                "SELECT result FROM alignments
WHERE config_id = ?
AND reference_region = ?
AND ref_offset = ? AND ref_limit = ?
AND query_offset = ? AND query_limit = ?
AND query_seq = ? AND query_seq_n_pos = ? AND query_seq_len = ?",
                (
                    cid,
                    key.reference_region.to_string(),
                    isize::try_from(key.alignment_ranges.reference_offset())?,
                    isize::try_from(key.alignment_ranges.reference_limit())?,
                    isize::try_from(key.alignment_ranges.query_offset())?,
                    isize::try_from(key.alignment_ranges.query_limit())?,
                    qry_seq,
                    qry_ns,
                    isize::try_from(key.query_sequence.len())?,
                ),
                |row| row.get(0),
            )
            .optional()?;

        match row {
            Some(bytes) => Ok(Some(rmp_serde::from_slice(&bytes)?)),
            None => Ok(None),
        }
    }

    /// Queues `result` to be written under `key`.
    ///
    /// A no-op when the mode does not permit writing. The entry is buffered and
    /// the buffer is flushed once it holds `WRITE_BUFFER_SIZE` entries.
    ///
    /// # Errors
    /// Fails when the handle is not initialised or when a triggered flush fails.
    pub fn store(
        &mut self,
        key: AlignmentKey,
        result: Arc<TwitcherAlignmentResult>,
    ) -> anyhow::Result<()> {
        if !self.initialized {
            anyhow::bail!("DB needs to be initialized before first use");
        }
        if !self.mode.is_write() {
            return Ok(());
        }
        self.pending_writes.push((key, result));
        if self.pending_writes.len() >= WRITE_BUFFER_SIZE {
            self.flush_writes()?;
        }
        Ok(())
    }

    /// Writes all buffered results inside a single transaction.
    ///
    /// Rows whose key already exists are skipped (`INSERT OR IGNORE`). If any
    /// insert fails the error is returned; by then the buffer has been drained and
    /// the transaction is dropped without being committed, so that batch is not
    /// retried.
    fn flush_writes(&mut self) -> anyhow::Result<()> {
        if self.pending_writes.is_empty() {
            return Ok(());
        }
        // The id is the same for every row, so resolve it once up front.
        let config_id = self
            .config_id
            .context("We should have a config id here!")?;

        let tx = self.conn.transaction()?;
        {
            // Prepared once per batch rather than once per row. The statement
            // borrows `tx`, so it is scoped to end before `commit` consumes it.
            let mut insert = tx.prepare(
                "INSERT OR IGNORE INTO alignments
(config_id, reference_region, ref_offset, ref_limit,
query_offset, query_limit, query_seq, query_seq_n_pos,
query_seq_len, result)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            )?;
            for (key, result) in self.pending_writes.drain(..) {
                let (qry_seq, qry_ns) = compress_sequence(&key.query_sequence);
                insert.execute((
                    config_id,
                    key.reference_region.to_string(),
                    isize::try_from(key.alignment_ranges.reference_offset())?,
                    isize::try_from(key.alignment_ranges.reference_limit())?,
                    isize::try_from(key.alignment_ranges.query_offset())?,
                    isize::try_from(key.alignment_ranges.query_limit())?,
                    qry_seq,
                    qry_ns,
                    isize::try_from(key.query_sequence.len())?,
                    rmp_serde::to_vec(&result)?,
                ))?;
            }
        }
        tx.commit()?;
        Ok(())
    }
}
impl Drop for Database {
fn drop(&mut self) {
if let Err(e) = self.flush_writes() {
error!("Error in flush during drop: {e}");
}
}
}
/// Packs a nucleotide sequence into two bits per base.
///
/// Returns `(packed, n_positions)`:
/// * `packed` holds four bases per byte, first base in the two most significant
///   bits, using `A`=00, `C`=01, `G`=10, `T`=11 (case-insensitive). A trailing
///   partial byte is zero-padded.
/// * `n_positions` lists the index of every byte that is not one of `ACGTacgt`
///   as a big-endian `u16`; such bases are packed as `00`.
///
/// NOTE(review): indices are truncated to 16 bits, so positions above 65535 wrap
/// around. The output is used as part of the database key, so widening the
/// encoding would change keys already stored — confirm before changing.
fn compress_sequence(seq: &[u8]) -> (Vec<u8>, Vec<u8>) {
    let mut n_positions = Vec::new();
    let mut packed = Vec::with_capacity(seq.len().div_ceil(4));

    for (chunk_idx, chunk) in seq.chunks(4).enumerate() {
        let mut byte = 0u8;
        for (slot, &base) in chunk.iter().enumerate() {
            let code: u8 = match base.to_ascii_uppercase() {
                b'A' => 0b00,
                b'C' => 0b01,
                b'G' => 0b10,
                b'T' => 0b11,
                _ => {
                    // Deliberately the same wrapping cast as before (see note above).
                    let index = (chunk_idx * 4 + slot) as u16;
                    n_positions.extend_from_slice(&index.to_be_bytes());
                    0b00
                }
            };
            byte |= code << (6 - 2 * slot);
        }
        packed.push(byte);
    }

    (packed, n_positions)
}