use std::net::SocketAddr;
use std::sync::Arc;
use uuid::Uuid;
use crate::cluster::{ClusterState, Node, NodeRef};
use crate::routing::Shard;
use super::{LoadBalancingPolicy, RoutingInfo};
/// Describes the single node that [`SingleTargetLoadBalancingPolicy`] routes to.
///
/// The variants differ in how the node is located at pick time: either a node
/// handle held directly, a host ID, or a socket address.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum NodeIdentifier {
    /// A node handle, returned as-is when picking.
    Node(Arc<Node>),
    /// The host ID of the node, looked up among the cluster's known peers.
    HostId(Uuid),
    /// The socket address of the node, matched against the cluster's nodes.
    NodeAddress(SocketAddr),
}
/// A load balancing policy that always targets one specific node, optionally
/// on one specific shard.
///
/// It offers no fallback targets: `pick` returns `None` when the requested node
/// cannot be found, and the fallback plan is always empty.
#[derive(Debug)]
pub struct SingleTargetLoadBalancingPolicy {
    /// Which node every pick should target.
    node_identifier: NodeIdentifier,
    /// Shard returned alongside the node by `pick`, passed through unchanged.
    shard: Option<Shard>,
}

impl SingleTargetLoadBalancingPolicy {
    /// Builds a policy targeting the node described by `node_identifier`
    /// (and `shard`, if given), already wrapped as a shared
    /// `Arc<dyn LoadBalancingPolicy>`.
    // Returning the trait object instead of `Self` is intentional, so the
    // `new_ret_no_self` lint is expected to fire here.
    #[expect(clippy::new_ret_no_self)]
    pub fn new(
        node_identifier: NodeIdentifier,
        shard: Option<Shard>,
    ) -> Arc<dyn LoadBalancingPolicy> {
        let policy = SingleTargetLoadBalancingPolicy {
            node_identifier,
            shard,
        };
        Arc::new(policy)
    }
}
impl LoadBalancingPolicy for SingleTargetLoadBalancingPolicy {
fn pick<'a>(
&'a self,
_request: &'a RoutingInfo,
cluster: &'a ClusterState,
) -> Option<(NodeRef<'a>, Option<Shard>)> {
let node = match &self.node_identifier {
NodeIdentifier::Node(node) => Some(node),
NodeIdentifier::HostId(host_id) => cluster.known_peers.get(host_id),
NodeIdentifier::NodeAddress(addr) => cluster
.all_nodes
.iter()
.find(|node| SocketAddr::new(node.address.ip(), node.address.port()) == *addr),
};
match node {
Some(node) => Some((node, self.shard)),
None => {
tracing::warn!(
"SingleTargetLoadBalancingPolicy failed to find requested node {:?} in cluster metadata.",
self.node_identifier
);
None
}
}
}
fn fallback<'a>(
&'a self,
_request: &'a RoutingInfo,
_cluster: &'a ClusterState,
) -> super::FallbackPlan<'a> {
Box::new(std::iter::empty())
}
fn name(&self) -> String {
"SingleTargetLoadBalancingPolicy".to_string()
}
}