use bytes::{Buf, Bytes};
use rand::Rng;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::net::SocketAddr;
use std::num::{NonZeroU16, Wrapping};
use thiserror::Error;
/// A node, identified here solely by its socket address.
///
/// Equality, ordering and hashing are all derived from `addr`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub struct Node {
/// Socket address (IP and port) of the node.
pub addr: SocketAddr,
}
/// A token: a signed 64-bit value that [`murmur3_token`] derives from a key and that
/// [`Sharder::shard_of`] maps to a shard.
///
/// Tokens compare and order by their numeric `value`, and can be parsed from a decimal
/// string through the `FromStr` implementation.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
pub struct Token {
/// Raw numeric value of the token.
pub value: i64,
}
/// Index of a shard, as returned by the [`Sharder`] lookup methods.
pub type Shard = u32;
/// Number of shards; using `NonZeroU16` makes a zero shard count unrepresentable.
pub type ShardCount = NonZeroU16;
/// Sharding parameters, typically built from an options map through `TryFrom`
/// (keys `SCYLLA_SHARD`, `SCYLLA_NR_SHARDS` and `SCYLLA_SHARDING_IGNORE_MSB`).
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ShardInfo {
/// Shard number taken from the `SCYLLA_SHARD` option.
pub shard: u16,
/// Total number of shards, taken from the `SCYLLA_NR_SHARDS` option.
pub nr_shards: ShardCount,
/// Value of the `SCYLLA_SHARDING_IGNORE_MSB` option; forwarded to [`Sharder`].
pub msb_ignore: u8,
}
/// Maps tokens and source ports to shard indices for a fixed shard count.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct Sharder {
/// Total number of shards that tokens and ports are distributed over.
pub nr_shards: ShardCount,
/// Number of most-significant token bits shifted out before a token is mapped to a shard.
pub msb_ignore: u8,
}
impl std::str::FromStr for Token {
    type Err = std::num::ParseIntError;

    /// Parses a token from its decimal `i64` representation.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`std::num::ParseIntError`] when `s` is not a valid `i64`
    /// (empty, non-numeric, or out of range).
    fn from_str(s: &str) -> Result<Token, std::num::ParseIntError> {
        s.parse::<i64>().map(|value| Token { value })
    }
}
/// Computes the token of the serialized key `pk`.
///
/// The token is the low 64 bits of [`hash3_x64_128`] applied to the key bytes. As in the
/// Murmur3 partitioner, a hash equal to `i64::MIN` is normalized to `i64::MAX`, so the
/// minimum value is never produced for a real key.
pub fn murmur3_token(pk: Bytes) -> Token {
    // The `as i64` cast deliberately truncates: only the low half of the hash is the token.
    let raw = hash3_x64_128(&pk) as i64;
    Token {
        value: if raw == i64::MIN { i64::MAX } else { raw },
    }
}
impl ShardInfo {
    /// Bundles the three sharding parameters into a `ShardInfo`.
    pub fn new(shard: u16, nr_shards: ShardCount, msb_ignore: u8) -> Self {
        Self {
            shard,
            nr_shards,
            msb_ignore,
        }
    }

    /// Builds a [`Sharder`] from this info's shard count and ignored-bit count.
    ///
    /// The `shard` field plays no part in the resulting sharder.
    pub fn get_sharder(&self) -> Sharder {
        let &ShardInfo {
            nr_shards,
            msb_ignore,
            ..
        } = self;
        Sharder::new(nr_shards, msb_ignore)
    }
}
impl Sharder {
pub fn new(nr_shards: ShardCount, msb_ignore: u8) -> Self {
Sharder {
nr_shards,
msb_ignore,
}
}
pub fn shard_of(&self, token: Token) -> Shard {
let mut biased_token = (token.value as u64).wrapping_add(1u64 << 63);
biased_token <<= self.msb_ignore;
(((biased_token as u128) * (self.nr_shards.get() as u128)) >> 64) as Shard
}
pub fn shard_of_source_port(&self, source_port: u16) -> Shard {
(source_port % self.nr_shards.get()) as Shard
}
pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 {
assert!(shard < self.nr_shards.get() as u32);
rand::thread_rng()
.gen_range((49152 + self.nr_shards.get() - 1)..(65535 - self.nr_shards.get() + 1))
/ self.nr_shards.get()
* self.nr_shards.get()
+ shard as u16
}
pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator<Item = u16> {
assert!(shard < self.nr_shards.get() as u32);
let starting_port = self.draw_source_port_for_shard(shard);
let first_valid_port = (49152 + self.nr_shards.get() - 1) / self.nr_shards.get()
* self.nr_shards.get()
+ shard as u16;
let before_wrap = (starting_port..=65535).step_by(self.nr_shards.get().into());
let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into());
before_wrap.chain(after_wrap)
}
}
/// Errors produced while building a [`ShardInfo`] from an options map.
#[derive(Error, Debug)]
pub enum ShardingError {
/// At least one of the three sharding options is absent from the map.
#[error("ShardInfo parameters missing")]
MissingShardInfoParameter,
/// All three options are present, but at least one of them has an empty value list.
#[error("ShardInfo parameters missing after unwraping")]
MissingUnwrapedShardInfoParameter,
/// The shard count option parsed successfully but its value is zero.
#[error("ShardInfo contains an invalid number of shards (0)")]
ZeroShards,
/// One of the option values is not a valid integer of the expected width.
#[error("ParseIntError encountered while getting ShardInfo")]
ParseIntError(#[from] std::num::ParseIntError),
}
impl<'a> TryFrom<&'a HashMap<String, Vec<String>>> for ShardInfo {
    type Error = ShardingError;

    /// Builds a `ShardInfo` from the `SCYLLA_SHARD`, `SCYLLA_NR_SHARDS` and
    /// `SCYLLA_SHARDING_IGNORE_MSB` entries of `options`.
    ///
    /// Only the first value of each entry is used.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// * [`ShardingError::MissingShardInfoParameter`] if any of the three keys is absent;
    /// * [`ShardingError::MissingUnwrapedShardInfoParameter`] if any of the three value
    ///   lists is empty;
    /// * [`ShardingError::ParseIntError`] if the shard or the shard count is not a `u16`;
    /// * [`ShardingError::ZeroShards`] if the shard count is zero;
    /// * [`ShardingError::ParseIntError`] if the ignored-bit count is not a `u8`.
    fn try_from(options: &'a HashMap<String, Vec<String>>) -> Result<Self, Self::Error> {
        // All three keys are looked up before any value is inspected, so a missing key
        // is always reported in preference to an empty value list.
        let (shard_values, nr_shards_values, msb_ignore_values) = match (
            options.get("SCYLLA_SHARD"),
            options.get("SCYLLA_NR_SHARDS"),
            options.get("SCYLLA_SHARDING_IGNORE_MSB"),
        ) {
            (Some(shard), Some(nr_shards), Some(msb_ignore)) => (shard, nr_shards, msb_ignore),
            _ => return Err(ShardingError::MissingShardInfoParameter),
        };

        let (shard_text, nr_shards_text, msb_ignore_text) = match (
            shard_values.first(),
            nr_shards_values.first(),
            msb_ignore_values.first(),
        ) {
            (Some(shard), Some(nr_shards), Some(msb_ignore)) => (shard, nr_shards, msb_ignore),
            _ => return Err(ShardingError::MissingUnwrapedShardInfoParameter),
        };

        let shard = shard_text.parse::<u16>()?;
        let nr_shards = nr_shards_text.parse::<u16>()?;
        let nr_shards = ShardCount::new(nr_shards).ok_or(ShardingError::ZeroShards)?;
        let msb_ignore = msb_ignore_text.parse::<u8>()?;
        Ok(ShardInfo::new(shard, nr_shards, msb_ignore))
    }
}
/// Computes a 128-bit Murmur3-style hash (x64 variant) of `data`, starting from an
/// all-zero state.
///
/// The result packs the two 64-bit halves of the state: the low 64 bits hold `h1` and
/// the high 64 bits hold `h2`. [`murmur3_token`] uses only the low half.
pub fn hash3_x64_128(mut data: &[u8]) -> i128 {
    // The total length is mixed in at the end, so remember it before `data` is consumed.
    let length = data.len();

    let c1: Wrapping<i64> = Wrapping(0x87c3_7b91_1142_53d5_u64 as i64);
    let c2: Wrapping<i64> = Wrapping(0x4cf5_ad43_2745_937f_u64 as i64);

    let mut h1 = Wrapping(0_i64);
    let mut h2 = Wrapping(0_i64);

    // Body: consume the input 16 bytes at a time as two little-endian 64-bit words.
    while data.len() >= 16 {
        let mut k1 = Wrapping(data.get_i64_le());
        let mut k2 = Wrapping(data.get_i64_le());

        k1 *= c1;
        k1 = rotl64(k1, 31);
        k1 *= c2;
        h1 ^= k1;

        h1 = rotl64(h1, 27);
        h1 += h2;
        h1 = h1 * Wrapping(5) + Wrapping(0x52dce729);

        k2 *= c2;
        k2 = rotl64(k2, 33);
        k2 *= c1;
        h2 ^= k2;

        h2 = rotl64(h2, 31);
        h2 += h1;
        h2 = h2 * Wrapping(5) + Wrapping(0x38495ab5);
    }

    // Tail: the remaining 0..=15 bytes. Bytes 8.. feed `k2`, bytes ..8 feed `k1`.
    let mut k1 = Wrapping(0_i64);
    let mut k2 = Wrapping(0_i64);

    debug_assert!(data.len() < 16);

    if data.len() > 8 {
        for i in (8..data.len()).rev() {
            // NOTE(review): `as i8 as i64` sign-extends each tail byte. Presumably this is
            // deliberate, for compatibility with an existing implementation; confirm
            // before "fixing" it.
            k2 ^= Wrapping(data[i] as i8 as i64) << ((i - 8) * 8);
        }

        k2 *= c2;
        k2 = rotl64(k2, 33);
        k2 *= c1;
        h2 ^= k2;
    }

    if !data.is_empty() {
        for i in (0..std::cmp::min(8, data.len())).rev() {
            k1 ^= Wrapping(data[i] as i8 as i64) << (i * 8);
        }

        k1 *= c1;
        k1 = rotl64(k1, 31);
        k1 *= c2;
        h1 ^= k1;
    }

    // Finalization: mix in the length and scramble both halves.
    h1 ^= Wrapping(length as i64);
    h2 ^= Wrapping(length as i64);

    h1 += h2;
    h2 += h1;

    h1 = fmix(h1);
    h2 = fmix(h2);

    h1 += h2;
    h2 += h1;

    // `h1` must be zero-extended: a plain `as i128` would sign-extend a negative `h1`
    // and overwrite the high half (`h2`) with ones.
    ((h2.0 as i128) << 64) | (h1.0 as u64 as i128)
}
/// Rotates the 64-bit value `v` left by `n` bits.
///
/// Uses the standard-library rotation, which is defined for every `n`; a hand-rolled
/// `(v << n) | (v >> (64 - n))` overflows its shift when `n` is 0.
#[inline]
fn rotl64(v: Wrapping<i64>, n: u32) -> Wrapping<i64> {
    Wrapping(v.0.rotate_left(n))
}
/// Final mixing step of the hash: three xor-with-logical-shift rounds interleaved with
/// two wrapping multiplications by fixed odd constants.
#[inline]
fn fmix(k: Wrapping<i64>) -> Wrapping<i64> {
    // Work on the raw bit pattern: unsigned `>>` is the logical shift the mix needs, and
    // wrapping multiplication yields the same bits for signed and unsigned operands.
    let mut bits = k.0 as u64;
    bits ^= bits >> 33;
    bits = bits.wrapping_mul(0xff51afd7ed558ccd_u64);
    bits ^= bits >> 33;
    bits = bits.wrapping_mul(0xc4ceb9fe1a85ec53_u64);
    bits ^= bits >> 33;
    Wrapping(bits as i64)
}
#[cfg(test)]
mod tests {
    use super::Token;
    use super::{ShardCount, Sharder};
    use std::collections::HashSet;

    /// Known token -> shard pairs for 4 shards with 12 ignored most-significant bits.
    #[test]
    fn test_shard_of() {
        let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12);
        let cases = [
            (-9219783007514621794_i64, 3),
            (9222582454147032830_i64, 3),
        ];
        for &(value, expected_shard) in cases.iter() {
            assert_eq!(sharder.shard_of(Token { value }), expected_shard);
        }
    }

    /// For every shard the iterator must yield each matching port exactly once.
    #[test]
    fn test_iter_source_ports_for_shard() {
        let nr_shards = 4;
        let max_port_num = 65535;
        let min_port_num = (49152 + nr_shards - 1) / nr_shards * nr_shards;
        let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);

        for shard in 0..nr_shards {
            // `min_port_num` is a multiple of `nr_shards`, so the lowest port of this
            // shard is simply `shard` ports above it.
            let lowest_port = min_port_num + shard;
            let possible_ports_number: usize =
                ((max_port_num - lowest_port) / nr_shards + 1).into();

            let mut returned_ports: HashSet<u16> = HashSet::new();
            for port in sharder.iter_source_ports_for_shard(shard.into()) {
                // `insert` returns false when the port was already seen.
                assert!(returned_ports.insert(port));
                assert_eq!(port % nr_shards, shard);
            }

            assert_eq!(returned_ports.len(), possible_ports_number);
        }
    }
}