use std::sync::Arc;
use grafeo_common::types::{EdgeId, LogicalType, NodeId, PropertyKey, Value};
use super::accumulator::AggregateFunction;
use super::aggregate::AggregateState;
use super::{Operator, OperatorResult};
use crate::execution::DataChunk;
use crate::execution::vector::ValueVector;
use crate::graph::traits::GraphStoreSearch;
/// Kind of graph entity referenced by the ids in an aggregated list.
///
/// The variant selects whether an id is interpreted as an edge id or as a
/// node id when a property is looked up.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum EntityKind {
    /// The ids identify edges.
    Edge,
    /// The ids identify nodes.
    Node,
}
/// Operator that, for every input row, aggregates one property over the list
/// of entity ids found in a single column and appends the aggregate as an
/// extra output column.
pub struct HorizontalAggregateOperator {
/// Upstream operator that supplies the input chunks.
child: Box<dyn Operator>,
/// Index of the input column expected to hold a `Value::List` of entity ids.
list_column_idx: usize,
/// Whether the ids in the list refer to edges or to nodes.
entity_kind: EntityKind,
/// Aggregate function applied to the collected property values.
function: AggregateFunction,
/// Name of the property read from each entity.
property: String,
/// Store used to look up edge / node properties.
store: Arc<dyn GraphStoreSearch>,
/// Number of leading input columns copied through to the output.
input_column_count: usize,
}
impl HorizontalAggregateOperator {
    /// Builds a horizontal aggregate over `child`.
    ///
    /// * `list_column_idx` - input column holding the list of entity ids.
    /// * `entity_kind` - whether those ids are edge ids or node ids.
    /// * `function` - aggregate function applied per row.
    /// * `property` - name of the property read from every entity.
    /// * `store` - store used for the property lookups.
    /// * `input_column_count` - number of leading input columns passed through.
    pub fn new(
        child: Box<dyn Operator>,
        list_column_idx: usize,
        entity_kind: EntityKind,
        function: AggregateFunction,
        property: String,
        store: Arc<dyn GraphStoreSearch>,
        input_column_count: usize,
    ) -> Self {
        Self {
            child,
            list_column_idx,
            entity_kind,
            function,
            property,
            store,
            input_column_count,
        }
    }

    /// Looks up the configured property on the entity identified by
    /// `entity_value`.
    ///
    /// Returns `None` when `entity_value` is not a `Value::Int64`; otherwise
    /// returns whatever the store reports for that id.
    fn get_property_value(&self, entity_value: &Value) -> Option<Value> {
        let key = PropertyKey::new(&self.property);

        let Value::Int64(signed_id) = entity_value else {
            return None;
        };
        // `as` reinterprets the bits, so a negative value maps to a large u64.
        // NOTE(review): presumably this mirrors ids having been stored via
        // `id as i64` - confirm against the producers of the list column.
        #[allow(clippy::cast_sign_loss)]
        let raw_id = *signed_id as u64;

        match self.entity_kind {
            EntityKind::Node => self.store.get_node_property(NodeId(raw_id), &key),
            EntityKind::Edge => self.store.get_edge_property(EdgeId(raw_id), &key),
        }
    }
}
impl Operator for HorizontalAggregateOperator {
    /// Pulls one chunk from the child and returns it with an aggregate column
    /// appended.
    ///
    /// For every selected input row the first `input_column_count` columns are
    /// copied through (missing cells become `Value::Null`). If the cell in
    /// `list_column_idx` is a `Value::List`, the configured property of every
    /// listed entity is fed into the aggregate, skipping entities whose
    /// property is absent or `Value::Null`; otherwise the aggregate for that
    /// row is `Value::Null`.
    ///
    /// Returns `Ok(None)` once the child is exhausted and propagates child
    /// errors.
    fn next(&mut self) -> OperatorResult {
        let input = match self.child.next()? {
            Some(chunk) => chunk,
            None => return Ok(None),
        };
        let capacity = input.row_count();

        let mut columns: Vec<ValueVector> =
            std::iter::repeat_with(|| ValueVector::with_capacity(LogicalType::Any, capacity))
                .take(self.input_column_count)
                .collect();
        // NOTE(review): the result column is declared Float64 regardless of the
        // aggregate function - confirm `push_value` copes with non-float
        // results (e.g. a count) and with nulls.
        let mut aggregates = ValueVector::with_capacity(LogicalType::Float64, capacity);

        for row in input.selected_indices() {
            // Pass-through columns.
            for (col_idx, column) in columns.iter_mut().enumerate() {
                let cell = input
                    .column(col_idx)
                    .and_then(|c| c.get_value(row))
                    .unwrap_or(Value::Null);
                column.push_value(cell);
            }

            // Aggregate over the entity list of this row.
            let list_cell = input
                .column(self.list_column_idx)
                .and_then(|c| c.get_value(row));
            let aggregated = match list_cell {
                Some(Value::List(entities)) => {
                    let mut state = AggregateState::new(self.function, false, None, None);
                    for entity in entities.iter() {
                        match self.get_property_value(entity) {
                            // Absent or null properties do not contribute.
                            None | Some(Value::Null) => {}
                            present => {
                                state.update(present);
                            }
                        }
                    }
                    state.finalize()
                }
                _ => Value::Null,
            };
            aggregates.push_value(aggregated);
        }

        columns.push(aggregates);
        Ok(Some(DataChunk::new(columns)))
    }

    /// Resets the child operator; this operator keeps no iteration state of
    /// its own.
    fn reset(&mut self) {
        self.child.reset();
    }

    /// Static operator name.
    fn name(&self) -> &'static str {
        "HorizontalAggregate"
    }

    /// Converts the boxed operator into a boxed `Any` for downcasting.
    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
        self
    }
}
#[cfg(all(test, feature = "lpg"))]
mod tests {
use super::*;
use crate::execution::chunk::DataChunkBuilder;
use crate::graph::lpg::LpgStore;
/// Test double that hands out pre-built chunks one at a time.
struct MockOperator {
/// Chunks to return; a slot is swapped for an empty chunk once handed out.
chunks: Vec<DataChunk>,
/// Index of the next chunk to return.
position: usize,
}
impl MockOperator {
    /// Creates a mock that starts at the first of `chunks`.
    fn new(chunks: Vec<DataChunk>) -> Self {
        let position = 0;
        Self { chunks, position }
    }
}
impl Operator for MockOperator {
    /// Returns the chunk at the current position, leaving an empty chunk in
    /// its slot, or `Ok(None)` once every slot has been visited.
    fn next(&mut self) -> OperatorResult {
        let Some(slot) = self.chunks.get_mut(self.position) else {
            return Ok(None);
        };
        let chunk = std::mem::replace(slot, DataChunk::empty());
        self.position += 1;
        Ok(Some(chunk))
    }

    /// Rewinds to the first slot (already handed-out slots stay empty).
    fn reset(&mut self) {
        self.position = 0;
    }

    /// Static operator name.
    fn name(&self) -> &'static str {
        "Mock"
    }

    /// Converts the boxed mock into a boxed `Any`.
    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
        self
    }
}
/// Builds a store with two nodes joined by three "ROAD" edges whose `weight`
/// properties are 1.5, 2.5 and 3.0, and returns the store together with the
/// edge ids encoded as `Value::Int64`.
#[allow(clippy::cast_possible_wrap)]
fn setup_store_with_edges() -> (Arc<dyn GraphStoreSearch>, Vec<Value>) {
    let store = LpgStore::new().unwrap();
    let source = store.create_node(&[]);
    let target = store.create_node(&[]);

    let mut edge_ids = Vec::with_capacity(3);
    for weight in [1.5, 2.5, 3.0] {
        let edge = store.create_edge(source, target, "ROAD");
        store.set_edge_property(edge, "weight", Value::Float64(weight));
        edge_ids.push(Value::Int64(edge.0 as i64));
    }

    (Arc::new(store), edge_ids)
}
/// Builds a store with three "City" nodes whose `pop` properties are 100.0,
/// 200.0 and 300.0, and returns the store together with the node ids encoded
/// as `Value::Int64`.
#[allow(clippy::cast_possible_wrap)]
fn setup_store_with_nodes() -> (Arc<dyn GraphStoreSearch>, Vec<Value>) {
    let store = LpgStore::new().unwrap();

    let node_ids: Vec<Value> = [100.0, 200.0, 300.0]
        .into_iter()
        .map(|pop| {
            let node = store.create_node_with_props(&["City"], [("pop", Value::Float64(pop))]);
            Value::Int64(node.0 as i64)
        })
        .collect();

    (Arc::new(store), node_ids)
}
/// Sum of `weight` over three edges, with one pass-through column in front.
#[test]
fn test_horizontal_sum_over_edges() {
    let (store, edge_ids) = setup_store_with_edges();

    // One row: a label column followed by the list of edge ids.
    let input = {
        let mut rows = DataChunkBuilder::new(&[LogicalType::String, LogicalType::Any]);
        rows.column_mut(0)
            .unwrap()
            .push_value(Value::String("path1".into()));
        rows.column_mut(1)
            .unwrap()
            .push_value(Value::List(edge_ids.into()));
        rows.advance_row();
        rows.finish()
    };

    let mut op = HorizontalAggregateOperator::new(
        Box::new(MockOperator::new(vec![input])),
        1,
        EntityKind::Edge,
        AggregateFunction::Sum,
        "weight".to_string(),
        store,
        2,
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    // 1.5 + 2.5 + 3.0, found in the column appended after the two inputs.
    let total = output.column(2).unwrap().get_float64(0).unwrap();
    assert!((total - 7.0).abs() < 0.001);

    assert!(op.next().unwrap().is_none());
}
/// Sum of `pop` over three nodes.
#[test]
fn test_horizontal_sum_over_nodes() {
    let (store, node_ids) = setup_store_with_nodes();

    let input = {
        let mut rows = DataChunkBuilder::new(&[LogicalType::Any]);
        rows.column_mut(0)
            .unwrap()
            .push_value(Value::List(node_ids.into()));
        rows.advance_row();
        rows.finish()
    };

    let mut op = HorizontalAggregateOperator::new(
        Box::new(MockOperator::new(vec![input])),
        0,
        EntityKind::Node,
        AggregateFunction::Sum,
        "pop".to_string(),
        store,
        1,
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    // 100 + 200 + 300
    let total = output.column(1).unwrap().get_float64(0).unwrap();
    assert!((total - 600.0).abs() < 0.001);
}
/// Average of `weight` over three edges.
#[test]
fn test_horizontal_avg_over_edges() {
    let (store, edge_ids) = setup_store_with_edges();

    let input = {
        let mut rows = DataChunkBuilder::new(&[LogicalType::Any]);
        rows.column_mut(0)
            .unwrap()
            .push_value(Value::List(edge_ids.into()));
        rows.advance_row();
        rows.finish()
    };

    let mut op = HorizontalAggregateOperator::new(
        Box::new(MockOperator::new(vec![input])),
        0,
        EntityKind::Edge,
        AggregateFunction::Avg,
        "weight".to_string(),
        store,
        1,
    );

    let output = op.next().unwrap().unwrap();
    // (1.5 + 2.5 + 3.0) / 3
    let mean = output.column(1).unwrap().get_float64(0).unwrap();
    assert!((mean - 7.0 / 3.0).abs() < 0.001);
}
/// Minimum and maximum of `weight` over the same three edges.
#[test]
fn test_horizontal_min_max_over_edges() {
    let (store, edge_ids) = setup_store_with_edges();

    // Min first, then Max, each through a fresh operator over the same list.
    let cases = [(AggregateFunction::Min, 1.5), (AggregateFunction::Max, 3.0)];
    for (function, expected) in cases {
        let mut rows = DataChunkBuilder::new(&[LogicalType::Any]);
        rows.column_mut(0)
            .unwrap()
            .push_value(Value::List(edge_ids.clone().into()));
        rows.advance_row();
        let input = rows.finish();

        let mut op = HorizontalAggregateOperator::new(
            Box::new(MockOperator::new(vec![input])),
            0,
            EntityKind::Edge,
            function,
            "weight".to_string(),
            Arc::clone(&store),
            1,
        );

        let output = op.next().unwrap().unwrap();
        let extreme = output.column(1).unwrap().get_float64(0).unwrap();
        assert!((extreme - expected).abs() < 0.001);
    }
}
/// A sum over an empty entity list is expected to come back as null.
#[test]
fn test_horizontal_empty_list_returns_null() {
    let store: Arc<dyn GraphStoreSearch> = Arc::new(LpgStore::new().unwrap());

    let input = {
        let mut rows = DataChunkBuilder::new(&[LogicalType::Any]);
        rows.column_mut(0)
            .unwrap()
            .push_value(Value::List(vec![].into()));
        rows.advance_row();
        rows.finish()
    };

    let mut op = HorizontalAggregateOperator::new(
        Box::new(MockOperator::new(vec![input])),
        0,
        EntityKind::Edge,
        AggregateFunction::Sum,
        "weight".to_string(),
        store,
        1,
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let agg_val = output.column(1).unwrap().get_value(0);
    assert!(
        matches!(agg_val, Some(Value::Null)),
        "Expected Null, got {agg_val:?}"
    );
}
/// A cell that is not a list yields a null aggregate for that row.
#[test]
fn test_horizontal_non_list_column_returns_null() {
    let store: Arc<dyn GraphStoreSearch> = Arc::new(LpgStore::new().unwrap());

    // The "list" column holds a plain integer instead of a list.
    let input = {
        let mut rows = DataChunkBuilder::new(&[LogicalType::Int64]);
        rows.column_mut(0).unwrap().push_int64(42);
        rows.advance_row();
        rows.finish()
    };

    let mut op = HorizontalAggregateOperator::new(
        Box::new(MockOperator::new(vec![input])),
        0,
        EntityKind::Edge,
        AggregateFunction::Sum,
        "weight".to_string(),
        store,
        1,
    );

    let output = op.next().unwrap().unwrap();
    let aggregate = output.column(1).unwrap().get_value(0);
    assert_eq!(aggregate, Some(Value::Null));
}
/// Each row is aggregated independently of the others.
#[test]
fn test_horizontal_multiple_rows() {
    let (store, edge_ids) = setup_store_with_edges();

    // Row 0 lists all three edges, row 1 only the first one.
    let rows = [
        ("path_all", edge_ids.clone()),
        ("path_one", vec![edge_ids[0].clone()]),
    ];
    let mut builder = DataChunkBuilder::new(&[LogicalType::String, LogicalType::Any]);
    for (label, ids) in rows {
        builder
            .column_mut(0)
            .unwrap()
            .push_value(Value::String(label.into()));
        builder
            .column_mut(1)
            .unwrap()
            .push_value(Value::List(ids.into()));
        builder.advance_row();
    }
    let input = builder.finish();

    let mut op = HorizontalAggregateOperator::new(
        Box::new(MockOperator::new(vec![input])),
        1,
        EntityKind::Edge,
        AggregateFunction::Sum,
        "weight".to_string(),
        store,
        2,
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 2);

    // Row 0: 1.5 + 2.5 + 3.0; row 1: 1.5.
    for (row, expected) in [(0, 7.0), (1, 1.5)] {
        let sum = output.column(2).unwrap().get_float64(row).unwrap();
        assert!((sum - expected).abs() < 0.001);
    }
}
/// `reset()` must be forwarded to the child so the operator can be pulled
/// again after it has been drained.
#[test]
fn test_horizontal_reset() {
    let (store, edge_ids) = setup_store_with_edges();

    let mut builder1 = DataChunkBuilder::new(&[LogicalType::Any]);
    builder1
        .column_mut(0)
        .unwrap()
        .push_value(Value::List(edge_ids.clone().into()));
    builder1.advance_row();

    let mut builder2 = DataChunkBuilder::new(&[LogicalType::Any]);
    builder2
        .column_mut(0)
        .unwrap()
        .push_value(Value::List(edge_ids.into()));
    builder2.advance_row();

    let mock = MockOperator::new(vec![builder1.finish(), builder2.finish()]);
    let mut op = HorizontalAggregateOperator::new(
        Box::new(mock),
        0,
        EntityKind::Edge,
        AggregateFunction::Sum,
        "weight".to_string(),
        store,
        1,
    );

    // Drain both chunks.
    let result = op.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let result = op.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    assert!(op.next().unwrap().is_none());

    op.reset();

    // The mock swaps every handed-out chunk for `DataChunk::empty()`, so after
    // the rewind it replays empty chunks. Only check that the operator
    // produces output again instead of staying exhausted.
    assert!(
        op.next().unwrap().is_some(),
        "reset() should rewind the child operator"
    );
}
/// The operator reports its static name.
#[test]
fn test_horizontal_name() {
    let store: Arc<dyn GraphStoreSearch> = Arc::new(LpgStore::new().unwrap());
    let child = Box::new(MockOperator::new(vec![]));

    let op = HorizontalAggregateOperator::new(
        child,
        0,
        EntityKind::Edge,
        AggregateFunction::Sum,
        "weight".to_string(),
        store,
        1,
    );

    assert_eq!(op.name(), "HorizontalAggregate");
}
/// With an exhausted child the operator yields nothing.
#[test]
fn test_horizontal_child_returns_none() {
    let store: Arc<dyn GraphStoreSearch> = Arc::new(LpgStore::new().unwrap());
    // A mock without chunks is exhausted from the start.
    let child = Box::new(MockOperator::new(vec![]));

    let mut op = HorizontalAggregateOperator::new(
        child,
        0,
        EntityKind::Edge,
        AggregateFunction::Sum,
        "weight".to_string(),
        store,
        1,
    );

    assert!(op.next().unwrap().is_none());
}
/// `into_any` keeps the concrete type recoverable through downcasting.
#[test]
fn test_horizontal_aggregate_into_any() {
    let store: Arc<dyn GraphStoreSearch> = Arc::new(LpgStore::new().unwrap());
    let child = Box::new(MockOperator::new(vec![]));

    let op = HorizontalAggregateOperator::new(
        child,
        0,
        EntityKind::Node,
        AggregateFunction::Count,
        "name".to_string(),
        store,
        1,
    );

    let erased: Box<dyn std::any::Any + Send> = Box::new(op).into_any();
    assert!(erased.downcast::<HorizontalAggregateOperator>().is_ok());
}
}