use std::fs;
use std::io::Write;
use anyhow::{Context, Result, bail};
use clap::{Args, Parser, Subcommand};
use datapress_client::blocking::Client;
use datapress_client::{Aggregation, OrderBy, Predicate, QueryRequest};
use serde_json::Value as JsonValue;
// Top-level command line: the shared connection/output flags (flattened in
// from `ConnectionArgs`) followed by exactly one subcommand.
// Plain `//` comments are used on purpose: `///` would become clap help text.
#[derive(Debug, Parser)]
#[command(name = "datapress-cli", version, about, long_about = None)]
struct Cli {
    // Shared connection and output flags; turned into a client in `main`.
    #[command(flatten)]
    conn: ConnectionArgs,
    // The selected subcommand; dispatched by the `match` in `main`.
    #[command(subcommand)]
    command: Command,
}
// Connection and output options shared by every subcommand. Every field is
// `global = true`, so the flags are also accepted after the subcommand name.
// (A `//` comment rather than `///` so clap does not pick it up as help text.)
#[derive(Debug, Args)]
struct ConnectionArgs {
    /// Server base URL.
    #[arg(
        long,
        global = true,
        env = "DATAPRESS_URL",
        default_value = "http://127.0.0.1:8000"
    )]
    url: String,
    /// API path prefix passed to the client.
    #[arg(long, global = true, default_value = "/api/v1")]
    api_base: String,
    /// Bearer token used by the client.
    // `hide_env_values` keeps `--help` from echoing the secret when
    // DATAPRESS_TOKEN is set in the environment.
    #[arg(long, global = true, env = "DATAPRESS_TOKEN", hide_env_values = true)]
    bearer_token: Option<String>,
    /// Admin token used by the client.
    // Same reasoning as `bearer_token`: never print DATAPRESS_ADMIN_TOKEN in help.
    #[arg(
        long,
        global = true,
        env = "DATAPRESS_ADMIN_TOKEN",
        hide_env_values = true
    )]
    admin_token: Option<String>,
    /// Client timeout in seconds (fractions allowed).
    #[arg(long, global = true, value_name = "SECONDS")]
    timeout: Option<f64>,
    /// Pretty-print JSON output.
    #[arg(long, global = true)]
    pretty: bool,
}
impl ConnectionArgs {
fn build(&self) -> Result<Client> {
let mut builder = Client::builder(&self.url).api_base(&self.api_base);
if let Some(t) = &self.bearer_token {
builder = builder.bearer_token(t);
}
if let Some(t) = &self.admin_token {
builder = builder.admin_token(t);
}
if let Some(secs) = self.timeout {
builder = builder.timeout(std::time::Duration::from_secs_f64(secs));
}
let async_client = builder.build().context("failed to build client")?;
Client::from_async(async_client).context("failed to start runtime")
}
}
// Subcommands, dispatched by the `match` in `main`. The `///` comments below
// are the help text clap shows in `--help`.
#[derive(Debug, Subcommand)]
enum Command {
    /// List dataset names as JSON.
    Datasets,
    /// Print a dataset's schema as JSON.
    Schema {
        /// Dataset name.
        dataset: String,
    },
    /// Print the count for a dataset, optionally filtered with --where.
    Count {
        /// Dataset name.
        dataset: String,
        /// Filter `col:op[:val]`; repeatable. `val` is parsed as JSON, falling
        /// back to a plain string; `is_null`/`is_not_null` take no value.
        #[arg(long = "where", value_name = "col:op[:val]")]
        filters: Vec<String>,
    },
    /// Run a structured query against a dataset.
    Query(QueryArgs),
    /// Execute a SQL statement and print the response as JSON.
    Sql {
        /// SQL statement to execute.
        statement: String,
        /// Row cap forwarded with the SQL request.
        #[arg(long)]
        max_rows: Option<u64>,
    },
    /// Reload a dataset and print the response as JSON.
    Reload {
        /// Dataset name.
        dataset: String,
    },
    /// Call the health endpoint and print the response as JSON.
    Health,
    /// Call the readiness endpoint and print the response as JSON.
    Ready,
}
// Flags for the `query` subcommand; `run_query` turns them into a
// `QueryRequest` and picks the output format.
#[derive(Debug, Args)]
struct QueryArgs {
    /// Dataset name.
    dataset: String,
    /// Columns to return (comma-separated or repeated).
    #[arg(long = "select", value_name = "col,col,…", value_delimiter = ',')]
    columns: Vec<String>,
    /// Row filter `col:op[:val]`; repeatable. `val` is parsed as JSON, falling
    /// back to a plain string; `is_null`/`is_not_null` take no value.
    #[arg(long = "where", value_name = "col:op[:val]")]
    filters: Vec<String>,
    /// Columns to group by (comma-separated or repeated).
    #[arg(long = "group-by", value_name = "col,col,…", value_delimiter = ',')]
    group_by: Vec<String>,
    /// Aggregation `op:col[:alias]` or `count[:alias]`; repeatable.
    #[arg(long = "agg", value_name = "op:col[:alias]")]
    aggregations: Vec<String>,
    /// HAVING filter with the same `col:op[:val]` syntax as --where; repeatable.
    #[arg(long = "having", value_name = "col:op[:val]")]
    having: Vec<String>,
    /// Request distinct rows.
    #[arg(long)]
    distinct: bool,
    /// Sort key `col[:asc|desc]`; repeatable.
    #[arg(long = "order-by", value_name = "col[:dir]")]
    order_by: Vec<String>,
    /// Row limit.
    #[arg(long)]
    limit: Option<u64>,
    /// Page number.
    #[arg(long)]
    page: Option<u64>,
    /// Page size.
    #[arg(long)]
    page_size: Option<u64>,
    /// Write the raw Arrow IPC response to PATH (`-` for stdout).
    // `run_query` checks `arrow_out` before `table`, so accepting both flags
    // would silently drop `--table`; let clap reject the combination instead.
    #[arg(long, value_name = "PATH", conflicts_with = "table")]
    arrow_out: Option<String>,
    /// Render the result as a text table instead of JSON.
    #[arg(long)]
    table: bool,
}
fn main() -> Result<()> {
let cli = Cli::parse();
let client = cli.conn.build()?;
let pretty = cli.conn.pretty;
match cli.command {
Command::Datasets => {
let names = client.datasets()?;
print_json(&JsonValue::from(names), pretty)?;
}
Command::Schema { dataset } => {
print_json(&client.schema(&dataset)?, pretty)?;
}
Command::Count { dataset, filters } => {
let preds = parse_predicates(&filters)?;
let n = client.count(&dataset, &preds)?;
println!("{n}");
}
Command::Query(args) => run_query(&client, args, pretty)?,
Command::Sql {
statement,
max_rows,
} => {
let resp = client.sql(statement, max_rows)?;
print_json(&serde_json::to_value(resp)?, pretty)?;
}
Command::Reload { dataset } => {
print_json(&client.reload(&dataset)?, pretty)?;
}
Command::Health => print_json(&client.healthz()?, pretty)?,
Command::Ready => print_json(&client.readyz()?, pretty)?,
}
Ok(())
}
fn run_query(client: &Client, args: QueryArgs, pretty: bool) -> Result<()> {
let mut builder = QueryRequest::builder();
if !args.columns.is_empty() {
builder = builder.columns(args.columns);
}
for p in parse_predicates(&args.filters)? {
builder = builder.predicate(p);
}
if !args.group_by.is_empty() {
builder = builder.group_by(args.group_by);
}
for a in &args.aggregations {
builder = builder.aggregation(parse_aggregation(a)?);
}
for h in parse_predicates(&args.having)? {
builder = builder.having(h);
}
if args.distinct {
builder = builder.distinct(true);
}
for o in &args.order_by {
builder = builder.order_by(parse_order_by(o));
}
if let Some(n) = args.limit {
builder = builder.limit(n);
}
if let Some(n) = args.page {
builder = builder.page(n);
}
if let Some(n) = args.page_size {
builder = builder.page_size(n);
}
let request = builder.build();
if let Some(path) = args.arrow_out {
let bytes = client.query_arrow_bytes(&args.dataset, &request)?;
write_bytes(&path, &bytes)?;
eprintln!("wrote {} bytes of Arrow IPC to {path}", bytes.len());
} else if args.table {
let batches = client.query_arrow(&args.dataset, &request)?;
let rendered = arrow::util::pretty::pretty_format_batches(&batches)
.context("failed to format record batches")?;
println!("{rendered}");
} else {
let resp = client.query_json(&args.dataset, &request)?;
print_json(&serde_json::to_value(resp)?, pretty)?;
}
Ok(())
}
/// Parses every `col:op[:val]` spec in order, stopping at the first one that
/// fails (see `parse_predicate` for the accepted syntax).
fn parse_predicates(specs: &[String]) -> Result<Vec<Predicate>> {
    let mut parsed = Vec::with_capacity(specs.len());
    for spec in specs {
        parsed.push(parse_predicate(spec)?);
    }
    Ok(parsed)
}
/// Parses a single `col:op[:val]` filter spec into a [`Predicate`].
///
/// The spec is split on at most the first two `:` characters, so `val` may
/// itself contain colons. `val` goes through `parse_value` (JSON, falling back
/// to a plain string). Without a value, only `is_null` and `is_not_null` are
/// accepted and produce a unary predicate.
///
/// # Errors
/// Fails when the column or operator is empty or missing, or when a
/// value-taking operator has no value.
fn parse_predicate(spec: &str) -> Result<Predicate> {
    let mut fields = spec.splitn(3, ':');

    let col = match fields.next() {
        Some(name) if !name.is_empty() => name,
        _ => bail!("predicate `{spec}`: missing column"),
    };
    let op = match fields.next() {
        Some(operator) if !operator.is_empty() => operator,
        _ => bail!("predicate `{spec}`: missing operator"),
    };

    match (fields.next(), op) {
        (Some(raw), _) => Ok(Predicate::new(col, op, parse_value(raw))),
        (None, "is_null" | "is_not_null") => Ok(Predicate::unary(col, op)),
        (None, _) => {
            bail!("predicate `{spec}`: operator `{op}` requires a value (col:op:val)")
        }
    }
}
/// Parses an `--agg` spec into an [`Aggregation`].
///
/// Accepted forms:
/// - `count` / `count:alias`: a column-less count (`Aggregation::count`);
/// - `op:col` / `op:col:alias`: `op` applied to `col` (`Aggregation::over`).
///
/// Note the asymmetry: the two-part `count:x` is read as a count aliased `x`,
/// while the three-part `count:x:y` goes to `Aggregation::over` with op `count`,
/// column `x` and alias `y`.
///
/// # Errors
/// Fails when any `:`-separated segment is empty (e.g. `count:`, `sum::n`,
/// `:price`) or the number of segments matches none of the forms above.
fn parse_aggregation(spec: &str) -> Result<Aggregation> {
    let parts: Vec<&str> = spec.split(':').collect();
    // Mirror `parse_predicate`, which rejects empty column/operator fields: an
    // empty op, column or alias would otherwise be forwarded verbatim.
    if parts.iter().any(|part| part.is_empty()) {
        bail!(
            "aggregation `{spec}`: empty segment; expected `op:col[:alias]` or `count[:alias]`"
        );
    }
    match parts.as_slice() {
        ["count"] => Ok(Aggregation::count(None)),
        ["count", alias] => Ok(Aggregation::count(Some(alias))),
        [op, col] => Ok(Aggregation::over(*op, *col, None)),
        [op, col, alias] => Ok(Aggregation::over(*op, *col, Some(alias))),
        _ => bail!("aggregation `{spec}`: expected `op:col[:alias]` or `count[:alias]`"),
    }
}
/// Parses an `--order-by` spec of the form `col[:asc|desc]` into an [`OrderBy`].
///
/// The direction is taken from the text after the *last* `:` and matched
/// case-insensitively, so `col:DESC` sorts descending and column names that
/// contain `:` are kept intact. An empty direction (`col:`) means ascending.
/// Any other suffix is treated as part of the column name, so a typo such as
/// `col:dsc` is passed through as the column `col:dsc` rather than silently
/// sorting `col` ascending. Presumably the server rejects it as unknown; confirm.
fn parse_order_by(spec: &str) -> OrderBy {
    match spec.rsplit_once(':') {
        Some((col, dir)) if dir.eq_ignore_ascii_case("desc") => OrderBy::desc(col),
        Some((col, dir)) if dir.is_empty() || dir.eq_ignore_ascii_case("asc") => {
            OrderBy::asc(col)
        }
        _ => OrderBy::asc(spec),
    }
}
/// Interprets a filter value from the command line.
///
/// Text that parses as JSON (numbers, booleans, `null`, quoted strings,
/// arrays, objects) is used as-is; anything else becomes a JSON string holding
/// the raw text unchanged.
fn parse_value(raw: &str) -> JsonValue {
    match serde_json::from_str::<JsonValue>(raw) {
        Ok(parsed) => parsed,
        Err(_) => JsonValue::String(raw.to_owned()),
    }
}
fn print_json(value: &JsonValue, pretty: bool) -> Result<()> {
let s = if pretty {
serde_json::to_string_pretty(value)?
} else {
serde_json::to_string(value)?
};
println!("{s}");
Ok(())
}
/// Writes raw `bytes` to `path`, or to stdout when `path` is `-`.
///
/// # Errors
/// Returns an error, with the destination in the message, if the file or
/// stdout cannot be written.
fn write_bytes(path: &str, bytes: &[u8]) -> Result<()> {
    if path == "-" {
        // Stdout is line-buffered: binary output after the last `\n` byte would
        // otherwise sit in the buffer until exit, where flush errors are
        // silently dropped. Flush explicitly so failures surface here.
        let mut out = std::io::stdout().lock();
        out.write_all(bytes)
            .and_then(|()| out.flush())
            .context("failed to write to stdout")
    } else {
        fs::write(path, bytes).with_context(|| format!("failed to write {path}"))
    }
}