use rust_decimal::Decimal;
use rustledger_core::NaiveDate;
use std::fmt;
/// A top-level query statement.
///
/// Each variant wraps the AST of one statement kind. `Select` is boxed,
/// presumably to keep `Query` small since `SelectQuery` carries many fields.
#[derive(Debug, Clone, PartialEq)]
pub enum Query {
    /// A `SELECT` query; see [`SelectQuery`].
    Select(Box<SelectQuery>),
    /// A journal statement; see [`JournalQuery`].
    Journal(JournalQuery),
    /// A balances statement; see [`BalancesQuery`].
    Balances(BalancesQuery),
    /// A print statement; see [`PrintQuery`].
    Print(PrintQuery),
    /// A `CREATE TABLE` statement; see [`CreateTableStmt`].
    CreateTable(CreateTableStmt),
    /// An `INSERT` statement; see [`InsertStmt`].
    Insert(InsertStmt),
}
/// One column declaration in a `CREATE TABLE` column list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Optional type name as written; stored as free text and not validated here.
    pub type_hint: Option<String>,
}
/// A `CREATE TABLE` statement.
///
/// The table shape comes from the explicit `columns` list and/or from the
/// `as_select` query. Nothing in this type enforces that only one of them is
/// used; presumably the parser/executor decides — confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct CreateTableStmt {
    /// Name of the table to create.
    pub table_name: String,
    /// Explicit column declarations (may be empty).
    pub columns: Vec<ColumnDef>,
    /// Source query for the `CREATE TABLE ... AS SELECT` form, if used.
    pub as_select: Option<Box<SelectQuery>>,
}
/// An `INSERT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertStmt {
    /// Name of the target table.
    pub table_name: String,
    /// Explicit target column list, or `None` when no list was given.
    pub columns: Option<Vec<String>>,
    /// Where the inserted rows come from.
    pub source: InsertSource,
}
/// Row source for an [`InsertStmt`].
#[derive(Debug, Clone, PartialEq)]
pub enum InsertSource {
    /// Literal `VALUES` rows; each inner `Vec` is one row of expressions.
    Values(Vec<Vec<Expr>>),
    /// Rows produced by a `SELECT` query.
    Select(Box<SelectQuery>),
}
/// A `SELECT` query.
///
/// Every clause except the target list is optional (`None` when absent).
/// Use [`SelectQuery::new`] plus the builder methods to construct one.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// `true` for `SELECT DISTINCT`.
    pub distinct: bool,
    /// The select list.
    pub targets: Vec<Target>,
    /// `FROM` clause.
    pub from: Option<FromClause>,
    /// `WHERE` predicate.
    pub where_clause: Option<Expr>,
    /// `GROUP BY` expressions.
    pub group_by: Option<Vec<Expr>>,
    /// `HAVING` predicate.
    pub having: Option<Expr>,
    /// `PIVOT BY` expressions.
    pub pivot_by: Option<Vec<Expr>>,
    /// `ORDER BY` sort keys.
    pub order_by: Option<Vec<OrderSpec>>,
    /// `LIMIT` row count.
    pub limit: Option<u64>,
}
/// One item of a `SELECT` list.
#[derive(Debug, Clone, PartialEq)]
pub struct Target {
    /// The selected expression.
    pub expr: Expr,
    /// Optional `AS` alias for the output column.
    pub alias: Option<String>,
}
/// A `FROM` clause.
///
/// `subquery` and `table_name` are independent options; this type does not
/// prevent both (or neither) from being set.
#[derive(Debug, Clone, PartialEq)]
pub struct FromClause {
    /// Optional `OPEN ON` date.
    pub open_on: Option<NaiveDate>,
    /// Optional `CLOSE ON` date.
    pub close_on: Option<NaiveDate>,
    /// `true` when the `CLEAR` modifier is present.
    pub clear: bool,
    /// Optional filter expression attached to the clause.
    pub filter: Option<Expr>,
    /// Optional subquery used as the row source.
    pub subquery: Option<Box<SelectQuery>>,
    /// Optional named table used as the row source.
    pub table_name: Option<String>,
}
/// One sort key: an expression plus its direction.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Expression to sort by.
    pub expr: Expr,
    /// Ascending or descending.
    pub direction: SortDirection,
}
/// Sort direction of an [`OrderSpec`]; defaults to ascending.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SortDirection {
    /// Ascending (`ASC`), the default.
    #[default]
    Asc,
    /// Descending (`DESC`).
    Desc,
}
/// A journal statement.
#[derive(Debug, Clone, PartialEq)]
pub struct JournalQuery {
    /// Account pattern argument as written; matching is done elsewhere.
    pub account_pattern: String,
    /// Optional `AT` function name — presumably applied to each row's
    /// position by the executor; confirm there.
    pub at_function: Option<String>,
    /// Optional `FROM` clause.
    pub from: Option<FromClause>,
}
/// A balances statement.
#[derive(Debug, Clone, PartialEq)]
pub struct BalancesQuery {
    /// Optional `AT` function name — semantics applied by the executor; confirm there.
    pub at_function: Option<String>,
    /// Optional `FROM` clause.
    pub from: Option<FromClause>,
}
/// A print statement.
#[derive(Debug, Clone, PartialEq)]
pub struct PrintQuery {
    /// Optional `FROM` clause.
    pub from: Option<FromClause>,
}
/// An expression node.
///
/// Recursive positions (`BinaryOp`, `UnaryOp`, `Paren`, `Between`) are boxed so
/// the enum has a finite size.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// The `*` wildcard.
    Wildcard,
    /// A reference to a column by name.
    Column(String),
    /// A constant value.
    Literal(Literal),
    /// A plain function call, rendered as `name(arg, ...)`.
    Function(FunctionCall),
    /// A window function call with an `OVER (...)` specification.
    Window(WindowFunction),
    /// A binary operation.
    BinaryOp(Box<BinaryOp>),
    /// A unary (prefix or postfix) operation.
    UnaryOp(Box<UnaryOp>),
    /// An explicitly parenthesized sub-expression.
    Paren(Box<Self>),
    /// `value BETWEEN low AND high`.
    Between {
        value: Box<Self>,
        low: Box<Self>,
        high: Box<Self>,
    },
}
/// A constant value appearing in an expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Literal {
    /// String literal (stored without surrounding quotes).
    String(String),
    /// Decimal number literal.
    Number(Decimal),
    /// Integer literal.
    Integer(i64),
    /// Date literal.
    Date(NaiveDate),
    /// `true` / `false`.
    Boolean(bool),
    /// `NULL`.
    Null,
}
/// A plain (non-window) function call.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionCall {
    /// Function name as written.
    pub name: String,
    /// Positional arguments, in call order.
    pub args: Vec<Expr>,
}
/// A window function call: `name(args) OVER (spec)`.
#[derive(Debug, Clone, PartialEq)]
pub struct WindowFunction {
    /// Function name as written.
    pub name: String,
    /// Positional arguments, in call order.
    pub args: Vec<Expr>,
    /// The `OVER (...)` specification.
    pub over: WindowSpec,
}
/// Contents of an `OVER (...)` clause; the default has both parts unset (`OVER ()`).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct WindowSpec {
    /// `PARTITION BY` expressions.
    pub partition_by: Option<Vec<Expr>>,
    /// `ORDER BY` sort keys within each partition.
    pub order_by: Option<Vec<OrderSpec>>,
}
/// A binary operation `left op right`.
#[derive(Debug, Clone, PartialEq)]
pub struct BinaryOp {
    /// Left operand.
    pub left: Expr,
    /// The operator.
    pub op: BinaryOperator,
    /// Right operand.
    pub right: Expr,
}
/// Binary operators; the comment on each variant is its `Display` spelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `~` (regular-expression match)
    Regex,
    /// `!~` (negated regular-expression match)
    NotRegex,
    /// `IN`
    In,
    /// `NOT IN`
    NotIn,
    /// `AND`
    And,
    /// `OR`
    Or,
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
}
/// A unary operation applied to a single operand.
#[derive(Debug, Clone, PartialEq)]
pub struct UnaryOp {
    /// The operator.
    pub op: UnaryOperator,
    /// The operand.
    pub operand: Expr,
}
/// Unary operators. `Not` and `Neg` render as prefixes; `IsNull` and
/// `IsNotNull` render as postfixes after the operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOperator {
    /// Prefix `NOT`.
    Not,
    /// Prefix `-`.
    Neg,
    /// Postfix `IS NULL`.
    IsNull,
    /// Postfix `IS NOT NULL`.
    IsNotNull,
}
impl SelectQuery {
    /// Starts a `SELECT` over `targets` with `DISTINCT` off and every optional
    /// clause unset.
    pub const fn new(targets: Vec<Target>) -> Self {
        Self {
            targets,
            distinct: false,
            from: None,
            where_clause: None,
            group_by: None,
            having: None,
            pivot_by: None,
            order_by: None,
            limit: None,
        }
    }

    /// Marks the query as `SELECT DISTINCT`.
    pub const fn distinct(mut self) -> Self {
        // Field assignment rather than `..self`: a partial move inside a
        // `const fn` is rejected by the const drop checker on stable Rust.
        self.distinct = true;
        self
    }

    /// Sets (or replaces) the `FROM` clause.
    pub fn from(self, from: FromClause) -> Self {
        Self {
            from: Some(from),
            ..self
        }
    }

    /// Sets (or replaces) the `WHERE` predicate.
    pub fn where_clause(self, expr: Expr) -> Self {
        Self {
            where_clause: Some(expr),
            ..self
        }
    }

    /// Sets (or replaces) the `GROUP BY` expressions.
    pub fn group_by(self, exprs: Vec<Expr>) -> Self {
        Self {
            group_by: Some(exprs),
            ..self
        }
    }

    /// Sets (or replaces) the `HAVING` predicate.
    pub fn having(self, expr: Expr) -> Self {
        Self {
            having: Some(expr),
            ..self
        }
    }

    /// Sets (or replaces) the `PIVOT BY` expressions.
    pub fn pivot_by(self, exprs: Vec<Expr>) -> Self {
        Self {
            pivot_by: Some(exprs),
            ..self
        }
    }

    /// Sets (or replaces) the `ORDER BY` sort keys.
    pub fn order_by(self, specs: Vec<OrderSpec>) -> Self {
        Self {
            order_by: Some(specs),
            ..self
        }
    }

    /// Sets (or replaces) the `LIMIT` row count.
    pub const fn limit(mut self, n: u64) -> Self {
        self.limit = Some(n);
        self
    }
}
impl Target {
    /// Creates an unaliased select item for `expr`.
    pub const fn new(expr: Expr) -> Self {
        Self { alias: None, expr }
    }

    /// Creates a select item for `expr` with an `AS alias`.
    pub fn with_alias(expr: Expr, alias: impl Into<String>) -> Self {
        let alias: String = alias.into();
        Self {
            alias: Some(alias),
            ..Self::new(expr)
        }
    }
}
impl FromClause {
    /// Creates an empty `FROM` clause: no dates, no `CLEAR`, no filter and no
    /// row source.
    pub const fn new() -> Self {
        Self {
            table_name: None,
            subquery: None,
            filter: None,
            clear: false,
            close_on: None,
            open_on: None,
        }
    }

    /// Creates a clause whose row source is `query`; everything else is unset.
    pub fn from_subquery(query: SelectQuery) -> Self {
        Self::new().subquery(query)
    }

    /// Creates a clause whose row source is the named table; everything else
    /// is unset.
    pub fn from_table(name: impl Into<String>) -> Self {
        let table: String = name.into();
        Self {
            table_name: Some(table),
            ..Self::new()
        }
    }

    /// Sets the `OPEN ON` date.
    pub const fn open_on(mut self, date: NaiveDate) -> Self {
        // Mutate-and-return keeps this usable in const contexts; a partial
        // move via `..self` would be rejected by the const drop checker.
        self.open_on = Some(date);
        self
    }

    /// Sets the `CLOSE ON` date.
    pub const fn close_on(mut self, date: NaiveDate) -> Self {
        self.close_on = Some(date);
        self
    }

    /// Turns on the `CLEAR` modifier.
    pub const fn clear(mut self) -> Self {
        self.clear = true;
        self
    }

    /// Sets (or replaces) the filter expression.
    pub fn filter(self, expr: Expr) -> Self {
        Self {
            filter: Some(expr),
            ..self
        }
    }

    /// Sets (or replaces) the subquery row source. Does not touch `table_name`.
    pub fn subquery(self, query: SelectQuery) -> Self {
        Self {
            subquery: Some(Box::new(query)),
            ..self
        }
    }
}
/// The default `FROM` clause is the empty one produced by [`FromClause::new`].
impl Default for FromClause {
    fn default() -> Self {
        Self::new()
    }
}
impl Expr {
    /// Column reference by name.
    pub fn column(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self::Column(name)
    }

    /// String literal; `s` is stored as-is, without quotes.
    pub fn string(s: impl Into<String>) -> Self {
        let text: String = s.into();
        Self::Literal(Literal::String(text))
    }

    /// Decimal number literal.
    pub const fn number(n: Decimal) -> Self {
        let lit = Literal::Number(n);
        Self::Literal(lit)
    }

    /// Integer literal.
    pub const fn integer(n: i64) -> Self {
        let lit = Literal::Integer(n);
        Self::Literal(lit)
    }

    /// Date literal.
    pub const fn date(d: NaiveDate) -> Self {
        let lit = Literal::Date(d);
        Self::Literal(lit)
    }

    /// Boolean literal.
    pub const fn boolean(b: bool) -> Self {
        let lit = Literal::Boolean(b);
        Self::Literal(lit)
    }

    /// `NULL` literal.
    pub const fn null() -> Self {
        let lit = Literal::Null;
        Self::Literal(lit)
    }

    /// Plain function call `name(args...)`.
    pub fn function(name: impl Into<String>, args: Vec<Self>) -> Self {
        let call = FunctionCall {
            name: name.into(),
            args,
        };
        Self::Function(call)
    }

    /// Binary operation `left op right`.
    pub fn binary(left: Self, op: BinaryOperator, right: Self) -> Self {
        let node = BinaryOp { left, op, right };
        Self::BinaryOp(Box::new(node))
    }

    /// Unary operation applying `op` to `operand`.
    pub fn unary(op: UnaryOperator, operand: Self) -> Self {
        let node = UnaryOp { op, operand };
        Self::UnaryOp(Box::new(node))
    }

    /// `value BETWEEN low AND high`.
    pub fn between(value: Self, low: Self, high: Self) -> Self {
        let [value, low, high] = [value, low, high].map(Box::new);
        Self::Between { value, low, high }
    }
}
impl OrderSpec {
    /// Sort key on `expr`, ascending.
    pub const fn asc(expr: Expr) -> Self {
        Self {
            direction: SortDirection::Asc,
            expr,
        }
    }

    /// Sort key on `expr`, descending.
    pub const fn desc(expr: Expr) -> Self {
        Self {
            direction: SortDirection::Desc,
            expr,
        }
    }
}
impl fmt::Display for Expr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Wildcard => write!(f, "*"),
Self::Column(name) => write!(f, "{name}"),
Self::Literal(lit) => write!(f, "{lit}"),
Self::Function(func) => {
write!(f, "{}(", func.name)?;
for (i, arg) in func.args.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{arg}")?;
}
write!(f, ")")
}
Self::Window(wf) => {
write!(f, "{}(", wf.name)?;
for (i, arg) in wf.args.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{arg}")?;
}
write!(f, ") OVER ()")
}
Self::BinaryOp(op) => write!(f, "({} {} {})", op.left, op.op, op.right),
Self::UnaryOp(op) => {
match op.op {
UnaryOperator::IsNull => write!(f, "{} IS NULL", op.operand),
UnaryOperator::IsNotNull => write!(f, "{} IS NOT NULL", op.operand),
_ => write!(f, "{}{}", op.op, op.operand),
}
}
Self::Paren(inner) => write!(f, "({inner})"),
Self::Between { value, low, high } => {
write!(f, "{value} BETWEEN {low} AND {high}")
}
}
}
}
impl fmt::Display for Literal {
    /// Writes the literal as query text: strings in double quotes, numbers and
    /// dates through their own `Display`, booleans as `true`/`false`, and
    /// `NULL` for null. Formatter flags (width, fill, ...) are not applied.
    ///
    /// NOTE(review): string contents are written verbatim, so an embedded `"`
    /// is not escaped and the output is then not re-parseable. Confirm whether
    /// callers rely on the raw form before adding escaping.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::String(text) => {
                f.write_str("\"")?;
                f.write_str(text)?;
                f.write_str("\"")
            }
            Self::Number(value) => write!(f, "{value}"),
            Self::Integer(value) => write!(f, "{value}"),
            Self::Date(day) => write!(f, "{day}"),
            Self::Boolean(flag) => f.write_str(if *flag { "true" } else { "false" }),
            Self::Null => f.write_str("NULL"),
        }
    }
}
impl fmt::Display for BinaryOperator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Self::Eq => "=",
Self::Ne => "!=",
Self::Lt => "<",
Self::Le => "<=",
Self::Gt => ">",
Self::Ge => ">=",
Self::Regex => "~",
Self::NotRegex => "!~",
Self::In => "IN",
Self::NotIn => "NOT IN",
Self::And => "AND",
Self::Or => "OR",
Self::Add => "+",
Self::Sub => "-",
Self::Mul => "*",
Self::Div => "/",
Self::Mod => "%",
};
write!(f, "{s}")
}
}
impl fmt::Display for UnaryOperator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Self::Not => "NOT ",
Self::Neg => "-",
Self::IsNull => " IS NULL",
Self::IsNotNull => " IS NOT NULL",
};
write!(f, "{s}")
}
}
#[cfg(test)]
mod tests {
    //! Rendering tests for the AST `Display` implementations.

    use super::*;
    use rust_decimal_macros::dec;

    /// Builds a window-function expression with the default (empty) `OVER ()` spec.
    fn bare_window(name: &str, args: Vec<Expr>) -> Expr {
        Expr::Window(WindowFunction {
            name: name.to_owned(),
            args,
            over: WindowSpec::default(),
        })
    }

    #[test]
    fn test_expr_display_wildcard() {
        let rendered = Expr::Wildcard.to_string();
        assert_eq!(rendered, "*");
    }

    #[test]
    fn test_expr_display_column() {
        let column = Expr::Column(String::from("account"));
        assert_eq!(column.to_string(), "account");
    }

    #[test]
    fn test_expr_display_literals() {
        let cases = [
            (Expr::string("hello"), "\"hello\""),
            (Expr::integer(42), "42"),
            (Expr::number(dec!(3.14)), "3.14"),
            (Expr::boolean(true), "true"),
            (Expr::null(), "NULL"),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }

    #[test]
    fn test_expr_display_date() {
        let day = NaiveDate::from_ymd_opt(2024, 1, 15).expect("2024-01-15 is a valid date");
        assert_eq!(Expr::date(day).to_string(), "2024-01-15");
    }

    #[test]
    fn test_expr_display_function_no_args() {
        let call = Expr::function("now", Vec::new());
        assert_eq!(call.to_string(), "now()");
    }

    #[test]
    fn test_expr_display_function_one_arg() {
        let call = Expr::function("account_sortkey", vec![Expr::column("account")]);
        assert_eq!(call.to_string(), "account_sortkey(account)");
    }

    #[test]
    fn test_expr_display_function_multiple_args() {
        let args = vec![Expr::column("a"), Expr::column("b"), Expr::integer(0)];
        let call = Expr::function("coalesce", args);
        assert_eq!(call.to_string(), "coalesce(a, b, 0)");
    }

    #[test]
    fn test_expr_display_window() {
        let rendered = bare_window("row_number", Vec::new()).to_string();
        assert_eq!(rendered, "row_number() OVER ()");
    }

    #[test]
    fn test_expr_display_window_with_args() {
        let rendered = bare_window("sum", vec![Expr::column("amount")]).to_string();
        assert_eq!(rendered, "sum(amount) OVER ()");
    }

    #[test]
    fn test_expr_display_binary_op() {
        let sum = Expr::binary(Expr::column("a"), BinaryOperator::Add, Expr::integer(1));
        assert_eq!(sum.to_string(), "(a + 1)");
    }

    #[test]
    fn test_expr_display_unary_not() {
        let negated = Expr::unary(UnaryOperator::Not, Expr::column("flag"));
        assert_eq!(negated.to_string(), "NOT flag");
    }

    #[test]
    fn test_expr_display_unary_neg() {
        let negative = Expr::unary(UnaryOperator::Neg, Expr::column("x"));
        assert_eq!(negative.to_string(), "-x");
    }

    #[test]
    fn test_expr_display_is_null() {
        let check = Expr::unary(UnaryOperator::IsNull, Expr::column("x"));
        assert_eq!(check.to_string(), "x IS NULL");
    }

    #[test]
    fn test_expr_display_is_not_null() {
        let check = Expr::unary(UnaryOperator::IsNotNull, Expr::column("x"));
        assert_eq!(check.to_string(), "x IS NOT NULL");
    }

    #[test]
    fn test_expr_display_paren() {
        // The binary op already parenthesizes itself, so Paren doubles up.
        let sum = Expr::binary(Expr::column("a"), BinaryOperator::Add, Expr::column("b"));
        let wrapped = Expr::Paren(Box::new(sum));
        assert_eq!(wrapped.to_string(), "((a + b))");
    }

    #[test]
    fn test_expr_display_between() {
        let range = Expr::between(Expr::column("x"), Expr::integer(1), Expr::integer(10));
        assert_eq!(range.to_string(), "x BETWEEN 1 AND 10");
    }

    #[test]
    fn test_binary_operator_display() {
        let cases = [
            (BinaryOperator::Eq, "="),
            (BinaryOperator::Ne, "!="),
            (BinaryOperator::Lt, "<"),
            (BinaryOperator::Le, "<="),
            (BinaryOperator::Gt, ">"),
            (BinaryOperator::Ge, ">="),
            (BinaryOperator::Regex, "~"),
            (BinaryOperator::NotRegex, "!~"),
            (BinaryOperator::In, "IN"),
            (BinaryOperator::NotIn, "NOT IN"),
            (BinaryOperator::And, "AND"),
            (BinaryOperator::Or, "OR"),
            (BinaryOperator::Add, "+"),
            (BinaryOperator::Sub, "-"),
            (BinaryOperator::Mul, "*"),
            (BinaryOperator::Div, "/"),
            (BinaryOperator::Mod, "%"),
        ];
        for (op, expected) in cases {
            assert_eq!(op.to_string(), expected, "rendering of {op:?}");
        }
    }

    #[test]
    fn test_unary_operator_display() {
        let cases = [
            (UnaryOperator::Not, "NOT "),
            (UnaryOperator::Neg, "-"),
            (UnaryOperator::IsNull, " IS NULL"),
            (UnaryOperator::IsNotNull, " IS NOT NULL"),
        ];
        for (op, expected) in cases {
            assert_eq!(op.to_string(), expected, "rendering of {op:?}");
        }
    }

    #[test]
    fn test_literal_display() {
        let cases = [
            (Literal::String("test".to_owned()), "\"test\""),
            (Literal::Number(dec!(1.5)), "1.5"),
            (Literal::Integer(42), "42"),
            (Literal::Boolean(false), "false"),
            (Literal::Null, "NULL"),
        ];
        for (lit, expected) in cases {
            assert_eq!(lit.to_string(), expected);
        }
    }
}