use super::node_registry::{Node, NodeRegistry};
use crate::{
domain::value_objects::{EntityId, PartitionKey},
error::{AllSourceError, Result},
};
use std::sync::Arc;
/// Routes entities and partitions to cluster nodes by consulting a shared
/// [`NodeRegistry`].
///
/// The router holds no state of its own; every query is answered by the
/// registry at call time.
pub struct RequestRouter {
/// Shared handle to the registry that every routing query is delegated to.
registry: Arc<NodeRegistry>,
}
impl RequestRouter {
pub fn new(registry: Arc<NodeRegistry>) -> Self {
Self { registry }
}
pub fn route_for_entity(&self, entity_id: &EntityId) -> Result<Node> {
let partition_key = PartitionKey::from_entity_id(entity_id.as_str());
self.route_for_partition(&partition_key)
}
pub fn route_for_partition(&self, partition_key: &PartitionKey) -> Result<Node> {
let partition_id = partition_key.partition_id();
let node_id = self
.registry
.node_for_partition(partition_id)
.ok_or_else(|| {
AllSourceError::StorageError(format!(
"No healthy node available for partition {partition_id}"
))
})?;
self.registry.get_node(node_id).ok_or_else(|| {
AllSourceError::InternalError(format!("Node {node_id} not found in registry"))
})
}
pub fn nodes_for_read(&self) -> Vec<Node> {
self.registry.healthy_nodes()
}
pub fn can_node_handle_entity(&self, entity_id: &EntityId, node_id: u32) -> bool {
let partition_key = PartitionKey::from_entity_id(entity_id.as_str());
let partition_id = partition_key.partition_id();
if let Some(assigned_node_id) = self.registry.node_for_partition(partition_id) {
assigned_node_id == node_id
} else {
false
}
}
pub fn partition_distribution(&self) -> std::collections::HashMap<u32, Vec<u32>> {
self.registry.partition_distribution()
}
pub fn is_available(&self) -> bool {
self.registry.is_cluster_healthy()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::cluster::node_registry::Node;

    /// Builds a registry created with `NodeRegistry::new(32)`, registers four
    /// healthy nodes (ids 0 through 3) and returns it with a router over it.
    fn setup_cluster() -> (Arc<NodeRegistry>, RequestRouter) {
        let registry = Arc::new(NodeRegistry::new(32));
        for i in 0..4 {
            registry.register_node(Node {
                id: i,
                address: format!("node-{i}:8080"),
                healthy: true,
                assigned_partitions: vec![],
            });
        }
        let router = RequestRouter::new(Arc::clone(&registry));
        (registry, router)
    }

    /// Constructing a router over an empty registry must not panic.
    #[test]
    fn test_create_router() {
        let registry = Arc::new(NodeRegistry::new(32));
        let _router = RequestRouter::new(registry);
    }

    /// An entity routes to one of the registered, healthy nodes.
    #[test]
    fn test_route_for_entity() {
        let (_registry, router) = setup_cluster();
        let entity_id = EntityId::new("user-123".to_string()).unwrap();
        let node = router.route_for_entity(&entity_id).unwrap();
        assert!(node.id < 4);
        assert!(node.healthy);
    }

    /// Routing the same entity repeatedly always yields the same node.
    #[test]
    fn test_consistent_routing() {
        let (_registry, router) = setup_cluster();
        let entity_id = EntityId::new("user-123".to_string()).unwrap();
        let node1 = router.route_for_entity(&entity_id).unwrap();
        let node2 = router.route_for_entity(&entity_id).unwrap();
        let node3 = router.route_for_entity(&entity_id).unwrap();
        assert_eq!(node1.id, node2.id);
        assert_eq!(node2.id, node3.id);
    }

    /// Different entities may or may not share a node, but each one must land
    /// on a registered, healthy node.
    #[test]
    fn test_different_entities_may_route_differently() {
        let (_registry, router) = setup_cluster();
        let entity1 = EntityId::new("user-1".to_string()).unwrap();
        let entity2 = EntityId::new("user-2".to_string()).unwrap();
        let entity3 = EntityId::new("user-3".to_string()).unwrap();
        let node1 = router.route_for_entity(&entity1).unwrap();
        let node2 = router.route_for_entity(&entity2).unwrap();
        let node3 = router.route_for_entity(&entity3).unwrap();

        // The spread over nodes depends on the hash, so only invariants that
        // hold for any spread are asserted.
        assert!(node1.healthy && node2.healthy && node3.healthy);
        let unique_nodes: std::collections::HashSet<_> =
            [node1.id, node2.id, node3.id].iter().copied().collect();
        assert!(
            unique_nodes.iter().all(|&id| id < 4),
            "routed to an unregistered node: {unique_nodes:?}"
        );
    }

    /// A partition key built from an explicit partition id routes to a
    /// registered, healthy node.
    #[test]
    fn test_route_for_partition() {
        let (_registry, router) = setup_cluster();
        let partition_key = PartitionKey::from_partition_id(15, 32).unwrap();
        let node = router.route_for_partition(&partition_key).unwrap();
        assert!(node.id < 4);
        assert!(node.healthy);
    }

    /// Only the node the entity routes to may claim it.
    #[test]
    fn test_can_node_handle_entity() {
        let (_registry, router) = setup_cluster();
        let entity_id = EntityId::new("user-123".to_string()).unwrap();
        let target_node = router.route_for_entity(&entity_id).unwrap();
        assert!(router.can_node_handle_entity(&entity_id, target_node.id));

        // The router compares against the single node assigned to the
        // entity's partition, so every other node must be rejected.
        for i in 0..4 {
            if i != target_node.id {
                assert!(
                    !router.can_node_handle_entity(&entity_id, i),
                    "node {i} must not handle an entity owned by node {}",
                    target_node.id
                );
            }
        }
    }

    /// With every node healthy, all four are offered for reads.
    #[test]
    fn test_nodes_for_read() {
        let (_registry, router) = setup_cluster();
        let nodes = router.nodes_for_read();
        assert_eq!(nodes.len(), 4);
        assert!(nodes.iter().all(|n| n.healthy));
    }

    /// Routing fails when the only registered node is unhealthy.
    #[test]
    fn test_no_healthy_nodes() {
        let registry = Arc::new(NodeRegistry::new(32));
        registry.register_node(Node {
            id: 0,
            address: "node-0:8080".to_string(),
            healthy: false,
            assigned_partitions: vec![],
        });
        let router = RequestRouter::new(registry);
        let entity_id = EntityId::new("user-123".to_string()).unwrap();
        let result = router.route_for_entity(&entity_id);
        assert!(result.is_err());
    }

    /// The distribution has one entry per node, each holding eight partitions.
    #[test]
    fn test_partition_distribution() {
        let (_registry, router) = setup_cluster();
        let distribution = router.partition_distribution();
        assert_eq!(distribution.len(), 4);
        for partitions in distribution.values() {
            assert_eq!(partitions.len(), 8);
        }
    }

    /// The router is available while nodes are healthy and unavailable once
    /// every node has been marked unhealthy.
    #[test]
    fn test_is_available() {
        let (registry, router) = setup_cluster();
        assert!(router.is_available());
        for i in 0..4 {
            registry.set_node_health(i, false);
        }
        assert!(!router.is_available());
    }
}