use super::AddExpr;
use crate::{
base::{
commitment::InnerProductProof,
database::{
owned_table_utility::*, table_utility::*, ColumnType, OwnedTableTestAccessor, TableRef,
TableTestAccessor,
},
math::decimal::Precision,
},
proof_primitive::inner_product::curve_25519_scalar::Curve25519Scalar,
sql::{
proof::{exercise_verification, VerifiableQueryResult},
proof_exprs::{test_utility::*, DynProofExpr, ProofExpr},
proof_plans::test_utility::*,
AnalyzeError,
},
};
use bumpalo::Bump;
use itertools::{multizip, MultiUnzip};
use rand::{
distributions::{Distribution, Uniform},
rngs::StdRng,
};
use rand_core::SeedableRng;
/// Proves `SELECT a, c, b + 4 AS res, d FROM t WHERE a - b = 3` over SmallInt, Int, BigInt and
/// VarChar columns, and checks that the verified result types `res` as `Decimal75(20, 0)`.
#[test]
fn we_can_prove_a_typical_add_subtract_query() {
    let data = owned_table([
        smallint("a", [1_i16, 2, 3, 4]),
        int("b", [0_i32, 1, 0, 1]),
        varchar("d", ["ab", "t", "efg", "g"]),
        bigint("c", [0_i64, 2, 2, 0]),
    ]);
    let t = TableRef::new("sxt", "t");
    let accessor =
        OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t.clone(), data, 0, ());
    let ast = filter(
        vec![
            col_expr_plan(&t, "a", &accessor),
            col_expr_plan(&t, "c", &accessor),
            aliased_plan(add(column(&t, "b", &accessor), const_bigint(4)), "res"),
            col_expr_plan(&t, "d", &accessor),
        ],
        table_exec(
            t.clone(),
            vec![
                column_field("a", ColumnType::SmallInt),
                column_field("b", ColumnType::Int),
                column_field("d", ColumnType::VarChar),
                column_field("c", ColumnType::BigInt),
            ],
        ),
        equal(
            subtract(column(&t, "a", &accessor), column(&t, "b", &accessor)),
            const_bigint(3),
        ),
    );
    let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &(), &[]).unwrap();
    exercise_verification(&verifiable_res, &ast, &accessor, &t);
    let res = verifiable_res
        .verify(&ast, &accessor, &(), &[])
        .unwrap()
        .table;
    // Only the last two rows satisfy `a - b = 3` (3 - 0 and 4 - 1).
    let expected_res = owned_table([
        smallint("a", [3_i16, 4]),
        // `c` is a BigInt column, so use i64 literals to match the input table.
        bigint("c", [2_i64, 0]),
        decimal75("res", 20, 0, [4_i64, 5]),
        varchar("d", ["efg", "g"]),
    ]);
    assert_eq!(res, expected_res);
}
/// Proves a filter over decimal columns with scales 1, 2 and 3. Both the projection and the
/// predicate use `scaling_cast` to align scales before adding or subtracting. The test checks
/// that the projected sum comes back as `Decimal75(17, 3)`.
#[test]
fn we_can_prove_a_typical_add_subtract_query_with_decimals() {
    // Shorthand for `Decimal75(precision, scale)` column types.
    let decimal =
        |precision, scale| ColumnType::Decimal75(Precision::new(precision).unwrap(), scale);

    let data = owned_table([
        decimal75("a", 12, 1, [4_i64, 2, 2, 7]),
        decimal75("b", 12, 2, [5_i64, -15, 42, 8]),
        varchar("d", ["ab", "t", "efg", "g"]),
        decimal75("c", 12, 3, [190_i64, 27, 253, 120]),
    ]);
    let tab = TableRef::new("sxt", "t");
    let accessor =
        OwnedTableTestAccessor::<InnerProductProof>::new_from_table(tab.clone(), data, 0, ());

    // `a` rescaled from (12, 1) to (13, 2) so it lines up with `b`.
    let a_at_scale_two = || scaling_cast(column(&tab, "a", &accessor), decimal(13, 2));

    // ((a + b) rescaled to scale 3) + c + 0.400
    let projected_sum = add(
        add(
            scaling_cast(
                add(a_at_scale_two(), column(&tab, "b", &accessor)),
                decimal(15, 3),
            ),
            column(&tab, "c", &accessor),
        ),
        const_decimal75(4, 3, 400),
    );
    // a - b = 0.35
    let predicate = equal(
        subtract(a_at_scale_two(), column(&tab, "b", &accessor)),
        const_decimal75(12, 2, 35),
    );

    let ast = filter(
        vec![
            col_expr_plan(&tab, "a", &accessor),
            aliased_plan(projected_sum, "c"),
            col_expr_plan(&tab, "d", &accessor),
        ],
        table_exec(
            tab.clone(),
            vec![
                column_field("a", decimal(12, 1)),
                column_field("b", decimal(12, 2)),
                column_field("d", ColumnType::VarChar),
                column_field("c", decimal(12, 3)),
            ],
        ),
        predicate,
    );

    let proof = VerifiableQueryResult::new(&ast, &accessor, &(), &[]).unwrap();
    exercise_verification(&proof, &ast, &accessor, &tab);
    let verified = proof.verify(&ast, &accessor, &(), &[]).unwrap();
    let res = verified.table;

    let expected_res = owned_table([
        decimal75("a", 12, 1, [4_i64, 2]),
        decimal75("c", 17, 3, [1040_i64, 477]),
        varchar("d", ["ab", "t"]),
    ]);
    assert_eq!(res, expected_res);
}
/// Runs 20 seeded random rounds of
/// `SELECT d, a + c - 4 AS f FROM t WHERE b = <string> AND c = <int>`.
///
/// Each round builds a fresh table at `offset`, proves and verifies the query, and compares the
/// result with one computed directly from the generated rows.
fn test_random_tables_with_given_offset(offset: usize) {
    let value_dist = Uniform::new(-3, 4);
    let len_dist = Uniform::new(1, 21);
    let mut rng = StdRng::from_seed([0u8; 32]);
    for _ in 0..20 {
        // All draws come from one seeded RNG, so keep this draw order unchanged.
        let n = len_dist.sample(&mut rng);
        let data = owned_table([
            bigint("a", value_dist.sample_iter(&mut rng).take(n)),
            varchar(
                "b",
                value_dist
                    .sample_iter(&mut rng)
                    .take(n)
                    .map(|v| format!("s{v}")),
            ),
            bigint("c", value_dist.sample_iter(&mut rng).take(n)),
            varchar(
                "d",
                value_dist
                    .sample_iter(&mut rng)
                    .take(n)
                    .map(|v| format!("s{v}")),
            ),
        ]);
        let wanted_b = format!("s{}", value_dist.sample(&mut rng));
        let wanted_c = value_dist.sample(&mut rng);

        let tab = TableRef::new("sxt", "t");
        let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(
            tab.clone(),
            data.clone(),
            offset,
            (),
        );

        let f_expr = subtract(
            add(column(&tab, "a", &accessor), column(&tab, "c", &accessor)),
            const_int128(4),
        );
        let predicate = and(
            equal(
                column(&tab, "b", &accessor),
                const_scalar::<Curve25519Scalar, _>(wanted_b.as_str()),
            ),
            equal(
                column(&tab, "c", &accessor),
                const_scalar::<Curve25519Scalar, _>(wanted_c),
            ),
        );
        let ast = filter(
            vec![
                col_expr_plan(&tab, "d", &accessor),
                aliased_plan(f_expr, "f"),
            ],
            table_exec(
                tab.clone(),
                vec![
                    column_field("a", ColumnType::BigInt),
                    column_field("b", ColumnType::VarChar),
                    column_field("c", ColumnType::BigInt),
                    column_field("d", ColumnType::VarChar),
                ],
            ),
            predicate,
        );

        let proof = VerifiableQueryResult::new(&ast, &accessor, &(), &[]).unwrap();
        exercise_verification(&proof, &ast, &accessor, &tab);
        let res = proof.verify(&ast, &accessor, &(), &[]).unwrap().table;

        // Recompute the expected output row by row from the generated table.
        let (expected_f, expected_d): (Vec<_>, Vec<_>) = multizip((
            data["a"].i64_iter(),
            data["b"].string_iter(),
            data["c"].i64_iter(),
            data["d"].string_iter(),
        ))
        .filter(|(_, b, c, _)| **b == wanted_b && **c == wanted_c)
        .map(|(a, _, c, d)| (Curve25519Scalar::from(*a + *c - 4), d.clone()))
        .multiunzip();
        let expected_result =
            owned_table([varchar("d", expected_d), decimal75("f", 40, 0, expected_f)]);
        assert_eq!(expected_result, res);
    }
}
/// Runs the randomized add/subtract round-trip with a table offset of 0.
#[test]
fn we_can_query_random_tables_using_a_zero_offset() {
    let offset = 0;
    test_random_tables_with_given_offset(offset);
}
/// Runs the randomized add/subtract round-trip with a non-zero table offset (123).
#[test]
fn we_can_query_random_tables_using_a_non_zero_offset() {
    let offset = 123;
    test_random_tables_with_given_offset(offset);
}
/// Evaluates `b + (a - 1)` directly with `first_round_evaluate`, without wrapping it in a query
/// plan. Checks the resulting column, which is typed as `Decimal75(21, 0)`.
#[test]
fn we_can_compute_the_correct_output_of_an_add_subtract_expr_using_first_round_evaluate() {
    let alloc = Bump::new();
    let data = table([
        borrowed_smallint("a", [1_i16, 2, 3, 4], &alloc),
        borrowed_int("b", [0_i32, 1, 0, 1], &alloc),
        borrowed_varchar("d", ["ab", "t", "efg", "g"], &alloc),
        borrowed_bigint("c", [0_i64, 2, 2, 0], &alloc),
    ]);
    let tab = TableRef::new("sxt", "t");
    let accessor =
        TableTestAccessor::<InnerProductProof>::new_from_table(tab.clone(), data.clone(), 0, ());

    let a_minus_one = subtract(column(&tab, "a", &accessor), const_bigint(1));
    let expr: DynProofExpr = add(column(&tab, "b", &accessor), a_minus_one);

    let evaluated = expr.first_round_evaluate(&alloc, &data, &[]);
    let res = evaluated.unwrap();

    let (_, expected_res) = borrowed_decimal75("res", 21, 0, [0_i64, 2, 2, 4], &alloc);
    assert_eq!(res, expected_res);
}
/// Checks that combining a SmallInt column with a VarChar column is rejected with
/// [`AnalyzeError::DataTypeMismatch`], both for addition and for subtraction.
#[test]
fn we_cannot_add_subtract_mismatching_types() {
    let alloc = Bump::new();
    let data = table([
        borrowed_smallint("a", [1_i16, 2, 3, 4], &alloc),
        borrowed_varchar("b", ["a", "b", "s", "z"], &alloc),
    ]);
    let t = TableRef::new("sxt", "t");
    // `data` is not used again, so it can be moved into the accessor.
    let accessor = TableTestAccessor::<InnerProductProof>::new_from_table(t.clone(), data, 0, ());
    let lhs = Box::new(column(&t, "a", &accessor));
    let rhs = Box::new(column(&t, "b", &accessor));

    let add_err = AddExpr::try_new(lhs.clone(), rhs.clone()).unwrap_err();
    assert!(matches!(add_err, AnalyzeError::DataTypeMismatch { .. }));

    // Use the subtraction constructor here. Calling `AddExpr::try_new` again would only repeat
    // the addition check and leave subtraction untested.
    let sub_err = DynProofExpr::try_new_subtract(*lhs, *rhs).unwrap_err();
    assert!(matches!(sub_err, AnalyzeError::DataTypeMismatch { .. }));
}