sqlitepipe 0.2.0

A simple tool for piping the output of a command into sqlite databases.
Documentation
//! module for managing a transaction daemon

use log::debug;
use nanoserde::{DeBin, SerBin};
use rusqlite::{Connection, TEMP_DB, params_from_iter};
use snafu::{ResultExt, Snafu};

use crate::{StdinMode, column::Column, insert_row, prepare_db, sanitizing::SqlSanitize, stmt};

use std::{
    io::{BufRead, Read, Write},
    ops::Deref,
    os::unix::net::UnixListener,
    path::Path,
};

/// Errors produced by the transaction daemon and its clients.
///
/// Each variant has a snafu-generated context selector (`DatabaseErrorContextSnafu`,
/// `BlobWriteSnafu`, `InsertSnafu`, `ReceiveMessageSnafu`, `CommunicationSnafu`,
/// `SourceReadSnafu`, `DiskIoSnafu`) that this module uses with `.context(...)`,
/// so renaming a variant also renames its selector.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A rusqlite error occurred within a specific context.
    ///
    /// `context` is a short static label for the failed operation,
    /// e.g. "open connection", "commit" or "insert blob".
    #[snafu(display("Database error (context: {context}): {source}"))]
    DatabaseErrorContext {
        source: rusqlite::Error,
        context: &'static str,
    },

    /// Failed to write chunked data to a SQLite BLOB.
    #[snafu(display("Blob write error: {source}"))]
    BlobWriteError { source: std::io::Error },

    /// An error from the crate-level table helpers (`prepare_db` / `insert_row`)
    /// during an insert operation.
    #[snafu(display("Error during insert"))]
    InsertError { source: crate::Error },

    /// Failed to decode a received length prefix or message with nanoserde.
    #[snafu(display("Error during receive"))]
    ReceiveMessageError { source: nanoserde::DeBinErr },

    /// I/O error on the IPC socket (binding, accepting, sending or receiving).
    #[snafu(display("Communication error"))]
    CommunicationError { source: std::io::Error },

    /// Error while reading from a source stream (the client's stdin).
    #[snafu(display("Source read error"))]
    SourceReadError { source: std::io::Error },

    /// Filesystem error, e.g. when removing the socket file after the daemon ends.
    #[snafu(display("Disk io error"))]
    DiskIoError { source: std::io::Error },
}

/// Column data sent from a client to the daemon as part of an insert message.
///
/// NOTE(review): the derived nanoserde encoding presumably identifies variants by
/// declaration order — verify before reordering or inserting variants, since the
/// client and the daemon must agree on the wire format.
#[derive(Debug, Clone, PartialEq, Eq, SerBin, DeBin)]
pub enum ColumnData {
    /// A raw key/value pair: column `key` gets the text `value`.
    Raw { key: String, value: String },
    /// A placeholder for blob data.
    ///
    /// The payload itself is streamed beforehand as blob-data messages;
    /// `column` names the target column and `size` is the total number of
    /// bytes the client streamed.
    Blob { column: String, size: usize },
}

impl From<&ColumnData> for Column {
    /// Builds the [`Column`] descriptor expected by the database helpers.
    ///
    /// Raw entries keep their value; blob placeholders only carry the column
    /// name, because the payload is stored separately while it is streamed.
    fn from(data: &ColumnData) -> Self {
        match data {
            ColumnData::Blob { column, .. } => Column::blob_column(column),
            ColumnData::Raw { key, value } => Column::raw_column(key, value),
        }
    }
}

/// A byte buffer that is either owned or borrowed, used as a blob chunk payload.
///
/// Senders can wrap a borrowed slice ([`ShareableBuffer::Shared`]) to avoid copying
/// the chunk, while decoding produces an owned buffer ([`ShareableBuffer::Owned`]).
/// Both variants dereference to `[u8]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShareableBuffer<'a> {
    /// Heap-allocated bytes owned by the buffer.
    Owned(Vec<u8>),
    /// Bytes borrowed from the caller for the lifetime `'a`.
    Shared(&'a [u8]),
}

impl Deref for ShareableBuffer<'_> {
    type Target = [u8];

    /// Views the buffer as a byte slice, regardless of ownership.
    fn deref(&self) -> &Self::Target {
        match self {
            ShareableBuffer::Owned(items) => items.as_slice(),
            ShareableBuffer::Shared(items) => items,
        }
    }
}

impl From<Vec<u8>> for ShareableBuffer<'_> {
    /// Takes ownership of `value` without copying.
    fn from(value: Vec<u8>) -> Self {
        Self::Owned(value)
    }
}

impl<'a> From<&'a [u8]> for ShareableBuffer<'a> {
    /// Borrows `value` without copying.
    fn from(value: &'a [u8]) -> Self {
        Self::Shared(value)
    }
}

impl SerBin for ShareableBuffer<'_> {
    /// Encodes the buffer as a `usize` length prefix followed by the raw bytes.
    ///
    /// This is the layout that `Vec::<u8>::de_bin` reads back in the `DeBin`
    /// impl, so owned and shared buffers both decode to an owned buffer.
    fn ser_bin(&self, output: &mut Vec<u8>) {
        let bytes: &[u8] = self;
        bytes.len().ser_bin(output);
        // Copy the payload in one go instead of serializing it byte by byte;
        // blob chunks can be large.
        output.extend_from_slice(bytes);
    }
}

impl DeBin for ShareableBuffer<'_> {
    /// Decodes a length-prefixed byte vector starting at `offset`.
    ///
    /// The result is always [`ShareableBuffer::Owned`]: the lifetime of `bytes`
    /// is unrelated to the buffer's lifetime, so the data cannot be borrowed.
    fn de_bin(offset: &mut usize, bytes: &[u8]) -> Result<Self, nanoserde::DeBinErr> {
        Vec::<u8>::de_bin(offset, bytes).map(ShareableBuffer::Owned)
    }
}

/// An IPC message exchanged between a client and the transaction daemon.
///
/// NOTE(review): the derived nanoserde encoding presumably identifies variants by
/// declaration order — verify before reordering or inserting variants.
#[derive(Debug, Clone, PartialEq, Eq, SerBin, DeBin)]
pub enum Message<'a> {
    /// Insert data into a table.
    ///
    /// `data` describes one row of `table_name`; a [`ColumnData::Blob`] entry
    /// refers to the chunks streamed since the last `BlobStart`.
    Insert {
        table_name: String,
        data: Vec<ColumnData>,
    },
    /// Mark start of transferring a blob; the daemon creates a temporary table
    /// that collects the following chunks.
    BlobStart,
    /// Blob data: one chunk, stored by the daemon as a new row of the temporary table.
    BlobData(ShareableBuffer<'a>),
    /// Client wants to disconnect; the daemon keeps the transaction open and
    /// waits for the next client.
    Disconnect,
    /// Commit transaction and end the daemon's session loop.
    Commit,
    /// Revert (roll back) transaction and end the daemon's session loop.
    Revert,
}

/// Sends a message into the given sink.
///
/// Frame layout: an 8-byte length prefix (the payload length as a nanoserde
/// `u64`) followed by the nanoserde-encoded message — the layout
/// [`receive_msg`] reads. `buf` is not flushed.
///
/// # Errors
///
/// Returns [`Error::CommunicationError`] if writing to `buf` fails.
pub fn send_msg(message: &Message, buf: &mut impl Write) -> Result<(), Error> {
    let payload = message.serialize_bin();
    // usize -> u64 is lossless on every target Rust supports (usize is at most 64 bits).
    let len = payload.len() as u64;
    debug!("sending {message:?} len {len:?}");
    buf.write_all(&len.serialize_bin())
        .context(CommunicationSnafu)?;
    buf.write_all(&payload).context(CommunicationSnafu)?;
    Ok(())
}

/// Read a message from the given source.
pub fn receive_msg<'a>(read: &mut impl Read) -> Result<Message<'a>, Error> {
    let mut buf = [0; 8];
    read.read_exact(&mut buf).context(CommunicationSnafu)?;
    let msg_len = u64::deserialize_bin(&buf).context(ReceiveMessageSnafu)?;
    let mut buf = vec![0; msg_len as usize];
    read.read_exact(&mut buf).context(CommunicationSnafu)?;
    let msg = Message::deserialize_bin(&buf).context(ReceiveMessageSnafu)?;
    Ok(msg)
}

fn handle_msg_inner(conn: &Connection, message: &Message) -> Result<(), Error> {
    match message {
        Message::Insert { table_name, data } => {
            let mut total_size: Option<usize> = None;
            let columns: Vec<Column> = data
                .iter()
                .inspect(|v| {
                    if let ColumnData::Blob { column: _, size } = v {
                        total_size = Some(*size);
                    }
                })
                .map(|v| v.into())
                .collect();

            prepare_db(conn, table_name, false, &columns).context(InsertSnafu)?;

            if let Some(total_size) = total_size {
                let sanitized_table_name = table_name.to_sql_sanitized_string();

                let mut statement = conn
                    .prepare(&stmt::insert_blob(
                        &sanitized_table_name,
                        &columns,
                        total_size,
                    ))
                    .context(DatabaseErrorContextSnafu {
                        context: "insert blob",
                    })?;

                let values = columns.iter().filter_map(|n| n.value());

                debug!("executing {statement:?} with values {columns:?}");

                statement
                    .execute(params_from_iter(values))
                    .context(DatabaseErrorContextSnafu {
                        context: "insert blob",
                    })?;

                conn.execute(stmt::drop_temporary_blob_table(), ())
                    .context(DatabaseErrorContextSnafu {
                        context: "cleanup blob",
                    })?;
            } else {
                insert_row(conn, table_name, &columns).context(InsertSnafu)?;
            }
        }
        Message::BlobStart => {
            conn.execute(stmt::create_temporary_blob_table(), [])
                .context(DatabaseErrorContextSnafu {
                    context: "insert blob",
                })?;
        }
        Message::BlobData(data) => {
            conn.execute(&stmt::insert_zero_blob(data.len()), [])
                .context(DatabaseErrorContextSnafu {
                    context: "insert blob",
                })?;
            let rowid = conn.last_insert_rowid();
            let mut blob = conn
                .blob_open(TEMP_DB, "blob_insert", "data", rowid, false)
                .context(DatabaseErrorContextSnafu {
                    context: "open blob",
                })?;

            blob.write_all(&data).context(BlobWriteSnafu)?;
        }
        Message::Disconnect | Message::Commit | Message::Revert => {
            panic!("Commit/Abort must be handled outside")
        }
    };
    Ok(())
}

/// Main entry point for a daemon.
///
/// Opens the SQLite database at `output_path`, binds a Unix socket listener at
/// `tx_path`, and serves clients through [`tx_main`] until one of them commits
/// or reverts. The socket file is removed afterwards. This also happens when
/// [`tx_main`] fails, so a stale socket file does not make the next `bind` fail.
///
/// # Errors
///
/// Returns the first error from opening the database, binding the socket or
/// [`tx_main`]. A failure to remove the socket file is reported only if
/// everything else succeeded.
pub fn tx_daemon_main(
    output_path: impl AsRef<Path>,
    tx_path: impl AsRef<Path>,
) -> Result<(), Error> {
    let tx_path = tx_path.as_ref();
    let conn = Connection::open(output_path).context(DatabaseErrorContextSnafu {
        context: "open connection",
    })?;

    let socket = UnixListener::bind(tx_path).context(CommunicationSnafu)?;

    // Run the session, then clean up the socket file regardless of its outcome.
    let outcome = tx_main(conn, socket);
    let cleanup = std::fs::remove_file(tx_path).context(DiskIoSnafu);

    outcome?;
    cleanup
}

/// Main entry point for tx management.
/// Expects the connection and socket to be already opened.
///
/// Opens one transaction on `conn`, then accepts clients on `listener` one after
/// another and applies their messages to that transaction. A `Disconnect` ends
/// the current client's session and waits for the next client. The first
/// `Commit` or `Revert` finishes the transaction and returns.
///
/// # Errors
///
/// Any socket, decoding or database error is returned immediately and the
/// transaction is dropped without being committed. This includes a client that
/// closes its connection without sending `Disconnect`.
pub fn tx_main(mut conn: Connection, listener: UnixListener) -> Result<(), Error> {
    /// How a single client's session ended.
    enum SessionEnd {
        Disconnected,
        Commit,
        Revert,
    }

    let transaction = conn.transaction().context(DatabaseErrorContextSnafu {
        context: "create transaction",
    })?;

    loop {
        debug!("tx_main loop waiting for connection");
        let (mut client, addr) = listener.accept().context(CommunicationSnafu)?;
        debug!("tx_main connected: {addr:?}");

        let session_end = loop {
            debug!("tx_main loop waiting for message");
            let msg = receive_msg(&mut client)?;
            debug!("tx_main received: {msg:?}");

            match msg {
                Message::Commit => break SessionEnd::Commit,
                Message::Revert => break SessionEnd::Revert,
                Message::Disconnect => break SessionEnd::Disconnected,
                other => handle_msg_inner(&transaction, &other)?,
            }
        };

        // `client` is still alive here, so it is dropped only after commit/rollback.
        match session_end {
            SessionEnd::Disconnected => {}
            SessionEnd::Commit => {
                transaction
                    .commit()
                    .context(DatabaseErrorContextSnafu { context: "commit" })?;
                return Ok(());
            }
            SessionEnd::Revert => {
                transaction.rollback().context(DatabaseErrorContextSnafu {
                    context: "rollback",
                })?;
                return Ok(());
            }
        }
    }
}

/// Main entrypoint for clients sending data.
pub fn client_main(
    sink: &mut impl Write,
    table_name: &str,
    columns: &[Column],
    mode: StdinMode<'_>,
) -> Result<(), Error> {
    match mode {
        StdinMode::None => {
            let data = columns
                .iter()
                .map(|v| ColumnData::Raw {
                    key: v.name().to_string(),
                    value: v.raw_value().to_string(),
                })
                .collect();
            let message = Message::Insert {
                table_name: table_name.to_string(),
                data,
            };
            send_msg(&message, sink)?;
        }
        StdinMode::Blob(_) => {
            send_msg(&Message::BlobStart, sink)?;
            let mut total_size = 0usize;
            {
                let mut stdin = std::io::stdin().lock();
                let mut buf = vec![0; stmt::BLOB_BUF_SIZE];

                loop {
                    let count = stdin.read(&mut buf).context(SourceReadSnafu)?;
                    total_size += count;
                    debug!("read {count} to a total of {total_size}");
                    if count == 0 {
                        break;
                    }

                    send_msg(&Message::BlobData(ShareableBuffer::Shared(&buf)), sink)?;
                }
            }

            let data = columns
                .iter()
                .map(|col| match col.value() {
                    Some(v) => ColumnData::Raw {
                        key: col.name().to_owned(),
                        value: v.to_owned(),
                    },
                    None => ColumnData::Blob {
                        column: col.name().to_owned(),
                        size: total_size,
                    },
                })
                .collect();

            send_msg(
                &Message::Insert {
                    table_name: table_name.to_owned(),
                    data,
                },
                sink,
            )?;
        }
        StdinMode::Lines(_) => {
            let stdin = std::io::stdin().lock();
            for line in stdin.lines() {
                let line = line.context(SourceReadSnafu)?;
                let data = columns
                    .iter()
                    .map(|col| ColumnData::Raw {
                        key: col.name().to_owned(),
                        value: col.value_or_line(&line).to_owned(),
                    })
                    .collect();
                send_msg(
                    &Message::Insert {
                        table_name: table_name.to_owned(),
                        data,
                    },
                    sink,
                )?;
            }
        }
    }

    Ok(())
}

/// Entrypoint for committing transactions.
///
/// Sends a `Commit` message, asking the daemon to commit its transaction.
///
/// # Errors
///
/// Returns [`Error::CommunicationError`] if writing to `sink` fails.
pub fn client_commit(sink: &mut impl Write) -> Result<(), Error> {
    send_msg(&Message::Commit, sink)
}

/// Entrypoint for reverting transactions.
///
/// Sends a `Revert` message, asking the daemon to roll back its transaction.
///
/// # Errors
///
/// Returns [`Error::CommunicationError`] if writing to `sink` fails.
pub fn client_revert(sink: &mut impl Write) -> Result<(), Error> {
    send_msg(&Message::Revert, sink)
}

/// Entrypoint for disconnecting.
///
/// Sends a `Disconnect` message; the daemon keeps its transaction open for
/// the next client.
///
/// # Errors
///
/// Returns [`Error::CommunicationError`] if writing to `sink` fails.
pub fn client_disconnect(sink: &mut impl Write) -> Result<(), Error> {
    send_msg(&Message::Disconnect, sink)
}