use log::debug;
use nanoserde::{DeBin, SerBin};
use rusqlite::{Connection, TEMP_DB, params_from_iter};
use snafu::{ResultExt, Snafu};
use crate::{StdinMode, column::Column, insert_row, prepare_db, sanitizing::SqlSanitize, stmt};
use std::{
io::{BufRead, Read, Write},
ops::Deref,
os::unix::net::UnixListener,
path::Path,
};
/// Errors produced by the transaction daemon (`tx_main`, `tx_daemon_main`) and by
/// its clients (`client_main`, `send_msg`, `receive_msg`, ...).
///
/// The `Snafu` derive generates one context selector per variant (for example
/// `DatabaseErrorContextSnafu`, `CommunicationSnafu`, `BlobWriteSnafu`), which the
/// rest of the module attaches to fallible calls with `.context(...)`.
#[derive(Debug, Snafu)]
pub enum Error {
/// A SQLite call failed; `context` names the operation that was being attempted.
#[snafu(display("Database error (context: {context}): {source}"))]
DatabaseErrorContext {
source: rusqlite::Error,
context: &'static str,
},
/// Writing a chunk into an open SQLite blob handle failed.
#[snafu(display("Blob write error: {source}"))]
BlobWriteError { source: std::io::Error },
/// A crate-level table-preparation or row-insert helper failed.
#[snafu(display("Error during insert"))]
InsertError { source: crate::Error },
/// A received length prefix or message payload could not be decoded.
#[snafu(display("Error during receive"))]
ReceiveMessageError { source: nanoserde::DeBinErr },
/// Socket I/O failed: binding, accepting, or reading/writing a message frame.
#[snafu(display("Communication error"))]
CommunicationError { source: std::io::Error },
/// Reading the client's input from stdin failed.
#[snafu(display("Source read error"))]
SourceReadError { source: std::io::Error },
/// A filesystem operation failed (e.g. removing the socket file).
#[snafu(display("Disk io error"))]
DiskIoError { source: std::io::Error },
}
/// One column of an `Insert` message, as sent over the daemon socket.
///
/// NOTE: this type uses nanoserde's derived binary encoding, so reordering the
/// variants or their fields likely changes the wire format shared by client and daemon.
#[derive(Debug, Clone, PartialEq, Eq, SerBin, DeBin)]
pub enum ColumnData {
// A column with a literal value; `key` is the column name.
Raw { key: String, value: String },
// A column whose value is the blob assembled from the preceding `BlobData`
// messages; `size` is the total number of bytes the client streamed.
Blob { column: String, size: usize },
}
/// Rebuilds the daemon-side [`Column`] description from a received [`ColumnData`].
///
/// Raw columns keep their name and literal value; blob columns keep only their name,
/// because the blob size is consumed separately by the insert logic.
impl From<&ColumnData> for Column {
    fn from(data: &ColumnData) -> Self {
        match data {
            ColumnData::Raw { key, value } => Self::raw_column(key, value),
            ColumnData::Blob { column, .. } => Self::blob_column(column),
        }
    }
}
/// A byte buffer that either owns its bytes or borrows them.
///
/// Senders can wrap a borrowed slice (`Shared`) to avoid copying; both variants
/// dereference to the same `[u8]` view.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShareableBuffer<'a> {
    /// Bytes owned by the buffer.
    Owned(Vec<u8>),
    /// Bytes borrowed from elsewhere for the lifetime `'a`.
    Shared(&'a [u8]),
}

impl Deref for ShareableBuffer<'_> {
    type Target = [u8];

    /// Views the bytes as a slice, whichever variant holds them.
    fn deref(&self) -> &[u8] {
        match self {
            Self::Shared(bytes) => bytes,
            Self::Owned(bytes) => bytes.as_slice(),
        }
    }
}

impl From<Vec<u8>> for ShareableBuffer<'_> {
    /// Takes ownership of `bytes` without copying.
    fn from(bytes: Vec<u8>) -> Self {
        ShareableBuffer::Owned(bytes)
    }
}

impl<'a> From<&'a [u8]> for ShareableBuffer<'a> {
    /// Borrows `bytes` without copying.
    fn from(bytes: &'a [u8]) -> Self {
        ShareableBuffer::Shared(bytes)
    }
}
impl<'a> SerBin for ShareableBuffer<'a> {
fn ser_bin(&self, output: &mut Vec<u8>) {
let buf: &[u8] = self;
let len = buf.len();
len.ser_bin(output);
buf.ser_bin(output);
}
}
/// Decodes a length-prefixed byte vector into an owned buffer.
///
/// Decoding always yields `Owned`, so a received buffer never borrows from the
/// input bytes and may outlive them.
impl DeBin for ShareableBuffer<'_> {
    fn de_bin(offset: &mut usize, bytes: &[u8]) -> Result<Self, nanoserde::DeBinErr> {
        Vec::<u8>::de_bin(offset, bytes).map(Self::Owned)
    }
}
/// A message exchanged between a client and the transaction daemon.
///
/// NOTE: this type uses nanoserde's derived binary encoding, so reordering the
/// variants likely changes the wire format shared by client and daemon.
#[derive(Debug, Clone, PartialEq, Eq, SerBin, DeBin)]
pub enum Message<'a> {
// Insert one row into `table_name`. The daemon runs `prepare_db`, then either a
// plain row insert or, if any column is a blob column, the blob insert.
Insert {
table_name: String,
data: Vec<ColumnData>,
},
// Start a blob upload: the daemon creates its temporary blob table.
BlobStart,
// One chunk of blob data; the daemon inserts each chunk as its own row in the
// temporary blob table.
BlobData(ShareableBuffer<'a>),
// The client is finished; the daemon keeps the transaction open and waits for
// the next client connection.
Disconnect,
// Commit the transaction; `tx_main` then returns.
Commit,
// Roll back the transaction; `tx_main` then returns.
Revert,
}
/// Serializes `message` and writes it to `buf` as one length-prefixed frame.
///
/// Frame layout: the payload length as a nanoserde-encoded `u64` (8 bytes),
/// followed by the nanoserde-encoded message. [`receive_msg`] reads the same layout.
/// `buf` is not flushed; callers that wrap the stream in a buffer must flush it.
///
/// # Errors
/// Returns [`Error::CommunicationError`] if writing to `buf` fails.
pub fn send_msg(message: &Message, buf: &mut impl Write) -> Result<(), Error> {
    let payload = message.serialize_bin();
    let len = payload.len() as u64;
    debug!("sending {message:?} len {len:?}");
    // `write_all` only needs shared slices; no mutable borrows are required here.
    buf.write_all(&len.serialize_bin())
        .context(CommunicationSnafu)?;
    buf.write_all(&payload).context(CommunicationSnafu)?;
    Ok(())
}
pub fn receive_msg<'a>(read: &mut impl Read) -> Result<Message<'a>, Error> {
let mut buf = [0; 8];
read.read_exact(&mut buf).context(CommunicationSnafu)?;
let msg_len = u64::deserialize_bin(&buf).context(ReceiveMessageSnafu)?;
let mut buf = vec![0; msg_len as usize];
read.read_exact(&mut buf).context(CommunicationSnafu)?;
let msg = Message::deserialize_bin(&buf).context(ReceiveMessageSnafu)?;
Ok(msg)
}
fn handle_msg_inner(conn: &Connection, message: &Message) -> Result<(), Error> {
match message {
Message::Insert { table_name, data } => {
let mut total_size: Option<usize> = None;
let columns: Vec<Column> = data
.iter()
.inspect(|v| {
if let ColumnData::Blob { column: _, size } = v {
total_size = Some(*size);
}
})
.map(|v| v.into())
.collect();
prepare_db(conn, table_name, false, &columns).context(InsertSnafu)?;
if let Some(total_size) = total_size {
let sanitized_table_name = table_name.to_sql_sanitized_string();
let mut statement = conn
.prepare(&stmt::insert_blob(
&sanitized_table_name,
&columns,
total_size,
))
.context(DatabaseErrorContextSnafu {
context: "insert blob",
})?;
let values = columns.iter().filter_map(|n| n.value());
debug!("executing {statement:?} with values {columns:?}");
statement
.execute(params_from_iter(values))
.context(DatabaseErrorContextSnafu {
context: "insert blob",
})?;
conn.execute(stmt::drop_temporary_blob_table(), ())
.context(DatabaseErrorContextSnafu {
context: "cleanup blob",
})?;
} else {
insert_row(conn, table_name, &columns).context(InsertSnafu)?;
}
}
Message::BlobStart => {
conn.execute(stmt::create_temporary_blob_table(), [])
.context(DatabaseErrorContextSnafu {
context: "insert blob",
})?;
}
Message::BlobData(data) => {
conn.execute(&stmt::insert_zero_blob(data.len()), [])
.context(DatabaseErrorContextSnafu {
context: "insert blob",
})?;
let rowid = conn.last_insert_rowid();
let mut blob = conn
.blob_open(TEMP_DB, "blob_insert", "data", rowid, false)
.context(DatabaseErrorContextSnafu {
context: "open blob",
})?;
blob.write_all(&data).context(BlobWriteSnafu)?;
}
Message::Disconnect | Message::Commit | Message::Revert => {
panic!("Commit/Abort must be handled outside")
}
};
Ok(())
}
/// Opens the SQLite database at `output_path` and listens on a Unix socket at
/// `tx_path`. It then runs [`tx_main`] until a client commits or reverts.
///
/// The socket file is removed once [`tx_main`] returns, whether it succeeded or
/// failed. Otherwise a failed run would leave a stale socket at `tx_path`, and the
/// next `bind` there would fail.
///
/// # Errors
/// - [`Error::DatabaseErrorContext`] if the database cannot be opened.
/// - [`Error::CommunicationError`] if the socket cannot be bound.
/// - Any error returned by [`tx_main`].
/// - [`Error::DiskIoError`] if the socket file cannot be removed.
///
/// When both [`tx_main`] and the cleanup fail, the [`tx_main`] error is reported.
pub fn tx_daemon_main(
    output_path: impl AsRef<Path>,
    tx_path: impl AsRef<Path>,
) -> Result<(), Error> {
    let tx_path = tx_path.as_ref();
    let conn = Connection::open(output_path).context(DatabaseErrorContextSnafu {
        context: "open connection",
    })?;
    let socket = UnixListener::bind(tx_path).context(CommunicationSnafu)?;
    // `tx_main` takes ownership of the listener, so the socket is already closed by
    // the time its path is unlinked below.
    let served = tx_main(conn, socket);
    let cleanup = std::fs::remove_file(tx_path).context(DiskIoSnafu);
    served.and(cleanup)
}
/// Serves clients on `listener`, applying their messages inside one transaction on `conn`.
///
/// Clients are accepted one at a time. Each client's messages are applied in order
/// until one of the following arrives:
/// - `Disconnect`: the next client is accepted and the transaction stays open.
/// - `Commit`: the transaction is committed and the function returns.
/// - `Revert`: the transaction is rolled back and the function returns.
///
/// # Errors
/// Any accept, receive or database error ends serving and is returned. The
/// transaction is then dropped without being committed.
pub fn tx_main(mut conn: Connection, listener: UnixListener) -> Result<(), Error> {
    let transaction = conn.transaction().context(DatabaseErrorContextSnafu {
        context: "create transaction",
    })?;
    loop {
        debug!("tx_main loop waiting for connection");
        let (mut client, peer) = listener.accept().context(CommunicationSnafu)?;
        debug!("tx_main connected: {peer:?}");
        loop {
            debug!("tx_main loop waiting for message");
            let incoming = receive_msg(&mut client)?;
            debug!("tx_main received: {incoming:?}");
            match incoming {
                // Hang up on this client but keep the transaction open for the next one.
                Message::Disconnect => break,
                Message::Commit => {
                    transaction
                        .commit()
                        .context(DatabaseErrorContextSnafu { context: "commit" })?;
                    return Ok(());
                }
                Message::Revert => {
                    transaction.rollback().context(DatabaseErrorContextSnafu {
                        context: "rollback",
                    })?;
                    return Ok(());
                }
                other => handle_msg_inner(&transaction, &other)?,
            }
        }
    }
}
pub fn client_main(
sink: &mut impl Write,
table_name: &str,
columns: &[Column],
mode: StdinMode<'_>,
) -> Result<(), Error> {
match mode {
StdinMode::None => {
let data = columns
.iter()
.map(|v| ColumnData::Raw {
key: v.name().to_string(),
value: v.raw_value().to_string(),
})
.collect();
let message = Message::Insert {
table_name: table_name.to_string(),
data,
};
send_msg(&message, sink)?;
}
StdinMode::Blob(_) => {
send_msg(&Message::BlobStart, sink)?;
let mut total_size = 0usize;
{
let mut stdin = std::io::stdin().lock();
let mut buf = vec![0; stmt::BLOB_BUF_SIZE];
loop {
let count = stdin.read(&mut buf).context(SourceReadSnafu)?;
total_size += count;
debug!("read {count} to a total of {total_size}");
if count == 0 {
break;
}
send_msg(&Message::BlobData(ShareableBuffer::Shared(&buf)), sink)?;
}
}
let data = columns
.iter()
.map(|col| match col.value() {
Some(v) => ColumnData::Raw {
key: col.name().to_owned(),
value: v.to_owned(),
},
None => ColumnData::Blob {
column: col.name().to_owned(),
size: total_size,
},
})
.collect();
send_msg(
&Message::Insert {
table_name: table_name.to_owned(),
data,
},
sink,
)?;
}
StdinMode::Lines(_) => {
let stdin = std::io::stdin().lock();
for line in stdin.lines() {
let line = line.context(SourceReadSnafu)?;
let data = columns
.iter()
.map(|col| ColumnData::Raw {
key: col.name().to_owned(),
value: col.value_or_line(&line).to_owned(),
})
.collect();
send_msg(
&Message::Insert {
table_name: table_name.to_owned(),
data,
},
sink,
)?;
}
}
}
Ok(())
}
/// Sends `Commit`, which makes the daemon (`tx_main`) commit its transaction and return.
///
/// # Errors
/// Returns [`Error::CommunicationError`] if the message cannot be written to `sink`.
pub fn client_commit(sink: &mut impl Write) -> Result<(), Error> {
    send_msg(&Message::Commit, sink)
}
/// Sends `Revert`, which makes the daemon (`tx_main`) roll back its transaction and return.
///
/// # Errors
/// Returns [`Error::CommunicationError`] if the message cannot be written to `sink`.
pub fn client_revert(sink: &mut impl Write) -> Result<(), Error> {
    send_msg(&Message::Revert, sink)
}
/// Sends `Disconnect`, ending this client's session while the daemon keeps the
/// transaction open for the next client.
///
/// # Errors
/// Returns [`Error::CommunicationError`] if the message cannot be written to `sink`.
pub fn client_disconnect(sink: &mut impl Write) -> Result<(), Error> {
    send_msg(&Message::Disconnect, sink)
}