use std::sync::Arc;
use relational_engine::{Row, Value};
use serde::{Deserialize, Serialize};
use tensor_store::{PartitionResult, Partitioner, SemanticPartitioner};
use crate::{QueryResult, Result, SimilarResult};
/// Index of a shard within the key partitioner's node list.
pub type ShardId = usize;

/// Execution plan that [`QueryPlanner`] derives from a query string.
#[derive(Clone, Debug)]
pub enum QueryPlan {
    /// Keep the query on this node. The planner chooses this when the key
    /// partitioner reports the key as local, or when it does not recognise
    /// the query type.
    Local {
        /// Query text to execute.
        query: String,
    },
    /// Forward the query to the single shard that owns the key.
    Remote {
        /// Target shard.
        shard: ShardId,
        /// Query text to execute.
        query: String,
    },
    /// Send the query to every listed shard and combine the replies.
    ScatterGather {
        /// Shards to contact.
        shards: Vec<ShardId>,
        /// Query text to execute on each shard.
        query: String,
        /// How the per-shard replies are combined.
        merge: MergeStrategy,
    },
}
/// How the per-shard results of a scatter/gather query are combined.
///
/// `PartialEq`/`Eq` are derived so strategies (and plans built from them) can
/// be compared directly; every payload type already supports equality.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MergeStrategy {
    /// Gather results of one kind (similarity hits, nodes, edges or rows)
    /// from every shard.
    Union,
    /// Keep the `k` highest-scoring similarity hits across all shards.
    TopK(usize),
    /// Combine scalar per-shard values with an aggregate function.
    Aggregate(AggregateFunction),
    /// Return the first shard result that is not empty.
    FirstNonEmpty,
    /// Currently merged exactly like [`MergeStrategy::Union`].
    Concat,
}
/// Aggregate applied when merging scalar per-shard results.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregateFunction {
    /// Sum of the per-shard values.
    Sum,
    /// Total count; per-shard counts are merged by summation.
    Count,
    /// Mean of the per-shard values (not weighted by shard size).
    Avg,
    /// Largest per-shard value.
    Max,
    /// Smallest per-shard value.
    Min,
}
/// Outcome of running a query on a single shard.
///
/// A failed shard carries `error: Some(..)`, an empty payload and a zero
/// execution time (see [`ShardResult::error`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ShardResult {
    /// Shard that produced this result.
    pub shard: ShardId,
    /// Result payload (`QueryResult::Empty` for failed shards).
    pub result: QueryResult,
    /// Reported execution time in microseconds.
    pub execution_time_us: u64,
    /// Failure message; `None` when the shard succeeded.
    pub error: Option<String>,
}
impl ShardResult {
    /// Wraps a successful shard reply with its measured execution time.
    #[must_use]
    pub const fn success(shard: ShardId, result: QueryResult, execution_time_us: u64) -> Self {
        Self {
            error: None,
            execution_time_us,
            result,
            shard,
        }
    }

    /// Records a failed shard: empty payload, zero execution time, and the
    /// given message in `error`.
    #[must_use]
    pub const fn error(shard: ShardId, error: String) -> Self {
        let result = QueryResult::Empty;
        Self {
            error: Some(error),
            execution_time_us: 0,
            result,
            shard,
        }
    }
}
/// Execution settings for distributed queries.
///
/// None of these fields is read elsewhere in this file; the descriptions
/// below follow the field names and should be confirmed against the executor.
#[derive(Clone, Debug)]
pub struct DistributedQueryConfig {
    /// Presumably the cap on simultaneously outstanding shard requests.
    pub max_concurrent: usize,
    /// Per-shard timeout in milliseconds.
    pub shard_timeout_ms: u64,
    /// Presumably how many times a failed shard request is retried.
    pub retry_count: usize,
    /// Presumably whether the first shard failure aborts the whole query.
    pub fail_fast: bool,
}

impl Default for DistributedQueryConfig {
    /// 10 concurrent requests, 5 000 ms shard timeout, 2 retries, no fail-fast.
    fn default() -> Self {
        Self {
            fail_fast: false,
            retry_count: 2,
            shard_timeout_ms: 5_000,
            max_concurrent: 10,
        }
    }
}
/// Builds [`QueryPlan`]s from query strings using the cluster partitioners.
#[derive(Debug)]
pub struct QueryPlanner {
    /// Key-based partitioner; the order of its node list defines shard ids.
    partitioner: Arc<dyn Partitioner + Send + Sync>,
    /// Optional embedding-aware partitioner consulted by `plan_with_embedding`.
    semantic_partitioner: Option<Arc<SemanticPartitioner>>,
    /// Presumably this node's shard id; not read by the planner yet, hence
    /// the `dead_code` allowance.
    #[allow(dead_code)]
    local_shard: ShardId,
}
impl QueryPlanner {
pub fn new(partitioner: Arc<dyn Partitioner + Send + Sync>, local_shard: ShardId) -> Self {
Self {
partitioner,
semantic_partitioner: None,
local_shard,
}
}
#[must_use]
pub fn with_semantic_partitioner(mut self, partitioner: Arc<SemanticPartitioner>) -> Self {
self.semantic_partitioner = Some(partitioner);
self
}
#[must_use]
pub fn plan(&self, query: &str) -> QueryPlan {
let query_type = Self::classify_query(query);
match query_type {
QueryType::PointLookup { key } => {
let result = self.partitioner.partition(&key);
if result.is_local {
QueryPlan::Local {
query: query.to_string(),
}
} else {
QueryPlan::Remote {
shard: self.shard_from_result(&result),
query: query.to_string(),
}
}
},
QueryType::SimilaritySearch { k } => {
QueryPlan::ScatterGather {
shards: self.all_shards(),
query: query.to_string(),
merge: MergeStrategy::TopK(k),
}
},
QueryType::TableScan => {
QueryPlan::ScatterGather {
shards: self.all_shards(),
query: query.to_string(),
merge: MergeStrategy::Union,
}
},
QueryType::Aggregate { func } => {
QueryPlan::ScatterGather {
shards: self.all_shards(),
query: query.to_string(),
merge: MergeStrategy::Aggregate(func),
}
},
QueryType::Unknown => {
QueryPlan::Local {
query: query.to_string(),
}
},
}
}
#[must_use]
pub fn plan_with_embedding(&self, query: &str, embedding: &[f32]) -> QueryPlan {
let relevant_shards = self.shards_for_embedding(embedding);
if relevant_shards.is_empty() {
return self.plan(query);
}
let query_type = Self::classify_query(query);
match query_type {
QueryType::SimilaritySearch { k } => QueryPlan::ScatterGather {
shards: relevant_shards,
query: query.to_string(),
merge: MergeStrategy::TopK(k),
},
_ => self.plan(query),
}
}
fn all_shards(&self) -> Vec<ShardId> {
let nodes = self.partitioner.nodes();
(0..nodes.len()).collect()
}
fn shard_from_result(&self, result: &PartitionResult) -> ShardId {
let nodes = self.partitioner.nodes();
nodes.iter().position(|n| *n == result.primary).unwrap_or(0)
}
fn shards_for_embedding(&self, embedding: &[f32]) -> Vec<ShardId> {
if let Some(sp) = &self.semantic_partitioner {
let results = sp.shards_for_embedding(embedding);
if !results.is_empty() {
return results.into_iter().map(|(shard, _score)| shard).collect();
}
}
self.all_shards()
}
fn classify_query(query: &str) -> QueryType {
let query_upper = query.to_uppercase();
let query_trimmed = query_upper.trim();
if query_trimmed.starts_with("GET ")
|| query_trimmed.starts_with("NODE GET ")
|| query_trimmed.starts_with("ENTITY GET ")
{
if let Some(key) = Self::extract_key(query) {
return QueryType::PointLookup { key };
}
}
if query_trimmed.starts_with("SIMILAR ") {
let k = Self::extract_top_k(query).unwrap_or(10);
return QueryType::SimilaritySearch { k };
}
if query_trimmed.starts_with("SELECT ") || query_trimmed.starts_with("NODE LIST") {
if query_trimmed.contains("COUNT(") {
return QueryType::Aggregate {
func: AggregateFunction::Count,
};
}
if query_trimmed.contains("SUM(") {
return QueryType::Aggregate {
func: AggregateFunction::Sum,
};
}
if query_trimmed.contains("AVG(") {
return QueryType::Aggregate {
func: AggregateFunction::Avg,
};
}
return QueryType::TableScan;
}
QueryType::Unknown
}
fn extract_key(query: &str) -> Option<String> {
let parts: Vec<&str> = query.split_whitespace().collect();
if parts.len() >= 2 {
for (i, part) in parts.iter().enumerate() {
if part.eq_ignore_ascii_case("GET") && i + 1 < parts.len() {
return Some(parts[i + 1].to_string());
}
}
}
None
}
fn extract_top_k(query: &str) -> Option<usize> {
let query_upper = query.to_uppercase();
if let Some(pos) = query_upper.find("TOP ") {
let rest = &query_upper[pos + 4..];
let num_str: String = rest.chars().take_while(char::is_ascii_digit).collect();
return num_str.parse().ok();
}
None
}
}
/// Coarse classification of a query string, used internally by the planner.
#[derive(Debug)]
enum QueryType {
    /// Single-key read (`GET`, `NODE GET`, `ENTITY GET`).
    PointLookup {
        /// Token following `GET`.
        key: String,
    },
    /// `SIMILAR` query keeping the best `k` matches.
    SimilaritySearch {
        /// Number of matches requested.
        k: usize,
    },
    /// `SELECT` / `NODE LIST` without a recognised aggregate.
    TableScan,
    /// `SELECT` / `NODE LIST` containing an aggregate call.
    Aggregate {
        /// Aggregate used when merging shard results.
        func: AggregateFunction,
    },
    /// Not recognised by the classifier.
    Unknown,
}
/// Stateless namespace for combining per-shard results; see
/// [`ResultMerger::merge`].
#[derive(Debug)]
pub struct ResultMerger;
impl ResultMerger {
    /// Combines per-shard results into one [`QueryResult`] using `strategy`.
    ///
    /// Shard results carrying an error are dropped before merging.
    /// NOTE(review): callers cannot tell a partial merge from a complete one
    /// from the return value alone. If no successful result remains,
    /// `QueryResult::Empty` is returned.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn merge(results: Vec<ShardResult>, strategy: &MergeStrategy) -> Result<QueryResult> {
        let successful: Vec<_> = results.into_iter().filter(|r| r.error.is_none()).collect();
        if successful.is_empty() {
            return Ok(QueryResult::Empty);
        }
        Ok(match strategy {
            MergeStrategy::Union => Self::merge_union(successful),
            MergeStrategy::TopK(k) => Self::merge_top_k(successful, *k),
            MergeStrategy::Aggregate(func) => Self::merge_aggregate(successful, *func),
            MergeStrategy::FirstNonEmpty => Self::merge_first_non_empty(successful),
            MergeStrategy::Concat => Self::merge_concat(successful),
        })
    }

    /// Concatenates rows, nodes, edges and similarity hits from all shards.
    ///
    /// A `Count(n)` reply becomes a single `count` row with id 0. Only the
    /// highest-priority non-empty category is returned (Similar, then Nodes,
    /// then Edges, then Rows), so mixed-type replies lose the lower-priority
    /// categories. Any other variant is ignored.
    fn merge_union(results: Vec<ShardResult>) -> QueryResult {
        let mut all_rows = Vec::new();
        let mut all_nodes = Vec::new();
        let mut all_edges = Vec::new();
        let mut all_similar = Vec::new();
        for shard_result in results {
            match shard_result.result {
                QueryResult::Rows(rows) => all_rows.extend(rows),
                QueryResult::Nodes(nodes) => all_nodes.extend(nodes),
                QueryResult::Edges(edges) => all_edges.extend(edges),
                QueryResult::Similar(similar) => all_similar.extend(similar),
                QueryResult::Count(n) => {
                    #[allow(clippy::cast_possible_wrap)]
                    let count_val = n as i64;
                    all_rows.push(Row {
                        id: 0,
                        values: vec![("count".to_string(), Value::Int(count_val))],
                    });
                },
                _ => {},
            }
        }
        if !all_similar.is_empty() {
            return QueryResult::Similar(all_similar);
        }
        if !all_nodes.is_empty() {
            return QueryResult::Nodes(all_nodes);
        }
        if !all_edges.is_empty() {
            return QueryResult::Edges(all_edges);
        }
        if !all_rows.is_empty() {
            return QueryResult::Rows(all_rows);
        }
        QueryResult::Empty
    }

    /// Keeps the `k` highest-scoring similarity hits across all shards.
    ///
    /// Non-`Similar` replies are ignored. Ties keep shard order (stable sort).
    /// NaN scores rank below every real score.
    fn merge_top_k(results: Vec<ShardResult>, k: usize) -> QueryResult {
        use std::cmp::Ordering;

        let mut all_similar: Vec<SimilarResult> = Vec::new();
        for shard_result in results {
            if let QueryResult::Similar(similar) = shard_result.result {
                all_similar.extend(similar);
            }
        }
        // The comparator must be a total order. `partial_cmp(..).unwrap_or(Equal)`
        // makes NaN "equal" to everything, which is intransitive and allowed
        // to make `sort_by` panic. Descending `total_cmp` with NaN pushed last
        // is total.
        all_similar.sort_by(|a, b| match (a.score.is_nan(), b.score.is_nan()) {
            (false, false) => b.score.total_cmp(&a.score),
            (false, true) => Ordering::Less,
            (true, false) => Ordering::Greater,
            (true, true) => Ordering::Equal,
        });
        all_similar.truncate(k);
        QueryResult::Similar(all_similar)
    }

    /// Reduces scalar per-shard values with `func`.
    ///
    /// Inputs are `Count(n)` replies and `Value` strings that parse (after
    /// trimming) as `i64`; everything else is ignored. With no inputs the
    /// result is `Count(0)`. Arithmetic is done in `i128`, so large counts
    /// and sums cannot overflow. A result that does not fit in `usize`
    /// (e.g. a negative sum) is returned as `Value(decimal string)` instead
    /// of wrapping into a bogus count.
    ///
    /// `Avg` averages the per-shard values without weighting by shard size
    /// and truncates toward zero. NOTE(review): this matches a global average
    /// only if every shard contributes equally.
    fn merge_aggregate(results: Vec<ShardResult>, func: AggregateFunction) -> QueryResult {
        let mut values: Vec<i128> = Vec::new();
        for shard_result in results {
            match shard_result.result {
                QueryResult::Count(n) => {
                    // `usize` is at most 64 bits on all current Rust targets,
                    // so widening to i128 cannot wrap.
                    #[allow(clippy::cast_possible_wrap)]
                    let count_val = n as i128;
                    values.push(count_val);
                },
                QueryResult::Value(s) => {
                    if let Ok(n) = s.trim().parse::<i64>() {
                        values.push(i128::from(n));
                    }
                },
                _ => {},
            }
        }
        if values.is_empty() {
            return QueryResult::Count(0);
        }
        let merged: i128 = match func {
            AggregateFunction::Sum | AggregateFunction::Count => values.iter().sum(),
            AggregateFunction::Max => values.iter().copied().max().unwrap_or(0),
            AggregateFunction::Min => values.iter().copied().min().unwrap_or(0),
            AggregateFunction::Avg => {
                #[allow(clippy::cast_possible_wrap)]
                let len = values.len() as i128;
                // `values` is non-empty here, so `len` is never zero.
                values.iter().sum::<i128>() / len
            },
        };
        usize::try_from(merged)
            .map_or_else(|_| QueryResult::Value(merged.to_string()), QueryResult::Count)
    }

    /// Returns the first shard result that is not `Empty` (in shard order),
    /// or `Empty` if all are.
    fn merge_first_non_empty(results: Vec<ShardResult>) -> QueryResult {
        for shard_result in results {
            if !matches!(&shard_result.result, QueryResult::Empty) {
                return shard_result.result;
            }
        }
        QueryResult::Empty
    }

    /// Concatenation is currently identical to [`ResultMerger::merge_union`].
    fn merge_concat(results: Vec<ShardResult>) -> QueryResult {
        Self::merge_union(results)
    }
}
/// Running counters for executed distributed queries; all start at zero.
#[derive(Clone, Debug, Default)]
pub struct DistributedQueryStats {
    /// Total number of recorded queries.
    pub queries_executed: u64,
    /// Queries planned as `QueryPlan::Local`.
    pub local_queries: u64,
    /// Queries planned as `QueryPlan::Remote`.
    pub remote_queries: u64,
    /// Queries planned as `QueryPlan::ScatterGather`.
    pub scatter_gather_queries: u64,
    /// Sum of shards contacted across all recorded queries.
    pub shards_contacted: u64,
    /// Running mean latency in microseconds (integer, truncated).
    pub avg_latency_us: u64,
    /// Sum of shard errors reported across all recorded queries.
    pub shard_errors: u64,
}
impl DistributedQueryStats {
    /// Records one executed query.
    ///
    /// Bumps the per-plan counter and adds the shards contacted: 1 for
    /// `Local`/`Remote`, `shards.len()` for `ScatterGather`. It also adds
    /// `errors` to `shard_errors` and folds `latency_us` into the running
    /// mean latency.
    pub const fn record_query(&mut self, plan: &QueryPlan, latency_us: u64, errors: usize) {
        self.queries_executed += 1;
        match plan {
            QueryPlan::Local { .. } => {
                self.local_queries += 1;
                self.shards_contacted += 1;
            },
            QueryPlan::Remote { .. } => {
                self.remote_queries += 1;
                self.shards_contacted += 1;
            },
            QueryPlan::ScatterGather { shards, .. } => {
                self.scatter_gather_queries += 1;
                self.shards_contacted += shards.len() as u64;
            },
        }
        self.shard_errors += errors as u64;

        // Running mean: avg_n = (avg_{n-1} * (n - 1) + latency) / n, truncated.
        // The product is formed in u128 because `avg * (n - 1)` can exceed
        // u64 after a single very large latency sample (a panic in debug
        // builds). The quotient is a weighted mean of two u64 values, so it
        // always fits back into u64. For n == 1 this reduces to `latency_us`.
        let n = self.queries_executed as u128;
        let total = self.avg_latency_us as u128 * (n - 1) + latency_us as u128;
        self.avg_latency_us = (total / n) as u64;
    }
}
#[cfg(test)]
mod tests {
    use tensor_store::{ConsistentHashConfig, ConsistentHashPartitioner};

    use super::*;

    /// Consistent-hash partitioner with nodes `node1`, `node2`, `node3`.
    fn create_test_partitioner() -> Arc<dyn Partitioner + Send + Sync> {
        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
        let mut partitioner = ConsistentHashPartitioner::new(config);
        partitioner.add_node("node1".to_string());
        partitioner.add_node("node2".to_string());
        partitioner.add_node("node3".to_string());
        Arc::new(partitioner)
    }

    #[test]
    fn test_query_plan_local() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("GET some_key");
        assert!(
            matches!(plan, QueryPlan::Local { .. } | QueryPlan::Remote { .. }),
            "Expected Local or Remote plan"
        );
    }

    #[test]
    fn test_query_plan_scatter_gather() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("SELECT users");
        assert!(
            matches!(
                plan,
                QueryPlan::ScatterGather {
                    merge: MergeStrategy::Union,
                    ..
                }
            ),
            "Expected ScatterGather with Union merge"
        );
    }

    #[test]
    fn test_query_plan_similar() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("SIMILAR key TOP 5");
        assert!(
            matches!(
                plan,
                QueryPlan::ScatterGather {
                    merge: MergeStrategy::TopK(5),
                    ..
                }
            ),
            "Expected ScatterGather with TopK(5) merge"
        );
    }

    #[test]
    fn test_query_plan_aggregate() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("SELECT COUNT(*) FROM users");
        assert!(
            matches!(
                plan,
                QueryPlan::ScatterGather {
                    merge: MergeStrategy::Aggregate(AggregateFunction::Count),
                    ..
                }
            ),
            "Expected ScatterGather with Count aggregate"
        );
    }

    #[test]
    fn test_merge_union() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(20), 150),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        let QueryResult::Rows(rows) = merged else {
            panic!("Expected Rows result");
        };
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_merge_top_k() {
        let results = vec![
            ShardResult::success(
                0,
                QueryResult::Similar(vec![
                    SimilarResult {
                        key: "a".to_string(),
                        score: 0.9,
                    },
                    SimilarResult {
                        key: "b".to_string(),
                        score: 0.8,
                    },
                ]),
                100,
            ),
            ShardResult::success(
                1,
                QueryResult::Similar(vec![SimilarResult {
                    key: "c".to_string(),
                    score: 0.95,
                }]),
                150,
            ),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::TopK(2)).unwrap();
        match merged {
            QueryResult::Similar(similar) => {
                assert_eq!(similar.len(), 2);
                assert_eq!(similar[0].key, "c");
                assert_eq!(similar[1].key, "a");
            },
            _ => panic!("Expected Similar result"),
        }
    }

    #[test]
    fn test_merge_aggregate_sum() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(20), 150),
            ShardResult::success(2, QueryResult::Count(30), 200),
        ];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 60),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_merge_aggregate_avg() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(20), 150),
            ShardResult::success(2, QueryResult::Count(30), 200),
        ];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Avg))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 20),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_merge_first_non_empty() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Value("found".to_string()), 150),
            ShardResult::success(2, QueryResult::Value("also_found".to_string()), 200),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::FirstNonEmpty).unwrap();
        match merged {
            QueryResult::Value(s) => assert_eq!(s, "found"),
            _ => panic!("Expected Value result"),
        }
    }

    #[test]
    fn test_shard_result_success() {
        let result = ShardResult::success(0, QueryResult::Count(10), 100);
        assert_eq!(result.shard, 0);
        assert!(result.error.is_none());
        assert_eq!(result.execution_time_us, 100);
    }

    #[test]
    fn test_shard_result_error() {
        let result = ShardResult::error(1, "timeout".to_string());
        assert_eq!(result.shard, 1);
        assert!(result.error.is_some());
        assert_eq!(result.error.unwrap(), "timeout");
    }

    #[test]
    fn test_config_default() {
        let config = DistributedQueryConfig::default();
        assert_eq!(config.max_concurrent, 10);
        assert_eq!(config.shard_timeout_ms, 5000);
        assert_eq!(config.retry_count, 2);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_stats_record_local() {
        let mut stats = DistributedQueryStats::default();
        let plan = QueryPlan::Local {
            query: "GET key".to_string(),
        };
        stats.record_query(&plan, 100, 0);
        assert_eq!(stats.queries_executed, 1);
        assert_eq!(stats.local_queries, 1);
        assert_eq!(stats.shards_contacted, 1);
        assert_eq!(stats.avg_latency_us, 100);
    }

    #[test]
    fn test_stats_record_scatter_gather() {
        let mut stats = DistributedQueryStats::default();
        let plan = QueryPlan::ScatterGather {
            shards: vec![0, 1, 2],
            query: "SELECT users".to_string(),
            merge: MergeStrategy::Union,
        };
        stats.record_query(&plan, 500, 1);
        assert_eq!(stats.queries_executed, 1);
        assert_eq!(stats.scatter_gather_queries, 1);
        assert_eq!(stats.shards_contacted, 3);
        assert_eq!(stats.shard_errors, 1);
    }

    #[test]
    fn test_merge_empty_results() {
        let results: Vec<ShardResult> = vec![];
        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        assert!(matches!(merged, QueryResult::Empty));
    }

    #[test]
    fn test_merge_filters_errors() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::error(1, "timeout".to_string()),
        ];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 10),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_planner_extract_key() {
        assert_eq!(
            QueryPlanner::extract_key("GET mykey"),
            Some("mykey".to_string())
        );
        assert_eq!(
            QueryPlanner::extract_key("NODE GET user:123"),
            Some("user:123".to_string())
        );
    }

    #[test]
    fn test_planner_extract_top_k() {
        assert_eq!(QueryPlanner::extract_top_k("SIMILAR key TOP 5"), Some(5));
        assert_eq!(
            QueryPlanner::extract_top_k("SIMILAR key TOP 100"),
            Some(100)
        );
        assert_eq!(QueryPlanner::extract_top_k("SIMILAR key"), None);
    }

    #[test]
    fn test_aggregate_function_equality() {
        assert_eq!(AggregateFunction::Sum, AggregateFunction::Sum);
        assert_ne!(AggregateFunction::Sum, AggregateFunction::Count);
    }

    #[test]
    fn test_all_shards() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let shards = planner.all_shards();
        assert_eq!(shards.len(), 3);
        assert_eq!(shards, vec![0, 1, 2]);
    }

    #[test]
    fn test_plan_with_embedding() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let embedding = [1.0, 0.0, 0.0, 0.0];
        let plan = planner.plan_with_embedding("SIMILAR key TOP 10", &embedding);
        match plan {
            QueryPlan::ScatterGather { .. } => {},
            _ => panic!("Expected ScatterGather plan"),
        }
    }

    #[test]
    fn test_merge_max() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(50), 150),
            ShardResult::success(2, QueryResult::Count(30), 200),
        ];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Max))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 50),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_merge_min() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(50), 150),
            ShardResult::success(2, QueryResult::Count(30), 200),
        ];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Min))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 10),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_stats_avg_latency_updates() {
        let mut stats = DistributedQueryStats::default();
        let plan = QueryPlan::Local {
            query: "GET key".to_string(),
        };
        stats.record_query(&plan, 100, 0);
        assert_eq!(stats.avg_latency_us, 100);
        stats.record_query(&plan, 200, 0);
        assert_eq!(stats.avg_latency_us, 150);
    }

    #[test]
    fn test_merge_concat() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(20), 150),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::Concat).unwrap();
        match merged {
            QueryResult::Rows(rows) => assert_eq!(rows.len(), 2),
            _ => panic!("Expected Rows result"),
        }
    }

    #[test]
    fn test_merge_union_nodes() {
        use crate::NodeResult;
        let results = vec![
            ShardResult::success(
                0,
                QueryResult::Nodes(vec![
                    NodeResult {
                        id: 1,
                        label: "Person".to_string(),
                        properties: std::collections::HashMap::new(),
                    },
                    NodeResult {
                        id: 2,
                        label: "Person".to_string(),
                        properties: std::collections::HashMap::new(),
                    },
                ]),
                100,
            ),
            ShardResult::success(
                1,
                QueryResult::Nodes(vec![NodeResult {
                    id: 3,
                    label: "Person".to_string(),
                    properties: std::collections::HashMap::new(),
                }]),
                150,
            ),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        match merged {
            QueryResult::Nodes(nodes) => assert_eq!(nodes.len(), 3),
            _ => panic!("Expected Nodes result"),
        }
    }

    #[test]
    fn test_merge_union_edges() {
        use crate::EdgeResult;
        let results = vec![
            ShardResult::success(
                0,
                QueryResult::Edges(vec![EdgeResult {
                    id: 1,
                    from: 1,
                    to: 2,
                    label: "KNOWS".to_string(),
                }]),
                100,
            ),
            ShardResult::success(
                1,
                QueryResult::Edges(vec![EdgeResult {
                    id: 2,
                    from: 2,
                    to: 3,
                    label: "KNOWS".to_string(),
                }]),
                150,
            ),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        match merged {
            QueryResult::Edges(edges) => assert_eq!(edges.len(), 2),
            _ => panic!("Expected Edges result"),
        }
    }

    #[test]
    fn test_merge_union_empty_all() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Empty, 150),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        assert!(matches!(merged, QueryResult::Empty));
    }

    #[test]
    fn test_merge_first_non_empty_all_empty() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Empty, 150),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::FirstNonEmpty).unwrap();
        assert!(matches!(merged, QueryResult::Empty));
    }

    #[test]
    fn test_merge_aggregate_value_strings() {
        let results = vec![
            ShardResult::success(0, QueryResult::Value("100".to_string()), 100),
            ShardResult::success(1, QueryResult::Value("200".to_string()), 150),
        ];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 300),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_merge_aggregate_empty_values() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Empty, 150),
        ];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 0),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_query_plan_node_list() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("NODE LIST users");
        match plan {
            QueryPlan::ScatterGather {
                merge: MergeStrategy::Union,
                ..
            } => {},
            _ => panic!("Expected ScatterGather with Union merge"),
        }
    }

    #[test]
    fn test_query_plan_select_sum() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("SELECT SUM(amount) FROM orders");
        match plan {
            QueryPlan::ScatterGather {
                merge: MergeStrategy::Aggregate(AggregateFunction::Sum),
                ..
            } => {},
            _ => panic!("Expected ScatterGather with Sum aggregate"),
        }
    }

    #[test]
    fn test_query_plan_select_avg() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("SELECT AVG(price) FROM products");
        match plan {
            QueryPlan::ScatterGather {
                merge: MergeStrategy::Aggregate(AggregateFunction::Avg),
                ..
            } => {},
            _ => panic!("Expected ScatterGather with Avg aggregate"),
        }
    }

    #[test]
    fn test_query_plan_unknown() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("FOOBAR something");
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan for unknown query"),
        }
    }

    #[test]
    fn test_plan_with_embedding_non_similar() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let embedding = [1.0, 0.0, 0.0, 0.0];
        let plan = planner.plan_with_embedding("SELECT * FROM users", &embedding);
        match plan {
            QueryPlan::ScatterGather { .. } => {},
            _ => panic!("Expected ScatterGather plan"),
        }
    }

    #[test]
    fn test_extract_key_no_get() {
        assert!(QueryPlanner::extract_key("something else").is_none());
    }

    #[test]
    fn test_extract_key_empty() {
        assert!(QueryPlanner::extract_key("").is_none());
    }

    #[test]
    fn test_query_plan_node_get() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("NODE GET user:123");
        match plan {
            QueryPlan::Local { .. } | QueryPlan::Remote { .. } => {},
            _ => panic!("Expected Local or Remote plan"),
        }
    }

    #[test]
    fn test_query_plan_entity_get() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("ENTITY GET entity:456");
        match plan {
            QueryPlan::Local { .. } | QueryPlan::Remote { .. } => {},
            _ => panic!("Expected Local or Remote plan"),
        }
    }

    #[test]
    fn test_merge_top_k_non_similar_results() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Count(10), 150),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::TopK(5)).unwrap();
        match merged {
            QueryResult::Similar(similar) => assert!(similar.is_empty()),
            _ => panic!("Expected Similar result"),
        }
    }

    #[test]
    fn test_merge_aggregate_avg_empty() {
        let results = vec![ShardResult::success(0, QueryResult::Rows(vec![]), 100)];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Avg))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 0),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_query_plan_get_only_no_key() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("GET");
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan for GET without key"),
        }
    }

    #[test]
    fn test_query_plan_node_get_only() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("NODE GET");
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan for NODE GET without key"),
        }
    }

    #[test]
    fn test_merge_union_other_result_types() {
        let results = vec![
            ShardResult::success(0, QueryResult::Path(vec![1, 2, 3]), 100),
            ShardResult::success(1, QueryResult::Value("test".to_string()), 150),
        ];
        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        assert!(matches!(merged, QueryResult::Empty));
    }

    #[test]
    fn test_stats_record_remote() {
        let mut stats = DistributedQueryStats::default();
        let plan = QueryPlan::Remote {
            shard: 1,
            query: "GET key".to_string(),
        };
        stats.record_query(&plan, 100, 0);
        assert_eq!(stats.queries_executed, 1);
        assert_eq!(stats.remote_queries, 1);
        assert_eq!(stats.shards_contacted, 1);
    }

    #[test]
    fn test_extract_key_get_at_end() {
        assert!(QueryPlanner::extract_key("something GET").is_none());
    }

    #[test]
    fn test_plan_with_embedding_empty_partitioner() {
        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
        let partitioner = ConsistentHashPartitioner::new(config);
        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
        let planner = QueryPlanner::new(partitioner, 0);
        let embedding = [1.0, 0.0, 0.0, 0.0];
        let plan = planner.plan_with_embedding("SIMILAR key TOP 10", &embedding);
        // With no nodes, the embedding path falls back to `plan`, which still
        // scatters (to zero shards) with the requested top-k.
        match plan {
            QueryPlan::ScatterGather { shards, merge, .. } => {
                assert!(shards.is_empty());
                assert!(matches!(merge, MergeStrategy::TopK(10)));
            },
            _ => panic!("Expected ScatterGather plan"),
        }
    }

    #[test]
    fn test_all_shards_empty() {
        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
        let partitioner = ConsistentHashPartitioner::new(config);
        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
        let planner = QueryPlanner::new(partitioner, 0);
        let shards = planner.all_shards();
        assert!(shards.is_empty());
    }

    #[test]
    fn test_plan_select_with_empty_partitioner() {
        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
        let partitioner = ConsistentHashPartitioner::new(config);
        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("SELECT * FROM users");
        match plan {
            QueryPlan::ScatterGather { shards, .. } => {
                assert!(shards.is_empty());
            },
            _ => panic!("Expected ScatterGather plan"),
        }
    }

    #[test]
    fn test_get_with_trailing_space_no_key() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("GET ");
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan for GET without key"),
        }
    }

    #[test]
    fn test_merge_aggregate_unparseable_value() {
        let results = vec![
            ShardResult::success(0, QueryResult::Value("not_a_number".to_string()), 100),
            ShardResult::success(1, QueryResult::Count(100), 150),
        ];
        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 100),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_node_get_trailing_space() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);
        let plan = planner.plan("NODE GET ");
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan"),
        }
    }

    #[test]
    fn test_debug_impls() {
        let config = DistributedQueryConfig::default();
        let _ = format!("{config:?}");
        let plan_local = QueryPlan::Local {
            query: "test".to_string(),
        };
        let plan_remote = QueryPlan::Remote {
            shard: 0,
            query: "test".to_string(),
        };
        let plan_scatter = QueryPlan::ScatterGather {
            shards: vec![0, 1],
            query: "test".to_string(),
            merge: MergeStrategy::Union,
        };
        let _ = format!("{plan_local:?}");
        let _ = format!("{plan_remote:?}");
        let _ = format!("{plan_scatter:?}");
        let _ = format!("{:?}", MergeStrategy::TopK(10));
        let _ = format!("{:?}", MergeStrategy::Aggregate(AggregateFunction::Count));
        let _ = format!("{:?}", MergeStrategy::FirstNonEmpty);
        let _ = format!("{:?}", MergeStrategy::Concat);
        let _ = format!("{:?}", AggregateFunction::Max);
        let _ = format!("{:?}", AggregateFunction::Min);
        let result = ShardResult::success(0, QueryResult::Empty, 100);
        let _ = format!("{result:?}");
        let stats = DistributedQueryStats::default();
        let _ = format!("{stats:?}");
    }

    #[test]
    fn test_shard_result_clone() {
        let result = ShardResult::success(0, QueryResult::Count(10), 100);
        let cloned = result.clone();
        assert_eq!(cloned.shard, result.shard);
    }

    #[test]
    fn test_config_clone() {
        let config = DistributedQueryConfig::default();
        let cloned = config.clone();
        assert_eq!(cloned.max_concurrent, config.max_concurrent);
    }

    #[test]
    fn test_stats_clone() {
        let stats = DistributedQueryStats {
            queries_executed: 10,
            ..Default::default()
        };
        let cloned = stats.clone();
        assert_eq!(cloned.queries_executed, 10);
    }

    #[test]
    fn test_merge_strategy_clone() {
        let strategy = MergeStrategy::TopK(5);
        let cloned = strategy.clone();
        assert!(matches!(cloned, MergeStrategy::TopK(5)));
    }

    #[test]
    fn test_aggregate_function_copy() {
        let func = AggregateFunction::Sum;
        let copied: AggregateFunction = func;
        assert_eq!(copied, AggregateFunction::Sum);
    }

    #[test]
    fn test_query_plan_clone() {
        let plan = QueryPlan::Local {
            query: "test".to_string(),
        };
        let cloned = plan.clone();
        assert!(matches!(cloned, QueryPlan::Local { .. }));
    }
}