use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use type_bridge_core_lib::ast::{FunctionCallValue, Pattern, ReduceAssignment, Value};
use crate::value::AttributeValue;
/// Sort direction selector.
///
/// This part of the file only declares the type and its serde round-trip;
/// whatever consumes it lives elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortDir {
/// Ascending order.
Asc,
/// Descending order.
Desc,
}
/// Filter expression tree that `Expr::to_patterns` lowers into query `Pattern`s.
///
/// Leaf variants constrain a single attribute (`attr` is the attribute type
/// name) of the variable being filtered; `And`, `Or`, `Not` and `RolePlayer`
/// combine or re-target other expressions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Expr {
/// Attribute equals `value`; lowered with the `==` operator.
Eq {
attr: String,
value: AttributeValue,
},
/// Attribute is greater than `value`; lowered with the `>` operator.
Gt {
attr: String,
value: AttributeValue,
},
/// Attribute is less than `value`; lowered with the `<` operator.
Lt {
attr: String,
value: AttributeValue,
},
/// Attribute is greater than or equal to `value`; lowered with `>=`.
Gte {
attr: String,
value: AttributeValue,
},
/// Attribute is less than or equal to `value`; lowered with `<=`.
Lte {
attr: String,
value: AttributeValue,
},
/// Attribute differs from `value`; lowered with the `!=` operator.
Neq {
attr: String,
value: AttributeValue,
},
/// Attribute compared against `substring` with the `contains` operator.
Contains {
attr: String,
substring: String,
},
/// Attribute compared against `pattern` with the `like` operator.
/// `Expr::startswith` / `Expr::endswith` build this variant from an
/// escaped literal.
Like {
attr: String,
pattern: String,
},
/// Conjunction: the children's patterns are concatenated in order.
And(Vec<Expr>),
/// Disjunction: lowered to one `Pattern::Or` with one branch per child.
Or(Vec<Expr>),
/// Negation: the inner patterns are wrapped in one `Pattern::Not`.
Not(Box<Expr>),
/// Applies `inner` to the variable `$<role>` instead of the variable the
/// surrounding expression is filtering.
RolePlayer {
role: String,
inner: Box<Expr>,
},
}
impl Expr {
pub fn eq(attr: impl Into<String>, value: AttributeValue) -> Self {
Self::Eq {
attr: attr.into(),
value,
}
}
pub fn gt(attr: impl Into<String>, value: AttributeValue) -> Self {
Self::Gt {
attr: attr.into(),
value,
}
}
pub fn lt(attr: impl Into<String>, value: AttributeValue) -> Self {
Self::Lt {
attr: attr.into(),
value,
}
}
pub fn gte(attr: impl Into<String>, value: AttributeValue) -> Self {
Self::Gte {
attr: attr.into(),
value,
}
}
pub fn lte(attr: impl Into<String>, value: AttributeValue) -> Self {
Self::Lte {
attr: attr.into(),
value,
}
}
pub fn neq(attr: impl Into<String>, value: AttributeValue) -> Self {
Self::Neq {
attr: attr.into(),
value,
}
}
pub fn contains(attr: impl Into<String>, substring: impl Into<String>) -> Self {
Self::Contains {
attr: attr.into(),
substring: substring.into(),
}
}
pub fn like(attr: impl Into<String>, pattern: impl Into<String>) -> Self {
Self::Like {
attr: attr.into(),
pattern: pattern.into(),
}
}
pub fn and(exprs: Vec<Expr>) -> Self {
Self::And(exprs)
}
pub fn or(exprs: Vec<Expr>) -> Self {
Self::Or(exprs)
}
#[allow(clippy::should_implement_trait)]
pub fn not(expr: Expr) -> Self {
Self::Not(Box::new(expr))
}
pub fn in_range(attr: impl Into<String>, low: AttributeValue, high: AttributeValue) -> Self {
let a = attr.into();
Self::And(vec![Self::gte(a.clone(), low), Self::lte(a, high)])
}
pub fn startswith(attr: impl Into<String>, prefix: impl Into<String>) -> Self {
Self::Like {
attr: attr.into(),
pattern: format!("^{}.*", regex_escape(&prefix.into())),
}
}
pub fn endswith(attr: impl Into<String>, suffix: impl Into<String>) -> Self {
Self::Like {
attr: attr.into(),
pattern: format!(".*{}$", regex_escape(&suffix.into())),
}
}
pub fn role_player(role: impl Into<String>, inner: Expr) -> Self {
Self::RolePlayer {
role: role.into(),
inner: Box::new(inner),
}
}
pub fn to_patterns(&self, entity_var: &str, counter: &mut usize) -> Vec<Pattern> {
match self {
Self::Eq { attr, value }
| Self::Gt { attr, value }
| Self::Lt { attr, value }
| Self::Gte { attr, value }
| Self::Lte { attr, value }
| Self::Neq { attr, value } => {
let op = match self {
Self::Eq { .. } => "==",
Self::Gt { .. } => ">",
Self::Lt { .. } => "<",
Self::Gte { .. } => ">=",
Self::Lte { .. } => "<=",
Self::Neq { .. } => "!=",
_ => unreachable!(),
};
let var_name = format!("$attr{}", counter);
*counter += 1;
vec![
Pattern::Has {
thing_var: entity_var.to_string(),
attr_type: attr.clone(),
attr_var: var_name.clone(),
},
Pattern::ValueComparison {
var: var_name,
operator: op.to_string(),
value: value.to_ast_value(),
},
]
}
Self::Contains { attr, substring } => {
let var_name = format!("$attr{}", counter);
*counter += 1;
vec![
Pattern::Has {
thing_var: entity_var.to_string(),
attr_type: attr.clone(),
attr_var: var_name.clone(),
},
Pattern::ValueComparison {
var: var_name,
operator: "contains".to_string(),
value: AttributeValue::String(substring.clone()).to_ast_value(),
},
]
}
Self::Like { attr, pattern } => {
let var_name = format!("$attr{}", counter);
*counter += 1;
vec![
Pattern::Has {
thing_var: entity_var.to_string(),
attr_type: attr.clone(),
attr_var: var_name.clone(),
},
Pattern::ValueComparison {
var: var_name,
operator: "like".to_string(),
value: AttributeValue::String(pattern.clone()).to_ast_value(),
},
]
}
Self::And(children) => {
let mut patterns = Vec::new();
for child in children {
patterns.extend(child.to_patterns(entity_var, counter));
}
patterns
}
Self::Or(children) => {
let branches: Vec<Vec<Pattern>> = children
.iter()
.map(|child| child.to_patterns(entity_var, counter))
.collect();
vec![Pattern::Or(branches)]
}
Self::Not(inner) => {
let inner_patterns = inner.to_patterns(entity_var, counter);
vec![Pattern::Not(inner_patterns)]
}
Self::RolePlayer { role, inner } => {
let role_var = format!("${}", role);
inner.to_patterns(&role_var, counter)
}
}
}
pub fn collect_roles(
&self,
roles: &mut Vec<String>,
seen: &mut std::collections::HashSet<String>,
) {
match self {
Self::RolePlayer { role, .. } if seen.insert(role.clone()) => {
roles.push(role.clone());
}
Self::And(children) | Self::Or(children) => {
for child in children {
child.collect_roles(roles, seen);
}
}
Self::Not(inner) => inner.collect_roles(roles, seen),
_ => {}
}
}
}
/// Returns `s` with every regex metacharacter preceded by a backslash so the
/// text can be embedded in a pattern and matched literally.
///
/// Escaped characters: `. * + ? ( ) [ ] { } \ ^ $ |`. All other characters are
/// copied through unchanged.
fn regex_escape(s: &str) -> String {
    const METACHARACTERS: [char; 14] = [
        '.', '*', '+', '?', '(', ')', '[', ']', '{', '}', '\\', '^', '$', '|',
    ];

    s.chars()
        .fold(String::with_capacity(s.len()), |mut escaped, ch| {
            if METACHARACTERS.contains(&ch) {
                escaped.push('\\');
            }
            escaped.push(ch);
            escaped
        })
}
/// Aggregate function to apply in a reduce step.
///
/// Every variant except `Count` carries the attribute type name whose values
/// are aggregated. `Agg::to_reduce_assignment` maps each variant to a function
/// name and a fixed result variable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Agg {
/// `count` over the entity variable itself; result variable `$count`.
Count,
/// `sum` of the named attribute; result variable `$sum`.
Sum(String),
/// `min` of the named attribute; result variable `$min`.
Min(String),
/// `max` of the named attribute; result variable `$max`.
Max(String),
/// `mean` of the named attribute; result variable `$mean`.
Mean(String),
/// `median` of the named attribute; result variable `$median`.
Median(String),
/// `std` of the named attribute; result variable `$std`.
Std(String),
}
impl Agg {
    /// Lowers this aggregate into a reduce assignment plus, when needed, the
    /// `Has` pattern that binds the aggregated attribute.
    ///
    /// * `entity_var` - variable being aggregated over.
    /// * `counter` - running index for generated `$aggN` variables; bumped by
    ///   one for every variant except `Count`.
    ///
    /// `Count` applies `count` directly to `entity_var`, returns no pattern and
    /// leaves `counter` untouched. Every other variant binds the attribute to a
    /// fresh `$aggN` variable and applies its function to that variable. The
    /// result variable name is fixed per variant (`$count`, `$sum`, `$min`, ...).
    pub fn to_reduce_assignment(
        &self,
        entity_var: &str,
        counter: &mut usize,
    ) -> (ReduceAssignment, Option<Pattern>) {
        // Assembles `<variable> = <function>(<argument>)`.
        let reduce = |variable: &str, function: &str, argument: String| ReduceAssignment {
            variable: variable.to_string(),
            expression: Value::FunctionCall(FunctionCallValue {
                function: function.to_string(),
                args: vec![Value::Variable(argument)],
            }),
        };

        let (function, attr_type, variable) = match self {
            // Counting needs no attribute binding, so finish early.
            Self::Count => return (reduce("$count", "count", entity_var.to_string()), None),
            Self::Sum(attr) => ("sum", attr, "$sum"),
            Self::Min(attr) => ("min", attr, "$min"),
            Self::Max(attr) => ("max", attr, "$max"),
            Self::Mean(attr) => ("mean", attr, "$mean"),
            Self::Median(attr) => ("median", attr, "$median"),
            Self::Std(attr) => ("std", attr, "$std"),
        };

        let bound_var = format!("$agg{}", counter);
        *counter += 1;

        let binding = Pattern::Has {
            thing_var: entity_var.to_string(),
            attr_type: attr_type.clone(),
            attr_var: bound_var.clone(),
        };
        (reduce(variable, function, bound_var), Some(binding))
    }
}
/// Result of one aggregation, stored as raw JSON values keyed by result name
/// (for example `$count` or `$sum`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggResult {
// Raw values; the typed accessors on `AggResult` interpret them on demand.
values: HashMap<String, serde_json::Value>,
}
impl AggResult {
pub fn new(values: HashMap<String, serde_json::Value>) -> Self {
Self { values }
}
pub fn count(&self) -> Option<u64> {
self.values
.get("$count")
.or_else(|| self.values.get("count"))
.and_then(parse_u64_json)
}
pub fn get_i64(&self, key: &str) -> Option<i64> {
self.values.get(key).and_then(|v| v.as_i64())
}
pub fn get_f64(&self, key: &str) -> Option<f64> {
self.values.get(key).and_then(|v| v.as_f64())
}
pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
self.values.get(key)
}
}
/// Interprets a JSON value as an unsigned 64-bit integer.
///
/// Accepted shapes, tried in this order:
/// 1. an object with a `"value"` member - the member is interpreted recursively;
/// 2. a non-negative integer;
/// 3. a non-negative float below 2^64 - the fractional part is truncated;
/// 4. a string that parses as a `u64`.
///
/// Returns `None` for anything else. Negative numbers in particular yield
/// `None`: a float-to-integer `as` cast saturates, so converting them blindly
/// would report a bogus `0`.
fn parse_u64_json(value: &serde_json::Value) -> Option<u64> {
    // 2^64: the smallest float that no longer fits in a `u64`.
    const U64_EXCLUSIVE_LIMIT: f64 = 18_446_744_073_709_551_616.0;

    if let Some(wrapped) = value.get("value") {
        return parse_u64_json(wrapped);
    }

    value
        // Covers every non-negative integer, so no separate `i64` step is needed.
        .as_u64()
        .or_else(|| {
            value
                .as_f64()
                // Rejects negatives, NaN and anything at or above 2^64.
                .filter(|number| *number >= 0.0 && *number < U64_EXCLUSIVE_LIMIT)
                .map(|number| number as u64)
        })
        .or_else(|| value.as_str().and_then(|text| text.parse().ok()))
}
/// Result of a grouped aggregation: one `AggResult` per group key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupByResult {
// (group key, aggregate) pairs kept in the order they were supplied;
// lookups scan this list linearly.
groups: Vec<(serde_json::Value, AggResult)>,
}
impl GroupByResult {
pub fn new(groups: Vec<(serde_json::Value, AggResult)>) -> Self {
Self { groups }
}
pub fn iter(&self) -> impl Iterator<Item = (&serde_json::Value, &AggResult)> {
self.groups.iter().map(|(k, v)| (k, v))
}
pub fn len(&self) -> usize {
self.groups.len()
}
pub fn is_empty(&self) -> bool {
self.groups.is_empty()
}
pub fn get(&self, key: &serde_json::Value) -> Option<&AggResult> {
self.groups.iter().find(|(k, _)| k == key).map(|(_, v)| v)
}
pub fn get_by_str(&self, key: &str) -> Option<&AggResult> {
self.get(&serde_json::Value::String(key.to_string()))
}
pub fn get_by_i64(&self, key: i64) -> Option<&AggResult> {
self.get(&serde_json::json!(key))
}
}
// Unit tests for expression lowering, aggregate lowering, result accessors and
// serde round-trips of the types defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// A single comparison lowers to a `Has` binding plus a `ValueComparison` on
// the same generated variable, and consumes one counter slot.
#[test]
fn test_expr_eq_patterns() {
let expr = Expr::eq("age", AttributeValue::Long(30));
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
assert_eq!(patterns.len(), 2);
assert_eq!(counter, 1);
match &patterns[0] {
Pattern::Has {
thing_var,
attr_type,
attr_var,
} => {
assert_eq!(thing_var, "$e");
assert_eq!(attr_type, "age");
assert_eq!(attr_var, "$attr0");
}
_ => panic!("expected Has"),
}
match &patterns[1] {
Pattern::ValueComparison { var, operator, .. } => {
assert_eq!(var, "$attr0");
assert_eq!(operator, "==");
}
_ => panic!("expected ValueComparison"),
}
}
// `Gt` lowers with the `>` operator.
#[test]
fn test_expr_gt_patterns() {
let expr = Expr::gt("salary", AttributeValue::Long(50000));
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
assert_eq!(patterns.len(), 2);
match &patterns[1] {
Pattern::ValueComparison { operator, .. } => assert_eq!(operator, ">"),
_ => panic!("expected ValueComparison"),
}
}
// `Lt` lowers with the `<` operator.
#[test]
fn test_expr_lt_patterns() {
let expr = Expr::lt("age", AttributeValue::Long(18));
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
match &patterns[1] {
Pattern::ValueComparison { operator, .. } => assert_eq!(operator, "<"),
_ => panic!("expected ValueComparison"),
}
}
// `Gte`, `Lte` and `Neq` lower with `>=`, `<=` and `!=` respectively.
#[test]
fn test_expr_gte_lte_neq() {
for (expr, expected_op) in [
(Expr::gte("x", AttributeValue::Long(1)), ">="),
(Expr::lte("x", AttributeValue::Long(1)), "<="),
(Expr::neq("x", AttributeValue::Long(1)), "!="),
] {
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
match &patterns[1] {
Pattern::ValueComparison { operator, .. } => assert_eq!(operator, expected_op),
_ => panic!("expected ValueComparison"),
}
}
}
// `Contains` lowers with the `contains` operator and passes the substring
// through as a literal value.
#[test]
fn test_expr_contains_patterns() {
let expr = Expr::contains("name", "Ali");
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
assert_eq!(patterns.len(), 2);
match &patterns[1] {
Pattern::ValueComparison {
operator, value, ..
} => {
assert_eq!(operator, "contains");
if let Value::Literal(lit) = value {
assert_eq!(lit.value, json!("Ali"));
} else {
panic!("expected Literal");
}
}
_ => panic!("expected ValueComparison"),
}
}
// `Like` lowers with the `like` operator.
#[test]
fn test_expr_like_patterns() {
let expr = Expr::like("email", "^.*@example\\.com$");
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
assert_eq!(patterns.len(), 2);
match &patterns[1] {
Pattern::ValueComparison { operator, .. } => assert_eq!(operator, "like"),
_ => panic!("expected ValueComparison"),
}
}
// `And` concatenates its children's patterns (2 + 2) and each child takes
// its own counter slot.
#[test]
fn test_expr_and_flattens() {
let expr = Expr::and(vec![
Expr::gt("age", AttributeValue::Long(18)),
Expr::lt("age", AttributeValue::Long(65)),
]);
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
assert_eq!(patterns.len(), 4);
assert_eq!(counter, 2);
}
// `Or` yields a single `Pattern::Or` with one two-pattern branch per child.
#[test]
fn test_expr_or_generates_or_pattern() {
let expr = Expr::or(vec![
Expr::eq("dept", AttributeValue::String("HR".into())),
Expr::eq("dept", AttributeValue::String("Eng".into())),
]);
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
assert_eq!(patterns.len(), 1);
match &patterns[0] {
Pattern::Or(branches) => {
assert_eq!(branches.len(), 2);
assert_eq!(branches[0].len(), 2); assert_eq!(branches[1].len(), 2);
}
_ => panic!("expected Or"),
}
}
// `Not` yields a single `Pattern::Not` wrapping the inner patterns.
#[test]
fn test_expr_not_generates_not_pattern() {
let expr = Expr::not(Expr::eq(
"status",
AttributeValue::String("inactive".into()),
));
let mut counter = 0;
let patterns = expr.to_patterns("$e", &mut counter);
assert_eq!(patterns.len(), 1);
match &patterns[0] {
Pattern::Not(inner) => assert_eq!(inner.len(), 2),
_ => panic!("expected Not"),
}
}
// The counter keeps advancing when the same variable is passed to several
// independent lowering calls.
#[test]
fn test_counter_increments_across_expressions() {
let mut counter = 0;
Expr::eq("a", AttributeValue::Long(1)).to_patterns("$e", &mut counter);
Expr::gt("b", AttributeValue::Long(2)).to_patterns("$e", &mut counter);
Expr::contains("c", "x").to_patterns("$e", &mut counter);
assert_eq!(counter, 3);
}
// `Count` assigns to `$count`, needs no attribute pattern and does not
// consume a counter slot.
#[test]
fn test_agg_count() {
let mut counter = 0;
let (assign, pattern) = Agg::Count.to_reduce_assignment("$e", &mut counter);
assert_eq!(assign.variable, "$count");
assert!(pattern.is_none());
assert_eq!(counter, 0); }
// `Sum` assigns to `$sum` and binds the attribute to `$agg0`, consuming one
// counter slot.
#[test]
fn test_agg_sum() {
let mut counter = 0;
let (assign, pattern) = Agg::Sum("salary".into()).to_reduce_assignment("$e", &mut counter);
assert_eq!(assign.variable, "$sum");
assert!(pattern.is_some());
match pattern.unwrap() {
Pattern::Has {
attr_type,
attr_var,
..
} => {
assert_eq!(attr_type, "salary");
assert_eq!(attr_var, "$agg0");
}
_ => panic!("expected Has"),
}
assert_eq!(counter, 1);
}
// `count()` reads the `$count` entry.
#[test]
fn test_agg_result_count() {
let mut values = HashMap::new();
values.insert("$count".into(), json!(42));
let result = AggResult::new(values);
assert_eq!(result.count(), Some(42));
}
// `count()` does not fall back to unrelated aggregate entries.
#[test]
fn test_agg_result_count_ignores_other_aggregates() {
let mut values = HashMap::new();
values.insert("$sum".into(), json!(99));
let result = AggResult::new(values);
assert_eq!(result.count(), None);
}
// `get_f64` returns a float entry as-is.
#[test]
fn test_agg_result_get_f64() {
let mut values = HashMap::new();
values.insert("$mean".into(), json!(2.78));
let result = AggResult::new(values);
assert_eq!(result.get_f64("$mean"), Some(2.78));
}
// `get_i64` returns an integer entry as-is.
#[test]
fn test_agg_result_get_i64() {
let mut values = HashMap::new();
values.insert("$sum".into(), json!(100));
let result = AggResult::new(values);
assert_eq!(result.get_i64("$sum"), Some(100));
}
// Accessors return `None` for keys that are not present.
#[test]
fn test_agg_result_missing_key() {
let result = AggResult::new(HashMap::new());
assert_eq!(result.count(), None);
assert_eq!(result.get_f64("$sum"), None);
}
// `in_range` is `And([Gte, Lte])` on the same attribute.
#[test]
fn test_expr_in_range() {
let expr = Expr::in_range("age", AttributeValue::Long(20), AttributeValue::Long(30));
match expr {
Expr::And(children) => {
assert_eq!(children.len(), 2);
assert!(matches!(&children[0], Expr::Gte { attr, .. } if attr == "age"));
assert!(matches!(&children[1], Expr::Lte { attr, .. } if attr == "age"));
}
_ => panic!("expected And"),
}
}
// `startswith` builds a `Like` with the pattern `^<prefix>.*`.
#[test]
fn test_expr_startswith() {
let expr = Expr::startswith("name", "Ali");
match expr {
Expr::Like { attr, pattern } => {
assert_eq!(attr, "name");
assert_eq!(pattern, "^Ali.*");
}
_ => panic!("expected Like"),
}
}
// `endswith` builds a `Like` with the pattern `.*<suffix>$`.
#[test]
fn test_expr_endswith() {
let expr = Expr::endswith("name", "ice");
match expr {
Expr::Like { attr, pattern } => {
assert_eq!(attr, "name");
assert_eq!(pattern, ".*ice$");
}
_ => panic!("expected Like"),
}
}
// Regex metacharacters in the prefix are backslash-escaped.
#[test]
fn test_expr_startswith_escapes_special_chars() {
let expr = Expr::startswith("email", "foo.bar");
match expr {
Expr::Like { pattern, .. } => {
assert_eq!(pattern, "^foo\\.bar.*");
}
_ => panic!("expected Like"),
}
}
// A role-player filter attaches its constraints to `$<role>` rather than to
// the variable passed in.
#[test]
fn test_expr_role_player_patterns() {
let expr = Expr::role_player("employee", Expr::gt("age", AttributeValue::Long(30)));
let mut counter = 0;
let patterns = expr.to_patterns("$r", &mut counter);
assert_eq!(patterns.len(), 2);
match &patterns[0] {
Pattern::Has {
thing_var,
attr_type,
..
} => {
assert_eq!(thing_var, "$employee");
assert_eq!(attr_type, "age");
}
_ => panic!("expected Has"),
}
}
// Roles are collected once each, in first-seen order.
#[test]
fn test_collect_roles() {
let expr = Expr::and(vec![
Expr::role_player("employee", Expr::gt("age", AttributeValue::Long(30))),
Expr::role_player(
"employer",
Expr::eq("name", AttributeValue::String("Corp".into())),
),
Expr::role_player("employee", Expr::lt("age", AttributeValue::Long(65))),
]);
let mut roles = Vec::new();
let mut seen = std::collections::HashSet::new();
expr.collect_roles(&mut roles, &mut seen);
assert_eq!(roles, vec!["employee", "employer"]);
}
// Grouped results report their size and can be looked up by string key;
// unknown keys yield `None`.
#[test]
fn test_group_by_result() {
let mut values1 = HashMap::new();
values1.insert("$mean".into(), json!(35.5));
let mut values2 = HashMap::new();
values2.insert("$mean".into(), json!(28.3));
let result = GroupByResult::new(vec![
(json!("Engineering"), AggResult::new(values1)),
(json!("Sales"), AggResult::new(values2)),
]);
assert_eq!(result.len(), 2);
assert!(!result.is_empty());
assert_eq!(
result.get_by_str("Engineering").unwrap().get_f64("$mean"),
Some(35.5)
);
assert_eq!(
result.get_by_str("Sales").unwrap().get_f64("$mean"),
Some(28.3)
);
assert!(result.get_by_str("HR").is_none());
}
// A nested expression keeps its shape through a JSON round-trip.
#[test]
fn expr_serde_roundtrip() {
let expr = Expr::and(vec![
Expr::eq("name", AttributeValue::String("Alice".into())),
Expr::or(vec![
Expr::gt("age", AttributeValue::Long(18)),
Expr::lt("age", AttributeValue::Long(65)),
]),
Expr::not(Expr::eq(
"status",
AttributeValue::String("inactive".into()),
)),
]);
let json = serde_json::to_string(&expr).unwrap();
let parsed: Expr = serde_json::from_str(&json).unwrap();
match parsed {
Expr::And(children) => {
assert_eq!(children.len(), 3);
assert!(matches!(&children[0], Expr::Eq { .. }));
assert!(matches!(&children[1], Expr::Or(_)));
assert!(matches!(&children[2], Expr::Not(_)));
}
_ => panic!("expected And"),
}
}
// Serializing and deserializing these aggregates succeeds. The parsed value
// is not compared, and `Median` / `Std` are not exercised here.
#[test]
fn agg_serde_roundtrip() {
for agg in [
Agg::Count,
Agg::Sum("salary".into()),
Agg::Min("age".into()),
Agg::Max("age".into()),
Agg::Mean("score".into()),
] {
let json = serde_json::to_string(&agg).unwrap();
let _parsed: Agg = serde_json::from_str(&json).unwrap();
}
}
// A role-player expression keeps its role and inner variant through a JSON
// round-trip.
#[test]
fn role_player_serde_roundtrip() {
let expr = Expr::role_player("employee", Expr::gt("age", AttributeValue::Long(30)));
let json = serde_json::to_string(&expr).unwrap();
let parsed: Expr = serde_json::from_str(&json).unwrap();
match parsed {
Expr::RolePlayer { role, inner } => {
assert_eq!(role, "employee");
assert!(matches!(*inner, Expr::Gt { .. }));
}
_ => panic!("expected RolePlayer"),
}
}
// `SortDir` survives a JSON round-trip.
#[test]
fn sort_dir_serde_roundtrip() {
let json = serde_json::to_string(&SortDir::Asc).unwrap();
let parsed: SortDir = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, SortDir::Asc);
}
}