use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};
use rust_decimal::Decimal;
use rustledger_core::{Amount, Inventory, Metadata, NaiveDate, Position, Transaction};
/// A location in a source file.
///
/// NOTE(review): presumably records where a parsed entry came from — confirm against
/// the producer of these values.
#[derive(Debug, Clone)]
pub struct SourceLocation {
    /// Name (or path) of the source file.
    pub filename: String,
    /// Line number within `filename` (NOTE(review): 0- vs 1-based is not visible here).
    pub lineno: usize,
}
/// The calendar unit of an [`Interval`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IntervalUnit {
    /// One calendar day.
    Day,
    /// Seven days.
    Week,
    /// One calendar month.
    Month,
    /// Three calendar months.
    Quarter,
    /// One calendar year.
    Year,
}

impl IntervalUnit {
    /// Parses a unit name, ignoring case.
    ///
    /// Accepts the singular, plural, and single-letter forms of each unit
    /// (`"day"`, `"DAYS"`, `"d"`, …). Returns `None` for anything else.
    pub fn parse_unit(s: &str) -> Option<Self> {
        // Accepted spellings (upper-case) for each unit.
        const ALIASES: [(&[&str], IntervalUnit); 5] = [
            (&["DAY", "DAYS", "D"], IntervalUnit::Day),
            (&["WEEK", "WEEKS", "W"], IntervalUnit::Week),
            (&["MONTH", "MONTHS", "M"], IntervalUnit::Month),
            (&["QUARTER", "QUARTERS", "Q"], IntervalUnit::Quarter),
            (&["YEAR", "YEARS", "Y"], IntervalUnit::Year),
        ];

        // Full Unicode upper-casing, so comparison matches `str::to_uppercase` semantics.
        let normalized = s.to_uppercase();
        let wanted = normalized.as_str();
        ALIASES
            .iter()
            .find(|(names, _)| names.contains(&wanted))
            .map(|&(_, unit)| unit)
    }
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Interval {
pub count: i64,
pub unit: IntervalUnit,
}
impl Interval {
pub const fn new(count: i64, unit: IntervalUnit) -> Self {
Self { count, unit }
}
pub(crate) const fn to_approx_days(&self) -> i64 {
let days_per_unit = match self.unit {
IntervalUnit::Day => 1,
IntervalUnit::Week => 7,
IntervalUnit::Month => 30,
IntervalUnit::Quarter => 91,
IntervalUnit::Year => 365,
};
self.count.saturating_mul(days_per_unit)
}
pub fn add_to_date(&self, date: NaiveDate) -> Option<NaiveDate> {
use jiff::ToSpan;
let span = match self.unit {
IntervalUnit::Day => self.count.days(),
IntervalUnit::Week => self.count.weeks(),
IntervalUnit::Month => self.count.months(),
IntervalUnit::Quarter => (self.count * 3).months(),
IntervalUnit::Year => self.count.years(),
};
date.checked_add(span).ok()
}
}
/// A single dynamically-typed cell value, as stored in [`Row`]s of a [`QueryResult`]
/// and in [`Table`] rows.
///
/// The larger payloads (`Position`, `Inventory`, `Metadata`, `Object`) are boxed;
/// `test_value_size` asserts the enum stays at most 48 bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
    /// A text value.
    String(String),
    /// A decimal number.
    Number(Decimal),
    /// A 64-bit signed integer.
    Integer(i64),
    /// A calendar date.
    Date(NaiveDate),
    /// A boolean.
    Boolean(bool),
    /// A number with a currency.
    Amount(Amount),
    /// Units plus an optional cost (boxed to keep `Value` small).
    Position(Box<Position>),
    /// A collection of positions (boxed to keep `Value` small).
    Inventory(Box<Inventory>),
    /// A list of strings.
    StringSet(Vec<String>),
    /// A list of nested values.
    Set(Vec<Self>),
    /// A metadata map (boxed to keep `Value` small).
    Metadata(Box<Metadata>),
    /// A calendar interval.
    Interval(Interval),
    /// A string-keyed map of nested values, iterated in key order (`BTreeMap`).
    Object(Box<BTreeMap<String, Self>>),
    /// Absence of a value.
    Null,
}
impl Value {
pub(crate) fn hash_value<H: Hasher>(&self, state: &mut H) {
std::mem::discriminant(self).hash(state);
match self {
Self::String(s) => s.hash(state),
Self::Number(d) => d.serialize().hash(state),
Self::Integer(i) => i.hash(state),
Self::Date(d) => {
d.year().hash(state);
d.month().hash(state);
d.day().hash(state);
}
Self::Boolean(b) => b.hash(state),
Self::Amount(a) => {
a.number.serialize().hash(state);
a.currency.as_str().hash(state);
}
Self::Position(p) => {
p.units.number.serialize().hash(state);
p.units.currency.as_str().hash(state);
if let Some(cost) = &p.cost {
cost.number.serialize().hash(state);
cost.currency.as_str().hash(state);
}
}
Self::Inventory(inv) => {
for pos in inv.positions() {
pos.units.number.serialize().hash(state);
pos.units.currency.as_str().hash(state);
if let Some(cost) = &pos.cost {
cost.number.serialize().hash(state);
cost.currency.as_str().hash(state);
}
}
}
Self::StringSet(ss) => {
let mut sorted = ss.clone();
sorted.sort();
for s in &sorted {
s.hash(state);
}
}
Self::Set(values) => {
for v in values {
v.hash_value(state);
}
}
Self::Metadata(meta) => {
let mut keys: Vec<_> = meta.keys().collect();
keys.sort();
for key in keys {
key.hash(state);
format!("{:?}", meta.get(key)).hash(state);
}
}
Self::Interval(interval) => {
interval.count.hash(state);
interval.unit.hash(state);
}
Self::Object(obj) => {
for (k, v) in obj.as_ref() {
k.hash(state);
v.hash_value(state);
}
}
Self::Null => {}
}
}
}
/// One result row: a sequence of [`Value`]s (presumably one per entry in
/// `QueryResult::columns` — not enforced here).
pub type Row = Vec<Value>;
/// Computes a 64-bit hash of a whole row.
///
/// Each cell is fed in order into a single `FxHasher` via `Value::hash_value`, so the
/// result depends on cell order. It is a fast, non-cryptographic hash.
pub fn hash_row(row: &Row) -> u64 {
    let mut state = rustc_hash::FxHasher::default();
    row.iter().for_each(|cell| cell.hash_value(&mut state));
    state.finish()
}
/// Computes a 64-bit hash of one value using a fresh `FxHasher`.
///
/// Equivalent to hashing a one-element row with `hash_row`.
pub fn hash_single_value(value: &Value) -> u64 {
    let state = {
        let mut h = rustc_hash::FxHasher::default();
        value.hash_value(&mut h);
        h
    };
    state.finish()
}
/// The output of a query: column names plus rows.
///
/// Alongside `rows` it keeps a parallel sidecar, `row_group_keys`, holding each row's
/// group key (if any). Every method here keeps the two vectors the same length and in
/// the same order.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column names, in output order.
    pub columns: Vec<String>,
    /// Result rows.
    pub rows: Vec<Row>,
    /// Group key per row, index-aligned with `rows`. `Some` only for rows added through
    /// `add_aggregate_row` with a non-empty key.
    pub(crate) row_group_keys: Vec<Option<Vec<Value>>>,
}

impl QueryResult {
    /// Creates an empty result with the given column names.
    pub const fn new(columns: Vec<String>) -> Self {
        Self {
            columns,
            rows: Vec::new(),
            row_group_keys: Vec::new(),
        }
    }

    /// Appends a row that has no group key.
    pub fn add_row(&mut self, row: Row) {
        self.push_with_key(row, None);
    }

    /// Appends a row produced by aggregation, together with its group key.
    ///
    /// An empty `group_key` is stored as "no key".
    pub fn add_aggregate_row(&mut self, row: Row, group_key: Vec<Value>) {
        let key = (!group_key.is_empty()).then_some(group_key);
        self.push_with_key(row, key);
    }

    /// Pushes a row and its sidecar entry together, keeping the vectors aligned.
    fn push_with_key(&mut self, row: Row, key: Option<Vec<Value>>) {
        self.rows.push(row);
        self.row_group_keys.push(key);
    }

    /// Returns the group key for row `row_idx`.
    ///
    /// Returns `None` if the row has no key or the index is out of range.
    #[must_use]
    pub fn group_key(&self, row_idx: usize) -> Option<&[Value]> {
        self.row_group_keys.get(row_idx)?.as_deref()
    }

    /// Returns `true` if at least one row carries a group key.
    #[must_use]
    pub fn has_aggregate_rows(&self) -> bool {
        !self.row_group_keys.iter().all(Option::is_none)
    }

    /// Keeps only the first `len` rows (and their group keys).
    pub fn truncate(&mut self, len: usize) {
        self.row_group_keys.truncate(len);
        self.rows.truncate(len);
    }

    /// Stably sorts rows with `compare`, moving each row's group key along with it.
    ///
    /// # Panics
    ///
    /// Panics if `rows` and `row_group_keys` have different lengths. This can happen
    /// when `rows` was modified directly through its public field.
    pub fn sort_by<F>(&mut self, mut compare: F)
    where
        F: FnMut(&Row, &Row) -> std::cmp::Ordering,
    {
        assert_eq!(
            self.rows.len(),
            self.row_group_keys.len(),
            "QueryResult invariant violated: rows.len() must equal row_group_keys.len()"
        );
        let rows = std::mem::take(&mut self.rows);
        let keys = std::mem::take(&mut self.row_group_keys);
        let mut pairs: Vec<(Row, Option<Vec<Value>>)> = rows.into_iter().zip(keys).collect();
        // Stable: rows comparing equal keep their insertion order.
        pairs.sort_by(|lhs, rhs| compare(&lhs.0, &rhs.0));
        let (sorted_rows, sorted_keys): (Vec<Row>, Vec<Option<Vec<Value>>>) =
            pairs.into_iter().unzip();
        self.rows = sorted_rows;
        self.row_group_keys = sorted_keys;
    }

    /// Number of rows.
    pub const fn len(&self) -> usize {
        self.rows.len()
    }

    /// Returns `true` if there are no rows.
    pub const fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }
}
/// A borrowed view of one posting within a transaction, plus optional extra state.
///
/// NOTE(review): presumably the per-row context for posting-level query evaluation —
/// confirm against callers.
#[derive(Debug)]
pub struct PostingContext<'a> {
    /// The transaction that contains the posting.
    pub transaction: &'a Transaction,
    /// Index of the posting within `transaction` (presumably into its postings list —
    /// confirm).
    pub posting_index: usize,
    /// Optional balance inventory (NOTE(review): exact scope of this running balance is
    /// not visible here).
    pub balance: Option<Inventory>,
    /// Optional per-account balance inventory (NOTE(review): confirm semantics).
    pub account_balance: Option<Inventory>,
    /// Optional directive index (NOTE(review): presumably the transaction's position in
    /// the directive list — confirm).
    pub directive_index: Option<usize>,
}
/// Per-row numbering values for window-style evaluation.
///
/// NOTE(review): field names suggest SQL `ROW_NUMBER` / `RANK` / `DENSE_RANK`
/// semantics; whether they are 0- or 1-based is not visible here.
#[derive(Debug, Clone)]
pub struct WindowContext {
    /// Sequential position of the row.
    pub row_number: usize,
    /// Rank of the row (presumably with gaps after ties — confirm).
    pub rank: usize,
    /// Rank of the row (presumably without gaps after ties — confirm).
    pub dense_rank: usize,
}
/// Information recorded for an account.
#[derive(Debug, Clone)]
pub struct AccountInfo {
    /// Date the account was opened, if known.
    pub open_date: Option<NaiveDate>,
    /// Date the account was closed, if any.
    pub close_date: Option<NaiveDate>,
    /// Metadata associated with the account's open entry (NOTE(review): presumably
    /// taken from the `open` directive — confirm).
    pub open_meta: Metadata,
}
/// A simple named-column table of [`Value`] rows.
#[derive(Debug, Clone)]
pub struct Table {
    /// Column names, in order.
    pub columns: Vec<String>,
    /// Row data. Nothing here checks that each row's length matches `columns`.
    pub rows: Vec<Vec<Value>>,
}

impl Table {
    /// Creates an empty table with the given column names.
    ///
    /// `const` for consistency with `QueryResult::new`, which has the same shape.
    pub const fn new(columns: Vec<String>) -> Self {
        Self {
            columns,
            rows: Vec::new(),
        }
    }

    /// Appends a row. The row's length is not validated against `columns`.
    pub fn add_row(&mut self, row: Vec<Value>) {
        self.rows.push(row);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards against `Value` growing; large payloads are boxed to keep it small.
    #[test]
    fn test_value_size() {
        use std::mem::size_of;
        assert!(
            size_of::<Value>() <= 48,
            "Value enum too large: {} bytes",
            size_of::<Value>()
        );
    }

    /// Three aggregate rows keyed by currency, in insertion order USD, EUR, GBP.
    fn make_keyed_result() -> QueryResult {
        let mut r = QueryResult::new(vec!["currency".into(), "sum".into()]);
        r.add_aggregate_row(
            vec![Value::String("USD".into()), Value::Integer(100)],
            vec![Value::String("USD".into())],
        );
        r.add_aggregate_row(
            vec![Value::String("EUR".into()), Value::Integer(50)],
            vec![Value::String("EUR".into())],
        );
        r.add_aggregate_row(
            vec![Value::String("GBP".into()), Value::Integer(75)],
            vec![Value::String("GBP".into())],
        );
        r
    }

    #[test]
    fn test_sort_by_keeps_row_group_keys_in_lockstep() {
        let mut r = make_keyed_result();
        r.sort_by(|a, b| match (&a[1], &b[1]) {
            (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
            _ => std::cmp::Ordering::Equal,
        });
        assert_eq!(r.group_key(0), Some(&[Value::String("EUR".into())][..]));
        assert_eq!(r.group_key(1), Some(&[Value::String("GBP".into())][..]));
        assert_eq!(r.group_key(2), Some(&[Value::String("USD".into())][..]));
    }

    #[test]
    fn test_sort_by_is_stable_with_mixed_rows() {
        let mut r = QueryResult::new(vec!["k".into(), "v".into()]);
        r.add_aggregate_row(
            vec![Value::Integer(1), Value::Integer(10)],
            vec![Value::String("a".into())],
        );
        r.add_row(vec![Value::Integer(0), Value::Integer(20)]);
        r.add_aggregate_row(
            vec![Value::Integer(1), Value::Integer(30)],
            vec![Value::String("b".into())],
        );
        r.sort_by(|a, b| match (&a[0], &b[0]) {
            (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
            _ => std::cmp::Ordering::Equal,
        });
        assert_eq!(r.rows[0][1], Value::Integer(20));
        assert_eq!(r.group_key(0), None);
        // Equal sort keys (1) keep insertion order: 10 before 30.
        assert_eq!(r.rows[1][1], Value::Integer(10));
        assert_eq!(r.group_key(1), Some(&[Value::String("a".into())][..]));
        assert_eq!(r.rows[2][1], Value::Integer(30));
        assert_eq!(r.group_key(2), Some(&[Value::String("b".into())][..]));
    }

    #[test]
    fn test_truncate_keeps_row_group_keys_in_lockstep() {
        let mut r = make_keyed_result();
        r.truncate(2);
        assert_eq!(r.rows.len(), 2);
        assert_eq!(r.row_group_keys.len(), 2);
        assert_eq!(r.group_key(0), Some(&[Value::String("USD".into())][..]));
        assert_eq!(r.group_key(1), Some(&[Value::String("EUR".into())][..]));
        assert_eq!(r.group_key(2), None);
    }

    #[test]
    fn test_add_row_and_add_aggregate_row_mixed() {
        let mut r = QueryResult::new(vec!["x".into()]);
        r.add_aggregate_row(vec![Value::Integer(1)], vec![Value::String("USD".into())]);
        r.add_row(vec![Value::Integer(2)]);
        r.add_aggregate_row(vec![Value::Integer(3)], vec![Value::String("EUR".into())]);
        assert_eq!(r.rows.len(), 3);
        assert_eq!(r.row_group_keys.len(), 3);
        assert_eq!(r.group_key(0), Some(&[Value::String("USD".into())][..]));
        assert_eq!(r.group_key(1), None);
        assert_eq!(r.group_key(2), Some(&[Value::String("EUR".into())][..]));
    }

    #[test]
    fn test_add_aggregate_row_empty_key_records_none() {
        let mut r = QueryResult::new(vec!["count".into()]);
        r.add_aggregate_row(vec![Value::Integer(42)], vec![]);
        assert_eq!(r.group_key(0), None);
    }

    #[test]
    fn test_has_aggregate_rows() {
        let mut r = QueryResult::new(vec!["x".into()]);
        assert!(!r.has_aggregate_rows());
        r.add_row(vec![Value::Integer(1)]);
        assert!(!r.has_aggregate_rows());
        // An empty group key is recorded as "no key".
        r.add_aggregate_row(vec![Value::Integer(2)], vec![]);
        assert!(!r.has_aggregate_rows());
        r.add_aggregate_row(vec![Value::Integer(3)], vec![Value::Integer(3)]);
        assert!(r.has_aggregate_rows());
        assert_eq!(r.len(), 4);
        assert!(!r.is_empty());
    }

    #[test]
    #[should_panic(expected = "QueryResult invariant violated")]
    fn test_sort_by_panics_on_lockstep_violation() {
        let mut r = QueryResult::new(vec!["x".into()]);
        r.rows.push(vec![Value::Integer(1)]);
        r.sort_by(|_, _| std::cmp::Ordering::Equal);
    }

    #[test]
    fn test_add_row_records_none_in_sidecar() {
        let mut r = QueryResult::new(vec!["x".into()]);
        r.add_row(vec![Value::Integer(1)]);
        assert_eq!(r.rows.len(), 1);
        assert_eq!(r.row_group_keys.len(), 1);
        assert_eq!(r.group_key(0), None);
    }

    #[test]
    fn test_interval_unit_parse_unit() {
        assert_eq!(IntervalUnit::parse_unit("day"), Some(IntervalUnit::Day));
        assert_eq!(IntervalUnit::parse_unit("Weeks"), Some(IntervalUnit::Week));
        assert_eq!(IntervalUnit::parse_unit("m"), Some(IntervalUnit::Month));
        assert_eq!(IntervalUnit::parse_unit("QUARTERS"), Some(IntervalUnit::Quarter));
        assert_eq!(IntervalUnit::parse_unit("y"), Some(IntervalUnit::Year));
        assert_eq!(IntervalUnit::parse_unit("fortnight"), None);
        assert_eq!(IntervalUnit::parse_unit(""), None);
    }

    #[test]
    fn test_interval_to_approx_days() {
        assert_eq!(Interval::new(2, IntervalUnit::Week).to_approx_days(), 14);
        assert_eq!(Interval::new(1, IntervalUnit::Quarter).to_approx_days(), 91);
        assert_eq!(Interval::new(-1, IntervalUnit::Month).to_approx_days(), -30);
        // Saturates instead of overflowing.
        assert_eq!(
            Interval::new(i64::MAX, IntervalUnit::Year).to_approx_days(),
            i64::MAX
        );
    }

    #[test]
    fn test_hash_single_value_matches_one_element_row() {
        let v = Value::String("USD".into());
        assert_eq!(hash_single_value(&v), hash_row(&vec![v.clone()]));
        let row = vec![Value::Integer(7), Value::Null];
        assert_eq!(hash_row(&row), hash_row(&row.clone()));
    }
}