use super::config::ShardConfig;
use super::types::ShardId;
use crate::core::id::NodeId;
use std::collections::{HashMap, HashSet};
/// Cost added by `ShardRouter::route_traversal` for every involved shard beyond the first.
const CROSS_SHARD_PENALTY: f64 = 2.0;
/// Cost added by `ShardRouter::plan_multi_hop` for every step of a plan.
const HOP_COST: f64 = 1.5;
/// Extra cost added by `ShardRouter::plan_multi_hop` for every step flagged `may_cross_shard`.
const CROSS_SHARD_HOP_COST: f64 = 3.0;
/// One hop of a traversal plan.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TraversalStep {
    /// Shard this step is associated with; `TraversalPlan::add_step` adds it to the
    /// plan's involved shards.
    pub shard_id: ShardId,
    /// Edge labels followed during this step.
    pub edge_labels: Vec<String>,
    /// Whether following the edges may leave `shard_id`; a `true` value marks the
    /// owning plan as distributed when the step is added.
    pub may_cross_shard: bool,
}
impl TraversalStep {
pub fn new(shard_id: ShardId, edge_labels: Vec<String>, may_cross_shard: bool) -> Self {
Self {
shard_id,
edge_labels,
may_cross_shard,
}
}
}
/// Routing plan for a traversal: which shards take part, the ordered steps, and a
/// relative cost estimate.
#[derive(Clone, Debug)]
pub struct TraversalPlan {
    /// Shard the traversal starts on.
    pub start_shard: ShardId,
    /// Every shard the plan is known to touch.
    pub involved_shards: HashSet<ShardId>,
    /// Steps in the order they were added.
    pub steps: Vec<TraversalStep>,
    /// Whether the plan is treated as spanning more than one shard.
    pub is_distributed: bool,
    /// Relative cost estimate; the constructors initialise it to `1.0`.
    pub estimated_cost: f64,
}
impl TraversalPlan {
    /// Creates a plan confined to `shard_id`: no steps, not distributed, cost `1.0`.
    pub fn single_shard(shard_id: ShardId) -> Self {
        // An empty extra-shard set leaves exactly the start shard involved.
        Self::multi_shard(shard_id, HashSet::new())
    }

    /// Creates a step-less plan that starts on `start_shard` and touches `shards`.
    ///
    /// `start_shard` is always added to the involved set, so a caller that passes only
    /// the *target* shards still gets a coherent plan. The plan is distributed when
    /// more than one shard ends up involved. The cost starts at `1.0`.
    pub fn multi_shard(start_shard: ShardId, mut shards: HashSet<ShardId>) -> Self {
        // The traversal necessarily runs on its start shard.
        shards.insert(start_shard);
        let is_distributed = shards.len() > 1;
        Self {
            start_shard,
            involved_shards: shards,
            steps: Vec::new(),
            is_distributed,
            estimated_cost: 1.0,
        }
    }

    /// Appends `step` and records its shard as involved.
    ///
    /// The plan becomes distributed when the step may cross shards, or when the step
    /// brings the number of involved shards above one (the same rule `multi_shard`
    /// applies). The flag is never reset to `false`.
    pub fn add_step(&mut self, step: TraversalStep) {
        self.involved_shards.insert(step.shard_id);
        if step.may_cross_shard || self.involved_shards.len() > 1 {
            self.is_distributed = true;
        }
        self.steps.push(step);
    }

    /// Returns the plan with `estimated_cost` replaced by `cost`.
    pub fn with_cost(mut self, cost: f64) -> Self {
        self.estimated_cost = cost;
        self
    }
}
/// Routes node labels to shards using a shard configuration.
#[derive(Debug)]
pub struct ShardRouter {
    /// Configuration the router was built from; also supplies the default shard.
    config: ShardConfig,
    /// Label-to-shard lookup table built once from `config` at construction time.
    label_to_shard: HashMap<String, ShardId>,
}
impl ShardRouter {
    /// Builds a router from `config`, caching the label map returned by
    /// `ShardConfig::build_label_map`.
    pub fn new(config: ShardConfig) -> Self {
        let label_to_shard = config.build_label_map();
        Self {
            config,
            label_to_shard,
        }
    }

    /// Returns the shard assigned to `label`, or the configured default shard when the
    /// label has no explicit assignment.
    pub fn route_node(&self, label: &str) -> ShardId {
        self.label_to_shard
            .get(label)
            .copied()
            .unwrap_or(self.config.default_shard)
    }

    /// Routes a node by its optional label.
    ///
    /// The node id is currently ignored: a labelled node goes through
    /// [`route_node`](Self::route_node), an unlabelled one goes to the default shard.
    pub fn route_node_by_id(&self, _node_id: NodeId, label: Option<&str>) -> ShardId {
        match label {
            Some(l) => self.route_node(l),
            None => self.config.default_shard,
        }
    }

    /// Plans a traversal from `start_label` towards nodes carrying `target_labels`.
    ///
    /// An empty `target_labels` means the targets are unconstrained, so every shard
    /// listed by the configuration is involved. Otherwise only the shards of the start
    /// and target labels are involved.
    ///
    /// A plan touching a single shard costs `1.0`. Every additional shard adds
    /// `CROSS_SHARD_PENALTY`, and this applies to the all-shards fan-out as well, so a
    /// broadcast is never priced like a local traversal.
    pub fn route_traversal(&self, start_label: &str, target_labels: &[&str]) -> TraversalPlan {
        let start_shard = self.route_node(start_label);

        let mut involved_shards: HashSet<ShardId> = if target_labels.is_empty() {
            self.config.shard_ids().into_iter().collect()
        } else {
            target_labels
                .iter()
                .map(|label| self.route_node(label))
                .collect()
        };
        // The traversal always runs on its start shard.
        involved_shards.insert(start_shard);

        if involved_shards.len() == 1 {
            return TraversalPlan::single_shard(start_shard);
        }

        let base_cost = 1.0;
        let cost = base_cost + (involved_shards.len() as f64 - 1.0) * CROSS_SHARD_PENALTY;
        TraversalPlan::multi_shard(start_shard, involved_shards).with_cost(cost)
    }

    /// Plans a multi-hop traversal that follows `edge_labels` in order, one step per
    /// edge label.
    ///
    /// `expected_target_labels[i]`, when present, names the label of the nodes reached
    /// by hop `i`. A hop with a known target may cross shards only if that target lives
    /// on a different shard than the hop's source. A hop without a known target is
    /// conservatively flagged as possibly crossing and leaves the current shard
    /// unchanged for the next hop.
    ///
    /// Each step is recorded on the shard the hop starts from. The shard of every known
    /// target is also added to `involved_shards`, so the shard the final hop lands on
    /// is part of the plan even though no step starts there.
    ///
    /// The cost is `1.0` plus `HOP_COST` per step, plus `CROSS_SHARD_HOP_COST` for each
    /// step that may cross shards.
    pub fn plan_multi_hop(
        &self,
        start_label: &str,
        edge_labels: &[&str],
        expected_target_labels: Option<&[&str]>,
    ) -> TraversalPlan {
        let start_shard = self.route_node(start_label);
        let mut plan = TraversalPlan::single_shard(start_shard);
        let target_labels = expected_target_labels.unwrap_or(&[]);

        let mut current_shard = start_shard;
        let mut cost = 1.0;
        for (i, edge_label) in edge_labels.iter().enumerate() {
            // `None` means the caller gave no expectation for this hop.
            let known_target = target_labels.get(i).map(|label| self.route_node(label));
            let may_cross_shard = known_target.map_or(true, |target| target != current_shard);

            plan.add_step(TraversalStep::new(
                current_shard,
                vec![(*edge_label).to_string()],
                may_cross_shard,
            ));

            cost += HOP_COST;
            if may_cross_shard {
                cost += CROSS_SHARD_HOP_COST;
            }

            if let Some(target) = known_target {
                // Record the destination: without this the shard reached by the last
                // hop would be missing from the plan.
                plan.involved_shards.insert(target);
                current_shard = target;
            }
        }

        plan.estimated_cost = cost;
        plan
    }

    /// Routes an edge write, returning `(source_shard, target_shard, is_cross_shard)`,
    /// where the flag is `true` when the two endpoints live on different shards.
    pub fn route_edge_write(
        &self,
        source_label: &str,
        target_label: &str,
    ) -> (ShardId, ShardId, bool) {
        let source_shard = self.route_node(source_label);
        let target_shard = self.route_node(target_label);
        let is_cross_shard = source_shard != target_shard;
        (source_shard, target_shard, is_cross_shard)
    }

    /// Returns the shards that may hold nodes with `label`: always a single-element
    /// vector containing the result of [`route_node`](Self::route_node).
    pub fn shards_for_label(&self, label: &str) -> Vec<ShardId> {
        vec![self.route_node(label)]
    }

    /// Returns the shard used for labels without an explicit assignment.
    pub fn default_shard(&self) -> ShardId {
        self.config.default_shard
    }

    /// Returns the configuration this router was built from.
    pub fn config(&self) -> &ShardConfig {
        &self.config
    }

    /// Returns `true` when `label` is explicitly assigned to a shard.
    pub fn has_explicit_assignment(&self, label: &str) -> bool {
        self.label_to_shard.contains_key(label)
    }

    /// Returns every explicitly assigned label, in unspecified (hash-map) order.
    pub fn assigned_labels(&self) -> Vec<&String> {
        self.label_to_shard.keys().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::sharding::config::ShardDefinition;

    /// Three-shard fixture: shard 0 owns the people-like labels, shard 1 the place-like
    /// labels and shard 2 the event-like labels (nine labels in total).
    fn test_config() -> ShardConfig {
        ShardConfig::new(vec![
            ShardDefinition::new(0, "shard0:9000", vec!["Person", "User", "Account"]),
            ShardDefinition::new(1, "shard1:9000", vec!["Place", "Location", "Address"]),
            ShardDefinition::new(2, "shard2:9000", vec!["Event", "Transaction", "Activity"]),
        ])
    }

    /// Float comparison helper for the exact cost checks below.
    fn assert_cost(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected cost {expected}, got {actual}"
        );
    }

    #[test]
    fn test_router_creation() {
        let config = test_config();
        let router = ShardRouter::new(config);
        assert!(router.has_explicit_assignment("Person"));
        assert!(router.has_explicit_assignment("Place"));
        assert!(!router.has_explicit_assignment("Unknown"));
    }

    #[test]
    fn test_route_node() {
        let router = ShardRouter::new(test_config());
        assert_eq!(router.route_node("Person").as_u16(), 0);
        assert_eq!(router.route_node("User").as_u16(), 0);
        assert_eq!(router.route_node("Place").as_u16(), 1);
        assert_eq!(router.route_node("Event").as_u16(), 2);
        // Labels without an assignment fall back to the default shard.
        assert_eq!(router.route_node("Unknown").as_u16(), 0);
    }

    #[test]
    fn test_route_node_by_id() {
        let router = ShardRouter::new(test_config());
        let node_id = NodeId::new(100).unwrap();
        assert_eq!(router.route_node_by_id(node_id, Some("Person")).as_u16(), 0);
        assert_eq!(router.route_node_by_id(node_id, Some("Place")).as_u16(), 1);
        // No label: default shard.
        assert_eq!(router.route_node_by_id(node_id, None).as_u16(), 0);
    }

    #[test]
    fn test_route_traversal_single_shard() {
        let router = ShardRouter::new(test_config());
        let plan = router.route_traversal("Person", &["User", "Account"]);
        assert!(!plan.is_distributed);
        assert_eq!(plan.start_shard.as_u16(), 0);
        assert_eq!(plan.involved_shards.len(), 1);
        assert_cost(plan.estimated_cost, 1.0);
    }

    #[test]
    fn test_route_traversal_cross_shard() {
        let router = ShardRouter::new(test_config());
        let plan = router.route_traversal("Person", &["Place", "Event"]);
        assert!(plan.is_distributed);
        assert_eq!(plan.involved_shards.len(), 3);
        // Base cost plus one penalty for each of the two extra shards.
        assert_cost(plan.estimated_cost, 1.0 + 2.0 * CROSS_SHARD_PENALTY);
    }

    #[test]
    fn test_route_traversal_empty_targets() {
        let router = ShardRouter::new(test_config());
        let plan = router.route_traversal("Person", &[]);
        // Unconstrained targets fan out to every configured shard.
        assert!(plan.is_distributed);
        assert_eq!(plan.involved_shards.len(), 3);
    }

    #[test]
    fn test_plan_multi_hop() {
        let router = ShardRouter::new(test_config());

        let plan = router.plan_multi_hop("Person", &["KNOWS", "KNOWS"], Some(&["User", "Account"]));
        assert_eq!(plan.steps.len(), 2);
        assert!(!plan.is_distributed);
        assert_cost(plan.estimated_cost, 1.0 + 2.0 * HOP_COST);

        let plan = router.plan_multi_hop(
            "Person",
            &["VISITED", "OCCURRED_AT"],
            Some(&["Place", "Event"]),
        );
        assert_eq!(plan.steps.len(), 2);
        assert!(plan.is_distributed);
        // Both hops change shard, so each pays the hop and the cross-shard cost.
        assert_cost(
            plan.estimated_cost,
            1.0 + 2.0 * (HOP_COST + CROSS_SHARD_HOP_COST),
        );
    }

    #[test]
    fn test_plan_multi_hop_unknown_targets() {
        let router = ShardRouter::new(test_config());
        let plan = router.plan_multi_hop("Person", &["KNOWS", "VISITED"], None);
        assert_eq!(plan.steps.len(), 2);
        for step in &plan.steps {
            assert!(step.may_cross_shard);
        }
        assert_cost(
            plan.estimated_cost,
            1.0 + 2.0 * (HOP_COST + CROSS_SHARD_HOP_COST),
        );
    }

    #[test]
    fn test_route_edge_write() {
        let router = ShardRouter::new(test_config());

        let (src, tgt, cross) = router.route_edge_write("Person", "User");
        assert_eq!(src.as_u16(), 0);
        assert_eq!(tgt.as_u16(), 0);
        assert!(!cross);

        let (src, tgt, cross) = router.route_edge_write("Person", "Place");
        assert_eq!(src.as_u16(), 0);
        assert_eq!(tgt.as_u16(), 1);
        assert!(cross);
    }

    #[test]
    fn test_shards_for_label() {
        let router = ShardRouter::new(test_config());

        let shards = router.shards_for_label("Person");
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0].as_u16(), 0);

        // Unassigned labels resolve to the default shard.
        let shards = router.shards_for_label("Unknown");
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0].as_u16(), 0);
    }

    #[test]
    fn test_traversal_step() {
        let step = TraversalStep::new(ShardId::new(1).unwrap(), vec!["KNOWS".to_string()], false);
        assert_eq!(step.shard_id.as_u16(), 1);
        assert_eq!(step.edge_labels, vec!["KNOWS"]);
        assert!(!step.may_cross_shard);
    }

    #[test]
    fn test_traversal_plan_operations() {
        let shard0 = ShardId::new(0).unwrap();
        let shard1 = ShardId::new(1).unwrap();

        let mut plan = TraversalPlan::single_shard(shard0);
        assert!(!plan.is_distributed);
        assert_eq!(plan.involved_shards.len(), 1);

        plan.add_step(TraversalStep::new(shard0, vec!["KNOWS".to_string()], false));
        assert!(!plan.is_distributed);

        plan.add_step(TraversalStep::new(
            shard1,
            vec!["VISITED".to_string()],
            true,
        ));
        assert!(plan.is_distributed);
        assert_eq!(plan.involved_shards.len(), 2);
    }

    #[test]
    fn test_assigned_labels() {
        let router = ShardRouter::new(test_config());
        let labels = router.assigned_labels();
        assert!(labels.contains(&&"Person".to_string()));
        assert!(labels.contains(&&"Place".to_string()));
        assert!(labels.contains(&&"Event".to_string()));
        // Three labels on each of the three shards.
        assert_eq!(labels.len(), 9);
    }

    #[test]
    fn test_router_default_shard() {
        let router = ShardRouter::new(test_config());
        assert_eq!(router.default_shard().as_u16(), 0);
    }

    #[test]
    fn test_router_config_access() {
        // The config is moved into the router; no clone is needed.
        let router = ShardRouter::new(test_config());
        assert_eq!(router.config().num_shards(), 3);
    }
}