use bigdecimal::BigDecimal;
use sqlite3_ext::{function::*, *};
use std::{cmp::Ordering, str::FromStr};
/// Converts one SQL argument into an optional decimal.
///
/// SQL `NULL` maps to `None`. Any other value is read as text and parsed as a
/// `BigDecimal`; text that does not parse is treated as zero instead of being
/// reported as an error.
///
/// # Errors
/// Propagates any error raised while reading the value as text.
fn process_value(a: &mut ValueRef) -> Result<Option<BigDecimal>> {
    if matches!(a.value_type(), ValueType::Null) {
        return Ok(None);
    }
    let text = a.get_str()?;
    let parsed = BigDecimal::from_str(text).unwrap_or_default();
    Ok(Some(parsed))
}
/// Parses every SQL argument with `process_value`, preserving argument order.
///
/// # Errors
/// Returns the first error raised while reading an argument; later arguments
/// are not read once an error occurs.
fn process_args(args: &mut [&mut ValueRef]) -> Result<Vec<Option<BigDecimal>>> {
    args.iter_mut().map(|x| process_value(*x)).collect()
}
/// Defines a two-argument SQL scalar function over decimal values.
///
/// Both arguments are parsed with `process_args`. If either one is SQL `NULL`,
/// no result is set, and the function yields `NULL`. Otherwise:
/// - the parsed values are bound to `$a` / `$b`;
/// - `$ret` is evaluated and checked against the declared result type `$ty`;
/// - the value is stored as the function result.
macro_rules! scalar_method {
    ($name:ident as ( $a:ident, $b:ident ) -> $ty:ty => $ret:expr) => {
        #[sqlite3_ext_fn(n_args=2, risk_level=Innocuous, deterministic)]
        fn $name(ctx: &mut Context, args: &mut [&mut ValueRef]) -> Result<()> {
            let mut args = process_args(args)?.into_iter();
            // `next()` yields `Option<Option<_>>`; a missing argument and a SQL
            // NULL are handled identically.
            let a = args.next().flatten();
            let b = args.next().flatten();
            if let (Some($a), Some($b)) = (a, b) {
                // Enforce the result type declared at the invocation site.
                let ret: $ty = $ret;
                // Report failures to SQLite instead of panicking inside the callback.
                ctx.set_result(ret)?;
            }
            Ok(())
        }
    };
}
// decimal_add(a, b): exact sum, normalized (trailing zeros dropped), as text.
scalar_method!(decimal_add as (a, b) -> String => (a + b).normalized().to_string());
// decimal_sub(a, b): exact difference `a - b`, normalized, as text.
scalar_method!(decimal_sub as (a, b) -> String => (a - b).normalized().to_string());
// decimal_mul(a, b): exact product, normalized, as text.
scalar_method!(decimal_mul as (a, b) -> String => (a * b).normalized().to_string());
// decimal_cmp(a, b): -1, 0 or 1 when `a` is less than, equal to, or greater than `b`.
scalar_method!(decimal_cmp as (a, b) -> i32 => {
    match a.cmp(&b) {
        Ordering::Less => -1,
        Ordering::Equal => 0,
        Ordering::Greater => 1,
    }
});
/// State for the `decimal_sum` aggregate / window function: the running total
/// of every non-NULL input currently being aggregated.
#[derive(Default)]
#[sqlite3_ext_fn(n_args=1, risk_level=Innocuous, deterministic)]
struct Sum {
    /// Running total; `Default` starts it at zero.
    cur: BigDecimal,
}

impl Sum {
    /// Parses the single argument passed to `step` / `inverse`.
    ///
    /// Returns `None` for SQL `NULL`; unparseable text is treated as zero.
    fn arg(args: &mut [&mut ValueRef]) -> Result<Option<BigDecimal>> {
        let arg = args
            .first_mut()
            .expect("decimal_sum is registered with n_args=1");
        process_value(*arg)
    }
}

impl AggregateFunction<()> for Sum {
    /// Result when the aggregate sees no rows at all: SQL `NULL`.
    fn default_value(_: &(), ctx: &Context) -> Result<()> {
        ctx.set_result(())
    }

    /// Adds one input row to the running total; `NULL` inputs are skipped.
    fn step(&mut self, _: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
        if let Some(x) = Self::arg(args)? {
            self.cur += x;
        }
        Ok(())
    }

    /// Reports the current total as normalized decimal text. A group made up
    /// only of `NULL`s therefore yields `"0"`, not `NULL`.
    fn value(&self, ctx: &Context) -> Result<()> {
        ctx.set_result(self.cur.normalized().to_string())
    }

    /// Removes a row previously added by `step` (window frames); `NULL`
    /// inputs are skipped, mirroring `step`.
    fn inverse(&mut self, _: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
        if let Some(x) = Self::arg(args)? {
            self.cur -= x;
        }
        Ok(())
    }
}
fn decimal_collation(a: &str, b: &str) -> Ordering {
let a = BigDecimal::from_str(a).unwrap_or_else(|_| BigDecimal::default());
let b = BigDecimal::from_str(b).unwrap_or_else(|_| BigDecimal::default());
a.cmp(&b)
}
/// Extension entry point. Registers on `db`:
/// - the scalar functions `decimal_add`, `decimal_sub`, `decimal_mul` and `decimal_cmp`;
/// - the `decimal_sum` aggregate;
/// - the `decimal` collation.
///
/// # Errors
/// Propagates the first registration error via `?`; later registrations are
/// not attempted.
#[sqlite3_ext_main]
fn init(db: &Connection) -> Result<()> {
    // NOTE(review): the `*_OPTS` statics are presumably generated by the
    // `#[sqlite3_ext_fn(...)]` attributes on each function / on `Sum` — confirm
    // against the sqlite3_ext documentation.
    db.create_scalar_function("decimal_add", &DECIMAL_ADD_OPTS, decimal_add)?;
    db.create_scalar_function("decimal_sub", &DECIMAL_SUB_OPTS, decimal_sub)?;
    db.create_scalar_function("decimal_mul", &DECIMAL_MUL_OPTS, decimal_mul)?;
    db.create_scalar_function("decimal_cmp", &DECIMAL_CMP_OPTS, decimal_cmp)?;
    // The `()` argument matches the `AggregateFunction<()>` type parameter of `Sum`.
    db.create_aggregate_function::<_, Sum>("decimal_sum", &SUM_OPTS, ())?;
    db.create_collation("decimal", decimal_collation)?;
    Ok(())
}
/// End-to-end tests. Each test opens an in-memory database, registers the
/// extension via `init`, and checks SQL results. Compiled only for test builds
/// with the `static` feature enabled.
#[cfg(all(test, feature = "static"))]
mod test {
    use super::*;

    /// Opens an in-memory database with all of this extension's functions and
    /// the `decimal` collation registered.
    fn setup() -> Result<Database> {
        let conn = Database::open(":memory:")?;
        init(&conn)?;
        Ok(conn)
    }

    /// Builds `SELECT <expr1>, <expr2>, ...` from the first element of each pair,
    /// runs it, and asserts the single returned row equals the second elements,
    /// in order.
    fn case(data: Vec<(&str, Value)>) -> Result<()> {
        let conn = setup()?;
        let (sql, expected): (Vec<&str>, Vec<Value>) = data.into_iter().unzip();
        let sql = format!("SELECT {}", sql.join(", "));
        println!("{}", sql);
        // Copy each column of the one result row into an owned `Value`.
        let ret: Vec<Value> = conn.query_row(&sql, (), |r| {
            (0..expected.len())
                .map(|i| r[i].to_owned())
                .collect::<Result<_>>()
        })?;
        assert_eq!(ret, expected);
        Ok(())
    }

    #[test]
    fn decimal_add() -> Result<()> {
        case(vec![
            // Exact result with no binary floating-point rounding.
            (
                "decimal_add('1000000000000000', '0.0000000000000001')",
                Value::Text("1000000000000000.0000000000000001".to_owned()),
            ),
            // Any NULL argument yields NULL.
            ("decimal_add(NULL, '0')", Value::Null),
            ("decimal_add('0', NULL)", Value::Null),
            ("decimal_add(NULL, NULL)", Value::Null),
            // Unparseable text is treated as zero; integers are accepted.
            ("decimal_add('invalid', 2)", Value::Text("2".to_owned())),
        ])
    }

    #[test]
    fn decimal_sub() -> Result<()> {
        case(vec![
            (
                "decimal_sub('1000000000000000', '0.0000000000000001')",
                Value::Text("999999999999999.9999999999999999".to_owned()),
            ),
            ("decimal_sub(NULL, '0')", Value::Null),
            ("decimal_sub('0', NULL)", Value::Null),
            ("decimal_sub(NULL, NULL)", Value::Null),
            // 0 - 2, since 'invalid' parses as zero.
            ("decimal_sub('invalid', 2)", Value::Text("-2".to_owned())),
        ])
    }

    #[test]
    fn decimal_mul() -> Result<()> {
        case(vec![
            // Result is normalized: trailing zeros are dropped.
            (
                "decimal_mul('1000000000000000', '0.0000000000000001')",
                Value::Text("0.1".to_owned()),
            ),
            ("decimal_mul(NULL, '0')", Value::Null),
            ("decimal_mul('0', NULL)", Value::Null),
            ("decimal_mul(NULL, NULL)", Value::Null),
            ("decimal_mul('invalid', 2)", Value::Text("0".to_owned())),
        ])
    }

    #[test]
    fn decimal_cmp() -> Result<()> {
        case(vec![
            // Returns an INTEGER: 1, -1 or 0.
            ("decimal_cmp('1', '-1')", Value::Integer(1)),
            ("decimal_cmp('-1', '1')", Value::Integer(-1)),
            ("decimal_cmp('1', '1')", Value::Integer(0)),
            ("decimal_cmp(NULL, '0')", Value::Null),
            ("decimal_cmp('0', NULL)", Value::Null),
            ("decimal_cmp(NULL, NULL)", Value::Null),
        ])
    }

    /// Runs `SELECT <expr> FROM ( VALUES (d1), (d2), ... )`, with one row per
    /// entry of `data`. Asserts that the first column of every result row, in
    /// order, equals `expected`.
    fn aggregate_case(expr: &str, data: Vec<&str>, expected: Vec<Value>) -> Result<()> {
        let conn = setup()?;
        let sql = format!(
            "SELECT {} FROM ( VALUES {} )",
            expr,
            data.iter()
                .map(|s| format!("({})", s))
                .collect::<Vec<String>>()
                .join(", ")
        );
        println!("{}", sql);
        let ret: Vec<Value> = conn
            .prepare(&sql)?
            .query(())?
            .map(|r| r[0].to_owned())
            .collect()?;
        assert_eq!(ret, expected);
        Ok(())
    }

    #[test]
    fn decimal_sum() -> Result<()> {
        // Exact sum over several rows.
        aggregate_case(
            "decimal_sum(column1)",
            vec!["1000000000000000", "0.0000000000000001", "1"],
            vec![Value::Text("1000000000000001.0000000000000001".to_owned())],
        )?;
        // NULL rows are skipped.
        aggregate_case(
            "decimal_sum(column1)",
            vec!["1", "NULL"],
            vec![Value::Text("1".to_owned())],
        )?;
        // A group made up only of NULLs sums to "0", not NULL.
        aggregate_case(
            "decimal_sum(column1)",
            vec!["NULL"],
            vec![Value::Text("0".to_owned())],
        )?;
        case(vec![("decimal_sum(NULL)", Value::Text("0".to_owned()))])?;
        // Unparseable text counts as zero.
        case(vec![(
            "decimal_sum('invalid')",
            Value::Text("0".to_owned()),
        )])?;
        // No input rows at all: the default value (NULL) is returned.
        case(vec![("decimal_sum(1) WHERE 1 = 0", Value::Null)])?;
        // Sliding two-row window frame. Rows leave the frame, presumably via
        // `Sum::inverse`.
        aggregate_case(
            "decimal_sum(column1) OVER ( ROWS 1 PRECEDING )",
            vec![
                "1000000000000000",
                "0.0000000000000001",
                "NULL",
                "NULL",
                "1",
            ],
            vec![
                Value::Text("1000000000000000".to_owned()),
                Value::Text("1000000000000000.0000000000000001".to_owned()),
                Value::Text("0.0000000000000001".to_owned()),
                Value::Text("0".to_owned()),
                Value::Text("1".to_owned()),
            ],
        )?;
        Ok(())
    }

    #[test]
    fn collation() -> Result<()> {
        let conn = setup()?;
        // The `decimal` collation orders by numeric value, not lexically.
        let ret: Vec<String> = conn
            .prepare(
                "SELECT column1 FROM ( VALUES (('1')), (('0100')), (('.1')) ) ORDER BY column1 COLLATE decimal",
            )?
            .query(())?.map(|row| Ok(row[0].get_str()?.to_owned()))
            .collect()?;
        assert_eq!(
            ret,
            vec![".1".to_owned(), "1".to_owned(), "0100".to_owned()]
        );
        Ok(())
    }
}