use std::collections::HashMap;
use std::num::NonZeroU16;
use std::ops::RangeInclusive;
use rand::Rng as _;
use thiserror::Error;
use super::Token;
/// An inclusive range of local ports from which shard-aware connections may pick
/// their source port.
///
/// Build one with [`ShardAwarePortRange::new`], which rejects empty ranges and ranges
/// starting below port 1024. [`Default`] yields
/// [`ShardAwarePortRange::EPHEMERAL_PORT_RANGE`].
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct ShardAwarePortRange(RangeInclusive<u16>);

impl ShardAwarePortRange {
    /// The IANA dynamic/private ("ephemeral") port range, `49152..=65535`.
    pub const EPHEMERAL_PORT_RANGE: Self = Self(49152..=65535);

    /// Validates `range` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidShardAwarePortRange`] if `range` is empty or if its start is
    /// below 1024.
    #[inline]
    pub fn new(range: impl Into<RangeInclusive<u16>>) -> Result<Self, InvalidShardAwarePortRange> {
        const LOWEST_ALLOWED_START: u16 = 1024;
        let candidate: RangeInclusive<u16> = range.into();
        if candidate.is_empty() || *candidate.start() < LOWEST_ALLOWED_START {
            Err(InvalidShardAwarePortRange)
        } else {
            Ok(Self(candidate))
        }
    }
}

impl Default for ShardAwarePortRange {
    /// Returns [`ShardAwarePortRange::EPHEMERAL_PORT_RANGE`].
    fn default() -> Self {
        ShardAwarePortRange::EPHEMERAL_PORT_RANGE
    }
}
#[derive(Debug, Error)]
#[error("Invalid shard-aware local port range")]
pub struct InvalidShardAwarePortRange;
/// Index of a shard, as produced by `Sharder::shard_of` and `Sharder::shard_of_source_port`.
pub type Shard = u32;
/// Number of shards; using `NonZeroU16` makes a zero shard count unrepresentable.
pub type ShardCount = NonZeroU16;
/// Sharding parameters parsed from an options map: the reported shard index plus
/// the parameters needed to build a [`Sharder`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct ShardInfo {
    /// Shard index, taken from the `SCYLLA_SHARD` option.
    pub(crate) shard: u16,
    /// Total number of shards, taken from the `SCYLLA_NR_SHARDS` option.
    pub(crate) nr_shards: ShardCount,
    /// Number of most significant token bits that `Sharder::shard_of` shifts out,
    /// taken from the `SCYLLA_SHARDING_IGNORE_MSB` option.
    pub(crate) msb_ignore: u8,
}
/// Maps tokens and source ports onto the shards of a node.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Sharder {
    /// Number of shards that tokens and source ports are distributed across.
    pub nr_shards: ShardCount,
    /// Number of most significant token bits shifted out before a token is mapped
    /// to a shard.
    pub msb_ignore: u8,
}
/// Parses a token from the textual form of its integer value.
impl std::str::FromStr for Token {
    type Err = std::num::ParseIntError;

    /// Parses `s` as the token's integer value; any integer parse failure is
    /// returned unchanged.
    fn from_str(s: &str) -> Result<Token, std::num::ParseIntError> {
        s.parse().map(|value| Token { value })
    }
}
impl ShardInfo {
    /// Creates a `ShardInfo` from a shard index, the shard count and the number of
    /// ignored most significant token bits.
    pub(crate) fn new(shard: u16, nr_shards: ShardCount, msb_ignore: u8) -> Self {
        Self {
            shard,
            nr_shards,
            msb_ignore,
        }
    }

    /// Builds a [`Sharder`] from this info's shard count and `msb_ignore` setting.
    /// The `shard` index is not needed for that.
    pub(crate) fn get_sharder(&self) -> Sharder {
        let &Self {
            nr_shards,
            msb_ignore,
            ..
        } = self;
        Sharder::new(nr_shards, msb_ignore)
    }
}
impl Sharder {
pub fn new(nr_shards: ShardCount, msb_ignore: u8) -> Self {
Sharder {
nr_shards,
msb_ignore,
}
}
pub fn shard_of(&self, token: Token) -> Shard {
let mut biased_token = (token.value as u64).wrapping_add(1u64 << 63);
biased_token <<= self.msb_ignore;
(((biased_token as u128) * (self.nr_shards.get() as u128)) >> 64) as Shard
}
pub fn shard_of_source_port(&self, source_port: u16) -> Shard {
(source_port % self.nr_shards.get()) as Shard
}
pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 {
self.draw_source_port_for_shard_from_range(
shard,
&ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
)
}
pub(crate) fn draw_source_port_for_shard_from_range(
&self,
shard: Shard,
port_range: &ShardAwarePortRange,
) -> u16 {
assert!(shard < self.nr_shards.get() as u32);
let (range_start, range_end) = (port_range.0.start(), port_range.0.end());
rand::rng().random_range(
(range_start + self.nr_shards.get() - 1)..(range_end - self.nr_shards.get() + 1),
) / self.nr_shards.get()
* self.nr_shards.get()
+ shard as u16
}
pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator<Item = u16> + use<> {
self.iter_source_ports_for_shard_from_range(
shard,
&ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
)
}
pub(crate) fn iter_source_ports_for_shard_from_range(
&self,
shard: Shard,
port_range: &ShardAwarePortRange,
) -> impl Iterator<Item = u16> + use<> {
assert!(shard < self.nr_shards.get() as u32);
let (range_start, range_end) = (port_range.0.start(), port_range.0.end());
let starting_port = self.draw_source_port_for_shard_from_range(shard, port_range);
let first_valid_port =
range_start.div_ceil(self.nr_shards.get()) * self.nr_shards.get() + shard as u16;
let before_wrap = (starting_port..=*range_end).step_by(self.nr_shards.get().into());
let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into());
before_wrap.chain(after_wrap)
}
}
#[derive(Clone, Error, Debug)]
pub(crate) enum ShardingError {
#[error("Server did not provide any sharding information")]
NoShardInfo,
#[error("Missing some sharding info parameters")]
MissingSomeShardInfoParameters,
#[error("Missing some sharding info parameter values")]
MissingShardInfoParameterValues,
#[error("Sharding info contains an invalid number of shards (0)")]
ZeroShards,
#[error("Failed to parse a sharding info parameter's value: {0}")]
ParseIntError(#[from] std::num::ParseIntError),
}
/// Option key whose first value is parsed as the shard index (`ShardInfo::shard`).
const SHARD_ENTRY: &str = "SCYLLA_SHARD";
/// Option key whose first value is parsed as the shard count (`ShardInfo::nr_shards`).
const NR_SHARDS_ENTRY: &str = "SCYLLA_NR_SHARDS";
/// Option key whose first value is parsed as the number of ignored most significant
/// token bits (`ShardInfo::msb_ignore`).
const MSB_IGNORE_ENTRY: &str = "SCYLLA_SHARDING_IGNORE_MSB";
impl<'a> TryFrom<&'a HashMap<String, Vec<String>>> for ShardInfo {
    type Error = ShardingError;

    /// Extracts sharding information from `options`.
    ///
    /// Reads the first value stored under each of `SCYLLA_SHARD`, `SCYLLA_NR_SHARDS` and
    /// `SCYLLA_SHARDING_IGNORE_MSB`.
    ///
    /// # Errors
    ///
    /// * [`ShardingError::NoShardInfo`] if none of the three keys is present.
    /// * [`ShardingError::MissingSomeShardInfoParameters`] if only some of them are present.
    /// * [`ShardingError::MissingShardInfoParameterValues`] if a key has no values.
    /// * [`ShardingError::ParseIntError`] if a value is not a valid `u16` (shard, shard
    ///   count) or `u8` (ignored MSBs).
    /// * [`ShardingError::ZeroShards`] if the shard count is zero.
    ///
    /// NOTE(review): `shard` is not checked to be smaller than `nr_shards`.
    fn try_from(options: &'a HashMap<String, Vec<String>>) -> Result<Self, Self::Error> {
        let lookups = [SHARD_ENTRY, NR_SHARDS_ENTRY, MSB_IGNORE_ENTRY].map(|key| options.get(key));
        let [Some(shard_values), Some(nr_shards_values), Some(msb_ignore_values)] = lookups else {
            // Distinguish "no sharding info at all" from "partially provided".
            let nothing_present = lookups.iter().all(Option::is_none);
            return Err(if nothing_present {
                ShardingError::NoShardInfo
            } else {
                ShardingError::MissingSomeShardInfoParameters
            });
        };

        let (Some(shard), Some(nr_shards), Some(msb_ignore)) = (
            shard_values.first(),
            nr_shards_values.first(),
            msb_ignore_values.first(),
        ) else {
            return Err(ShardingError::MissingShardInfoParameterValues);
        };

        // Parse in the same order as before so the first reported error is unchanged.
        let shard: u16 = shard.parse()?;
        let nr_shards = ShardCount::new(nr_shards.parse()?).ok_or(ShardingError::ZeroShards)?;
        let msb_ignore: u8 = msb_ignore.parse()?;
        Ok(ShardInfo::new(shard, nr_shards, msb_ignore))
    }
}
#[cfg(test)]
impl ShardInfo {
    /// Test helper: stores this sharding information in `options`.
    ///
    /// Uses the same keys that the `TryFrom<&HashMap<String, Vec<String>>>` impl reads,
    /// each mapped to a one-element list. Existing entries for those keys are
    /// overwritten; other keys are left untouched.
    pub(crate) fn add_to_options(&self, options: &mut HashMap<String, Vec<String>>) {
        options.insert(SHARD_ENTRY.to_owned(), vec![self.shard.to_string()]);
        options.insert(NR_SHARDS_ENTRY.to_owned(), vec![self.nr_shards.to_string()]);
        options.insert(MSB_IGNORE_ENTRY.to_owned(), vec![self.msb_ignore.to_string()]);
    }
}
#[cfg(test)]
mod tests {
    use crate::routing::{Shard, ShardAwarePortRange};
    use crate::test_utils::setup_tracing;

    use super::Token;
    use super::{ShardCount, Sharder};
    use std::collections::HashSet;

    #[test]
    fn test_shard_aware_port_range_constructor() {
        setup_tracing();
        let range = ShardAwarePortRange::new(49152..=65535).unwrap();
        assert_eq!(range, ShardAwarePortRange::EPHEMERAL_PORT_RANGE);
        #[expect(clippy::reversed_empty_ranges)]
        {
            assert!(ShardAwarePortRange::new(49152..=49151).is_err());
        }
        assert!(ShardAwarePortRange::new(0..=65535).is_err());
        // Boundary of the lower-bound check: 1023 is rejected, 1024 is accepted.
        assert!(ShardAwarePortRange::new(1023..=65535).is_err());
        assert!(ShardAwarePortRange::new(1024..=65535).is_ok());
    }

    #[test]
    fn test_shard_of() {
        setup_tracing();
        let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12);
        assert_eq!(
            sharder.shard_of(Token {
                value: -9219783007514621794
            }),
            3
        );
        assert_eq!(
            sharder.shard_of(Token {
                value: 9222582454147032830
            }),
            3
        );
    }

    #[test]
    fn test_iter_source_ports_for_shard() {
        setup_tracing();

        // For every shard, checks that the iterator yields each of the shard's ports exactly
        // once and nothing else. The ports checked run from the range start (rounded up to a
        // multiple of `nr_shards`) to the range end.
        fn test_helper<F, I>(nr_shards: u16, port_range: ShardAwarePortRange, get_iter: F)
        where
            F: Fn(&Sharder, Shard) -> I,
            I: Iterator<Item = u16>,
        {
            let max_port_num = *port_range.0.end();
            let min_port_num = port_range.0.start().div_ceil(nr_shards) * nr_shards;
            let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);

            for shard in 0..nr_shards {
                // `min_port_num` is a multiple of `nr_shards`, so the shard's lowest port
                // is simply offset by the shard index.
                let lowest_port = min_port_num + shard;
                let possible_ports_number =
                    usize::from((max_port_num - lowest_port) / nr_shards + 1);

                let mut returned_ports: HashSet<u16> = HashSet::new();
                for port in get_iter(&sharder, shard.into()) {
                    assert!(
                        port_range.0.contains(&port),
                        "port {port} is outside {:?}",
                        port_range.0
                    );
                    assert_eq!(port % nr_shards, shard);
                    assert_eq!(sharder.shard_of_source_port(port), Shard::from(shard));
                    assert!(returned_ports.insert(port), "port {port} was returned twice");
                }
                assert_eq!(returned_ports.len(), possible_ports_number);
            }
        }

        {
            test_helper(
                4,
                ShardAwarePortRange::EPHEMERAL_PORT_RANGE,
                |sharder, shard| sharder.iter_source_ports_for_shard(shard),
            );
        }

        {
            let port_range = ShardAwarePortRange::new(21371..=42424).unwrap();
            test_helper(4, port_range.clone(), |sharder, shard| {
                sharder.iter_source_ports_for_shard_from_range(shard, &port_range)
            });
        }
    }

    #[test]
    #[should_panic]
    fn test_draw_source_port_for_shard_rejects_too_narrow_range() {
        setup_tracing();
        // Two ports cannot hold the sampling interval needed for 4 shards.
        let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12);
        let port_range = ShardAwarePortRange::new(1024..=1025).unwrap();
        sharder.draw_source_port_for_shard_from_range(0, &port_range);
    }
}