use crate::Result;
use oxirs_core::model::Triple;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::info;
/// Identifier of a shard: the key of `ShardRouter`'s shard table and the value produced
/// by every routing function in this module.
pub type ShardId = u32;
/// How `ShardRouter` assigns triples to shards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ShardingStrategy {
    /// Hash the subject's string form into `0..num_shards`.
    Hash { num_shards: u32 },
    /// Hash the subject's string form into `0..num_shards`.
    /// `ShardRouter` currently routes this exactly like `Hash`.
    Subject { num_shards: u32 },
    /// Explicit predicate -> shard assignments; predicates without an entry are hashed
    /// into `0..predicate_groups.len()`.
    Predicate {
        predicate_groups: HashMap<String, ShardId>,
    },
    /// Namespace/prefix -> shard assignments applied to the subject IRI: the exact
    /// namespace (IRI up to its last `#` or `/`) is tried first, then the longest
    /// matching prefix, then a hash into `0..namespace_mapping.len()`.
    Namespace {
        namespace_mapping: HashMap<String, ShardId>,
    },
    /// NOTE(review): `ShardRouter` does not read the keys of `graph_mapping`; it only
    /// uses the map's length as the bucket count when hashing the subject.
    Graph {
        graph_mapping: HashMap<String, ShardId>,
    },
    /// Route by concept similarity of the subject against `concept_clusters`, using the
    /// router's `ConceptSimilarity` calculator, falling back to a hash of the subject.
    /// NOTE(review): `similarity_threshold` is currently not applied by `ShardRouter`.
    Semantic {
        concept_clusters: Vec<ConceptCluster>,
        similarity_threshold: f64,
    },
    /// Route with `primary`; if that returns an error, route with `secondary`.
    Hybrid {
        primary: Box<ShardingStrategy>,
        secondary: Box<ShardingStrategy>,
    },
}
/// A group of related concepts that is stored on one shard under
/// `ShardingStrategy::Semantic`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConceptCluster {
    /// Shard id returned when this cluster is selected.
    pub cluster_id: ShardId,
    /// Representative concept IRIs compared against a triple's subject by
    /// `ConceptSimilarity::similarity` (in `DefaultConceptSimilarity`).
    pub core_concepts: HashSet<String>,
    /// NOTE(review): not read by any code in this module.
    pub predicates: HashSet<String>,
    /// IRI prefixes; `DefaultConceptSimilarity` selects this cluster as soon as one of
    /// them is a prefix of the concept.
    pub namespace_patterns: Vec<String>,
    /// Multiplier applied to similarity scores for this cluster by
    /// `DefaultConceptSimilarity`.
    pub weight: f64,
}
/// Bookkeeping for one shard, stored in `ShardRouter`'s shard table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardMetadata {
    /// Id of the shard this record describes (also its key in the shard table).
    pub shard_id: ShardId,
    /// Ids of the nodes assigned to this shard; empty after `init_shards`.
    pub node_ids: Vec<u64>,
    /// Presumably the node id of the primary replica (TODO confirm); `init_shards`
    /// sets it to 0.
    pub primary_node: u64,
    /// Number of triples on the shard; drives `load_factor` in the statistics.
    pub triple_count: usize,
    /// Shard size in bytes; summed into `ShardingStatistics::total_size`.
    pub size_bytes: u64,
    /// Lifecycle state; only `Active` shards count toward `active_shards`.
    pub state: ShardState,
    /// `init_shards` sets this to whole seconds since the UNIX epoch.
    pub last_updated: u64,
}
/// Lifecycle state of a shard.
///
/// `init_shards` creates shards as `Active`. No code in this module moves a shard into
/// any other state; those transitions presumably happen elsewhere (TODO confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardState {
    /// Serving normally; counted in `ShardingStatistics::active_shards`.
    Active,
    Migrating,
    Splitting,
    Merging,
    Offline,
}
/// Maps triples and query patterns to shard ids according to a `ShardingStrategy`,
/// and keeps per-shard metadata.
pub struct ShardRouter {
    /// Strategy used by `route_triple` and `route_query_pattern`.
    strategy: ShardingStrategy,
    /// Shard metadata keyed by shard id; filled by `init_shards` and
    /// `update_shard_metadata`.
    shards: Arc<RwLock<HashMap<ShardId, ShardMetadata>>>,
    /// Calculator consulted by semantic routing; `None` means semantic routing always
    /// falls back to hashing.
    similarity_calc: Option<Arc<dyn ConceptSimilarity>>,
    /// Memoised `route_triple` results keyed by the triple's `Debug` rendering.
    routing_cache: Arc<RwLock<HashMap<String, ShardId>>>,
}
/// Pluggable concept-similarity measure used for semantic sharding.
///
/// `ShardRouter` stores implementations as `Arc<dyn ConceptSimilarity>`.
pub trait ConceptSimilarity: Send + Sync {
    /// Similarity score between two concept strings. `DefaultConceptSimilarity` yields
    /// values in `[0.0, 1.0]`; other implementations are not constrained by this trait.
    fn similarity(&self, concept1: &str, concept2: &str) -> f64;
    /// Chooses a cluster for `concept` and returns its shard id, or `None` to let the
    /// caller fall back to hashing. `ShardRouter` passes the subject's `to_string()`
    /// form as `concept`.
    fn find_cluster(&self, concept: &str, clusters: &[ConceptCluster]) -> Option<ShardId>;
}
impl ShardRouter {
pub fn new(strategy: ShardingStrategy) -> Self {
Self {
strategy,
shards: Arc::new(RwLock::new(HashMap::new())),
similarity_calc: None,
routing_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn with_similarity_calculator(mut self, calc: Arc<dyn ConceptSimilarity>) -> Self {
self.similarity_calc = Some(calc);
self
}
pub async fn init_shards(&self, num_shards: u32, _nodes_per_shard: usize) -> Result<()> {
let mut shards = self.shards.write().await;
for shard_id in 0..num_shards {
let metadata = ShardMetadata {
shard_id,
node_ids: Vec::new(), primary_node: 0,
triple_count: 0,
size_bytes: 0,
state: ShardState::Active,
last_updated: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs(),
};
shards.insert(shard_id, metadata);
}
info!("Initialized {} shards", num_shards);
Ok(())
}
pub async fn route_triple(&self, triple: &Triple) -> Result<ShardId> {
let cache_key = format!("{triple:?}");
if let Some(&shard_id) = self.routing_cache.read().await.get(&cache_key) {
return Ok(shard_id);
}
let shard_id = match &self.strategy {
ShardingStrategy::Hash { num_shards } => {
self.hash_route(&triple.subject().to_string(), *num_shards)
}
ShardingStrategy::Subject { num_shards } => {
self.hash_route(&triple.subject().to_string(), *num_shards)
}
ShardingStrategy::Predicate { predicate_groups } => {
let predicate_str = triple.predicate().to_string();
predicate_groups
.get(&predicate_str)
.copied()
.unwrap_or_else(|| {
self.hash_route(&predicate_str, predicate_groups.len() as u32)
})
}
ShardingStrategy::Namespace { namespace_mapping } => {
self.route_by_namespace(&triple.subject().to_string(), namespace_mapping)?
}
ShardingStrategy::Graph { graph_mapping } => {
self.hash_route(&triple.subject().to_string(), graph_mapping.len() as u32)
}
ShardingStrategy::Semantic {
concept_clusters,
similarity_threshold,
} => self.semantic_route(triple, concept_clusters, *similarity_threshold)?,
ShardingStrategy::Hybrid { primary, secondary } => {
match self.route_with_strategy(triple, primary) {
Ok(shard) => shard,
Err(_) => self.route_with_strategy(triple, secondary)?,
}
}
};
self.routing_cache.write().await.insert(cache_key, shard_id);
Ok(shard_id)
}
fn route_with_strategy(&self, triple: &Triple, strategy: &ShardingStrategy) -> Result<ShardId> {
match strategy {
ShardingStrategy::Hash { num_shards } => {
Ok(self.hash_route(&triple.subject().to_string(), *num_shards))
}
ShardingStrategy::Subject { num_shards } => {
Ok(self.hash_route(&triple.subject().to_string(), *num_shards))
}
ShardingStrategy::Predicate { predicate_groups } => {
let predicate_str = triple.predicate().to_string();
Ok(predicate_groups
.get(&predicate_str)
.copied()
.unwrap_or_else(|| {
self.hash_route(&predicate_str, predicate_groups.len() as u32)
}))
}
ShardingStrategy::Namespace { namespace_mapping } => {
self.route_by_namespace(&triple.subject().to_string(), namespace_mapping)
}
ShardingStrategy::Graph { graph_mapping } => {
Ok(self.hash_route(&triple.subject().to_string(), graph_mapping.len() as u32))
}
ShardingStrategy::Semantic {
concept_clusters,
similarity_threshold,
} => self.semantic_route(triple, concept_clusters, *similarity_threshold),
ShardingStrategy::Hybrid { .. } => {
Ok(self.hash_route(&triple.subject().to_string(), 10))
}
}
}
fn hash_route(&self, key: &str, num_shards: u32) -> ShardId {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() % num_shards as u64) as ShardId
}
fn route_by_namespace(
&self,
iri: &str,
namespace_mapping: &HashMap<String, ShardId>,
) -> Result<ShardId> {
let clean_iri = if iri.starts_with('<') && iri.ends_with('>') {
&iri[1..iri.len() - 1]
} else {
iri
};
let namespace = if let Some(pos) = clean_iri.rfind('#') {
&clean_iri[..=pos]
} else if let Some(pos) = clean_iri.rfind('/') {
&clean_iri[..=pos]
} else {
clean_iri
};
if let Some(&shard_id) = namespace_mapping.get(namespace) {
return Ok(shard_id);
}
let mut prefixes: Vec<_> = namespace_mapping.iter().collect();
prefixes.sort_by_key(|b| std::cmp::Reverse(b.0.len()));
for (prefix, &shard_id) in prefixes {
if clean_iri.starts_with(prefix) {
return Ok(shard_id);
}
}
Ok(self.hash_route(namespace, namespace_mapping.len() as u32))
}
fn semantic_route(
&self,
triple: &Triple,
clusters: &[ConceptCluster],
_threshold: f64,
) -> Result<ShardId> {
if let Some(similarity_calc) = &self.similarity_calc {
let concept = triple.subject().to_string();
if let Some(cluster_id) = similarity_calc.find_cluster(&concept, clusters) {
return Ok(cluster_id);
}
}
Ok(self.hash_route(&triple.subject().to_string(), clusters.len() as u32))
}
pub async fn route_query_pattern(
&self,
subject: Option<&str>,
predicate: Option<&str>,
_object: Option<&str>,
) -> Result<Vec<ShardId>> {
match &self.strategy {
ShardingStrategy::Subject { num_shards } => {
if let Some(subj) = subject {
Ok(vec![self.hash_route(subj, *num_shards)])
} else {
Ok((0..*num_shards).collect())
}
}
ShardingStrategy::Predicate { predicate_groups } => {
if let Some(pred) = predicate {
if let Some(&shard_id) = predicate_groups.get(pred) {
Ok(vec![shard_id])
} else {
Ok(vec![self.hash_route(pred, predicate_groups.len() as u32)])
}
} else {
let mut shard_ids: Vec<ShardId> = predicate_groups.values().copied().collect();
shard_ids.sort_unstable();
shard_ids.dedup();
Ok(shard_ids)
}
}
ShardingStrategy::Hash { num_shards } => {
if let Some(subj) = subject {
Ok(vec![self.hash_route(subj, *num_shards)])
} else {
Ok((0..*num_shards).collect())
}
}
_ => {
let shards = self.shards.read().await;
Ok(shards.keys().copied().collect())
}
}
}
pub async fn get_shard_metadata(&self, shard_id: ShardId) -> Option<ShardMetadata> {
self.shards.read().await.get(&shard_id).cloned()
}
pub async fn update_shard_metadata(&self, metadata: ShardMetadata) -> Result<()> {
self.shards
.write()
.await
.insert(metadata.shard_id, metadata);
Ok(())
}
pub async fn get_statistics(&self) -> ShardingStatistics {
let shards = self.shards.read().await;
let total_triples: usize = shards.values().map(|s| s.triple_count).sum();
let total_size: u64 = shards.values().map(|s| s.size_bytes).sum();
let active_shards = shards
.values()
.filter(|s| s.state == ShardState::Active)
.count();
let mut distribution = Vec::new();
for shard in shards.values() {
distribution.push(ShardDistribution {
shard_id: shard.shard_id,
triple_count: shard.triple_count,
size_bytes: shard.size_bytes,
load_factor: if total_triples > 0 {
shard.triple_count as f64 / total_triples as f64
} else {
0.0
},
primary_node: shard.primary_node,
});
}
distribution.sort_by_key(|d| d.shard_id);
ShardingStatistics {
total_shards: shards.len(),
active_shards,
total_triples,
total_size,
distribution,
}
}
}
/// Snapshot of the shard table produced by `ShardRouter::get_statistics`.
#[derive(Debug, Clone)]
pub struct ShardingStatistics {
    /// Number of entries in the shard table.
    pub total_shards: usize,
    /// Number of shards whose state is `ShardState::Active`.
    pub active_shards: usize,
    /// Sum of `triple_count` over all shards.
    pub total_triples: usize,
    /// Sum of `size_bytes` over all shards.
    pub total_size: u64,
    /// Per-shard breakdown, sorted by shard id.
    pub distribution: Vec<ShardDistribution>,
}
/// One shard's entry in `ShardingStatistics::distribution`.
#[derive(Debug, Clone)]
pub struct ShardDistribution {
    pub shard_id: ShardId,
    pub triple_count: usize,
    pub size_bytes: u64,
    /// This shard's `triple_count` divided by the total across all shards; 0.0 when
    /// the total is zero.
    pub load_factor: f64,
    pub primary_node: u64,
}
/// Lightweight `ConceptSimilarity` based on shared string prefixes.
///
/// Intended for IRIs, where a long shared prefix usually indicates a shared namespace.
pub struct DefaultConceptSimilarity;

impl ConceptSimilarity for DefaultConceptSimilarity {
    /// Returns the length of the longest common prefix divided by the length of the
    /// longer input, both measured in `char`s.
    ///
    /// The result lies in `[0.0, 1.0]`, identical non-empty inputs score `1.0`, and two
    /// empty strings score `0.0`.
    fn similarity(&self, concept1: &str, concept2: &str) -> f64 {
        let common_prefix_len = concept1
            .chars()
            .zip(concept2.chars())
            .take_while(|(a, b)| a == b)
            .count();
        // Measure lengths in chars (not bytes) so the ratio is consistent with the
        // prefix count above. Mixing the two made identical non-ASCII IRIs score < 1.0.
        let max_len = concept1.chars().count().max(concept2.chars().count());
        if max_len == 0 {
            return 0.0;
        }
        common_prefix_len as f64 / max_len as f64
    }

    /// Picks the cluster for `concept`; one surrounding `<`...`>` pair is ignored.
    ///
    /// The first cluster, in slice order, with a `namespace_patterns` entry that is a
    /// prefix of the concept wins outright. Otherwise the cluster whose core concept
    /// gives the highest `similarity * weight` is chosen, and ties keep the earliest
    /// cluster. `None` is returned when no score exceeds `0.0`.
    fn find_cluster(&self, concept: &str, clusters: &[ConceptCluster]) -> Option<ShardId> {
        let clean_concept = concept
            .strip_prefix('<')
            .and_then(|inner| inner.strip_suffix('>'))
            .unwrap_or(concept);

        let pattern_match = clusters.iter().find(|cluster| {
            cluster
                .namespace_patterns
                .iter()
                .any(|pattern| clean_concept.starts_with(pattern.as_str()))
        });
        if let Some(cluster) = pattern_match {
            return Some(cluster.cluster_id);
        }

        let mut best_cluster = None;
        let mut best_score = 0.0;
        for cluster in clusters {
            for core_concept in &cluster.core_concepts {
                let weighted_score = self.similarity(clean_concept, core_concept) * cluster.weight;
                if weighted_score > best_score {
                    best_score = weighted_score;
                    best_cluster = Some(cluster.cluster_id);
                }
            }
        }
        best_cluster
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxirs_core::model::{NamedNode, Triple};

    /// Test helper: parses `iri` into a `NamedNode`, panicking if it is invalid.
    fn node(iri: &str) -> NamedNode {
        NamedNode::new(iri).unwrap()
    }

    /// Test helper: a one-element `HashSet<String>` for cluster fixtures.
    fn set_of(value: &str) -> HashSet<String> {
        HashSet::from([value.to_string()])
    }

    #[tokio::test]
    async fn test_hash_sharding() {
        let router = ShardRouter::new(ShardingStrategy::Hash { num_shards: 4 });
        router.init_shards(4, 3).await.unwrap();
        let triple = Triple::new(
            node("http://example.org/subject1"),
            node("http://example.org/predicate1"),
            node("http://example.org/object1"),
        );

        let first = router.route_triple(&triple).await.unwrap();
        assert!(first < 4);

        // Routing the same triple again goes through the routing cache and must agree.
        let second = router.route_triple(&triple).await.unwrap();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_namespace_sharding() {
        let namespace_mapping: HashMap<String, ShardId> = [
            ("http://example.org/".to_string(), 0),
            ("http://schema.org/".to_string(), 1),
        ]
        .into_iter()
        .collect();
        let router = ShardRouter::new(ShardingStrategy::Namespace { namespace_mapping });

        let example_triple = Triple::new(
            node("http://example.org/subject1"),
            node("http://example.org/predicate1"),
            node("http://example.org/object1"),
        );
        let schema_triple = Triple::new(
            node("http://schema.org/Person"),
            node("http://schema.org/name"),
            node("http://example.org/john"),
        );

        assert_eq!(router.route_triple(&example_triple).await.unwrap(), 0);
        assert_eq!(router.route_triple(&schema_triple).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_semantic_sharding() {
        let person_cluster = ConceptCluster {
            cluster_id: 0,
            core_concepts: set_of("http://schema.org/Person"),
            predicates: set_of("http://schema.org/name"),
            namespace_patterns: vec!["http://schema.org/".to_string()],
            weight: 1.0,
        };
        let document_cluster = ConceptCluster {
            cluster_id: 1,
            core_concepts: set_of("http://example.org/Document"),
            predicates: set_of("http://example.org/title"),
            namespace_patterns: vec!["http://example.org/".to_string()],
            weight: 1.0,
        };
        let router = ShardRouter::new(ShardingStrategy::Semantic {
            concept_clusters: vec![person_cluster, document_cluster],
            similarity_threshold: 0.5,
        })
        .with_similarity_calculator(Arc::new(DefaultConceptSimilarity));

        let triple = Triple::new(
            node("http://schema.org/Person/123"),
            node("http://schema.org/name"),
            oxirs_core::model::Literal::new_simple_literal("John Doe"),
        );

        assert_eq!(router.route_triple(&triple).await.unwrap(), 0);
    }

    #[test]
    fn test_concept_similarity() {
        let calc = DefaultConceptSimilarity;
        let person = "http://example.org/Person";
        assert_eq!(calc.similarity(person, person), 1.0);
        assert!(calc.similarity(person, "http://example.org/Place") > 0.5);
        assert!(calc.similarity(person, "http://schema.org/Person") < 0.5);
    }
}