use std::{
collections::HashSet, error::Error, path::PathBuf, process::ExitCode, str::FromStr,
};
use clap::{self, Parser};
use log::{debug, error};
use rusqlite::Connection;
use snafu::{ResultExt, Snafu, Whatever, ensure_whatever, whatever};
#[cfg(feature = "tx")]
use sqlitepipe::txmgmt::client_revert;
use sqlitepipe::{
StdinMode, column::Column, insert_blob, insert_files, insert_lines, insert_row, prepare_db,
};
use crate::defaults::{BLOB_COLUMN, FILENAME_COLUMN};
#[cfg(feature = "tx")]
use {
snafu::OptionExt,
sqlitepipe::txmgmt::{client_commit, client_disconnect, client_main, tx_daemon_main},
std::os::unix::net::UnixStream,
};
/// File-local result alias: every fallible function in this file reports
/// failures as snafu's string-based `Whatever` error.
type Result<T> = std::result::Result<T, Whatever>;
/// Fallback names and paths used when the matching command-line option is
/// omitted or given without a value.
mod defaults {
    /// Column name used for `--stdin-blob` without a value, and for the
    /// content column in files mode.
    pub const BLOB_COLUMN: &str = "blob";

    /// Column name used for `--stdin-lines` without a value.
    pub const LINE_COLUMN: &str = "line";

    /// Column name that holds the file name in files mode.
    pub const FILENAME_COLUMN: &str = "filename";

    /// Table name used when `--table-name` is not given.
    pub const TABLE_NAME: &str = "data";

    /// Database file used when `--output-db` is not given.
    pub const DATABASE_PATH: &str = "stdin.data.sqlite3";

    /// Socket path used when a transaction option is given without a value.
    #[cfg(feature = "tx")]
    pub const TX_PATH: &str = "sqlitepipe_tx.socket";
}
/// Error returned when a `key=value` command-line argument cannot be parsed.
///
/// The explicit `display` attribute below makes the `Display` output exactly
/// the contents of `message`.
#[derive(Debug, Snafu)]
#[snafu(display("{message}"))]
struct KeyValueParseError {
/// Human-readable description of what was wrong with the input.
message: String,
}
/// One `key=value` pair taken from the command line (the `--value` option).
///
/// `get_all_columns` turns each pair into a raw column, using `key` as the
/// column name and `value` as its content.
#[derive(Debug, Clone)]
struct KeyValueArg {
/// Text before the first `=`.
pub key: String,
/// Text after the first `=`; may itself contain further `=` characters.
pub value: String,
}
impl From<(&str, &str)> for KeyValueArg {
fn from(value: (&str, &str)) -> Self {
KeyValueArg {
key: value.0.to_string(),
value: value.1.to_string(),
}
}
}
impl FromStr for KeyValueArg {
type Err = KeyValueParseError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let Some(kv) = s.split_once('=') else {
return Err(KeyValueParseError {
message: "invalid key value pair; missing '='".to_string(),
});
};
Ok(kv.into())
}
}
// Options selecting how standard input is consumed. `multiple = false` puts
// them in a clap group where at most one may be given on the command line.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///`
// doc comments into help text, so adding them would change `--help` output.
#[derive(Debug, Default, clap::Args)]
#[group(multiple = false)]
struct StdinArgs {
// `-n` / `--stdin-none`: select no stdin mode. `mode()` then yields
// `StdinMode::None`, and `get_args` does not fall back to blob mode.
#[arg(short = 'n', long)]
stdin_none: bool,
// `-b` / `--stdin-blob[=COLUMN]`: blob mode; the value is the name of the
// blob column. Given without a value, the name is `defaults::BLOB_COLUMN`.
// `require_equals` means a value must be attached with `=`.
#[arg(
short = 'b',
long,
default_missing_value = defaults::BLOB_COLUMN,
require_equals = true,
num_args = 0..=1,
)]
stdin_blob: Option<String>,
// `-l` / `--stdin-lines[=COLUMN]`: lines mode; the value is the name of the
// line column. Given without a value, the name is `defaults::LINE_COLUMN`.
#[arg(
short = 'l',
long,
default_missing_value = defaults::LINE_COLUMN,
require_equals = true,
num_args = 0..=1,
)]
stdin_lines: Option<String>,
// `-f` / `--stdin-files`: files mode. `get_all_columns` adds a filename
// column and a blob column with the default names for it.
// NOTE(review): presumably stdin then lists file paths; confirm against
// `insert_files`.
#[arg(short = 'f', long)]
stdin_files: bool,
}
impl StdinArgs {
    /// Maps the parsed flags to the stdin handling mode.
    ///
    /// Precedence when several fields are set: blob, then lines, then files.
    /// With none of them set the result is `StdinMode::None`.
    fn mode(&self) -> StdinMode {
        match (&self.stdin_blob, &self.stdin_lines, self.stdin_files) {
            (Some(_), _, _) => StdinMode::Blob,
            (None, Some(_), _) => StdinMode::Lines,
            (None, None, true) => StdinMode::Files,
            (None, None, false) => StdinMode::None,
        }
    }
}
// Transaction-related options (only built with the `tx` feature).
// `multiple = false` puts them in a clap group where at most one may be given.
//
// Each option takes an optional path attached with `=`; given without a
// value, the path is `defaults::TX_PATH`. The commit/revert/transaction paths
// are opened as Unix sockets; the start path is handed to `tx_daemon_main`.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///`
// doc comments into help text, so adding them would change `--help` output.
#[cfg(feature = "tx")]
#[derive(Debug, clap::Args)]
#[group(multiple = false)]
struct TransactionArgs {
// `--start-transaction[=PATH]`: handled by `start_transaction`, which forks
// a daemon running `tx_daemon_main` with the database path and this path.
#[arg(
long,
default_missing_value = defaults::TX_PATH,
require_equals = true,
num_args = 0..=1,
)]
start_transaction: Option<PathBuf>,
// `--commit-transaction[=PATH]`: handled by `commit_transaction`, which
// connects to the socket and calls `client_commit`.
#[arg(
long,
default_missing_value = defaults::TX_PATH,
require_equals = true,
num_args = 0..=1,
)]
commit_transaction: Option<PathBuf>,
// `--revert-transaction[=PATH]`: handled by `revert_transaction`, which
// connects to the socket and calls `client_revert`.
#[arg(
long,
default_missing_value = defaults::TX_PATH,
require_equals = true,
num_args = 0..=1,
)]
revert_transaction: Option<PathBuf>,
// `--transaction[=PATH]`: handled by `transfer_transaction`, which connects
// to the socket and passes the insert on through `client_main`.
#[arg(
long,
default_missing_value = defaults::TX_PATH,
require_equals = true,
num_args = 0..=1,
)]
transaction: Option<PathBuf>,
}
// Top-level command-line arguments.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///`
// doc comments into help/about text, so adding them would change `--help`
// output.
#[derive(Debug, clap::Parser)]
#[command(version, about)]
struct Args {
// Path of the SQLite database file; falls back to `defaults::DATABASE_PATH`
// where it is consumed.
#[arg(short, long)]
output_db: Option<PathBuf>,
// Target table name; falls back to `defaults::TABLE_NAME` where it is
// consumed.
#[arg(short, long)]
table_name: Option<String>,
// Forwarded unchanged to `prepare_db`.
// NOTE(review): presumably recreates the table first; confirm in `prepare_db`.
#[arg(short, long)]
reset: bool,
// Repeatable `key=value` pairs; each becomes a raw column in
// `get_all_columns`.
#[arg(short, long)]
value: Vec<KeyValueArg>,
// Passed to `set_json` on the stdin-backed blob/line column.
#[arg(short, long)]
json_input: bool,
// Mutually exclusive stdin mode options.
#[clap(flatten)]
stdin_args: StdinArgs,
// Mutually exclusive transaction options (feature `tx` only).
#[cfg(feature = "tx")]
#[clap(flatten)]
tx_args: TransactionArgs,
}
/// Builds the ordered list of columns for one insert.
///
/// The result starts with one raw column per `key=value` pair in `columns`,
/// in the given order, followed by the stdin-backed column(s) selected in
/// `stdin_args`:
///
/// * blob mode: one blob column with the requested name,
/// * lines mode: one line column with the requested name,
/// * files mode: a filename column followed by a blob column, both using the
///   default names.
///
/// `stdin_is_json` is forwarded to `set_json` on the blob/line column; the
/// filename column is never marked as JSON.
///
/// # Errors
///
/// Fails when a stdin-backed column name is already used by one of the
/// `key=value` columns, or when more than one stdin mode is set.
fn get_all_columns(
    columns: &[KeyValueArg],
    stdin_args: &StdinArgs,
    stdin_is_json: bool,
) -> Result<Vec<Column>> {
    let mut result: Vec<_> = columns
        .iter()
        .map(|v| Column::raw_column(&v.key, &v.value))
        .collect();

    // Names claimed by the user-supplied columns; the stdin-backed columns
    // must not reuse any of them.
    let unique_names: HashSet<_> = result.iter().map(|v| v.name()).collect();

    // Set by at most one of the branches below. The `is_none` checks are a
    // second line of defence behind the clap group on `StdinArgs`.
    let mut new_col = None;

    if let Some(stdin_blob_column) = &stdin_args.stdin_blob {
        ensure_whatever!(
            !unique_names.contains(&stdin_blob_column.as_str()),
            "name clash with column name and blob column: {stdin_blob_column}"
        );
        ensure_whatever!(
            new_col.is_none(),
            "more than one special column not allowed"
        );
        let col = Column::blob_column(stdin_blob_column).set_json(stdin_is_json);
        new_col = Some(vec![col]);
    }

    if let Some(stdin_lines_column) = &stdin_args.stdin_lines {
        ensure_whatever!(
            !unique_names.contains(&stdin_lines_column.as_str()),
            "name clash with column name and line column: {stdin_lines_column}"
        );
        ensure_whatever!(
            new_col.is_none(),
            "more than one special column not allowed"
        );
        let col = Column::line_column(stdin_lines_column).set_json(stdin_is_json);
        new_col = Some(vec![col]);
    }

    if stdin_args.stdin_files {
        ensure_whatever!(
            !unique_names.contains(FILENAME_COLUMN),
            "name clash with column name and filename column: {FILENAME_COLUMN}"
        );
        ensure_whatever!(
            !unique_names.contains(BLOB_COLUMN),
            "name clash with column name and blob column: {BLOB_COLUMN}"
        );
        ensure_whatever!(
            new_col.is_none(),
            "more than one special column not allowed"
        );
        let blob_col = Column::blob_column(BLOB_COLUMN).set_json(stdin_is_json);
        let filename_col = Column::line_column(FILENAME_COLUMN);
        new_col = Some(vec![filename_col, blob_col]);
    }

    if let Some(col) = new_col {
        result.extend(col);
    }
    Ok(result)
}
fn get_args() -> Args {
let mut args = Args::parse();
if !args.stdin_args.stdin_none
&& !args.stdin_args.stdin_files
&& args.stdin_args.stdin_blob.is_none()
&& args.stdin_args.stdin_lines.is_none()
{
args.stdin_args = StdinArgs {
stdin_blob: Some(defaults::BLOB_COLUMN.to_string()),
..Default::default()
};
}
args
}
fn normal_main(args: Args) -> Result<()> {
let raw_table_name = args
.table_name
.as_ref()
.map(|v| v.as_str())
.unwrap_or_else(|| defaults::TABLE_NAME);
let output_path = args
.output_db
.as_ref()
.map(|v| v.to_owned())
.unwrap_or_else(|| defaults::DATABASE_PATH.into());
let columns = get_all_columns(&args.value, &args.stdin_args, args.json_input)?;
debug!("found columns: {columns:?}");
if columns.is_empty() {
whatever!("No data to insert.");
}
let mut conn = Connection::open(output_path).whatever_context("sqlite")?;
let tx = conn.transaction().whatever_context("sqlite")?;
prepare_db(&tx, &raw_table_name, args.reset, &columns).whatever_context("prepare_db")?;
match args.stdin_args.mode() {
StdinMode::None => {
insert_row(&tx, &raw_table_name, &columns).whatever_context("insert_row")?;
}
StdinMode::Blob => {
let mut stdin = std::io::stdin().lock();
insert_blob(&tx, &raw_table_name, &columns, &mut stdin)
.whatever_context("insert_blob")?;
}
StdinMode::Lines => {
let stdin = std::io::stdin().lock();
insert_lines(&tx, &raw_table_name, &columns, stdin).whatever_context("insert_lines")?;
}
StdinMode::Files => {
let stdin = std::io::stdin().lock();
insert_files(&tx, &raw_table_name, &columns, stdin).whatever_context("insert_files")?;
}
}
tx.commit().whatever_context("sqlite")?;
Ok(())
}
/// Starts the transaction daemon in a forked background process.
///
/// The database path falls back to `defaults::DATABASE_PATH`. After
/// `fork::daemon` returns, only the `Fork::Child` side runs `tx_daemon_main`;
/// the other side returns `Ok(())` straight away.
///
/// NOTE(review): per the `fork` crate, the arguments to `daemon` are
/// `(nochdir, noclose)`, so the daemon presumably keeps its working directory
/// while its standard streams are redirected; confirm against the crate docs.
///
/// # Errors
///
/// Fails when the socket path is missing, when forking fails, or (in the
/// daemon process) when `tx_daemon_main` returns an error.
#[cfg(feature = "tx")]
fn start_transaction(args: Args) -> Result<()> {
    use fork::Fork;

    let output_path = args
        .output_db
        .clone()
        .unwrap_or_else(|| defaults::DATABASE_PATH.into());
    let tx_path = args
        .tx_args
        .start_transaction
        .whatever_context("need tx path")?;

    let child = fork::daemon(true, false).whatever_context("fork")?;
    if let Fork::Child = child {
        // A daemon failure used to be dropped silently; propagate it so that
        // `main` logs it and the daemon process exits with a failure status.
        tx_daemon_main(output_path, tx_path).whatever_context("tx daemon")?;
    }
    Ok(())
}
/// Connects to the transaction socket and sends a commit request through
/// `client_commit`.
///
/// # Errors
///
/// Fails when the socket path is missing, the connection cannot be made, or
/// `client_commit` reports an error.
#[cfg(feature = "tx")]
fn commit_transaction(args: Args) -> Result<()> {
    let TransactionArgs {
        commit_transaction: socket_path,
        ..
    } = args.tx_args;
    let socket_path = socket_path.whatever_context("need tx path")?;

    let mut stream = UnixStream::connect(socket_path).whatever_context("connect socket")?;
    client_commit(&mut stream).whatever_context("client commit")?;
    Ok(())
}
/// Connects to the transaction socket and sends a revert request through
/// `client_revert`.
///
/// # Errors
///
/// Fails when the socket path is missing, the connection cannot be made, or
/// `client_revert` reports an error.
#[cfg(feature = "tx")]
fn revert_transaction(args: Args) -> Result<()> {
    let TransactionArgs {
        revert_transaction: socket_path,
        ..
    } = args.tx_args;
    let socket_path = socket_path.whatever_context("need tx path")?;

    let mut stream = UnixStream::connect(socket_path).whatever_context("connect socket")?;
    client_revert(&mut stream).whatever_context("client revert")?;
    Ok(())
}
/// Sends an insert over the transaction socket instead of writing to the
/// database directly.
///
/// The socket is connected first; only then are the columns built and
/// checked. The request is handed to `client_main`, followed by
/// `client_disconnect`.
///
/// # Errors
///
/// Fails when the socket path is missing, the connection cannot be made, no
/// columns are available, or one of the client calls reports an error.
#[cfg(feature = "tx")]
fn transfer_transaction(args: Args) -> Result<()> {
    let tx_path = args.tx_args.transaction.whatever_context("need tx path")?;
    let mut socket = UnixStream::connect(tx_path).whatever_context("connect socket")?;

    let raw_table_name = args.table_name.as_deref().unwrap_or(defaults::TABLE_NAME);
    let columns = get_all_columns(&args.value, &args.stdin_args, args.json_input)?;
    ensure_whatever!(!columns.is_empty(), "No data to insert.");

    let stdin_mode = args.stdin_args.mode();
    client_main(&mut socket, raw_table_name, &columns, stdin_mode)
        .whatever_context("client main")?;
    client_disconnect(&mut socket).whatever_context("disconnect")?;
    Ok(())
}
/// Runs the action selected by the transaction options (feature `tx`).
///
/// The first option that is set wins, checked in this order: start, commit,
/// revert, transaction. With none set, a stand-alone insert is performed.
#[cfg(feature = "tx")]
fn main_with_error(args: Args) -> Result<()> {
    debug!("Args: {args:?}");

    // Pick the handler first; `args` is moved into it afterwards.
    let tx = &args.tx_args;
    let handler: fn(Args) -> Result<()> = if tx.start_transaction.is_some() {
        start_transaction
    } else if tx.commit_transaction.is_some() {
        commit_transaction
    } else if tx.revert_transaction.is_some() {
        revert_transaction
    } else if tx.transaction.is_some() {
        transfer_transaction
    } else {
        normal_main
    };

    handler(args)
}
/// Runs a stand-alone insert; without the `tx` feature this is the only
/// available action.
#[cfg(not(feature = "tx"))]
fn main_with_error(args: Args) -> Result<()> {
    debug!("Args: {args:?}");
    normal_main(args)
}
fn main() -> ExitCode {
env_logger::init();
let args = get_args();
if let Err(e) = main_with_error(args) {
if let Some(source) = e.source() {
error!("{source}");
} else {
error!("Error: {e}");
}
debug!("Detailed error: {e:?}");
ExitCode::FAILURE
} else {
ExitCode::SUCCESS
}
}