use std::{
collections::HashSet,
io::{BufRead, Read, Write},
};
use log::debug;
use rusqlite::{Connection, TEMP_DB, params_from_iter};
use snafu::{ResultExt, Snafu};
use crate::{column::Column, sanitizing::SqlSanitize};
pub mod column;
pub mod sanitizing;
mod stmt;
#[cfg(feature = "tx")]
pub mod txmgmt;
/// Errors produced by the database preparation and insertion functions in this module.
#[derive(Debug, Snafu)]
pub enum Error {
/// A `rusqlite` call failed. `context` is a short static label naming the
/// step that was running (the functions below use labels such as
/// "drop table", "create table", "sync columns", "add column", "insert row",
/// "insert blob", "open blob" and "insert lines").
#[snafu(display("Database error (context: {context}): {source}"))]
DatabaseErrorContext {
source: rusqlite::Error,
context: &'static str,
},
/// A `rusqlite` call failed and no context label was attached.
/// In this file it is raised while reading the existing column names in `prepare_db`.
#[snafu(display("Database error: {source}"))]
WhateverDatabaseError { source: rusqlite::Error },
/// Reading from the caller-supplied input source failed.
#[snafu(display("Read error: {source}"))]
SourceReadError { source: std::io::Error },
/// Writing a chunk of data into an opened SQLite blob failed.
#[snafu(display("Blob write error: {source}"))]
BlobWriteError { source: std::io::Error },
/// Catch-all variant that only carries a free-form message.
#[snafu(display("Unknown error: {message}"))]
Unknown { message: String },
}
/// Shorthand for a `std::result::Result` whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// How standard input should be consumed.
///
/// NOTE(review): the consumer of this enum is not visible in this file; the
/// variant descriptions below are inferred from the names and from the
/// `insert_blob` / `insert_lines` functions — confirm against the caller.
#[derive(Debug, PartialEq, Eq)]
pub enum StdinMode {
/// Presumably: standard input is not read at all.
None,
/// Presumably: standard input is read in full and stored as a single blob.
Blob,
/// Presumably: standard input is read line by line, one row per line.
Lines,
}
pub fn prepare_db(
conn: &Connection,
table_name: &str,
reset: bool,
columns: &[Column],
) -> Result<()> {
let sanitized_table_name = table_name.to_sql_sanitized_string();
if reset {
conn.execute(&stmt::drop_table_if_exists(&sanitized_table_name), [])
.context(DatabaseErrorContextSnafu {
context: "drop table",
})?;
}
let sanitized_column_names: Vec<_> = columns.iter().map(|v| v.sanitized_name()).collect();
conn.execute(
&stmt::create_table_if_not_exists(&sanitized_table_name, &sanitized_column_names),
[],
)
.context(DatabaseErrorContextSnafu {
context: "create table",
})?;
let mut statement =
conn.prepare(stmt::select_pragram_table_info())
.context(DatabaseErrorContextSnafu {
context: "sync columns",
})?;
let existing_columns: std::result::Result<HashSet<String>, _> = statement
.query_map([table_name], |row| row.get(0))
.context(WhateverDatabaseSnafu)?
.collect();
let existing_columns = existing_columns.context(WhateverDatabaseSnafu)?;
for (col_name, san_col_name) in columns
.iter()
.zip(sanitized_column_names.iter())
.filter(|&(n, _)| !existing_columns.contains(n.name()))
{
debug!("adding column {col_name:?}");
conn.execute(
&stmt::alter_table_add_column(&sanitized_table_name, &san_col_name),
[],
)
.context(DatabaseErrorContextSnafu {
context: "add column",
})?;
}
Ok(())
}
/// Inserts a single row into `table_name`, binding each column's `raw_value()`.
///
/// The SQL text comes from `stmt::insert_row`, built from the sanitized table
/// name and `columns`.
///
/// # Errors
///
/// Returns [`Error::DatabaseErrorContext`] (context `"insert row"`) when the
/// statement cannot be prepared or executed.
pub fn insert_row(conn: &Connection, table_name: &str, columns: &[Column]) -> Result<()> {
    let table = table_name.to_sql_sanitized_string();
    let sql = stmt::insert_row(&table, &columns);
    let mut prepared = conn.prepare(&sql).context(DatabaseErrorContextSnafu {
        context: "insert row",
    })?;

    debug!("executing {prepared:?} with values {columns:?}");

    let bound = params_from_iter(columns.iter().map(|column| column.raw_value()));
    prepared.execute(bound).context(DatabaseErrorContextSnafu {
        context: "insert row",
    })?;

    Ok(())
}
pub fn insert_blob(
conn: &Connection,
table_name: &str,
columns: &[Column],
mut source: impl Read,
) -> Result<()> {
conn.execute(stmt::create_temporary_blob_table(), [])
.context(DatabaseErrorContextSnafu {
context: "insert blob",
})?;
let mut total_size = 0usize;
{
let mut buf = vec![0; stmt::BLOB_BUF_SIZE];
loop {
let count = source.read(&mut buf).context(SourceReadSnafu)?;
total_size += count;
debug!("read {count} to a total of {total_size}");
if count == 0 {
break;
}
conn.execute(&stmt::insert_zero_blob(count), []).context(
DatabaseErrorContextSnafu {
context: "insert blob",
},
)?;
let rowid = conn.last_insert_rowid();
let mut blob = conn
.blob_open(TEMP_DB, "blob_insert", "data", rowid, false)
.context(DatabaseErrorContextSnafu {
context: "open blob",
})?;
blob.write_all(&buf[..count]).context(BlobWriteSnafu)?;
}
}
debug!("read a total of {total_size} bytes");
let sanitized_table_name = table_name.to_sql_sanitized_string();
let mut statement = conn
.prepare(&stmt::insert_blob(
&sanitized_table_name,
&columns,
total_size,
))
.context(DatabaseErrorContextSnafu {
context: "insert blob",
})?;
let values = columns.iter().filter_map(|n| n.value());
debug!("executing {statement:?} with values {columns:?}");
statement
.execute(params_from_iter(values))
.context(DatabaseErrorContextSnafu {
context: "insert blob",
})?;
Ok(())
}
/// Inserts one row into `table_name` for every line yielded by `source.lines()`.
///
/// A single statement (SQL from `stmt::insert_row`) is prepared up front and
/// executed once per line; each column binds `value_or_line(&line)`.
/// Processing stops at the first failure.
///
/// # Errors
///
/// * [`Error::SourceReadError`] when a line cannot be read from `source`,
/// * [`Error::DatabaseErrorContext`] (context `"insert lines"`) when preparing
///   or executing the statement fails.
pub fn insert_lines(
    conn: &Connection,
    table_name: &str,
    columns: &[Column],
    source: impl BufRead,
) -> Result<()> {
    let table = table_name.to_sql_sanitized_string();
    let sql = stmt::insert_row(&table, &columns);
    let mut prepared = conn.prepare(&sql).context(DatabaseErrorContextSnafu {
        context: "insert lines",
    })?;
    debug!("prepared statement {prepared:?} with values {columns:?}");

    let mut lines = source.lines();
    lines.try_for_each(|line| -> Result<()> {
        let line = line.context(SourceReadSnafu)?;
        debug!("executing statement with line {line:?}");
        let bound = params_from_iter(columns.iter().map(|column| column.value_or_line(&line)));
        prepared.execute(bound).context(DatabaseErrorContextSnafu {
            context: "insert lines",
        })?;
        Ok(())
    })
}
#[cfg(test)]
mod test {
use std::io::{Cursor, Read};
use rusqlite::Connection;
use super::*;
use crate::Column;
// Guards the logger setup in `init()` so it runs only once for the whole test binary.
static INIT: std::sync::Once = std::sync::Once::new();
/// Test reader over a fixed-size byte array that never reads across the
/// `interrupt_after` offset in a single call.
///
/// While the cursor position is below `interrupt_after`, a read returns at most
/// the bytes up to that offset, even if the caller's buffer is larger. From
/// that offset on, reads are passed straight through to the inner cursor.
struct InterruptibleCursor<const N: usize> {
    cursor: Cursor<[u8; N]>,
    interrupt_after: u64,
}

impl<const N: usize> Read for InterruptibleCursor<N> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let position = self.cursor.position();
        // Number of bytes of `buf` the inner cursor is allowed to fill.
        let limit = match self.interrupt_after.checked_sub(position) {
            Some(remaining) if remaining > 0 => (buf.len() as u64).min(remaining) as usize,
            _ => buf.len(),
        };
        self.cursor.read(&mut buf[..limit])
    }
}
/// Installs `env_logger` at `Debug` level; the setup runs only once no matter
/// how many tests call this.
///
/// # Panics
///
/// Panics if `env_logger` fails to initialise.
pub fn init() {
    INIT.call_once(|| {
        let mut builder = env_logger::builder();
        builder.filter_level(log::LevelFilter::Debug);
        builder
            .try_init()
            .expect("expected env logger to init");
    });
}
/// A short ASCII payload inserted via `insert_blob` is stored with SQLite type
/// `blob` and reads back unchanged when cast to text.
#[test]
fn test_insert_blob_simple() {
    init();
    let conn = Connection::open_in_memory().unwrap();
    let cols = vec![Column::blob_column("blob")];
    prepare_db(&conn, "test", true, &cols).unwrap();

    let payload = Cursor::new("Hello World".as_bytes());
    insert_blob(&conn, "test", &cols, payload).unwrap();

    let stored_type = conn
        .query_one("SELECT typeof(blob) from test;", [], |row| {
            row.get::<_, String>(0)
        })
        .unwrap();
    assert_eq!(stored_type, "blob");

    let stored_text = conn
        .query_one("SELECT cast(blob as text) from test;", [], |row| {
            row.get::<_, String>(0)
        })
        .unwrap();
    assert_eq!(stored_text, "Hello World");
}
/// A payload containing an embedded zero byte keeps its full length and exact
/// content when stored through `insert_blob`.
#[test]
fn test_insert_blob_with_zero() {
    init();
    let conn = Connection::open_in_memory().unwrap();
    let cols = vec![Column::blob_column("blob")];
    prepare_db(&conn, "test", true, &cols).unwrap();

    let mut payload = b"Hello World".to_vec();
    payload[5] = 0;
    let expected = payload.clone();
    insert_blob(&conn, "test", &cols, Cursor::new(payload)).unwrap();

    let stored_type = conn
        .query_one("SELECT typeof(blob) from test;", [], |row| {
            row.get::<_, String>(0)
        })
        .unwrap();
    assert_eq!(stored_type, "blob");

    let stored_len = conn
        .query_one("SELECT length(blob) from test;", [], |row| {
            row.get::<_, i64>(0)
        })
        .unwrap();
    assert_eq!(stored_len, 11);

    let stored_bytes = conn
        .query_one("SELECT blob from test;", [], |row| {
            row.get::<_, Vec<u8>>(0)
        })
        .unwrap();
    assert_eq!(stored_bytes, expected);
}
/// When the source returns a short first read (5 bytes), `insert_blob` still
/// stores the complete 8192-byte payload in the right order.
#[test]
fn test_insert_blob_read_smaller_than_buf() {
    init();
    let conn = Connection::open_in_memory().unwrap();
    let cols = vec![Column::blob_column("blob")];
    prepare_db(&conn, "test", true, &cols).unwrap();

    // Repeating 0..=255 pattern so misplaced chunks would be detected.
    let expected: [u8; 8192] = std::array::from_fn(|i| (i % 256) as u8);
    let reader = InterruptibleCursor {
        cursor: Cursor::new(expected),
        interrupt_after: 5,
    };
    insert_blob(&conn, "test", &cols, reader).unwrap();

    let stored_type = conn
        .query_one("SELECT typeof(blob) from test;", [], |row| {
            row.get::<_, String>(0)
        })
        .unwrap();
    assert_eq!(stored_type, "blob");

    let stored_len = conn
        .query_one("SELECT length(blob) from test;", [], |row| {
            row.get::<_, i64>(0)
        })
        .unwrap();
    assert_eq!(stored_len, expected.len() as i64);

    let stored_bytes = conn
        .query_one("SELECT blob from test;", [], |row| {
            row.get::<_, Vec<u8>>(0)
        })
        .unwrap();
    assert_eq!(&stored_bytes, &expected)
}
/// Two input lines inserted via `insert_lines` become two text rows with the
/// line content.
#[test]
fn test_insert_lines_simple() {
    init();
    let conn = Connection::open_in_memory().unwrap();
    let cols = vec![Column::line_column("line")];
    prepare_db(&conn, "test", true, &cols).unwrap();

    let input = Cursor::new("Hello World\nHello World".as_bytes());
    insert_lines(&conn, "test", &cols, input).unwrap();

    let stored_type = conn
        .query_one("SELECT DISTINCT typeof(line) from test;", [], |row| {
            row.get::<_, String>(0)
        })
        .unwrap();
    assert_eq!(stored_type, "text");

    let mut select = conn.prepare("SELECT line from test;").unwrap();
    let stored_lines = select
        .query_map([], |row| row.get::<_, String>(0))
        .unwrap()
        .collect::<std::result::Result<Vec<String>, _>>()
        .unwrap();
    let expected = vec![String::from("Hello World"), String::from("Hello World")];
    assert_eq!(stored_lines, expected);
}
/// Two raw-valued columns inserted via `insert_row` are both stored as text
/// and read back with their original values.
#[test]
fn test_insert_raw_values() {
    init();
    let db = Connection::open_in_memory().unwrap();
    let columns = vec![
        Column::raw_column("test1", "value1"),
        Column::raw_column("test2", "value2"),
    ];
    prepare_db(&db, "test", true, &columns).unwrap();
    insert_row(&db, "test", &columns).unwrap();

    let col1_type: String = db
        .query_one("SELECT DISTINCT typeof(test1) from test;", [], |row| {
            row.get(0)
        })
        .unwrap();
    assert_eq!(&col1_type, "text");

    // Check the second column's own type; this query previously inspected
    // `test1` again, leaving `test2` unverified.
    let col2_type: String = db
        .query_one("SELECT DISTINCT typeof(test2) from test;", [], |row| {
            row.get(0)
        })
        .unwrap();
    assert_eq!(&col2_type, "text");

    let values: (String, String) = db
        .query_one("SELECT test1, test2 from test;", [], |row| {
            // Propagate conversion errors to the outer `unwrap` instead of
            // panicking inside the row callback.
            Ok((row.get(0)?, row.get(1)?))
        })
        .unwrap();
    assert_eq!(values, ("value1".to_string(), "value2".to_string()));
}
/// A JSON document inserted through a blob column with `set_json(true)` is
/// stored as a blob and can be queried with the `->>` JSON operator.
#[test]
fn test_insert_jsonb_value() {
    init();
    let conn = Connection::open_in_memory().unwrap();
    let cols = vec![Column::blob_column("json").set_json(true)];
    prepare_db(&conn, "test", true, &cols).unwrap();

    let document = r#"{"hello": [1, 2, "World"]}"#;
    insert_blob(&conn, "test", &cols, Cursor::new(document.as_bytes())).unwrap();

    let stored_type = conn
        .query_one("SELECT typeof(json) from test;", [], |row| {
            row.get::<_, String>(0)
        })
        .unwrap();
    assert_eq!(stored_type, "blob");

    let extracted = conn
        .query_one("SELECT json ->> '$.hello[2]' from test;", [], |row| {
            row.get::<_, String>(0)
        })
        .unwrap();
    assert_eq!(extracted, "World");
}
/// Two JSON documents, one per input line, inserted through a line column with
/// `set_json(true)` yield two blob rows that can be queried with `->>`.
#[test]
fn test_insert_jsonb_lines() {
    init();
    let conn = Connection::open_in_memory().unwrap();
    let cols = vec![Column::line_column("json").set_json(true)];
    prepare_db(&conn, "test", true, &cols).unwrap();

    let input = [
        r#"{"hello": [1, 2, "World"]}"#,
        r#"{"hello": [1, 2, "World2"]}"#,
    ]
    .join("\n");
    insert_lines(&conn, "test", &cols, Cursor::new(input.as_bytes())).unwrap();

    let stored_type = conn
        .query_one("SELECT DISTINCT typeof(json) from test;", [], |row| {
            row.get::<_, String>(0)
        })
        .unwrap();
    assert_eq!(stored_type, "blob");

    let row_count = conn
        .query_one("SELECT count(json) from test;", [], |row| {
            row.get::<_, i64>(0)
        })
        .unwrap();
    assert_eq!(row_count, 2);

    let extracted = conn
        .query_one(
            "SELECT group_concat(json ->> '$.hello[2]') from test;",
            [],
            |row| row.get::<_, String>(0),
        )
        .unwrap();
    assert_eq!(extracted, "World,World2");
}
/// Malformed JSON inserted through a JSON blob column is rejected with a
/// `DatabaseErrorContext` error that wraps an SQLite error.
#[test]
fn test_insert_invalid_jsonb_value() {
    init();
    let conn = Connection::open_in_memory().unwrap();
    let cols = vec![Column::blob_column("json").set_json(true)];
    prepare_db(&conn, "test", true, &cols).unwrap();

    let malformed = r#"{"hello": [invalid}"#;
    let failure = insert_blob(&conn, "test", &cols, Cursor::new(malformed.as_bytes()))
        .expect_err("should throw");

    match failure {
        Error::DatabaseErrorContext { source, .. } => {
            debug!("sqlite error: {source:?}");
            assert!(source.sqlite_error().is_some());
        }
        _ => panic!("expected database error"),
    }
}
}