use std::{
collections::HashSet,
io::{BufRead, Write},
};
use rusqlite::{TEMP_DB, Transaction, params_from_iter};
use snafu::{ResultExt, Snafu};
use crate::{column::Column, sanitizing::SqlSanitize};
pub mod column;
pub mod sanitizing;
mod stmt;
/// Errors returned by the database helper functions in this module.
#[derive(Debug, Snafu)]
pub enum Error {
/// A `rusqlite` operation failed; `context` is a short static label naming the step
/// that was being performed (for example `"drop table"` or `"insert row"`).
#[snafu(display("Database error (context: {context}): {source}"))]
DatabaseErrorContext {
/// The underlying `rusqlite` error.
source: rusqlite::Error,
/// Label of the step that failed.
context: &'static str,
},
/// A `rusqlite` operation failed and no context label was attached.
#[snafu(display("Database error: {source}"))]
WhateverDatabaseError { source: rusqlite::Error },
/// Reading from the caller-supplied input source failed.
#[snafu(display("Read error: {source}"))]
SourceReadError { source: std::io::Error },
/// Writing bytes into an opened SQLite blob failed.
#[snafu(display("Blob write error: {source}"))]
BlobWriteError { source: std::io::Error },
/// Error described only by a free-form message.
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[snafu(display("Unknown error: {message}"))]
Unknown { message: String },
}
/// Result type whose error is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Makes sure the table `table_name` exists and has a column for every entry of `columns`.
///
/// All statements are executed on `tx`, in this order:
///
/// 1. If `reset` is `true`, the table is dropped (if it exists).
/// 2. The table is created (if it does not exist) with the sanitized column names.
/// 3. The names of the columns currently present are read through the table-info statement
///    from `stmt`. The lookup is done with the *unsanitized* `table_name`.
/// 4. Every column whose unsanitized [`Column::name`] is not among the existing names is
///    added to the table.
///
/// # Errors
///
/// Returns [`Error::DatabaseErrorContext`] labelled `"drop table"`, `"create table"`,
/// `"sync columns"` or `"add column"` when the respective statement fails, and
/// [`Error::WhateverDatabaseError`] when reading the existing column names fails.
pub fn prepare_db(
    tx: &mut Transaction,
    table_name: &str,
    reset: bool,
    columns: &[Column],
) -> Result<()> {
    let sanitized_table_name = table_name.to_sql_sanitized_string();

    if reset {
        let drop_sql = stmt::drop_table_if_exists(&sanitized_table_name);
        tx.execute(&drop_sql, []).context(DatabaseErrorContextSnafu {
            context: "drop table",
        })?;
    }

    let sanitized_column_names: Vec<_> = columns
        .iter()
        .map(|column| column.sanitized_name())
        .collect();
    let create_sql = stmt::create_table_if_not_exists(&sanitized_table_name, &sanitized_column_names);
    tx.execute(&create_sql, []).context(DatabaseErrorContextSnafu {
        context: "create table",
    })?;

    // Collect the names of the columns the table already has; the first error encountered
    // while mapping rows aborts the collection.
    let mut table_info = tx
        .prepare(stmt::select_pragram_table_info())
        .context(DatabaseErrorContextSnafu {
            context: "sync columns",
        })?;
    let existing_columns = table_info
        .query_map([table_name], |row| row.get(0))
        .context(WhateverDatabaseSnafu)?
        .collect::<std::result::Result<HashSet<String>, _>>()
        .context(WhateverDatabaseSnafu)?;

    // Add whichever requested columns are still missing.
    for (column, sanitized_column_name) in columns.iter().zip(sanitized_column_names.iter()) {
        if existing_columns.contains(column.name()) {
            continue;
        }
        let alter_sql = stmt::alter_table_add_column(&sanitized_table_name, &sanitized_column_name);
        tx.execute(&alter_sql, []).context(DatabaseErrorContextSnafu {
            context: "add column",
        })?;
    }

    Ok(())
}
/// Inserts one row into `table_name`, binding one parameter per entry of `columns`.
///
/// The statement text comes from `stmt::insert_row`; the bound values are taken from
/// [`Column::raw_value`] in the order of `columns`.
///
/// # Errors
///
/// Returns [`Error::DatabaseErrorContext`] labelled `"insert row"` if preparing or executing
/// the statement fails.
pub fn insert_row(tx: &mut Transaction, table_name: &str, columns: &[Column]) -> Result<()> {
    let sanitized_table_name = table_name.to_sql_sanitized_string();

    let insert_sql = stmt::insert_row(&sanitized_table_name, &columns);
    let mut insert = tx.prepare(&insert_sql).context(DatabaseErrorContextSnafu {
        context: "insert row",
    })?;

    let values = params_from_iter(columns.iter().map(|column| column.raw_value()));
    insert.execute(values).context(DatabaseErrorContextSnafu {
        context: "insert row",
    })?;

    Ok(())
}
/// Reads from `source` until `buf` is full or the source is exhausted.
///
/// Returns the number of bytes placed at the start of `buf`; a value smaller than
/// `buf.len()` means the end of the input was reached. A single `read` call may return
/// fewer bytes than requested without being at the end of the input, so it is repeated.
/// Reads that fail with `ErrorKind::Interrupted` are retried.
fn read_full_chunk(source: &mut impl BufRead, buf: &mut [u8]) -> std::io::Result<usize> {
    let mut filled = 0;
    while filled < buf.len() {
        match source.read(&mut buf[filled..]) {
            Ok(0) => break,
            Ok(count) => filled += count,
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
    Ok(filled)
}

/// Inserts one row into `table_name` whose blob content is everything read from `source`.
///
/// The input is staged in chunks: for every chunk a zero-filled blob of
/// `stmt::BLOB_BUF_SIZE` bytes is inserted into the temporary `blob_insert` table and the
/// chunk is written into its `data` column. Afterwards the statement built by
/// `stmt::insert_blob` (given the table name, the columns and the total number of bytes
/// read) is executed with the values from [`Column::value`]; columns for which it returns
/// `None` bind no parameter.
///
/// Every staged blob has the full buffer size, so each chunk except the last must fill the
/// buffer completely; otherwise zero padding would end up in the middle of the data. The
/// buffer is therefore filled across as many `read` calls as needed before it is staged.
/// NOTE(review): this assumes `stmt::insert_blob` uses `total_size` to trim the padding of
/// the final chunk — confirm against `stmt`.
///
/// # Errors
///
/// * [`Error::SourceReadError`] if reading from `source` fails.
/// * [`Error::BlobWriteError`] if writing a chunk into its blob fails.
/// * [`Error::DatabaseErrorContext`] labelled `"insert blob"` or `"open blob"` if a database
///   operation fails.
pub fn insert_blob(
    tx: &mut Transaction,
    table_name: &str,
    columns: &[Column],
    mut source: impl BufRead,
) -> Result<()> {
    tx.execute(stmt::create_temporary_blob_table(), [])
        .context(DatabaseErrorContextSnafu {
            context: "insert blob",
        })?;

    let mut total_size = 0usize;
    {
        let mut buf = vec![0u8; stmt::BLOB_BUF_SIZE];
        loop {
            // Fill the whole buffer unless the input ends first, so that only the final
            // chunk can be shorter than the zero-blob it is written into.
            let count = read_full_chunk(&mut source, &mut buf).context(SourceReadSnafu)?;
            if count == 0 {
                break;
            }
            total_size += count;

            tx.execute(&stmt::insert_zero_blob(stmt::BLOB_BUF_SIZE), [])
                .context(DatabaseErrorContextSnafu {
                    context: "insert blob",
                })?;
            let rowid = tx.last_insert_rowid();
            let mut blob = tx
                .blob_open(TEMP_DB, "blob_insert", "data", rowid, false)
                .context(DatabaseErrorContextSnafu {
                    context: "open blob",
                })?;
            blob.write_all(&buf[..count]).context(BlobWriteSnafu)?;
        }
    }

    let sanitized_table_name = table_name.to_sql_sanitized_string();
    let mut statement = tx
        .prepare(&stmt::insert_blob(
            &sanitized_table_name,
            &columns,
            total_size,
        ))
        .context(DatabaseErrorContextSnafu {
            context: "insert blob",
        })?;
    let values = columns.iter().filter_map(|n| n.value());
    statement
        .execute(params_from_iter(values))
        .context(DatabaseErrorContextSnafu {
            context: "insert blob",
        })?;
    Ok(())
}
/// Inserts one row into `table_name` for every line read from `source`.
///
/// The insert statement is prepared once from `stmt::insert_row` and executed per line; the
/// bound values are produced by [`Column::value_or_line`], which is handed the current line.
/// Processing stops at the first failure.
///
/// # Errors
///
/// * [`Error::SourceReadError`] if a line cannot be read from `source`.
/// * [`Error::DatabaseErrorContext`] labelled `"insert lines"` if preparing or executing the
///   statement fails.
pub fn insert_lines(
    tx: &mut Transaction,
    table_name: &str,
    columns: &[Column],
    source: impl BufRead,
) -> Result<()> {
    let sanitized_table_name = table_name.to_sql_sanitized_string();

    let insert_sql = stmt::insert_row(&sanitized_table_name, &columns);
    let mut insert = tx.prepare(&insert_sql).context(DatabaseErrorContextSnafu {
        context: "insert lines",
    })?;

    source.lines().try_for_each(|line| -> Result<()> {
        let line = line.context(SourceReadSnafu)?;
        let values = params_from_iter(columns.iter().map(|column| column.value_or_line(&line)));
        insert.execute(values).context(DatabaseErrorContextSnafu {
            context: "insert lines",
        })?;
        Ok(())
    })
}
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use rusqlite::Connection;

    use super::*;
    use crate::Column;

    /// A blob read from an in-memory source is stored with SQLite type `blob` and its
    /// content round-trips unchanged.
    #[test]
    fn test_insert_blob_simple() {
        let mut db = Connection::open_in_memory().unwrap();
        let mut tx = db.transaction().unwrap();
        let columns = vec![Column::blob_column("blob")];
        prepare_db(&mut tx, "test", true, &columns).unwrap();

        let source_bytes = "Hello World".as_bytes();
        let source = Cursor::new(source_bytes);
        insert_blob(&mut tx, "test", &columns, source).unwrap();
        tx.commit().unwrap();

        let blob_type: String = db
            .query_one("SELECT typeof(blob) from test;", [], |row| row.get(0))
            .unwrap();
        assert_eq!(&blob_type, "blob");

        let blob_content: String = db
            .query_one("SELECT cast(blob as text) from test;", [], |row| row.get(0))
            .unwrap();
        assert_eq!(&blob_content, "Hello World");
    }

    /// Every input line becomes its own row, stored with SQLite type `text`.
    #[test]
    fn test_insert_lines_simple() {
        let mut db = Connection::open_in_memory().unwrap();
        let mut tx = db.transaction().unwrap();
        let columns = vec![Column::line_column("line")];
        prepare_db(&mut tx, "test", true, &columns).unwrap();

        let source_bytes = "Hello World\nHello World".as_bytes();
        let source = Cursor::new(source_bytes);
        insert_lines(&mut tx, "test", &columns, source).unwrap();
        tx.commit().unwrap();

        let col_type: String = db
            .query_one("SELECT DISTINCT typeof(line) from test;", [], |row| {
                row.get(0)
            })
            .unwrap();
        assert_eq!(&col_type, "text");

        let mut stmt = db.prepare("SELECT line from test;").unwrap();
        let rows: std::result::Result<Vec<String>, _> =
            stmt.query_map([], |row| row.get(0)).unwrap().collect();
        let rows = rows.unwrap();
        assert_eq!(
            rows,
            vec!["Hello World".to_string(), "Hello World".to_string()]
        );
    }

    /// Raw column values are inserted as a single row; both columns are stored as `text`.
    #[test]
    fn test_insert_raw_values() {
        let mut db = Connection::open_in_memory().unwrap();
        let mut tx = db.transaction().unwrap();
        let columns = vec![
            Column::raw_column("test1", "value1"),
            Column::raw_column("test2", "value2"),
        ];
        prepare_db(&mut tx, "test", true, &columns).unwrap();
        insert_row(&mut tx, "test", &columns).unwrap();
        tx.commit().unwrap();

        let col1_type: String = db
            .query_one("SELECT DISTINCT typeof(test1) from test;", [], |row| {
                row.get(0)
            })
            .unwrap();
        assert_eq!(&col1_type, "text");

        // Check the second column here; querying `test1` again would leave `test2` unverified.
        let col2_type: String = db
            .query_one("SELECT DISTINCT typeof(test2) from test;", [], |row| {
                row.get(0)
            })
            .unwrap();
        assert_eq!(&col2_type, "text");

        let values: (String, String) = db
            .query_one("SELECT test1, test2 from test;", [], |row| {
                Ok((row.get(0)?, row.get(1)?))
            })
            .unwrap();
        assert_eq!(values, ("value1".to_string(), "value2".to_string()));
    }
}