use crate::executor::pipeline::RowBatch;
use crate::executor::plan::{Predicate, Value};
use crate::executor::{ExecutionError, Result};
use std::collections::HashMap;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
pub trait Operator: Send + Sync {
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>>;
fn name(&self) -> &str;
fn is_pipeline_breaker(&self) -> bool {
false
}
}
#[derive(Debug, Clone)]
pub enum ScanMode {
Sequential,
Index { index_name: String },
Range { start: Value, end: Value },
}
pub struct NodeScan {
mode: ScanMode,
filter: Option<Predicate>,
position: usize,
}
impl NodeScan {
pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
Self {
mode,
filter,
position: 0,
}
}
}
impl Operator for NodeScan {
fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
Ok(None)
}
fn name(&self) -> &str {
"NodeScan"
}
}
pub struct EdgeScan {
mode: ScanMode,
filter: Option<Predicate>,
position: usize,
}
impl EdgeScan {
pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
Self {
mode,
filter,
position: 0,
}
}
}
impl Operator for EdgeScan {
fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
Ok(None)
}
fn name(&self) -> &str {
"EdgeScan"
}
}
pub struct HyperedgeScan {
mode: ScanMode,
filter: Option<Predicate>,
}
impl HyperedgeScan {
pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
Self { mode, filter }
}
}
impl Operator for HyperedgeScan {
fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
Ok(None)
}
fn name(&self) -> &str {
"HyperedgeScan"
}
}
pub struct Filter {
predicate: Predicate,
}
impl Filter {
pub fn new(predicate: Predicate) -> Self {
Self { predicate }
}
fn evaluate(&self, row: &HashMap<String, Value>) -> bool {
self.evaluate_predicate(&self.predicate, row)
}
fn evaluate_predicate(&self, pred: &Predicate, row: &HashMap<String, Value>) -> bool {
match pred {
Predicate::Equals(col, val) => row.get(col).map(|v| v == val).unwrap_or(false),
Predicate::NotEquals(col, val) => row.get(col).map(|v| v != val).unwrap_or(false),
Predicate::GreaterThan(col, val) => row
.get(col)
.and_then(|v| v.compare(val))
.map(|ord| ord == std::cmp::Ordering::Greater)
.unwrap_or(false),
Predicate::GreaterThanOrEqual(col, val) => row
.get(col)
.and_then(|v| v.compare(val))
.map(|ord| ord != std::cmp::Ordering::Less)
.unwrap_or(false),
Predicate::LessThan(col, val) => row
.get(col)
.and_then(|v| v.compare(val))
.map(|ord| ord == std::cmp::Ordering::Less)
.unwrap_or(false),
Predicate::LessThanOrEqual(col, val) => row
.get(col)
.and_then(|v| v.compare(val))
.map(|ord| ord != std::cmp::Ordering::Greater)
.unwrap_or(false),
Predicate::In(col, values) => row.get(col).map(|v| values.contains(v)).unwrap_or(false),
Predicate::Like(col, pattern) => {
if let Some(Value::String(s)) = row.get(col) {
self.pattern_match(s, pattern)
} else {
false
}
}
Predicate::And(preds) => preds.iter().all(|p| self.evaluate_predicate(p, row)),
Predicate::Or(preds) => preds.iter().any(|p| self.evaluate_predicate(p, row)),
Predicate::Not(pred) => !self.evaluate_predicate(pred, row),
}
}
fn pattern_match(&self, s: &str, pattern: &str) -> bool {
if pattern.starts_with('%') && pattern.ends_with('%') {
let p = &pattern[1..pattern.len() - 1];
s.contains(p)
} else if pattern.starts_with('%') {
let p = &pattern[1..];
s.ends_with(p)
} else if pattern.ends_with('%') {
let p = &pattern[..pattern.len() - 1];
s.starts_with(p)
} else {
s == pattern
}
}
#[cfg(target_arch = "x86_64")]
fn filter_batch_simd(&self, values: &[f32], threshold: f32) -> Vec<bool> {
if is_x86_feature_detected!("avx2") {
unsafe { self.filter_batch_avx2(values, threshold) }
} else {
self.filter_batch_scalar(values, threshold)
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn filter_batch_avx2(&self, values: &[f32], threshold: f32) -> Vec<bool> {
let mut result = vec![false; values.len()];
let threshold_vec = _mm256_set1_ps(threshold);
let chunks = values.len() / 8;
for i in 0..chunks {
let idx = i * 8;
let vals = _mm256_loadu_ps(values.as_ptr().add(idx));
let cmp = _mm256_cmp_ps(vals, threshold_vec, _CMP_GT_OQ);
let mask: [f32; 8] = std::mem::transmute(cmp);
for j in 0..8 {
result[idx + j] = mask[j] != 0.0;
}
}
for i in (chunks * 8)..values.len() {
result[i] = values[i] > threshold;
}
result
}
#[cfg(not(target_arch = "x86_64"))]
fn filter_batch_simd(&self, values: &[f32], threshold: f32) -> Vec<bool> {
self.filter_batch_scalar(values, threshold)
}
fn filter_batch_scalar(&self, values: &[f32], threshold: f32) -> Vec<bool> {
values.iter().map(|&v| v > threshold).collect()
}
}
impl Operator for Filter {
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
if let Some(batch) = input {
let filtered_rows: Vec<_> = batch
.rows
.into_iter()
.filter(|row| self.evaluate(row))
.collect();
Ok(Some(RowBatch {
rows: filtered_rows,
schema: batch.schema,
}))
} else {
Ok(None)
}
}
fn name(&self) -> &str {
"Filter"
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
Inner,
LeftOuter,
RightOuter,
FullOuter,
}
pub struct Join {
join_type: JoinType,
on: Vec<(String, String)>,
hash_table: HashMap<Vec<Value>, Vec<HashMap<String, Value>>>,
built: bool,
}
impl Join {
pub fn new(join_type: JoinType, on: Vec<(String, String)>) -> Self {
Self {
join_type,
on,
hash_table: HashMap::new(),
built: false,
}
}
fn build_hash_table(&mut self, build_side: RowBatch) {
for row in build_side.rows {
let key: Vec<Value> = self
.on
.iter()
.filter_map(|(_, right_col)| row.get(right_col).cloned())
.collect();
self.hash_table
.entry(key)
.or_insert_with(Vec::new)
.push(row);
}
self.built = true;
}
fn probe(&self, probe_row: &HashMap<String, Value>) -> Vec<HashMap<String, Value>> {
let key: Vec<Value> = self
.on
.iter()
.filter_map(|(left_col, _)| probe_row.get(left_col).cloned())
.collect();
if let Some(matches) = self.hash_table.get(&key) {
matches
.iter()
.map(|right_row| {
let mut joined = probe_row.clone();
joined.extend(right_row.clone());
joined
})
.collect()
} else {
Vec::new()
}
}
}
impl Operator for Join {
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
Ok(None)
}
fn name(&self) -> &str {
"Join"
}
fn is_pipeline_breaker(&self) -> bool {
true }
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregateFunction {
Count,
Sum,
Avg,
Min,
Max,
}
pub struct Aggregate {
group_by: Vec<String>,
aggregates: Vec<(AggregateFunction, String)>,
state: HashMap<Vec<Value>, Vec<f64>>,
}
impl Aggregate {
pub fn new(group_by: Vec<String>, aggregates: Vec<(AggregateFunction, String)>) -> Self {
Self {
group_by,
aggregates,
state: HashMap::new(),
}
}
}
impl Operator for Aggregate {
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
Ok(None)
}
fn name(&self) -> &str {
"Aggregate"
}
fn is_pipeline_breaker(&self) -> bool {
true
}
}
pub struct Project {
columns: Vec<String>,
}
impl Project {
pub fn new(columns: Vec<String>) -> Self {
Self { columns }
}
}
impl Operator for Project {
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
if let Some(batch) = input {
let projected: Vec<_> = batch
.rows
.into_iter()
.map(|row| {
self.columns
.iter()
.filter_map(|col| row.get(col).map(|v| (col.clone(), v.clone())))
.collect()
})
.collect();
Ok(Some(RowBatch {
rows: projected,
schema: batch.schema,
}))
} else {
Ok(None)
}
}
fn name(&self) -> &str {
"Project"
}
}
pub struct Sort {
order_by: Vec<(String, crate::executor::plan::SortOrder)>,
buffer: Vec<HashMap<String, Value>>,
}
impl Sort {
pub fn new(order_by: Vec<(String, crate::executor::plan::SortOrder)>) -> Self {
Self {
order_by,
buffer: Vec::new(),
}
}
}
impl Operator for Sort {
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
Ok(None)
}
fn name(&self) -> &str {
"Sort"
}
fn is_pipeline_breaker(&self) -> bool {
true
}
}
pub struct Limit {
limit: usize,
offset: usize,
current: usize,
}
impl Limit {
pub fn new(limit: usize, offset: usize) -> Self {
Self {
limit,
offset,
current: 0,
}
}
}
impl Operator for Limit {
fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
if let Some(batch) = input {
let start = self.offset.saturating_sub(self.current);
let end = start + self.limit;
let limited: Vec<_> = batch
.rows
.into_iter()
.skip(start)
.take(end - start)
.collect();
self.current += limited.len();
Ok(Some(RowBatch {
rows: limited,
schema: batch.schema,
}))
} else {
Ok(None)
}
}
fn name(&self) -> &str {
"Limit"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_filter_operator() {
let mut filter = Filter::new(Predicate::Equals("id".to_string(), Value::Int64(42)));
let mut row = HashMap::new();
row.insert("id".to_string(), Value::Int64(42));
assert!(filter.evaluate(&row));
}
#[test]
fn test_pattern_matching() {
let filter = Filter::new(Predicate::Like("name".to_string(), "%test%".to_string()));
assert!(filter.pattern_match("this is a test", "%test%"));
}
#[test]
fn test_simd_filtering() {
let filter = Filter::new(Predicate::GreaterThan(
"value".to_string(),
Value::Float64(5.0),
));
let values = vec![1.0, 6.0, 3.0, 8.0, 4.0, 9.0, 2.0, 7.0];
let result = filter.filter_batch_simd(&values, 5.0);
assert_eq!(
result,
vec![false, true, false, true, false, true, false, true]
);
}
}