use serde::{Deserialize, Serialize};
/// Kind of operator held by an [`OperatorNode`], together with the
/// operator-specific parameters.
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize`; renaming
/// variants changes name-based encodings (e.g. JSON), and reordering them
/// changes formats that encode variants by index — keep both stable.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OperatorType {
    /// Sequential scan of `table`.
    SeqScan {
        table: String,
    },
    /// Scan of `table` through the index named `index`.
    IndexScan {
        index: String,
        table: String,
    },
    /// Scan through the HNSW index named `index`.
    /// `ef_search` is presumably HNSW's query-time candidate-list size — confirm.
    HnswScan {
        index: String,
        ef_search: u32,
    },
    /// Scan through the IVFFlat index named `index`.
    /// `nprobe` is presumably the number of inverted lists probed — confirm.
    IvfFlatScan {
        index: String,
        nprobe: u32,
    },
    /// Nested-loop join; carries no parameters.
    NestedLoopJoin,
    /// Hash join on the key `hash_key`.
    HashJoin {
        hash_key: String,
    },
    /// Merge join on the key `merge_key`.
    MergeJoin {
        merge_key: String,
    },
    /// Aggregation over `functions` (plain strings; their syntax is not
    /// defined in this type).
    Aggregate {
        functions: Vec<String>,
    },
    /// Grouping on `keys`.
    GroupBy {
        keys: Vec<String>,
    },
    /// Filter by `predicate` (a plain string; its syntax is not defined in
    /// this type).
    Filter {
        predicate: String,
    },
    /// Projection onto `columns`.
    Project {
        columns: Vec<String>,
    },
    /// Sort on `keys`.
    ///
    /// `descending` is expected to be parallel to `keys` (`OperatorNode::sort`
    /// builds one `false` per key); equal lengths are not enforced by the type.
    Sort {
        keys: Vec<String>,
        descending: Vec<bool>,
    },
    /// Row limit of `count`.
    Limit {
        count: usize,
    },
    /// Vector-distance computation using the metric named by `metric`.
    VectorDistance {
        metric: String,
    },
    /// Reranking with the model named by `model`.
    Rerank {
        model: String,
    },
    /// Materialization step; carries no parameters.
    Materialize,
    /// Result node; carries no parameters.
    Result,
    /// Legacy generic scan, superseded by `SeqScan`.
    /// Presumably kept so previously serialized plans still deserialize — confirm.
    #[deprecated(note = "Use SeqScan instead")]
    Scan,
    /// Legacy generic join, superseded by `HashJoin` / `NestedLoopJoin`.
    /// Presumably kept so previously serialized plans still deserialize — confirm.
    #[deprecated(note = "Use HashJoin or NestedLoopJoin instead")]
    Join,
}
/// A single operator in a plan: the operator kind plus planner estimates,
/// optional measured actuals, and an optional embedding vector.
///
/// Typically built with [`OperatorNode::new`] or one of the typed
/// constructors (e.g. [`OperatorNode::seq_scan`]), then refined with the
/// chained `with_*` methods.
///
/// `PartialEq` is derived (matching [`OperatorType`]) so whole nodes can be
/// compared, e.g. after a serialization round trip. The float fields use
/// ordinary `f64`/`f32` equality, so `NaN` never compares equal.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OperatorNode {
    /// Identifier of this node. Uniqueness within a plan is not enforced here.
    pub id: usize,
    /// Operator kind and its parameters.
    pub op_type: OperatorType,
    /// Estimated number of output rows; [`OperatorNode::new`] initializes it
    /// to `0.0` and `with_estimates` overwrites it.
    pub estimated_rows: f64,
    /// Estimated cost; [`OperatorNode::new`] initializes it to `0.0` and
    /// `with_estimates` overwrites it.
    pub estimated_cost: f64,
    /// Measured row count, `None` until recorded (e.g. via `with_actuals`).
    pub actual_rows: Option<f64>,
    /// Measured execution time — milliseconds per the field name — `None`
    /// until recorded (e.g. via `with_actuals`).
    pub actual_time_ms: Option<f64>,
    /// Optional embedding attached to this node (e.g. via `with_embedding`).
    /// How it is produced or consumed is not defined here.
    pub embedding: Option<Vec<f32>>,
}
impl OperatorNode {
    /// Creates a node with the given `id` and operator kind.
    ///
    /// Estimates start at `0.0`; `actual_rows`, `actual_time_ms` and
    /// `embedding` start as `None`.
    #[must_use]
    pub fn new(id: usize, op_type: OperatorType) -> Self {
        Self {
            id,
            op_type,
            estimated_rows: 0.0,
            estimated_cost: 0.0,
            actual_rows: None,
            actual_time_ms: None,
            embedding: None,
        }
    }

    /// Creates an [`OperatorType::SeqScan`] node over `table`.
    #[must_use]
    pub fn seq_scan(id: usize, table: &str) -> Self {
        Self::new(
            id,
            OperatorType::SeqScan {
                table: table.to_string(),
            },
        )
    }

    /// Creates an [`OperatorType::IndexScan`] node reading `table` through `index`.
    #[must_use]
    pub fn index_scan(id: usize, index: &str, table: &str) -> Self {
        Self::new(
            id,
            OperatorType::IndexScan {
                index: index.to_string(),
                table: table.to_string(),
            },
        )
    }

    /// Creates an [`OperatorType::HnswScan`] node over `index` with the given
    /// `ef_search` value.
    #[must_use]
    pub fn hnsw_scan(id: usize, index: &str, ef_search: u32) -> Self {
        Self::new(
            id,
            OperatorType::HnswScan {
                index: index.to_string(),
                ef_search,
            },
        )
    }

    /// Creates an [`OperatorType::IvfFlatScan`] node over `index` with the
    /// given `nprobe` value.
    #[must_use]
    pub fn ivf_flat_scan(id: usize, index: &str, nprobe: u32) -> Self {
        Self::new(
            id,
            OperatorType::IvfFlatScan {
                index: index.to_string(),
                nprobe,
            },
        )
    }

    /// Creates an [`OperatorType::NestedLoopJoin`] node.
    #[must_use]
    pub fn nested_loop_join(id: usize) -> Self {
        Self::new(id, OperatorType::NestedLoopJoin)
    }

    /// Creates an [`OperatorType::HashJoin`] node joining on `key`.
    #[must_use]
    pub fn hash_join(id: usize, key: &str) -> Self {
        Self::new(
            id,
            OperatorType::HashJoin {
                hash_key: key.to_string(),
            },
        )
    }

    /// Creates an [`OperatorType::MergeJoin`] node joining on `key`.
    #[must_use]
    pub fn merge_join(id: usize, key: &str) -> Self {
        Self::new(
            id,
            OperatorType::MergeJoin {
                merge_key: key.to_string(),
            },
        )
    }

    /// Creates an [`OperatorType::Filter`] node with the given predicate text.
    #[must_use]
    pub fn filter(id: usize, predicate: &str) -> Self {
        Self::new(
            id,
            OperatorType::Filter {
                predicate: predicate.to_string(),
            },
        )
    }

    /// Creates an [`OperatorType::Project`] node onto `columns`.
    #[must_use]
    pub fn project(id: usize, columns: Vec<String>) -> Self {
        Self::new(id, OperatorType::Project { columns })
    }

    /// Creates an [`OperatorType::Sort`] node on `keys` with every key
    /// ascending (`descending` is one `false` per key).
    #[must_use]
    pub fn sort(id: usize, keys: Vec<String>) -> Self {
        let descending = vec![false; keys.len()];
        Self::new(id, OperatorType::Sort { keys, descending })
    }

    /// Creates an [`OperatorType::Sort`] node with an explicit per-key order.
    ///
    /// `descending` is expected to be parallel to `keys`, as [`Self::sort`]
    /// builds it. The lengths are NOT checked here; pass vectors of equal
    /// length.
    #[must_use]
    pub fn sort_with_order(id: usize, keys: Vec<String>, descending: Vec<bool>) -> Self {
        Self::new(id, OperatorType::Sort { keys, descending })
    }

    /// Creates an [`OperatorType::Limit`] node with the given `count`.
    #[must_use]
    pub fn limit(id: usize, count: usize) -> Self {
        Self::new(id, OperatorType::Limit { count })
    }

    /// Creates an [`OperatorType::Aggregate`] node computing `functions`.
    #[must_use]
    pub fn aggregate(id: usize, functions: Vec<String>) -> Self {
        Self::new(id, OperatorType::Aggregate { functions })
    }

    /// Creates an [`OperatorType::GroupBy`] node grouping on `keys`.
    #[must_use]
    pub fn group_by(id: usize, keys: Vec<String>) -> Self {
        Self::new(id, OperatorType::GroupBy { keys })
    }

    /// Creates an [`OperatorType::VectorDistance`] node using the metric named
    /// `metric`.
    #[must_use]
    pub fn vector_distance(id: usize, metric: &str) -> Self {
        Self::new(
            id,
            OperatorType::VectorDistance {
                metric: metric.to_string(),
            },
        )
    }

    /// Creates an [`OperatorType::Rerank`] node using the model named `model`.
    #[must_use]
    pub fn rerank(id: usize, model: &str) -> Self {
        Self::new(
            id,
            OperatorType::Rerank {
                model: model.to_string(),
            },
        )
    }

    /// Creates an [`OperatorType::Materialize`] node.
    #[must_use]
    pub fn materialize(id: usize) -> Self {
        Self::new(id, OperatorType::Materialize)
    }

    /// Creates an [`OperatorType::Result`] node.
    #[must_use]
    pub fn result(id: usize) -> Self {
        Self::new(id, OperatorType::Result)
    }

    /// Sets `estimated_rows` and `estimated_cost` and returns the updated node.
    // `self` is consumed, so discarding the return value loses the update;
    // `#[must_use]` turns that mistake into a compiler warning.
    #[must_use = "`with_estimates` consumes the node and returns the updated one"]
    pub fn with_estimates(mut self, rows: f64, cost: f64) -> Self {
        self.estimated_rows = rows;
        self.estimated_cost = cost;
        self
    }

    /// Records `actual_rows` and `actual_time_ms` (both wrapped in `Some`) and
    /// returns the updated node.
    #[must_use = "`with_actuals` consumes the node and returns the updated one"]
    pub fn with_actuals(mut self, rows: f64, time_ms: f64) -> Self {
        self.actual_rows = Some(rows);
        self.actual_time_ms = Some(time_ms);
        self
    }

    /// Attaches `embedding` (replacing any previous one) and returns the
    /// updated node.
    #[must_use = "`with_embedding` consumes the node and returns the updated one"]
    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
        self.embedding = Some(embedding);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A typed constructor sets the requested parameters and `new`'s defaults.
    #[test]
    fn test_operator_node_creation() {
        let node = OperatorNode::seq_scan(1, "users");
        assert_eq!(node.id, 1);
        assert_eq!(
            node.op_type,
            OperatorType::SeqScan {
                table: "users".to_string()
            }
        );
        // `new` zeroes the estimates and leaves every optional field unset.
        assert_eq!(node.estimated_rows, 0.0);
        assert_eq!(node.estimated_cost, 0.0);
        assert!(node.actual_rows.is_none());
        assert!(node.actual_time_ms.is_none());
        assert!(node.embedding.is_none());
    }

    /// Chained `with_*` calls populate estimates and actuals.
    #[test]
    fn test_builder_pattern() {
        let node = OperatorNode::hash_join(2, "id")
            .with_estimates(1000.0, 50.0)
            .with_actuals(987.0, 45.2);
        assert_eq!(node.estimated_rows, 1000.0);
        assert_eq!(node.estimated_cost, 50.0);
        assert_eq!(node.actual_rows, Some(987.0));
        assert_eq!(node.actual_time_ms, Some(45.2));
    }

    /// `sort` produces one ascending (`false`) flag per key.
    #[test]
    fn test_sort_defaults_to_ascending() {
        let keys = vec!["a".to_string(), "b".to_string()];
        let node = OperatorNode::sort(4, keys.clone());
        assert_eq!(
            node.op_type,
            OperatorType::Sort {
                keys,
                descending: vec![false, false],
            }
        );
    }

    /// A JSON round trip preserves every field, not just the id.
    #[test]
    fn test_serialization() {
        // Values are exactly representable so float equality is reliable.
        let node = OperatorNode::hnsw_scan(3, "embeddings_idx", 100)
            .with_estimates(10.0, 2.5)
            .with_actuals(8.0, 1.5)
            .with_embedding(vec![0.5, -1.0]);
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: OperatorNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node.id, deserialized.id);
        assert_eq!(node.op_type, deserialized.op_type);
        assert_eq!(node.estimated_rows, deserialized.estimated_rows);
        assert_eq!(node.estimated_cost, deserialized.estimated_cost);
        assert_eq!(node.actual_rows, deserialized.actual_rows);
        assert_eq!(node.actual_time_ms, deserialized.actual_time_ms);
        assert_eq!(node.embedding, deserialized.embedding);
    }
}