use std::io::Write;
use std::net::{SocketAddr, ToSocketAddrs};
use crate::error::{Result, ShardCacheClientError};
/// Client-side router that maps keys to shard ids and shard ids to socket addresses.
///
/// Shard addresses are derived from a single base address: shard `n` uses the base
/// address with its port increased by `n`.
#[derive(Debug, Clone, Copy)]
pub struct ShardCacheDirectRouter {
/// Address of shard 0; the other shards share everything but the port.
base_addr: SocketAddr,
/// Number of shards; the constructor only accepts a non-zero power of two.
shard_count: usize,
/// Right-shift used to reduce a routing hash to a shard index, derived from
/// `shard_count`.
shift: u32,
/// Selects which bytes of a key feed the routing hash.
route_mode: ShardCacheRouteMode,
}
impl ShardCacheDirectRouter {
    /// Builds a router for `shard_count` shards whose first shard is reachable at `addr`.
    ///
    /// `addr` is resolved once and the first resolved socket address becomes the base
    /// address. The router starts out in [`ShardCacheRouteMode::FullKey`] mode.
    ///
    /// # Errors
    ///
    /// Returns a `Config` error when `shard_count` is zero or not a power of two, or when
    /// `addr` resolves to no socket address. Resolution failures are propagated.
    pub fn new(addr: impl ToSocketAddrs, shard_count: usize) -> Result<Self> {
        // `is_power_of_two` is false for zero, so this single test also rejects an
        // empty shard set.
        if !shard_count.is_power_of_two() {
            return Err(ShardCacheClientError::Config(format!(
                "SCNP direct shard count must be a non-zero power of two: {shard_count}"
            )));
        }
        let router = Self {
            base_addr: resolve_one(addr)?,
            shard_count,
            shift: shift_for(shard_count),
            route_mode: ShardCacheRouteMode::FullKey,
        };
        Ok(router)
    }

    /// Returns a copy of this router that routes with `route_mode`.
    pub fn with_route_mode(self, route_mode: ShardCacheRouteMode) -> Self {
        Self { route_mode, ..self }
    }

    /// Number of shards this router distributes keys over.
    pub fn shard_count(&self) -> usize {
        self.shard_count
    }

    /// Computes the route for `key`: its full hash, its tag, and the shard it maps to.
    ///
    /// The hash and tag always come from the whole key; only the shard selection
    /// depends on the configured route mode.
    pub fn route_key(&self, key: &[u8]) -> ShardCacheRoute {
        let key_hash = hash_key(key);
        let shard_id = stripe_index(self.route_mode.route_hash(key, key_hash), self.shift);
        let key_tag = hash_key_tag_from_hash(key_hash);
        ShardCacheRoute {
            key_hash,
            key_tag,
            shard_id,
        }
    }

    /// Returns the socket address of shard `shard_id`: the base address with its port
    /// increased by `shard_id`.
    ///
    /// # Errors
    ///
    /// Returns a `Config` error when `shard_id` is not below the shard count, does not
    /// fit in a `u16`, or pushes the port past `u16::MAX`.
    pub fn shard_addr(&self, shard_id: usize) -> Result<SocketAddr> {
        if self.shard_count <= shard_id {
            return Err(ShardCacheClientError::Config(format!(
                "SCNP direct shard {shard_id} out of range for {} shards",
                self.shard_count
            )));
        }
        let offset = match u16::try_from(shard_id) {
            Ok(offset) => offset,
            Err(_) => {
                return Err(ShardCacheClientError::Config(format!(
                    "SCNP direct shard id exceeds u16: {shard_id}"
                )))
            }
        };
        match self.base_addr.port().checked_add(offset) {
            Some(port) => {
                // Copy the base address and only swap the port so every other part of
                // the address is carried over unchanged.
                let mut shard = self.base_addr;
                shard.set_port(port);
                Ok(shard)
            }
            None => Err(ShardCacheClientError::Config(format!(
                "SCNP direct shard port overflows for shard {shard_id}"
            ))),
        }
    }
}
/// Selects which bytes of a key are hashed to choose its shard.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardCacheRouteMode {
/// Route by the hash of the whole key.
FullKey,
/// For keys that start with `s:` and contain `:c:`, route by the hash of the bytes
/// before the last `:c:`; every other key is routed by the hash of the whole key.
SessionPrefix,
}
impl ShardCacheRouteMode {
    /// Parses a route mode from its textual name.
    ///
    /// Accepts `full_key`, `full-key` or `point` for [`Self::FullKey`], and
    /// `session_prefix`, `session-prefix` or `session` for [`Self::SessionPrefix`].
    /// Matching is exact and case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns a `Config` error naming the rejected value for any other input.
    pub fn parse(value: &str) -> Result<Self> {
        let mode = match value {
            "session_prefix" | "session-prefix" | "session" => Self::SessionPrefix,
            "full_key" | "full-key" | "point" => Self::FullKey,
            other => {
                return Err(ShardCacheClientError::Config(format!(
                    "unknown SCNP direct route mode `{other}`; use full_key or session_prefix"
                )))
            }
        };
        Ok(mode)
    }

    /// Returns the hash used for shard selection.
    ///
    /// `key_hash` must be the hash of the whole `key`; it is reused as-is in
    /// full-key mode so the key is not hashed twice.
    fn route_hash(self, key: &[u8], key_hash: u64) -> u64 {
        match self {
            Self::SessionPrefix => {
                let prefix = session_route_prefix(key);
                hash_key(prefix)
            }
            Self::FullKey => key_hash,
        }
    }
}
/// Routing information for a single key, as produced by
/// `ShardCacheDirectRouter::route_key`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShardCacheRoute {
/// 64-bit hash of the whole key.
pub key_hash: u64,
/// Top 8 bits of `key_hash` when filled in by the router.
pub key_tag: u64,
/// Index of the shard the key maps to.
pub shard_id: usize,
}
impl ShardCacheRoute {
pub(crate) fn write_to<W: Write>(&self, w: &mut W) -> Result<()> {
w.write_all(&self.key_hash.to_le_bytes())?;
w.write_all(&(self.shard_id as u32).to_le_bytes())?;
w.write_all(&self.key_tag.to_le_bytes())?;
Ok(())
}
}
/// Hashes `key` to 64 bits with `xxhash_rust::xxh3::xxh3_64`.
///
/// Every routing hash and key tag in this file is derived from this value.
pub fn hash_key(key: &[u8]) -> u64 {
    use xxhash_rust::xxh3::xxh3_64;

    xxh3_64(key)
}
/// Hashes `key` and returns the tag derived from that hash.
///
/// Equivalent to `hash_key_tag_from_hash(hash_key(key))`.
pub fn hash_key_tag(key: &[u8]) -> u64 {
    let hash = hash_key(key);
    hash_key_tag_from_hash(hash)
}
/// Extracts the key tag from an already computed key hash.
///
/// The tag is the most significant byte of `hash`, so the result is always in
/// `0..=0xFF`.
pub fn hash_key_tag_from_hash(hash: u64) -> u64 {
    // Shifting out everything but the top eight bits.
    const TAG_SHIFT: u32 = u64::BITS - 8;
    hash >> TAG_SHIFT
}
/// Returns the part of `key` that is hashed in session-prefix routing.
///
/// Keys that start with `s:` and contain a `:c:` separator are cut just before the
/// last such separator; every other key is returned unchanged.
fn session_route_prefix(key: &[u8]) -> &[u8] {
    if key.starts_with(b"s:") {
        session_chunk_separator(key).map_or(key, |cut| &key[..cut])
    } else {
        key
    }
}
/// Finds the start index of the last `:c:` sequence in `key`.
///
/// Returns `None` when `key` is shorter than three bytes or contains no `:c:`.
#[inline(always)]
fn session_chunk_separator(key: &[u8]) -> Option<usize> {
    // `windows(3)` yields nothing for inputs shorter than three bytes, and
    // `rposition` scans from the end, so the last occurrence wins.
    key.windows(3).rposition(|window| window == b":c:")
}
/// Maps `hash` to a shard index in `0..shard_count`.
///
/// # Errors
///
/// Returns a `Config` error when `shard_count` is zero or not a power of two.
pub fn shard_index(hash: u64, shard_count: usize) -> Result<usize> {
    let valid_count = shard_count != 0 && shard_count.is_power_of_two();
    if valid_count {
        let shift = shift_for(shard_count);
        Ok(stripe_index(hash, shift))
    } else {
        Err(ShardCacheClientError::Config(format!(
            "shard count must be a non-zero power of two: {shard_count}"
        )))
    }
}
/// Reduces `hash` to a shard index using a precomputed `shift`.
///
/// The hash is narrowed to `usize`, its top seven bits are shifted out, and the
/// highest remaining bits are kept. A `shift` equal to the pointer width means a
/// single shard, which always maps to index 0.
///
/// NOTE(review): the reason the top seven bits are skipped is not visible in this
/// file; presumably the servers apply the same mapping, so confirm before changing it.
fn stripe_index(hash: u64, shift: u32) -> usize {
    match shift {
        // Shifting by the full bit width is not a valid shift, so the single-shard
        // case is answered directly.
        usize::BITS => 0,
        _ => {
            let without_top_bits = (hash as usize) << 7;
            without_top_bits >> shift
        }
    }
}
/// Computes the right-shift that `stripe_index` needs for `shard_count` shards.
///
/// The result is the pointer width minus `log2(shard_count)`. `shard_count` must be a
/// non-zero power of two; this is only checked in debug builds.
fn shift_for(shard_count: usize) -> u32 {
    debug_assert!(shard_count > 0 && shard_count.is_power_of_two());
    // For a power of two, the number of trailing zero bits equals its base-2 logarithm.
    let index_bits = shard_count.trailing_zeros();
    usize::BITS - index_bits
}
/// Resolves `addr` and returns the first socket address it yields.
///
/// # Errors
///
/// Propagates resolution failures, and returns a `Config` error when resolution
/// succeeds but produces no address.
fn resolve_one(addr: impl ToSocketAddrs) -> Result<SocketAddr> {
    let mut candidates = addr.to_socket_addrs()?;
    match candidates.next() {
        Some(first) => Ok(first),
        None => Err(ShardCacheClientError::Config(
            "SCNP address resolved to no socket addresses".into(),
        )),
    }
}