use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use core_types::TransportDomain;
use transport_core::Endpoint;
use super::{RouteHint, RouteTrafficKind};
fn endpoint_weight(endpoint: &Endpoint) -> u32 {
for label in &endpoint.labels {
if let Some(raw) = label
.strip_prefix("weight=")
.or_else(|| label.strip_prefix("weight:"))
.or_else(|| label.strip_prefix("route-weight="))
.or_else(|| label.strip_prefix("route-weight:"))
.or_else(|| label.strip_prefix("w="))
.or_else(|| label.strip_prefix("w:"))
&& let Ok(parsed) = raw.parse::<u32>()
{
return parsed.max(1);
}
}
1
}
/// Deterministically selects one endpoint from `candidates`, weighted by `endpoint_weight`.
///
/// The slot is computed as `route_seed(hint) % total_weight`. The candidates are then walked
/// in order, each one covering `weight` consecutive slots, and the endpoint whose range
/// contains the slot is returned. Because the seed depends only on `hint`, the same hint
/// resolves to the same endpoint for the same candidate slice within a given build.
///
/// Returns `None` only when `candidates` is empty. A single candidate is returned without
/// computing any weights or seed.
pub(super) fn weighted_pick<'a>(
    candidates: &'a [&Endpoint],
    hint: &RouteHint,
) -> Option<&'a Endpoint> {
    let (&first, rest) = candidates.split_first()?;
    if rest.is_empty() {
        return Some(first);
    }

    let weight_of = |endpoint: &Endpoint| u64::from(endpoint_weight(endpoint));
    let total_weight = candidates
        .iter()
        .fold(0u64, |acc, &endpoint| acc + weight_of(endpoint));
    if total_weight == 0 {
        // Guards the `%` below; `endpoint_weight` never returns 0, so this is defensive only.
        return Some(first);
    }

    let mut remaining = route_seed(hint) % total_weight;
    candidates
        .iter()
        .copied()
        .find(|&endpoint| {
            let weight = weight_of(endpoint);
            let hit = remaining < weight;
            if !hit {
                remaining = remaining.saturating_sub(weight);
            }
            hit
        })
        // Unreachable while the slot is `< total_weight`; kept as a non-panicking fallback.
        .or_else(|| candidates.last().copied())
}
/// Derives the deterministic seed that `weighted_pick` uses to choose a slot.
///
/// Every field of `hint` contributes to the seed. Each variable-length or optional field is
/// framed so that two different hints cannot feed the hasher the same byte stream:
/// - the number of labels is hashed before the labels themselves, and
/// - `target_name` and `traffic_kind` are hashed as whole `Option`s, so absence is recorded
///   instead of contributing nothing.
///
/// Without this framing, distinct hints could collide. For example,
/// `labels = ["a", "b"], target_name = None` and `labels = ["a"], target_name = Some("b")`
/// hashed identically and were always routed together.
///
/// Every `DefaultHasher::new()` instance starts in the same state, so the seed is stable for
/// a given hint within one build. The hasher's algorithm is not guaranteed across Rust
/// releases, so seeds (and the endpoints they select) may change with the toolchain.
fn route_seed(hint: &RouteHint) -> u64 {
    let mut hasher = DefaultHasher::new();

    // Explicit tags keep the seed independent of the enum's declaration order.
    let domain_tag: u8 = match hint.preferred_domain {
        TransportDomain::Network => 0,
        TransportDomain::Local => 1,
    };
    domain_tag.hash(&mut hasher);

    // Length prefix: marks where the labels end and `target_name` begins.
    hint.labels.len().hash(&mut hasher);
    for label in &hint.labels {
        label.hash(&mut hasher);
    }

    // Hashing the whole `Option` includes its discriminant, so `None` is distinguishable.
    hint.target_name.hash(&mut hasher);

    let kind_tag: Option<u8> = hint.traffic_kind.map(|kind| match kind {
        RouteTrafficKind::Topic => 0,
        RouteTrafficKind::Service => 1,
        RouteTrafficKind::Action => 2,
        RouteTrafficKind::Mission => 3,
    });
    kind_tag.hash(&mut hasher);

    hasher.finish()
}